commit 0a9e489cbcf9f090d653799afa97794b91ae6de8
parent ec0e71d99f8cf1b00224e07131dc2cbdcbbaa60c
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sun, 21 Jan 2024 14:32:29 +0900
add client TestMux
Diffstat:
2 files changed, 68 insertions(+), 16 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -24,6 +24,9 @@ type Client struct {
// FPool is the fidPool which hold the list of open fids.
fPool *fidPool
+ // RPool is the set of all outstanding requests.
+ rPool *reqPool
+
// Txc is used to send a reqest to the multiplexer goroutine
txc chan<- *req
@@ -42,14 +45,15 @@ type Client struct {
// NewClient creates a Client and prepare to transact with a server via r and w.
// It runs several goroutines to handle requests.
-// And the returned client should be stopped by cancelling ctx and closing
-// r and w afterwords.
+// And the returned client should be stopped afterwords by cancelling ctx
+// and then closing r and w.
func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
c := &Client{
msize: mSize,
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
+ rPool: newReqPool(),
wg: new(sync.WaitGroup),
done: ctx.Done(),
}
@@ -74,16 +78,12 @@ func (c *Client) setMSize(mSize uint32) {
// RunMultiplexer runs two goroutines,
// one for recieving Rmsg and another for sending Tmsg.
// The goroutine for Tmsg recieves *req from the returned channel,
-// and send the lib9p.Msg to the speaker goroutine via tmsgc.
-// The goroutine for Rmsg recieves *req from the Tmsg goroutine and waits for
-// the reply to the corresponding message from the listener goroutine via rmsgc.
-// After recieving the reply, it sets the *req.rmsg and sends it to the
-// *req.rxc.
-// It reports any errors to the client's errc channel.
+// and sends the lib9p.Msg to w.
+// The goroutine for Rmsg reads lib9p.Msg from r and sends it to rxc of the
+// corresponding request.
func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) chan<- *req {
c.wg.Add(2)
txc := make(chan *req)
- rPool := newReqPool()
// Rmsg
go func() {
defer c.wg.Done()
@@ -95,22 +95,22 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
}
msg, err := lib9p.RecvMsg(r)
if err != nil {
- rPool.cancelAll(fmt.Errorf("recv: %v", err))
+ c.rPool.cancelAll(fmt.Errorf("recv: %v", err))
continue // TODO: should return?
}
- rq, ok := rPool.lookup(msg.GetTag())
+ rq, ok := c.rPool.lookup(msg.GetTag())
if !ok {
log.Printf("mux: unknown tag for msg: %v", msg)
continue // TODO: how to recover?
}
- rPool.delete(msg.GetTag())
+ c.rPool.delete(msg.GetTag())
if tflush, ok := rq.tmsg.(*lib9p.TFlush); ok {
if _, ok := msg.(*lib9p.RFlush); !ok {
rq.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
}
- if oldreq, ok := rPool.lookup(tflush.Oldtag); ok {
+ if oldreq, ok := c.rPool.lookup(tflush.Oldtag); ok {
oldreq.errc <- errors.New("request flushed")
- rPool.delete(tflush.Oldtag)
+ c.rPool.delete(tflush.Oldtag)
}
}
rq.rmsg = msg
@@ -131,11 +131,11 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
if !ok {
return
}
- if _, ok := rPool.lookup(r.tag); ok {
+ if _, ok := c.rPool.lookup(r.tag); ok {
r.errc <- fmt.Errorf("mux: %w: %d", lib9p.ErrDupTag, r.tag)
continue
}
- rPool.add(r)
+ c.rPool.add(r)
if err := lib9p.SendMsg(r.tmsg, w); err != nil {
r.errc <- fmt.Errorf("send: %v", err)
}
diff --git a/client/client_test.go b/client/client_test.go
@@ -59,6 +59,7 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
+ rPool: newReqPool(),
wg: new(sync.WaitGroup),
}
cr, sw := io.Pipe()
@@ -67,6 +68,57 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
return c, sr, sw
}
+func TestMux(t *testing.T) {
+ tests := []struct{
+ tmsg, rmsg lib9p.Msg
+ }{
+ {&lib9p.TVersion{}, &lib9p.RVersion{}},
+ {&lib9p.TCreate{}, &lib9p.RError{Ename:errors.New("e")}},
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
+ for i, test := range tests{
+ test.tmsg.SetTag(uint16(i))
+ test.rmsg.SetTag(uint16(i))
+ rq := newReq(test.tmsg)
+ c.txc <- rq
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if !reflect.DeepEqual(test.tmsg, gottmsg) {
+ t.Errorf("%d: mux modified tmsg:\n\twant: %v\n\tgot: %v",
+ i, test.tmsg, gottmsg)
+ }
+ rq0, ok := c.rPool.lookup(uint16(i))
+ if !ok {
+ t.Errorf("%d: req not registered to the pool.", i)
+ continue
+ }
+ if rq != rq0 {
+ t.Errorf("%d: wrong req registered.", i)
+ continue
+ }
+ if err = lib9p.SendMsg(test.rmsg, w); err != nil {
+ t.Error(err)
+ continue
+ }
+ rq = <- rq.rxc
+ if !reflect.DeepEqual(test.rmsg, rq.rmsg) {
+ t.Errorf("%d: mux modified tmsg:\n\twant: %v\n\tgot: %v",
+ i, test.tmsg, gottmsg)
+ }
+ _, ok = c.rPool.lookup(uint16(i))
+ if ok {
+ t.Errorf("req not deleted from the pool.")
+ continue
+ }
+ }
+}
+
func TestVersion(t *testing.T) {
tests := []struct {
name string