commit 1579f3bf347e9b217b1f310a885bc7443491122a
parent 813e7119de7a4a2864c487c87b59eec939c366da
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Sun, 21 Jan 2024 08:25:02 +0900
delete client's listener goroutine
Diffstat:
| M | client/client.go | | | 109 | +++++++++++++++++++++---------------------------------------------------------- | 
| M | client/client_test.go | | | 61 | +++++++++++++++++++++++++++++++------------------------------ | 
2 files changed, 59 insertions(+), 111 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -61,8 +61,7 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
 		done:      ctx.Done(),
 	}
 	tmsgc := c.runSpeaker(ctx, w)
-	rmsgc := c.runListener(ctx, r)
-	c.txc = c.runMultiplexer(ctx, tmsgc, rmsgc)
+	c.txc = c.runMultiplexer(ctx, tmsgc, r)
 	c.runErrorReporter(ctx)
 	return c
 }
@@ -103,57 +102,6 @@ func (c *Client) runErrorReporter(ctx context.Context) {
 	}()
 }
 
-// RunListener runs listener goroutine.
-// Listener reads byte array of 9P messages from r and make each of them into
-// corresponding struct that implements lib9p.Msg,
-// and sends it to the returned channel.
-// Listener goroutine returns when ctx is canceled.
-// Listener goroutine reports errors to the client's errc channel.
-func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg {
-	c.wg.Add(1)
-	rmsgc := make(chan lib9p.Msg, 3)
-	go func() {
-		defer func() {
-			close(rmsgc)
-			c.wg.Done()
-		}()
-		for {
-			select {
-			case <-ctx.Done():
-				return
-			default:
-				done := make(chan struct{})
-				var (
-					msg lib9p.Msg
-					err error
-				)
-				go func() {
-					defer close(done)
-					msg, err = lib9p.RecvMsg(r)
-				}()
-				select {
-				case <-done:
-				case <-ctx.Done():
-					return
-				}
-				if err != nil {
-					if err == io.EOF {
-						c.errc <- err
-					} else {
-						c.errc <- fmt.Errorf("recv: %v", err)
-					}
-					continue
-				}
-				select {
-				case rmsgc <- msg:
-				case <-ctx.Done():
-				}
-			}
-		}
-	}()
-	return rmsgc
-}
-
 // RunSpeaker runs speaker goroutine.
 // Speaker goroutine recieves 9P lib9p.Msgs from the returned channel,
 // marshal them into byte arrays and sends them to w.
@@ -190,44 +138,43 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
 // After recieving the reply, it sets the *req.rmsg and sends it to the
 // *req.rxc.
 // It reports any errors to the client's errc channel.
-func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rmsgc <-chan lib9p.Msg) chan<- *req {
+func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, r io.Reader) chan<- *req {
 	c.wg.Add(2)
 	txc := make(chan *req)
 	rPool := newReqPool()
 	// Rmsg
 	go func() {
-		defer func() {
-			c.wg.Done()
-		}()
+		defer c.wg.Done()
 		for {
-			select {
-			case <-ctx.Done():
-				return
-			case msg, ok := <-rmsgc:
-				if !ok {
-					return
+			msg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				if err == io.EOF {
+					c.errc <- err
+				} else {
+					c.errc <- fmt.Errorf("recv: %v", err)
 				}
-				r, ok := rPool.lookup(msg.GetTag())
-				if !ok {
-					c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg)
-					continue
-				}
-				rPool.delete(msg.GetTag())
-				if tflush, ok := r.tmsg.(*lib9p.TFlush); ok {
-					if _, ok := msg.(*lib9p.RFlush); !ok {
-						r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
-					}
-					if oldreq, ok := rPool.lookup(tflush.Oldtag); ok {
-						oldreq.errc <- errors.New("request flushed")
-						rPool.delete(tflush.Oldtag)
-					}
+				continue
+			}
+			rq, ok := rPool.lookup(msg.GetTag())
+			if !ok {
+				c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg)
+				continue
+			}
+			rPool.delete(msg.GetTag())
+			if tflush, ok := rq.tmsg.(*lib9p.TFlush); ok {
+				if _, ok := msg.(*lib9p.RFlush); !ok {
+					rq.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
 				}
-				r.rmsg = msg
-				select {
-				case <-ctx.Done():
-				case r.rxc <- r:
+				if oldreq, ok := rPool.lookup(tflush.Oldtag); ok {
+					oldreq.errc <- errors.New("request flushed")
+					rPool.delete(tflush.Oldtag)
 				}
 			}
+			rq.rmsg = msg
+			select {
+			case <-ctx.Done():
+			case rq.rxc <- rq:
+			}
 		}
 	}()
 	// Tmsg
diff --git a/client/client_test.go b/client/client_test.go
@@ -50,7 +50,7 @@ func TestDupTag(t *testing.T) {
 	}
 }
 
-func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, chan<- lib9p.Msg) {
+func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, io.Writer) {
 	c := &Client{
 		msize:     msize,
 		mSizeLock: new(sync.Mutex),
@@ -59,9 +59,10 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
 		errc:      make(chan error),
 		wg:        new(sync.WaitGroup),
 	}
-	tmsgc, rmsgc := make(chan lib9p.Msg), make(chan lib9p.Msg)
-	c.txc = c.runMultiplexer(ctx, tmsgc, rmsgc)
-	return c, tmsgc, rmsgc
+	tmsgc := make(chan lib9p.Msg)
+	r, w := io.Pipe()
+	c.txc = c.runMultiplexer(ctx, tmsgc, r)
+	return c, tmsgc, w
 }
 
 func TestVersion(t *testing.T) {
@@ -80,7 +81,7 @@ func TestVersion(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotmsize   uint32
@@ -100,7 +101,7 @@ func TestVersion(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -145,7 +146,7 @@ func TestAuth(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotaqid lib9p.Qid
@@ -164,7 +165,7 @@ func TestAuth(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -208,7 +209,7 @@ func TestAttach(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid lib9p.Qid
@@ -227,7 +228,7 @@ func TestAttach(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -268,7 +269,7 @@ func TestFlush(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -286,7 +287,7 @@ func TestFlush(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -326,7 +327,7 @@ func TestWalk(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqids []lib9p.Qid
@@ -345,7 +346,7 @@ func TestWalk(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -389,7 +390,7 @@ func TestOpen(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid    lib9p.Qid
@@ -409,7 +410,7 @@ func TestOpen(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -453,7 +454,7 @@ func TestCreate(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid    lib9p.Qid
@@ -473,7 +474,7 @@ func TestCreate(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -517,7 +518,7 @@ func TestRead(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotdata []byte
@@ -536,7 +537,7 @@ func TestRead(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -580,7 +581,7 @@ func TestWrite(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotcount uint32
@@ -599,7 +600,7 @@ func TestWrite(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -643,7 +644,7 @@ func TestClunk(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -661,7 +662,7 @@ func TestClunk(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -701,7 +702,7 @@ func TestRemove(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -719,7 +720,7 @@ func TestRemove(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -756,7 +757,7 @@ func TestStat(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotstat *lib9p.Stat
@@ -775,7 +776,7 @@ func TestStat(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)
@@ -819,7 +820,7 @@ func TestWstat(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
+			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -837,7 +838,7 @@ func TestWstat(t *testing.T) {
 					test.name, test.tmsg, gottmsg)
 				return
 			}
-			rmsgc <- test.rmsg
+			lib9p.SendMsg(test.rmsg, w)
 			select {
 			case err := <-c.errc:
 				t.Errorf("client error: %v", err)