commit 58f68fc6fd687b6efa2d857d20927b832dbc8adc
parent 9195bcd6cff0707bdabf5dba81787d8429f6f91b
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Sun, 21 Jan 2024 11:42:45 +0900
move error handler into client's multiplexer
Diffstat:
2 files changed, 17 insertions(+), 120 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -27,10 +27,6 @@ type Client struct {
 	// Txc is used to send a reqest to the multiplexer goroutine
 	txc chan<- *req
 
-	// Errc is used to report any error which is not relevant to
-	// a specific request
-	errc chan error
-
 	// Cancel is the CancelFunc to stop the goroutines evoked by this client.
 	cancel context.CancelFunc
 
@@ -57,12 +53,10 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
 		mSizeLock: new(sync.Mutex),
 		uname:     uname,
 		fPool:     allocClientFidPool(),
-		errc:      make(chan error),
 		wg:        new(sync.WaitGroup),
 		done:      ctx.Done(),
 	}
 	c.txc = c.runMultiplexer(ctx, r, w)
-	c.runErrorReporter(ctx)
 	return c
 }
 
@@ -80,33 +74,6 @@ func (c *Client) setMSize(mSize uint32) {
 	c.msize = mSize
 }
 
-// TODO: handle errors properly.
-// By just printing log message, transact function can't return.
-func (c *Client) runErrorReporter(ctx context.Context) {
-	go func() {
-		for {
-			select {
-			case err, ok := <-c.errc:
-				if !ok {
-					return
-				}
-				switch {
-				case errors.Is(err, io.EOF):
-					return
-				// case receive error:
-				// case send error:
-				// I want to flush all blocking request calls by
-				// returning the error.
-				default:
-				}
-				log.Println("client err:", err)
-			case <-ctx.Done():
-				return
-			}
-		}
-	}()
-}
-
 // RunMultiplexer runs two goroutines,
 // one for recieving Rmsg and another for sending Tmsg.
 // The goroutine for Tmsg recieves *req from the returned channel,
@@ -131,16 +98,12 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
 			}
 			msg, err := lib9p.RecvMsg(r)
 			if err != nil {
-				if err == io.EOF {
-					c.errc <- err
-				} else {
-					c.errc <- fmt.Errorf("recv: %v", err)
-				}
-				continue
+				rPool.cancelAll(fmt.Errorf("recv: %v", err))
+				continue // TODO: should return?
 			}
 			rq, ok := rPool.lookup(msg.GetTag())
 			if !ok {
-				c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg)
+				log.Printf("mux: unknown tag for msg: %v", msg)
 				continue // TODO: how to recover?
 			}
 			rPool.delete(msg.GetTag())
@@ -177,7 +140,7 @@ func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) c
 				}
 				rPool.add(r)
 				if err := lib9p.SendMsg(r.tmsg, w); err != nil {
-					c.errc <- fmt.Errorf("send: %v", err)
+					r.errc <- fmt.Errorf("send: %v", err)
 				}
 			}
 		}
diff --git a/client/client_test.go b/client/client_test.go
@@ -59,7 +59,6 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
 		mSizeLock: new(sync.Mutex),
 		uname:     uname,
 		fPool:     allocClientFidPool(),
-		errc:      make(chan error),
 		wg:        new(sync.WaitGroup),
 	}
 	cr, sw := io.Pipe()
@@ -108,12 +107,7 @@ func TestVersion(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RVersion:
 				if goterr != nil {
@@ -175,12 +169,7 @@ func TestAuth(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RAuth:
 				if goterr != nil {
@@ -241,12 +230,7 @@ func TestAttach(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RAttach:
 				if goterr != nil {
@@ -303,12 +287,7 @@ func TestFlush(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RFlush:
 				if goterr != nil {
@@ -365,12 +344,7 @@ func TestWalk(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RWalk:
 				if goterr != nil {
@@ -432,12 +406,7 @@ func TestOpen(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.ROpen:
 				if goterr != nil {
@@ -499,12 +468,7 @@ func TestCreate(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RCreate:
 				if goterr != nil {
@@ -565,12 +529,7 @@ func TestRead(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RRead:
 				if goterr != nil {
@@ -631,12 +590,7 @@ func TestWrite(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RWrite:
 				if goterr != nil {
@@ -696,12 +650,7 @@ func TestClunk(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RClunk:
 				if goterr != nil {
@@ -757,12 +706,7 @@ func TestRemove(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RRemove:
 				if goterr != nil {
@@ -816,12 +760,7 @@ func TestStat(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RStat:
 				if goterr != nil {
@@ -881,12 +820,7 @@ func TestWstat(t *testing.T) {
 				return
 			}
 			lib9p.SendMsg(test.rmsg, w)
-			select {
-			case err := <-c.errc:
-				t.Errorf("client error: %v", err)
-				return
-			case <-done:
-			}
+			<-done
 			switch ofcall := test.rmsg.(type) {
 			case *lib9p.RWstat:
 				if goterr != nil {