commit 95d560fcbdd3772213b0fdcd84b5a242b528e4c6
parent 48cd056e2a840f80b01fbc881b518df2b49fbef6
Author: Matsuda Kenji <info@mtkn.jp>
Date: Mon, 25 Dec 2023 08:02:37 +0900
check context.Context.Done() channel inside select.
Diffstat:
4 files changed, 98 insertions(+), 33 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -100,8 +100,6 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg
for {
select {
case <-ctx.Done():
- // TODO: should return error via ec??
- // TODO: should close r?
return
default:
done := make(chan struct{})
@@ -116,6 +114,7 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg
select {
case <-done:
case <-ctx.Done():
+ return
}
if err != nil {
c.errc <- fmt.Errorf("recv: %v", err)
diff --git a/client/fs.go b/client/fs.go
@@ -138,3 +138,5 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) {
cfs.c.rootFid = fid
return cfs, nil
}
+
+func (fsys *FS) Unmount() { fsys.c.Stop() }
+\ No newline at end of file
diff --git a/server.go b/server.go
@@ -206,8 +206,10 @@ func sVersion(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TVersion)
version := ifcall.Version
@@ -237,8 +239,10 @@ func rVersion(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
panic(fmt.Errorf("rVersion err: %w", r.err))
@@ -264,8 +268,10 @@ func sAuth(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TAuth)
var err error
@@ -294,8 +300,10 @@ func rAuth(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
r.Srv.fPool.delete(r.Ifcall.(*TAuth).Afid)
@@ -315,8 +323,10 @@ func sAttach(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TAttach)
fid, err := s.fPool.add(ifcall.Fid)
@@ -378,8 +388,10 @@ func rAttach(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
r.Srv.fPool.delete(r.Ifcall.(*TAttach).Fid)
@@ -399,8 +411,10 @@ func sFlush(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
rc <- r
}
@@ -413,8 +427,10 @@ func rFlush(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
panic(fmt.Errorf("err in flush: %v", r.err))
@@ -437,8 +453,10 @@ func sWalk(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TWalk)
oldFid, ok := s.fPool.lookup(ifcall.Fid)
@@ -505,8 +523,10 @@ func rWalk(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TWalk)
if r.Ofcall == nil {
@@ -542,8 +562,10 @@ func sOpen(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TOpen)
var ok bool
@@ -646,8 +668,10 @@ func rOpen(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -680,8 +704,10 @@ func sCreate(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TCreate)
var ok bool
@@ -751,8 +777,10 @@ func rCreate(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -773,8 +801,10 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TRead)
var ok bool
@@ -876,8 +906,10 @@ func rRead(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -897,8 +929,10 @@ func sWrite(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TWrite)
var ok bool
@@ -963,8 +997,10 @@ func rWrite(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -983,8 +1019,10 @@ func sClunk(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TClunk)
fid, ok := s.fPool.lookup(ifcall.Fid)
@@ -1013,8 +1051,10 @@ func rClunk(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
log.Printf("clunk: %v", r.err)
@@ -1034,8 +1074,10 @@ func sRemove(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TRemove)
var ok bool
@@ -1084,8 +1126,10 @@ func rRemove(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -1104,8 +1148,10 @@ func sStat(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TStat)
var ok bool
@@ -1135,8 +1181,10 @@ func rStat(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -1155,8 +1203,10 @@ func sWStat(ctx context.Context, s *Server, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
ifcall := r.Ifcall.(*TWStat)
var ok bool
@@ -1302,8 +1352,10 @@ func rWStat(ctx context.Context, c <-chan *Req) {
case <-ctx.Done():
return
case r := <-c:
- if r == nil {
+ select {
+ case <-ctx.Done():
return
+ default:
}
if r.err != nil {
setError(r, r.err)
@@ -1367,6 +1419,11 @@ L:
for {
select {
case r := <-s.listenChan:
+ select {
+ case <-ctx.Done():
+ break L
+ default:
+ }
if r.listenErr != nil {
log.Printf("listen: %v", r.listenErr)
if r.listenErr == io.EOF {
@@ -1412,7 +1469,6 @@ L:
}
}()
case <-ctx.Done():
- //log.Println(ctx.Err())
break L
}
}
@@ -1424,6 +1480,11 @@ func respond(ctx context.Context, s *Server) {
case <-ctx.Done():
return
case r := <-s.respChan:
+ select {
+ case <-ctx.Done():
+ return
+ default: // respChan should be closed after ctx is canceled.
+ }
r.Ofcall.SetTag(r.Tag)
// free tag.
if r.pool == nil && r.err != ErrDupTag {
@@ -1442,6 +1503,11 @@ func respond(ctx context.Context, s *Server) {
}
select {
case err := <-r.speakErrChan:
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
// TODO: handle errors.
if err != nil {
log.Printf("speak: %v", err)
diff --git a/testfs/fs.go b/testfs/fs.go
@@ -5,13 +5,10 @@ import (
"fmt"
"io/fs"
"strings"
- "time"
"git.mtkn.jp/lib9p"
)
-const SleepTime = 10 * time.Millisecond
-
type File struct {
// Fsys is the FS this File belongs to.
Fsys *FS
@@ -178,7 +175,7 @@ func init() {
Content: []byte("b\n"),
St: lib9p.Stat{
Qid: lib9p.Qid{Path: 2, Type: lib9p.QTFILE},
- Mode: lib9p.FileMode(0400),
+ Mode: lib9p.FileMode(0600),
Name: "b",
Uid: "ken",
Gid: "ken",