commit 99140d0231703ffba9f8e08b33029f2fa6922bcb
parent 1579f3bf347e9b217b1f310a885bc7443491122a
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sun, 21 Jan 2024 09:17:42 +0900
delete client' speaker goroutine
Diffstat:
2 files changed, 87 insertions(+), 74 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -49,7 +49,8 @@ type Client struct {
// NewClient creates a Client and prepare to transact with a server via r and w.
// It runs several goroutines to handle requests.
-// And the returned client should be stopped by cancelling ctx afterwards.
+// And the returned client should be stopped by cancelling ctx and closing
+// r and w afterwords.
func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
c := &Client{
msize: mSize,
@@ -60,8 +61,7 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
wg: new(sync.WaitGroup),
done: ctx.Done(),
}
- tmsgc := c.runSpeaker(ctx, w)
- c.txc = c.runMultiplexer(ctx, tmsgc, r)
+ c.txc = c.runMultiplexer(ctx, r, w)
c.runErrorReporter(ctx)
return c
}
@@ -102,33 +102,6 @@ func (c *Client) runErrorReporter(ctx context.Context) {
}()
}
-// RunSpeaker runs speaker goroutine.
-// Speaker goroutine recieves 9P lib9p.Msgs from the returned channel,
-// marshal them into byte arrays and sends them to w.
-// It reports any errors to the clients errc channel.
-// It returnes when ctx is canceled.
-func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
- c.wg.Add(1)
- tmsgc := make(chan lib9p.Msg, 3)
- go func() {
- defer c.wg.Done()
- for {
- select {
- case <-ctx.Done():
- return
- case msg := <-tmsgc:
- if msg == nil {
- return
- }
- if err := lib9p.SendMsg(msg, w); err != nil {
- c.errc <- fmt.Errorf("send: %v", err)
- }
- }
- }
- }()
- return tmsgc
-}
-
// RunMultiplexer runs two goroutines,
// one for recieving Rmsg and another for sending Tmsg.
// The goroutine for Tmsg recieves *req from the returned channel,
@@ -138,7 +111,7 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
// After recieving the reply, it sets the *req.rmsg and sends it to the
// *req.rxc.
// It reports any errors to the client's errc channel.
-func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, r io.Reader) chan<- *req {
+func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) chan<- *req {
c.wg.Add(2)
txc := make(chan *req)
rPool := newReqPool()
@@ -179,24 +152,22 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, r i
}()
// Tmsg
go func() {
- defer func() {
- close(tmsgc)
- c.wg.Done()
- }()
+ defer c.wg.Done()
for {
select {
case <-ctx.Done():
return
- case r := <-txc:
+ case r, ok := <-txc:
+ if !ok {
+ return
+ }
if _, ok := rPool.lookup(r.tag); ok {
r.errc <- fmt.Errorf("mux: %w: %d", lib9p.ErrDupTag, r.tag)
continue
}
rPool.add(r)
- select {
- case tmsgc <- r.tmsg:
- case <-ctx.Done():
- return
+ if err := lib9p.SendMsg(r.tmsg, w); err != nil {
+ c.errc <- fmt.Errorf("send: %v", err)
}
}
}
diff --git a/client/client_test.go b/client/client_test.go
@@ -40,17 +40,20 @@ func TestClientCancel(t *testing.T) {
func TestDupTag(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, _ := newClientForTest(ctx, 1024, "glenda")
+ c, r, _ := newClientForTest(ctx, 1024, "glenda")
defer cancel()
go c.Version(0, 1024, "9P2000")
- <-tmsgc
- _, _, err := c.Version(0, 1024, "9P2000")
+ _, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = c.Version(0, 1024, "9P2000")
if !errors.Is(err, lib9p.ErrDupTag) {
t.Error("duplicate tag error not reported: err:", err)
}
}
-func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, io.Writer) {
+func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, io.Reader, io.Writer) {
c := &Client{
msize: msize,
mSizeLock: new(sync.Mutex),
@@ -59,10 +62,10 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
errc: make(chan error),
wg: new(sync.WaitGroup),
}
- tmsgc := make(chan lib9p.Msg)
- r, w := io.Pipe()
- c.txc = c.runMultiplexer(ctx, tmsgc, r)
- return c, tmsgc, w
+ cr, sw := io.Pipe()
+ sr, cw := io.Pipe()
+ c.txc = c.runMultiplexer(ctx, cr, cw)
+ return c, sr, sw
}
func TestVersion(t *testing.T) {
@@ -81,7 +84,7 @@ func TestVersion(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotmsize uint32
@@ -95,7 +98,10 @@ func TestVersion(t *testing.T) {
c.Version(ifcall.Tag, ifcall.Msize, ifcall.Version)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -146,7 +152,7 @@ func TestAuth(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotaqid lib9p.Qid
@@ -159,7 +165,10 @@ func TestAuth(t *testing.T) {
c.Auth(ifcall.Tag, ifcall.Afid, ifcall.Uname, ifcall.Aname)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -209,7 +218,7 @@ func TestAttach(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -222,7 +231,10 @@ func TestAttach(t *testing.T) {
c.Attach(ifcall.Tag, ifcall.Fid, ifcall.Afid, ifcall.Uname, ifcall.Aname)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -269,7 +281,7 @@ func TestFlush(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -281,7 +293,10 @@ func TestFlush(t *testing.T) {
c.Flush(ifcall.Tag, ifcall.Oldtag)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -327,7 +342,7 @@ func TestWalk(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqids []lib9p.Qid
@@ -340,7 +355,10 @@ func TestWalk(t *testing.T) {
c.Walk(ifcall.Tag, ifcall.Fid, ifcall.Newfid, ifcall.Wnames)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -390,7 +408,7 @@ func TestOpen(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -404,7 +422,10 @@ func TestOpen(t *testing.T) {
c.Open(ifcall.Tag, ifcall.Fid, ifcall.Mode)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -454,7 +475,7 @@ func TestCreate(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -468,7 +489,10 @@ func TestCreate(t *testing.T) {
c.Create(ifcall.Tag, ifcall.Fid, ifcall.Name, ifcall.Perm, ifcall.Mode)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -518,7 +542,7 @@ func TestRead(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotdata []byte
@@ -531,7 +555,10 @@ func TestRead(t *testing.T) {
c.Read(ifcall.Tag, ifcall.Fid, ifcall.Offset, ifcall.Count)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -581,7 +608,7 @@ func TestWrite(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotcount uint32
@@ -594,7 +621,10 @@ func TestWrite(t *testing.T) {
c.Write(ifcall.Tag, ifcall.Fid, ifcall.Offset, ifcall.Count, ifcall.Data)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -644,7 +674,7 @@ func TestClunk(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -656,7 +686,10 @@ func TestClunk(t *testing.T) {
c.Clunk(ifcall.Tag, ifcall.Fid)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -702,7 +735,7 @@ func TestRemove(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -714,7 +747,10 @@ func TestRemove(t *testing.T) {
c.Remove(ifcall.Tag, ifcall.Fid)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -757,7 +793,7 @@ func TestStat(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotstat *lib9p.Stat
@@ -770,7 +806,10 @@ func TestStat(t *testing.T) {
c.Stat(ifcall.Tag, ifcall.Fid)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)
@@ -820,7 +859,7 @@ func TestWstat(t *testing.T) {
for _, test := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
- c, tmsgc, w := newClientForTest(ctx, 1024, "glenda")
+ c, r, w := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -832,7 +871,10 @@ func TestWstat(t *testing.T) {
c.Wstat(ifcall.Tag, ifcall.Fid, ifcall.Stat)
close(done)
}()
- gottmsg := <-tmsgc
+ gottmsg, err := lib9p.RecvMsg(r)
+ if err != nil {
+ t.Error(err)
+ }
if !reflect.DeepEqual(test.tmsg, gottmsg) {
t.Errorf("%s: tmsg modified:\n\twant: %v\n\tgot: %v",
test.name, test.tmsg, gottmsg)