commit 9f326e9115d885e8fcd7d651ef6773ccf0351a45
parent 5aaab3fe165c5a0d44708c4c091354bd63eecfee
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sun, 7 Jan 2024 08:09:34 +0900
add client.reqPool and use it in mux
Diffstat:
3 files changed, 73 insertions(+), 32 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -24,6 +24,8 @@ type Client struct {
// fPool is the fidPool which hold the list of opend fids.
fPool *fidPool
+ rPool *reqPool
+
// txc is used to send a reqest to the multiplexer goroutine
txc chan<- *req
@@ -51,6 +53,7 @@ func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w i
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
+ rPool: newReqPool(),
errc: make(chan error),
wg: new(sync.WaitGroup),
}
@@ -180,55 +183,40 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- lib9p.Msg {
// and send the 9P lib9p.Msg to the speaker goroutine via tmsgc.
// The goroutine for Rmsg recieves *req from the Tmsg goroutine and waits for
// the reply to the corresponding message from the listener goroutine via rmsgc.
-// After recieving the reply, it sets the *req.rmsg and sends it t the
+// After recieving the reply, it sets the *req.rmsg and sends it to the
// *req.rxc.
// It reports any errors to the client's errc channel.
func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rmsgc <-chan lib9p.Msg) chan<- *req {
c.wg.Add(2)
txc := make(chan *req)
- reqc := make(chan *req)
// Rmsg
- go func(reqc <-chan *req) {
- rPool := make(map[uint16]*req)
+ go func() {
defer func() {
- for _, r := range rPool {
- r.cancel()
- }
+ c.rPool.cancelAll()
c.wg.Done()
}()
for {
select {
case <-ctx.Done():
return
- case r, ok := <-reqc:
- if !ok {
- return
- }
- if _, ok := rPool[r.tag]; ok {
- // r.cancel() is not needed because transaction unblocks by
- // sending error via r.errc.
- r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag)
- continue
- }
- rPool[r.tag] = r
case msg, ok := <-rmsgc:
if !ok {
return
}
- r, ok := rPool[msg.GetTag()]
+ r, ok := c.rPool.lookup(msg.GetTag())
if !ok {
c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg)
continue
}
- delete(rPool, msg.GetTag())
+ c.rPool.delete(msg.GetTag())
if tflush, ok := r.tmsg.(*lib9p.TFlush); ok {
if _, ok := msg.(*lib9p.RFlush); !ok {
r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
}
- if oldreq, ok := rPool[tflush.Oldtag]; ok {
+ if oldreq, ok := c.rPool.lookup(tflush.Oldtag); ok {
oldreq.cancel()
+ c.rPool.delete(tflush.Oldtag)
}
- delete(rPool, tflush.Oldtag)
}
r.rmsg = msg
select {
@@ -237,11 +225,10 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
}
}
}
- }(reqc)
+ }()
// Tmsg
- go func(reqc chan<- *req) {
+ go func() {
defer func() {
- close(reqc)
close(tmsgc)
c.wg.Done()
}()
@@ -250,12 +237,13 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
case <-ctx.Done():
return
case r := <-txc:
- select {
- case reqc <- r:
- case <-ctx.Done():
- r.cancel() // before registering to rPool of Rmsg goroutine.
- return
+ if _, ok := c.rPool.lookup(r.tag); ok {
+ // r.cancel() is not needed because transaction unblocks by
+ // sending error via r.errc.
+ r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag)
+ continue
}
+ c.rPool.add(r)
select {
case tmsgc <- r.tmsg:
case <-ctx.Done():
@@ -265,7 +253,7 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
}
}
}
- }(reqc)
+ }()
return txc
}
diff --git a/client/client_test.go b/client/client_test.go
@@ -29,7 +29,7 @@ func TestClientCancel(t *testing.T) {
t.Logf("%d: %v", i, err)
}(i)
}
- for i := 0; i < 7; i++ {
+ for i := 0; i < 6; i++ {
_, err := lib9p.RecvMsg(sr)
if err != nil {
t.Fatal(err)
@@ -122,6 +122,7 @@ func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client,
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
+ rPool: newReqPool(),
errc: make(chan error),
wg: new(sync.WaitGroup),
}
diff --git a/client/req.go b/client/req.go
@@ -32,6 +32,58 @@ func (r *req) cancel() {
close(r.errc)
}
+// reqPool is the pool of Reqs the server is dealing with.
+type reqPool struct {
+ m map[uint16]*req
+ *sync.Mutex
+}
+
+// newReqPool allocats a reqPool.
+func newReqPool() *reqPool {
+ return &reqPool{
+ make(map[uint16]*req),
+ new(sync.Mutex),
+ }
+}
+
+// Add allocates a request with the specified tag in reqPool rp.
+// It returns (nil, ErrDupTag) if there is already a request with the specified tag.
+func (rp *reqPool) add(r *req) error {
+ rp.Lock()
+ defer rp.Unlock()
+ if _, ok := rp.m[r.tag]; ok {
+ return lib9p.ErrDupTag
+ }
+ rp.m[r.tag] = r
+ return nil
+}
+
+// lookup looks for the request in the pool with tag.
+// If found, it returns the found request and true, otherwise
+// it returns nil and false.
+func (rp *reqPool) lookup(tag uint16) (*req, bool) {
+ rp.Lock()
+ defer rp.Unlock()
+ r, ok := rp.m[tag]
+ return r, ok
+}
+
+// delete delets the request with tag from the pool.
+func (rp *reqPool) delete(tag uint16) {
+ rp.Lock()
+ defer rp.Unlock()
+ delete(rp.m, tag)
+}
+
+func (rp *reqPool) cancelAll() {
+ rp.Lock()
+ defer rp.Unlock()
+ for tag, r := range rp.m {
+ r.cancel()
+ delete(rp.m, tag)
+ }
+}
+
// tagPool is a pool of tags being used by a client.
type tagPool struct {
m map[uint16]bool