lib9p

Go 9P library.
Log | Files | Refs | LICENSE

commit ddfea05b5376d2c0f35f6fba766f608df2e3b8d5
parent 98d89fa079b3fe3c485c938a7fbd1e16d17b0345
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Thu, 28 Dec 2023 09:03:47 +0900

delete context from req and replace it by a done channel

Diffstat:
Mreq.go | 11+++++------
Mserver.go | 83++++++++++++++++++++++++++++++++++++-------------------------------------------
Mserver2_test.go | 10+---------
3 files changed, 44 insertions(+), 60 deletions(-)

diff --git a/req.go b/req.go @@ -1,7 +1,6 @@ package lib9p import ( - "context" "sync" ) @@ -15,13 +14,13 @@ type Req struct { Ofcall Msg Fid *Fid Afid *Fid + // Oldreq is set with Tflush message. Oldreq *Req + // Pool is the pool this request belongs to. pool *ReqPool - // Cancel is used to flush the request. - Cancel context.CancelFunc // Done is used by time consuming goroutines to check whether the request - // is canceled. - Done <-chan struct{} + // is flushed. + done chan struct{} // listenErr is any error encountered while waiting for new 9P message. listenErr error // speakErrChan is used to report any error encountered while sending @@ -36,7 +35,7 @@ type Req struct { // It also delete the request from its pool. func (r *Req) flush() { // TODO: need mutex? - r.Cancel() + close(r.done) r.pool.delete(r.Tag) } diff --git a/server.go b/server.go @@ -119,7 +119,7 @@ func (s *Server) runListener(ctx context.Context) { defer close(rc) for { select { - case rc <- s.getReq(ctx): + case rc <- s.getReq(): case <-ctx.Done(): return } @@ -158,8 +158,7 @@ func (s *Server) runSpeaker(ctx context.Context) { // Any error it encountered is embedded into the Req struct. // This function is called only by the server's listener goroutine, // and does not need to lock s.r. -// The argument ctx is used to set the request's context. -func (s *Server) getReq(ctx context.Context) *Req { +func (s *Server) getReq() *Req { ifcall, err := RecvMsg(s.r) if err != nil { if err == io.EOF { @@ -167,7 +166,6 @@ func (s *Server) getReq(ctx context.Context) *Req { } return &Req{listenErr: fmt.Errorf("readMsg: %v", err)} } - ctx, cancel := context.WithCancel(ctx) req, err := s.rPool.add(ifcall.GetTag()) if err != nil { // duplicate tag: cons up a fake Req @@ -175,8 +173,7 @@ func (s *Server) getReq(ctx context.Context) *Req { req.Srv = s req.Ifcall = ifcall req.listenErr = ErrDupTag - req.Cancel = cancel - req.Done = ctx.Done() + req.done = make(chan struct{}) if s.chatty9P { fmt.Fprintf(os.Stderr, "<-- %v\n", req.Ifcall) } @@ -185,8 +182,7 @@ func (s *Server) getReq(ctx context.Context) *Req { req.Srv = s req.Tag = ifcall.GetTag() req.Ifcall = ifcall - req.Cancel = cancel - req.Done = ctx.Done() + req.done = make(chan struct{}) if ifcall, ok := req.Ifcall.(*TFlush); ok { req.Oldreq, _ = s.rPool.lookup(ifcall.Oldtag) } @@ -280,7 +276,7 @@ func sAttach(ctx context.Context, s *Server, c <-chan *Req) { return case r, ok := <-c: var ( - st fs.FileInfo + st fs.FileInfo err error ) if !ok { @@ -728,7 +724,7 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) { r.err = err goto resp } - case <-r.Done: + case <-r.done: continue } r.Ofcall = &RRead{ @@ -803,7 +799,7 @@ func sWrite(ctx context.Context, s *Server, c <-chan *Req) { r.err = err goto resp } - case <-r.Done: + case <-r.done: continue } r.Ofcall = ofcall @@ -1147,39 +1143,37 @@ L: } continue L } - go func() { // TODO: is this worth a goroutine? - switch r.Ifcall.(type) { - default: - setError(r, fmt.Errorf("unknown message type: %d", r.Ifcall.Type())) - s.respChan <- r - case *TVersion: - versionChan <- r - case *TAuth: - authChan <- r - case *TAttach: - attachChan <- r - case *TFlush: - flushChan <- r - case *TWalk: - walkChan <- r - case *TOpen: - openChan <- r - case *TCreate: - createChan <- r - case *TRead: - readChan <- r - case *TWrite: - writeChan <- r - case *TClunk: - clunkChan <- r - case *TRemove: - removeChan <- r - case *TStat: - statChan <- r - case *TWStat: - wstatChan <- r - } - }() + switch r.Ifcall.(type) { + default: + setError(r, fmt.Errorf("unknown message type: %d", r.Ifcall.Type())) + s.respChan <- r + case *TVersion: + versionChan <- r + case *TAuth: + authChan <- r + case *TAttach: + attachChan <- r + case *TFlush: + flushChan <- r + case *TWalk: + walkChan <- r + case *TOpen: + openChan <- r + case *TCreate: + createChan <- r + case *TRead: + readChan <- r + case *TWrite: + writeChan <- r + case *TClunk: + clunkChan <- r + case *TRemove: + removeChan <- r + case *TStat: + statChan <- r + case *TWStat: + wstatChan <- r + } } } } @@ -1215,7 +1209,6 @@ func respond(ctx context.Context, s *Server) { if err != nil { log.Printf("speak: %v", err) } - r.Cancel() case <-ctx.Done(): return } diff --git a/server2_test.go b/server2_test.go @@ -105,9 +105,7 @@ func TestGetReq(t *testing.T) { defer tFile2.Close() s := &Server{r: tFile, rPool: newReqPool()} for { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // in case error. - got := s.getReq(ctx) + got := s.getReq() if got.listenErr == io.EOF { break } else if got.listenErr != nil { @@ -134,12 +132,6 @@ func TestGetReq(t *testing.T) { t.Errorf("wrong message in pool:\n\twant: %p,\n\tgot: %p", got, got2) } s.rPool.delete(wantMsg.GetTag()) - cancel() - select { - case <-got.Done: - default: - t.Errorf("r.Done not closed") - } } }