commit 4c24389b9733f6966e66fb863329f5c9e4e4c2bf
parent 9f326e9115d885e8fcd7d651ef6773ccf0351a45
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sun, 7 Jan 2024 09:52:35 +0900
refactor ctx
Diffstat:
4 files changed, 59 insertions(+), 41 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -13,35 +13,41 @@ import (
// Client is a client side of the 9P conversation.
type Client struct {
- // msize is the maximum message size in length
+ // Msize is the maximum message size in length
msize uint32
- // mSizeLock is the mutex used when msize is to be changed.
+ // MSizeLock is the mutex used when msize is to be changed.
mSizeLock *sync.Mutex
- // uname is used to communicate with a server.
+ // Uname is used to communicate with a server.
uname string
- // fPool is the fidPool which hold the list of opend fids.
+ // FPool is the fidPool which hold the list of opend fids.
fPool *fidPool
+ // RPool is the list of outstanding requests.
rPool *reqPool
- // txc is used to send a reqest to the multiplexer goroutine
+ // Txc is used to send a reqest to the multiplexer goroutine
txc chan<- *req
- // errc is used to report any error which is not relevant to
+ // Errc is used to report any error which is not relevant to
// a specific request
errc chan error
- // cancel is the CancelFunc to stop the goroutines evoked by this client.
+ // Cancel is the CancelFunc to stop the goroutines evoked by this client.
cancel context.CancelFunc
- // rootFid is the fid of the root of the file system.
+ // RootFid is the fid of the root of the file system.
rootFid *fid
- // wg is the WaitGroup of all goroutines evoked by this client and its
+ // Wg is the WaitGroup of all goroutines evoked by this client and its
// descendants.
wg *sync.WaitGroup
+
+ // Done is closed when the context passed to NewClient is canceled.
+ // This is used to notify transact() that the client is already
+ // canceled and the transaction should also be canceled.
+ done <-chan struct{}
}
// NewClient creates a Client.
@@ -56,6 +62,7 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
rPool: newReqPool(),
errc: make(chan error),
wg: new(sync.WaitGroup),
+ done: ctx.Done(),
}
tmsgc := c.runSpeaker(ctx, w)
rmsgc := c.runListener(ctx, r)
@@ -192,7 +199,6 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
// Rmsg
go func() {
defer func() {
- c.rPool.cancelAll()
c.wg.Done()
}()
for {
@@ -214,7 +220,7 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
}
if oldreq, ok := c.rPool.lookup(tflush.Oldtag); ok {
- oldreq.cancel()
+ oldreq.errc <- errors.New("request flushed")
c.rPool.delete(tflush.Oldtag)
}
}
@@ -229,6 +235,16 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
// Tmsg
go func() {
defer func() {
+ L:
+ for {
+ select {
+ case r := <-txc:
+ r.errc <- errors.New("client stopped")
+ default:
+ break L
+ }
+ }
+ c.rPool.cancelAll(errors.New("client stopped"))
close(tmsgc)
c.wg.Done()
}()
@@ -238,8 +254,6 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
return
case r := <-txc:
if _, ok := c.rPool.lookup(r.tag); ok {
- // r.cancel() is not needed because transaction unblocks by
- // sending error via r.errc.
r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag)
continue
}
@@ -247,8 +261,6 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
select {
case tmsgc <- r.tmsg:
case <-ctx.Done():
- // after registergin to rPool of Rmsg goroutine,
- // cancelling is done by that goroutine.
return
}
}
@@ -265,20 +277,18 @@ func (c *Client) transact(ctx context.Context, tmsg lib9p.Msg) (lib9p.Msg, error
case <-ctx.Done():
return nil, ctx.Err()
case c.txc <- r:
+ case <-c.done:
+ return nil, errors.New("client stopped")
}
select {
- case r, ok := <-r.rxc:
- if !ok {
- return nil, errors.New("reqest canceled")
- }
+ case r := <-r.rxc:
return r.rmsg, r.err
- case err, ok := <-r.errc: // Client side error.
- if !ok {
- return nil, errors.New("reqest canceled")
- }
+ case err := <-r.errc: // Client side error.
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
+ case <-c.done:
+ return nil, <-r.errc
}
}
diff --git a/client/client_test.go b/client/client_test.go
@@ -12,7 +12,7 @@ import (
"git.mtkn.jp/lib9p"
)
-// TestClientCancel checks if the client goroutine cancel outstanding transactions
+// TestClientCancel checks if the client goroutines cancel outstanding transactions
// propperly when the Client is canceled.
func TestClientCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
@@ -25,21 +25,24 @@ func TestClientCancel(t *testing.T) {
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
- _, _, err := c.Version(context.Background(), uint16(i), 1024, "9P2000")
- t.Logf("%d: %v", i, err)
+ c.Version(context.Background(), uint16(i), 1024, "9P2000")
}(i)
}
- for i := 0; i < 6; i++ {
+ for i := 0; i < 5; i++ {
_, err := lib9p.RecvMsg(sr)
if err != nil {
t.Fatal(err)
}
}
- t.Logf("cancel client.")
cancel()
wg.Wait()
+ if len(c.rPool.m) != 0 {
+ t.Errorf("req pool clogged: %v", c.rPool)
+ }
}
+// TestReqCancel checks if the function transact cancels outstanding transactions
+// propperly when the requests are canceled.
func TestReqCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -53,19 +56,20 @@ func TestReqCancel(t *testing.T) {
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
- _, _, err := c.Version(ctx0, uint16(i), 1024, "9P2000")
- t.Logf("%d: %v", i, err)
+ c.Version(ctx0, uint16(i), 1024, "9P2000")
}(i)
}
- for i := 0; i < 6; i++ {
+ for i := 0; i < 5; i++ {
_, err := lib9p.RecvMsg(sr)
if err != nil {
t.Fatal(err)
}
}
- t.Logf("cancel requests.")
cancel0()
wg.Wait()
+ if len(c.rPool.m) != 0 {
+ t.Errorf("req pool clogged: %v", c.rPool)
+ }
}
diff --git a/client/file_test.go b/client/file_test.go
@@ -36,14 +36,14 @@ func TestFileStat(t *testing.T) {
}
defer cancel()
var testFileStat = func(path string, d fs.DirEntry, err error) error {
- cfi, err := fs.Stat(lib9p.ExportFS{cfs}, path)
+ cfi, err := fs.Stat(lib9p.ExportFS{FS: cfs}, path)
if err != nil {
if strings.Contains(err.Error(), "permission") {
return nil
}
t.Error(err)
}
- tfi, err := fs.Stat(lib9p.ExportFS{testfs}, path)
+ tfi, err := fs.Stat(lib9p.ExportFS{FS: testfs}, path)
if err != nil {
t.Error(err)
}
diff --git a/client/req.go b/client/req.go
@@ -15,6 +15,8 @@ type req struct {
err error
errc chan error // To report any client side error to transact().
rxc chan *req
+ // used to notify client goroutines that the request is canceled.
+ done chan struct{}
}
// newReq allocates a req with msg.
@@ -24,14 +26,10 @@ func newReq(msg lib9p.Msg) *req {
tmsg: msg,
rxc: make(chan *req),
errc: make(chan error),
+ done: make(chan struct{}),
}
}
-func (r *req) cancel() {
- close(r.rxc)
- close(r.errc)
-}
-
// reqPool is the pool of Reqs the server is dealing with.
type reqPool struct {
m map[uint16]*req
@@ -75,15 +73,21 @@ func (rp *reqPool) delete(tag uint16) {
delete(rp.m, tag)
}
-func (rp *reqPool) cancelAll() {
+func (rp *reqPool) cancelAll(err error) {
rp.Lock()
defer rp.Unlock()
for tag, r := range rp.m {
- r.cancel()
+ r.errc <- err
delete(rp.m, tag)
}
}
+func (rp *reqPool) String() string {
+ rp.Lock()
+ defer rp.Unlock()
+ return fmt.Sprint(rp.m)
+}
+
// tagPool is a pool of tags being used by a client.
type tagPool struct {
m map[uint16]bool