commit 5aaab3fe165c5a0d44708c4c091354bd63eecfee
parent d92515436c042cbc39ea8d0566ccd8c0c6ab5c24
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sat, 6 Jan 2024 15:51:14 +0900
refactor ctx of client
Diffstat:
3 files changed, 72 insertions(+), 19 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -44,7 +44,7 @@ type Client struct {
// NewClient creates a Client.
// It also runs several goroutines to handle requests.
-// And the returned client should be stopped by calling *Client.Stop afterwards.
+// And the returned client should be stopped by cancelling ctx afterwards.
func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
c := &Client{
msize: mSize,
@@ -174,8 +174,8 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
return tmsgc
}
-// RunMultiplexer runs multiplexer goroutine.
-// Multiplexer goroutines, one for recieving Rmsg and another for sending Tmsg.
+// RunMultiplexer runs multiplexer goroutines,
+// one for recieving Rmsg and another for sending Tmsg.
// The goroutine for Tmsg recieves *req from the returned channel,
// and send the 9P lib9p.Msg to the speaker goroutine via tmsgc.
// The goroutine for Rmsg recieves *req from the Tmsg goroutine and waits for
@@ -192,7 +192,7 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
rPool := make(map[uint16]*req)
defer func() {
for _, r := range rPool {
- <-r.ctxDone
+ r.cancel()
}
c.wg.Done()
}()
@@ -205,6 +205,8 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
return
}
if _, ok := rPool[r.tag]; ok {
+ // r.cancel() is not needed because transaction unblocks by
+ // sending error via r.errc.
r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag)
continue
}
@@ -223,15 +225,14 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
if _, ok := msg.(*lib9p.RFlush); !ok {
r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
}
- oldreq, ok := rPool[tflush.Oldtag]
- if ok {
- oldreq.Close()
+ if oldreq, ok := rPool[tflush.Oldtag]; ok {
+ oldreq.cancel()
}
delete(rPool, tflush.Oldtag)
}
r.rmsg = msg
select {
- case <-r.ctxDone:
+ case <-ctx.Done():
case r.rxc <- r:
}
}
@@ -252,11 +253,14 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
select {
case reqc <- r:
case <-ctx.Done():
+ r.cancel() // before registering to rPool of Rmsg goroutine.
return
}
select {
case tmsgc <- r.tmsg:
case <-ctx.Done():
+ // after registergin to rPool of Rmsg goroutine,
+ // cancelling is done by that goroutine.
return
}
}
@@ -268,9 +272,7 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
// Transact send 9P lib9p.Msg of r to the multiplexer goroutines and recieves
// the reply.
func (c *Client) transact(ctx context.Context, tmsg lib9p.Msg) (lib9p.Msg, error) {
- ctx1, cancel1 := context.WithCancel(ctx)
- defer cancel1()
- r := newReq(ctx1, tmsg)
+ r := newReq(tmsg)
select {
case <-ctx.Done():
return nil, ctx.Err()
diff --git a/client/client_test.go b/client/client_test.go
@@ -12,6 +12,63 @@ import (
"git.mtkn.jp/lib9p"
)
+// TestClientCancel checks if the client goroutine cancel outstanding transactions
+// propperly when the Client is canceled.
+func TestClientCancel(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cr, _ := io.Pipe()
+ sr, cw := io.Pipe()
+ defer func() { cr.Close(); sr.Close() }()
+ c := NewClient(ctx, 1024, "", cr, cw)
+ wg := new(sync.WaitGroup)
+ wg.Add(10)
+ for i := 0; i < 10; i++ {
+ go func(i int) {
+ defer wg.Done()
+ _, _, err := c.Version(context.Background(), uint16(i), 1024, "9P2000")
+ t.Logf("%d: %v", i, err)
+ }(i)
+ }
+ for i := 0; i < 7; i++ {
+ _, err := lib9p.RecvMsg(sr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ t.Logf("cancel client.")
+ cancel()
+ wg.Wait()
+}
+
+func TestReqCancel(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ ctx0, cancel0 := context.WithCancel(context.Background())
+ cr, _ := io.Pipe()
+ sr, cw := io.Pipe()
+ defer func() { cr.Close(); sr.Close() }()
+ c := NewClient(ctx, 1024, "", cr, cw)
+ wg := new(sync.WaitGroup)
+ wg.Add(10)
+ for i := 0; i < 10; i++ {
+ go func(i int) {
+ defer wg.Done()
+ _, _, err := c.Version(ctx0, uint16(i), 1024, "9P2000")
+ t.Logf("%d: %v", i, err)
+ }(i)
+ }
+ for i := 0; i < 6; i++ {
+ _, err := lib9p.RecvMsg(sr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ t.Logf("cancel requests.")
+ cancel0()
+ wg.Wait()
+}
+
+
func setupClientAndServer(fs lib9p.FS) (*Client, context.CancelFunc) {
cr, sw := io.Pipe()
sr, cw := io.Pipe()
diff --git a/client/req.go b/client/req.go
@@ -1,7 +1,6 @@
package client
import (
- "context"
"fmt"
"sync"
@@ -16,24 +15,19 @@ type req struct {
err error
errc chan error // To report any client side error to transact().
rxc chan *req
- ctxDone <-chan struct{}
}
// newReq allocates a req with msg.
-// It also sets the ctxDone channel to ctx.Done().
-// TODO: passing ctx is confusing?
-// it only needs the done channel.
-func newReq(ctx context.Context, msg lib9p.Msg) *req {
+func newReq(msg lib9p.Msg) *req {
return &req{
tag: msg.GetTag(),
tmsg: msg,
rxc: make(chan *req),
- ctxDone: ctx.Done(),
errc: make(chan error),
}
}
-func (r *req) Close() {
+func (r *req) cancel() {
close(r.rxc)
close(r.errc)
}