commit 99140d0231703ffba9f8e08b33029f2fa6922bcb
parent 1579f3bf347e9b217b1f310a885bc7443491122a
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Sun, 21 Jan 2024 09:17:42 +0900
delete client' speaker goroutine
Diffstat:
2 files changed, 87 insertions(+), 74 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -49,7 +49,8 @@ type Client struct {
 
 // NewClient creates a Client and prepare to transact with a server via r and w.
 // It runs several goroutines to handle requests.
-// And the returned client should be stopped by cancelling ctx afterwards.
+// And the returned client should be stopped by cancelling ctx and closing
+// r and w afterwords.
 func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
 	c := &Client{
 		msize:     mSize,
@@ -60,8 +61,7 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
 		wg:        new(sync.WaitGroup),
 		done:      ctx.Done(),
 	}
-	tmsgc := c.runSpeaker(ctx, w)
-	c.txc = c.runMultiplexer(ctx, tmsgc, r)
+	c.txc = c.runMultiplexer(ctx, r, w)
 	c.runErrorReporter(ctx)
 	return c
 }
@@ -102,33 +102,6 @@ func (c *Client) runErrorReporter(ctx context.Context) {
 	}()
 }
 
-// RunSpeaker runs speaker goroutine.
-// Speaker goroutine recieves 9P lib9p.Msgs from the returned channel,
-// marshal them into byte arrays and sends them to w.
-// It reports any errors to the clients errc channel.
-// It returnes when ctx is canceled.
-func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
-	c.wg.Add(1)
-	tmsgc := make(chan lib9p.Msg, 3)
-	go func() {
-		defer c.wg.Done()
-		for {
-			select {
-			case <-ctx.Done():
-				return
-			case msg := <-tmsgc:
-				if msg == nil {
-					return
-				}
-				if err := lib9p.SendMsg(msg, w); err != nil {
-					c.errc <- fmt.Errorf("send: %v", err)
-				}
-			}
-		}
-	}()
-	return tmsgc
-}
-
 // RunMultiplexer runs two goroutines,
 // one for recieving Rmsg and another for sending Tmsg.
 // The goroutine for Tmsg recieves *req from the returned channel,
@@ -138,7 +111,7 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
 // After recieving the reply, it sets the *req.rmsg and sends it to the
 // *req.rxc.
 // It reports any errors to the client's errc channel.
-func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, r io.Reader) chan<- *req {
+func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) chan<- *req {
 	c.wg.Add(2)
 	txc := make(chan *req)
 	rPool := newReqPool()
@@ -179,24 +152,22 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, r i
 	}()
 	// Tmsg
 	go func() {
-		defer func() {
-			close(tmsgc)
-			c.wg.Done()
-		}()
+		defer c.wg.Done()
 		for {
 			select {
 			case <-ctx.Done():
 				return
-			case r := <-txc:
+			case r, ok := <-txc:
+				if !ok {
+					return
+				}
 				if _, ok := rPool.lookup(r.tag); ok {
 					r.errc <- fmt.Errorf("mux: %w: %d", lib9p.ErrDupTag, r.tag)
 					continue
 				}
 				rPool.add(r)
-				select {
-				case tmsgc <- r.tmsg:
-				case <-ctx.Done():
-					return
+				if err := lib9p.SendMsg(r.tmsg, w); err != nil {
+					c.errc <- fmt.Errorf("send: %v", err)
 				}
 			}
 		}
diff --git a/client/client_test.go b/client/client_test.go
@@ -40,17 +40,20 @@ func TestClientCancel(t *testing.T) {
 
 func TestDupTag(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
-	c, tmsgc, _ := newClientForTest(ctx, 1024, "glenda")
+	c, r, _ := newClientForTest(ctx, 1024, "glenda")
 	defer cancel()
 	go c.Version(0, 1024, "9P2000")
-	<-tmsgc
-	_, _, err := c.Version(0, 1024, "9P2000")
+	_, err := lib9p.RecvMsg(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, _, err = c.Version(0, 1024, "9P2000")
 	if !errors.Is(err, lib9p.ErrDupTag) {
 		t.Error("duplicate tag error not reported: err:", err)
 	}
 }
 
-func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, io.Writer) {
+func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, io.Reader, io.Writer) {
 	c := &Client{
 		msize:     msize,
 		mSizeLock: new(sync.Mutex),
@@ -59,10 +62,10 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
 		errc:      make(chan error),
 		wg:        new(sync.WaitGroup),
 	}
-	tmsgc := make(chan lib9p.Msg)
-	r, w := io.Pipe()
-	c.txc = c.runMultiplexer(ctx, tmsgc, r)
-	return c, tmsgc, w
+	cr, sw := io.Pipe()
+	sr, cw := io.Pipe()
+	c.txc = c.runMultiplexer(ctx, cr, cw)
+	return c, sr, sw
 }
 
 func TestVersion(t *testing.T) {
@@ -81,7 +84,7 @@ func TestVersion(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotmsize   uint32
@@ -95,7 +98,10 @@ func TestVersion(t *testing.T) {
 					c.Version(ifcall.Tag, ifcall.Msize, ifcall.Version)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -146,7 +152,7 @@ func TestAuth(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotaqid lib9p.Qid
@@ -159,7 +165,10 @@ func TestAuth(t *testing.T) {
 					c.Auth(ifcall.Tag, ifcall.Afid, ifcall.Uname, ifcall.Aname)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -209,7 +218,7 @@ func TestAttach(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid lib9p.Qid
@@ -222,7 +231,10 @@ func TestAttach(t *testing.T) {
 					c.Attach(ifcall.Tag, ifcall.Fid, ifcall.Afid, ifcall.Uname, ifcall.Aname)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -269,7 +281,7 @@ func TestFlush(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -281,7 +293,10 @@ func TestFlush(t *testing.T) {
 					c.Flush(ifcall.Tag, ifcall.Oldtag)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -327,7 +342,7 @@ func TestWalk(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqids []lib9p.Qid
@@ -340,7 +355,10 @@ func TestWalk(t *testing.T) {
 					c.Walk(ifcall.Tag, ifcall.Fid, ifcall.Newfid, ifcall.Wnames)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -390,7 +408,7 @@ func TestOpen(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid    lib9p.Qid
@@ -404,7 +422,10 @@ func TestOpen(t *testing.T) {
 					c.Open(ifcall.Tag, ifcall.Fid, ifcall.Mode)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -454,7 +475,7 @@ func TestCreate(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotqid    lib9p.Qid
@@ -468,7 +489,10 @@ func TestCreate(t *testing.T) {
 					c.Create(ifcall.Tag, ifcall.Fid, ifcall.Name, ifcall.Perm, ifcall.Mode)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -518,7 +542,7 @@ func TestRead(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotdata []byte
@@ -531,7 +555,10 @@ func TestRead(t *testing.T) {
 					c.Read(ifcall.Tag, ifcall.Fid, ifcall.Offset, ifcall.Count)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -581,7 +608,7 @@ func TestWrite(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotcount uint32
@@ -594,7 +621,10 @@ func TestWrite(t *testing.T) {
 					c.Write(ifcall.Tag, ifcall.Fid, ifcall.Offset, ifcall.Count, ifcall.Data)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -644,7 +674,7 @@ func TestClunk(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -656,7 +686,10 @@ func TestClunk(t *testing.T) {
 					c.Clunk(ifcall.Tag, ifcall.Fid)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -702,7 +735,7 @@ func TestRemove(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -714,7 +747,10 @@ func TestRemove(t *testing.T) {
 					c.Remove(ifcall.Tag, ifcall.Fid)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -757,7 +793,7 @@ func TestStat(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				gotstat *lib9p.Stat
@@ -770,7 +806,10 @@ func TestStat(t *testing.T) {
 					c.Stat(ifcall.Tag, ifcall.Fid)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)
@@ -820,7 +859,7 @@ func TestWstat(t *testing.T) {
 	for _, test := range tests {
 		func() {
 			ctx, cancel := context.WithCancel(context.Background())
-			c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+			c, r, w := newClientForTest(ctx, 1024, "glenda")
 			defer cancel()
 			var (
 				goterr error
@@ -832,7 +871,10 @@ func TestWstat(t *testing.T) {
 					c.Wstat(ifcall.Tag, ifcall.Fid, ifcall.Stat)
 				close(done)
 			}()
-			gottmsg := <-tmsgc
+			gottmsg, err := lib9p.RecvMsg(r)
+			if err != nil {
+				t.Error(err)
+			}
 			if !reflect.DeepEqual(test.tmsg, gottmsg) {
 				t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot:  %v",
 					test.name, test.tmsg, gottmsg)