commit 0a9e489cbcf9f090d653799afa97794b91ae6de8
parent ec0e71d99f8cf1b00224e07131dc2cbdcbbaa60c
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Sun, 21 Jan 2024 14:32:29 +0900
add client TestMux
Diffstat:
2 files changed, 68 insertions(+), 16 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -24,6 +24,9 @@ type Client struct {
 	// FPool is the fidPool which hold the list of open fids.
 	fPool *fidPool
 
+	// RPool is the set of all outstanding requests.
+	rPool *reqPool
+
 	// Txc is used to send a reqest to the multiplexer goroutine
 	txc chan<- *req
 
@@ -42,14 +45,15 @@ type Client struct {
 
 // NewClient creates a Client and prepare to transact with a server via r and w.
 // It runs several goroutines to handle requests.
-// And the returned client should be stopped by cancelling ctx and closing
-// r and w afterwords.
+// And the returned client should be stopped afterwords by cancelling ctx
+// and then closing r and w.
 func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
 	c := &Client{
 		msize:     mSize,
 		mSizeLock: new(sync.Mutex),
 		uname:     uname,
 		fPool:     allocClientFidPool(),
+		rPool:     newReqPool(),
 		wg:        new(sync.WaitGroup),
 		done:      ctx.Done(),
 	}
@@ -74,16 +78,12 @@ func (c *Client) setMSize(mSize uint32) {
 // RunMultiplexer runs two goroutines,
 // one for recieving Rmsg and another for sending Tmsg.
 // The goroutine for Tmsg recieves *req from the returned channel,
-// and send the lib9p.Msg to the speaker goroutine via tmsgc.
-// The goroutine for Rmsg recieves *req from the Tmsg goroutine and waits for
-// the reply to the corresponding message from the listener goroutine via rmsgc.
-// After recieving the reply, it sets the *req.rmsg and sends it to the
-// *req.rxc.
-// It reports any errors to the client's errc channel.
+// and sends the lib9p.Msg to w.
+// The goroutine for Rmsg reads lib9p.Msg from r and sends it to rxc of the
+// corresponding request.
 func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) chan<- *req {
 	c.wg.Add(2)
 	txc := make(chan *req)
-	rPool := newReqPool()
 	// Rmsg
 	go func() {
 		defer c.wg.Done()
@@ -95,22 +95,22 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
 			}
 			msg, err := lib9p.RecvMsg(r)
 			if err != nil {
-				rPool.cancelAll(fmt.Errorf("recv: %v", err))
+				c.rPool.cancelAll(fmt.Errorf("recv: %v", err))
 				continue // TODO: should return?
 			}
-			rq, ok := rPool.lookup(msg.GetTag())
+			rq, ok := c.rPool.lookup(msg.GetTag())
 			if !ok {
 				log.Printf("mux: unknown tag for msg: %v", msg)
 				continue // TODO: how to recover?
 			}
-			rPool.delete(msg.GetTag())
+			c.rPool.delete(msg.GetTag())
 			if tflush, ok := rq.tmsg.(*lib9p.TFlush); ok {
 				if _, ok := msg.(*lib9p.RFlush); !ok {
 					rq.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
 				}
-				if oldreq, ok := rPool.lookup(tflush.Oldtag); ok {
+				if oldreq, ok := c.rPool.lookup(tflush.Oldtag); ok {
 					oldreq.errc <- errors.New("request flushed")
-					rPool.delete(tflush.Oldtag)
+					c.rPool.delete(tflush.Oldtag)
 				}
 			}
 			rq.rmsg = msg
@@ -131,11 +131,11 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
 				if !ok {
 					return
 				}
-				if _, ok := rPool.lookup(r.tag); ok {
+				if _, ok := c.rPool.lookup(r.tag); ok {
 					r.errc <- fmt.Errorf("mux: %w: %d", lib9p.ErrDupTag, r.tag)
 					continue
 				}
-				rPool.add(r)
+				c.rPool.add(r)
 				if err := lib9p.SendMsg(r.tmsg, w); err != nil {
 					r.errc <- fmt.Errorf("send: %v", err)
 				}
diff --git a/client/client_test.go b/client/client_test.go
@@ -59,6 +59,7 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
 		mSizeLock: new(sync.Mutex),
 		uname:     uname,
 		fPool:     allocClientFidPool(),
+		rPool:     newReqPool(),
 		wg:        new(sync.WaitGroup),
 	}
 	cr, sw := io.Pipe()
@@ -67,6 +68,57 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
 	return c, sr, sw
 }
 
+func TestMux(t *testing.T) {
+	tests := []struct{
+		tmsg, rmsg lib9p.Msg
+	}{
+		{&lib9p.TVersion{}, &lib9p.RVersion{}},
+		{&lib9p.TCreate{}, &lib9p.RError{Ename:errors.New("e")}},
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c, r, w := newClientForTest(ctx, 1024, "glenda")
+	for i, test := range tests{
+		test.tmsg.SetTag(uint16(i))
+		test.rmsg.SetTag(uint16(i))
+		rq := newReq(test.tmsg)
+		c.txc <- rq
+		gottmsg, err := lib9p.RecvMsg(r)
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+		if !reflect.DeepEqual(test.tmsg, gottmsg) {
+			t.Errorf("%d: mux modified tmsg:\n\twant: %v\n\tgot:  %v",
+				i, test.tmsg, gottmsg)
+		}
+		rq0, ok := c.rPool.lookup(uint16(i))
+		if !ok {
+			t.Errorf("%d: req not registered to the pool.", i)
+			continue
+		}
+		if rq != rq0 {
+			t.Errorf("%d: wrong req registered.", i)
+			continue
+		}
+		if err = lib9p.SendMsg(test.rmsg, w); err != nil {
+			t.Error(err)
+			continue
+		}
+		rq = <- rq.rxc
+		if !reflect.DeepEqual(test.rmsg, rq.rmsg) {
+			t.Errorf("%d: mux modified tmsg:\n\twant: %v\n\tgot:  %v",
+				i, test.tmsg, gottmsg)
+		}
+		_, ok = c.rPool.lookup(uint16(i))
+		if ok {
+			t.Errorf("req not deleted from the pool.")
+			continue
+		}
+	}
+}
+
 func TestVersion(t *testing.T) {
 	tests := []struct {
 		name string