commit ddfea05b5376d2c0f35f6fba766f608df2e3b8d5
parent 98d89fa079b3fe3c485c938a7fbd1e16d17b0345
Author: Matsuda Kenji <info@mtkn.jp>
Date: Thu, 28 Dec 2023 09:03:47 +0900
delete context from req and replace it by a done channel
Diffstat:
| M | req.go | | | 11 | +++++------ |
| M | server.go | | | 83 | ++++++++++++++++++++++++++++++++++++------------------------------------------- |
| M | server2_test.go | | | 10 | +--------- |
3 files changed, 44 insertions(+), 60 deletions(-)
diff --git a/req.go b/req.go
@@ -1,7 +1,6 @@
package lib9p
import (
- "context"
"sync"
)
@@ -15,13 +14,13 @@ type Req struct {
Ofcall Msg
Fid *Fid
Afid *Fid
+ // Oldreq is set with Tflush message.
Oldreq *Req
+ // Pool is the pool this request belongs to.
pool *ReqPool
- // Cancel is used to flush the request.
- Cancel context.CancelFunc
// Done is used by time consuming goroutines to check whether the request
- // is canceled.
- Done <-chan struct{}
+ // is flushed.
+ done chan struct{}
// listenErr is any error encountered while waiting for new 9P message.
listenErr error
// speakErrChan is used to report any error encountered while sending
@@ -36,7 +35,7 @@ type Req struct {
// It also delete the request from its pool.
func (r *Req) flush() {
// TODO: need mutex?
- r.Cancel()
+ close(r.done)
r.pool.delete(r.Tag)
}
diff --git a/server.go b/server.go
@@ -119,7 +119,7 @@ func (s *Server) runListener(ctx context.Context) {
defer close(rc)
for {
select {
- case rc <- s.getReq(ctx):
+ case rc <- s.getReq():
case <-ctx.Done():
return
}
@@ -158,8 +158,7 @@ func (s *Server) runSpeaker(ctx context.Context) {
// Any error it encountered is embedded into the Req struct.
// This function is called only by the server's listener goroutine,
// and does not need to lock s.r.
-// The argument ctx is used to set the request's context.
-func (s *Server) getReq(ctx context.Context) *Req {
+func (s *Server) getReq() *Req {
ifcall, err := RecvMsg(s.r)
if err != nil {
if err == io.EOF {
@@ -167,7 +166,6 @@ func (s *Server) getReq(ctx context.Context) *Req {
}
return &Req{listenErr: fmt.Errorf("readMsg: %v", err)}
}
- ctx, cancel := context.WithCancel(ctx)
req, err := s.rPool.add(ifcall.GetTag())
if err != nil {
// duplicate tag: cons up a fake Req
@@ -175,8 +173,7 @@ func (s *Server) getReq(ctx context.Context) *Req {
req.Srv = s
req.Ifcall = ifcall
req.listenErr = ErrDupTag
- req.Cancel = cancel
- req.Done = ctx.Done()
+ req.done = make(chan struct{})
if s.chatty9P {
fmt.Fprintf(os.Stderr, "<-- %v\n", req.Ifcall)
}
@@ -185,8 +182,7 @@ func (s *Server) getReq(ctx context.Context) *Req {
req.Srv = s
req.Tag = ifcall.GetTag()
req.Ifcall = ifcall
- req.Cancel = cancel
- req.Done = ctx.Done()
+ req.done = make(chan struct{})
if ifcall, ok := req.Ifcall.(*TFlush); ok {
req.Oldreq, _ = s.rPool.lookup(ifcall.Oldtag)
}
@@ -280,7 +276,7 @@ func sAttach(ctx context.Context, s *Server, c <-chan *Req) {
return
case r, ok := <-c:
var (
- st fs.FileInfo
+ st fs.FileInfo
err error
)
if !ok {
@@ -728,7 +724,7 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) {
r.err = err
goto resp
}
- case <-r.Done:
+ case <-r.done:
continue
}
r.Ofcall = &RRead{
@@ -803,7 +799,7 @@ func sWrite(ctx context.Context, s *Server, c <-chan *Req) {
r.err = err
goto resp
}
- case <-r.Done:
+ case <-r.done:
continue
}
r.Ofcall = ofcall
@@ -1147,39 +1143,37 @@ L:
}
continue L
}
- go func() { // TODO: is this worth a goroutine?
- switch r.Ifcall.(type) {
- default:
- setError(r, fmt.Errorf("unknown message type: %d", r.Ifcall.Type()))
- s.respChan <- r
- case *TVersion:
- versionChan <- r
- case *TAuth:
- authChan <- r
- case *TAttach:
- attachChan <- r
- case *TFlush:
- flushChan <- r
- case *TWalk:
- walkChan <- r
- case *TOpen:
- openChan <- r
- case *TCreate:
- createChan <- r
- case *TRead:
- readChan <- r
- case *TWrite:
- writeChan <- r
- case *TClunk:
- clunkChan <- r
- case *TRemove:
- removeChan <- r
- case *TStat:
- statChan <- r
- case *TWStat:
- wstatChan <- r
- }
- }()
+ switch r.Ifcall.(type) {
+ default:
+ setError(r, fmt.Errorf("unknown message type: %d", r.Ifcall.Type()))
+ s.respChan <- r
+ case *TVersion:
+ versionChan <- r
+ case *TAuth:
+ authChan <- r
+ case *TAttach:
+ attachChan <- r
+ case *TFlush:
+ flushChan <- r
+ case *TWalk:
+ walkChan <- r
+ case *TOpen:
+ openChan <- r
+ case *TCreate:
+ createChan <- r
+ case *TRead:
+ readChan <- r
+ case *TWrite:
+ writeChan <- r
+ case *TClunk:
+ clunkChan <- r
+ case *TRemove:
+ removeChan <- r
+ case *TStat:
+ statChan <- r
+ case *TWStat:
+ wstatChan <- r
+ }
}
}
}
@@ -1215,7 +1209,6 @@ func respond(ctx context.Context, s *Server) {
if err != nil {
log.Printf("speak: %v", err)
}
- r.Cancel()
case <-ctx.Done():
return
}
diff --git a/server2_test.go b/server2_test.go
@@ -105,9 +105,7 @@ func TestGetReq(t *testing.T) {
defer tFile2.Close()
s := &Server{r: tFile, rPool: newReqPool()}
for {
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel() // in case error.
- got := s.getReq(ctx)
+ got := s.getReq()
if got.listenErr == io.EOF {
break
} else if got.listenErr != nil {
@@ -134,12 +132,6 @@ func TestGetReq(t *testing.T) {
t.Errorf("wrong message in pool:\n\twant: %p,\n\tgot: %p", got, got2)
}
s.rPool.delete(wantMsg.GetTag())
- cancel()
- select {
- case <-got.Done:
- default:
- t.Errorf("r.Done not closed")
- }
}
}