lib9p

Go 9P library.
Log | Files | Refs | LICENSE

commit 95d560fcbdd3772213b0fdcd84b5a242b528e4c6
parent 48cd056e2a840f80b01fbc881b518df2b49fbef6
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Mon, 25 Dec 2023 08:02:37 +0900

check context.Context.Done() channel inside select.

Diffstat:
Mclient/client.go | 3+--
Mclient/fs.go | 3+++
Mserver.go | 120+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------
Mtestfs/fs.go | 5+----
4 files changed, 98 insertions(+), 33 deletions(-)

diff --git a/client/client.go b/client/client.go @@ -100,8 +100,6 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg for { select { case <-ctx.Done(): - // TODO: should return error via ec?? - // TODO: should close r? return default: done := make(chan struct{}) @@ -116,6 +114,7 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg select { case <-done: case <-ctx.Done(): + return } if err != nil { c.errc <- fmt.Errorf("recv: %v", err) diff --git a/client/fs.go b/client/fs.go @@ -138,3 +138,5 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) { cfs.c.rootFid = fid return cfs, nil } + +func (fsys *FS) Unmount() { fsys.c.Stop() } +\ No newline at end of file diff --git a/server.go b/server.go @@ -206,8 +206,10 @@ func sVersion(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TVersion) version := ifcall.Version @@ -237,8 +239,10 @@ func rVersion(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { panic(fmt.Errorf("rVersion err: %w", r.err)) @@ -264,8 +268,10 @@ func sAuth(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TAuth) var err error @@ -294,8 +300,10 @@ func rAuth(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { r.Srv.fPool.delete(r.Ifcall.(*TAuth).Afid) @@ -315,8 +323,10 @@ func sAttach(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TAttach) fid, err := s.fPool.add(ifcall.Fid) @@ -378,8 +388,10 @@ func rAttach(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { r.Srv.fPool.delete(r.Ifcall.(*TAttach).Fid) @@ -399,8 +411,10 @@ func sFlush(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } rc <- r } @@ -413,8 +427,10 @@ func rFlush(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { panic(fmt.Errorf("err in flush: %v", r.err)) @@ -437,8 +453,10 @@ func sWalk(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TWalk) oldFid, ok := s.fPool.lookup(ifcall.Fid) @@ -505,8 +523,10 @@ func rWalk(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TWalk) if r.Ofcall == nil { @@ -542,8 +562,10 @@ func sOpen(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TOpen) var ok bool @@ -646,8 +668,10 @@ func rOpen(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -680,8 +704,10 @@ func sCreate(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TCreate) var ok bool @@ -751,8 +777,10 @@ func rCreate(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -773,8 +801,10 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TRead) var ok bool @@ -876,8 +906,10 @@ func rRead(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -897,8 +929,10 @@ func sWrite(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TWrite) var ok bool @@ -963,8 +997,10 @@ func rWrite(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -983,8 +1019,10 @@ func sClunk(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TClunk) fid, ok := s.fPool.lookup(ifcall.Fid) @@ -1013,8 +1051,10 @@ func rClunk(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { log.Printf("clunk: %v", r.err) @@ -1034,8 +1074,10 @@ func sRemove(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TRemove) var ok bool @@ -1084,8 +1126,10 @@ func rRemove(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -1104,8 +1148,10 @@ func sStat(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TStat) var ok bool @@ -1135,8 +1181,10 @@ func rStat(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -1155,8 +1203,10 @@ func sWStat(ctx context.Context, s *Server, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } ifcall := r.Ifcall.(*TWStat) var ok bool @@ -1302,8 +1352,10 @@ func rWStat(ctx context.Context, c <-chan *Req) { case <-ctx.Done(): return case r := <-c: - if r == nil { + select { + case <-ctx.Done(): return + default: } if r.err != nil { setError(r, r.err) @@ -1367,6 +1419,11 @@ L: for { select { case r := <-s.listenChan: + select { + case <-ctx.Done(): + break L + default: + } if r.listenErr != nil { log.Printf("listen: %v", r.listenErr) if r.listenErr == io.EOF { @@ -1412,7 +1469,6 @@ L: } }() case <-ctx.Done(): - //log.Println(ctx.Err()) break L } } @@ -1424,6 +1480,11 @@ func respond(ctx context.Context, s *Server) { case <-ctx.Done(): return case r := <-s.respChan: + select { + case <-ctx.Done(): + return + default: // respChan should be closed after ctx is canceled. + } r.Ofcall.SetTag(r.Tag) // free tag. if r.pool == nil && r.err != ErrDupTag { @@ -1442,6 +1503,11 @@ func respond(ctx context.Context, s *Server) { } select { case err := <-r.speakErrChan: + select { + case <-ctx.Done(): + return + default: + } // TODO: handle errors. if err != nil { log.Printf("speak: %v", err) diff --git a/testfs/fs.go b/testfs/fs.go @@ -5,13 +5,10 @@ import ( "fmt" "io/fs" "strings" - "time" "git.mtkn.jp/lib9p" ) -const SleepTime = 10 * time.Millisecond - type File struct { // Fsys is the FS this File belongs to. Fsys *FS @@ -178,7 +175,7 @@ func init() { Content: []byte("b\n"), St: lib9p.Stat{ Qid: lib9p.Qid{Path: 2, Type: lib9p.QTFILE}, - Mode: lib9p.FileMode(0400), + Mode: lib9p.FileMode(0600), Name: "b", Uid: "ken", Gid: "ken",