commit 27938facf38b31090b0292c7495e811e3d9cde1a
parent 11053d7793b1f0299e8fbc104ef1d02d176bc0eb
Author: Matsuda Kenji <info@mtkn.jp>
Date: Tue, 3 Oct 2023 10:04:44 +0900
add transact()
Diffstat:
| M | client.go | | | 114 | +++++++++++++++++++++++++++++++++++++++---------------------------------------- |
| M | req.go | | | 70 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
2 files changed, 126 insertions(+), 58 deletions(-)
diff --git a/client.go b/client.go
@@ -8,25 +8,27 @@ import (
)
type Client struct {
- msize uint32
+ msize uint32
mSizeLock *sync.Mutex
- uname string
- fPool *FidPool
- rPool *ReqPool
- reader io.Reader
- writer io.Writer
+ uname string
+ fPool *FidPool
+ rPool *clientReqPool
+ tmsgc chan<- Msg
+ terrc <-chan error
+ rerrc <-chan error
}
func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
- return &Client{
- msize: mSize,
+ c := &Client{
+ msize: mSize,
mSizeLock: new(sync.Mutex),
- uname: uname,
- fPool: allocFidPool(),
- rPool: allocReqPool(),
- reader: r,
- writer: w,
+ uname: uname,
+ fPool: allocFidPool(),
+ rPool: newClientReqPool(),
}
+ c.tmsgc, c.terrc = c.runSpeaker(context.TODO(), w)
+ c.rerrc = c.runListener(context.TODO(), r)
+ return c
}
func (c *Client) mSize() uint32 {
@@ -41,23 +43,38 @@ func (c *Client) setMSize(mSize uint32) {
c.msize = mSize
}
-func (c *Client) runListnener(ctx context.Context, r io.Reader) (<-chan Msg, <-chan error) {
+func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan error {
// TODO: terminate with ctx.Done()
- mc := make(chan Msg)
ec := make(chan error)
go func() {
- defer close(mc)
defer close(ec)
for {
- msg, err := recv(r)
- if err != nil {
- ec <- err
- continue
+ select {
+ case <-ctx.Done():
+ // TODO: should return error via ec??
+ return
+ default:
+ msg, err := recv(r)
+ if err != nil {
+ ec <- err
+ continue
+ }
+ req, ok := c.rPool.lookup(msg.Tag())
+ if !ok {
+ ec <- fmt.Errorf("unknown tag: %d", msg.Tag())
+ }
+ go func() {
+ defer close(req.rmsgc)
+ select {
+ case req.rmsgc <- msg:
+ c.rPool.delete(msg.Tag())
+ case <-ctx.Done():
+ }
+ }()
}
- mc <- msg
}
}()
- return mc, ec
+ return ec
}
func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-chan error) {
@@ -70,7 +87,7 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-cha
select {
case <-ctx.Done():
return
- case msg := <- mc:
+ case msg := <-mc:
if err := send(msg, w); err != nil {
ec <- err
}
@@ -80,50 +97,30 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-cha
return mc, ec
}
-func transact(ctx context.Context, w io.Writer, r io.Reader, tmsg Msg) (<-chan Msg, <-chan error) {
- err := send(tmsg, w)
- rmsgc := make(chan Msg, 1)
- errc := make(chan error, 1)
- go func() {
- defer close(rmsgc)
+// tag of tmsg is managed by this function.
+func (c *Client) transact(ctx context.Context, tmsg Msg) (<-chan Msg, <-chan error) {
+ req, err := c.rPool.add(tmsg)
+ if err != nil {
+ errc := make(chan error, 1)
defer close(errc)
- if err != nil {
- errc <- err
- return
- }
-
- rmsgc1 := make(chan Msg, 1)
- errc1 := make(chan error, 1)
- // TODO: cancel recv() with ctx.Done()
- go func() {
- defer close(rmsgc1)
- defer close(errc1)
- rmsg, err := recv(r)
- if err != nil {
- errc1 <- err
- return
- }
- rmsgc1 <- rmsg
- }()
+ errc <- fmt.Errorf("add clientReq: %v", err)
+ return nil, errc
+ }
+ go func() {
select {
- case rmsg := <-rmsgc1:
- rmsgc <- rmsg
- case err := <-errc1:
- errc <- err
+ case c.tmsgc <- tmsg:
case <-ctx.Done():
- errc <- fmt.Errorf("wait rmsg: %w", ctx.Err())
}
}()
- return rmsgc, errc
+ return req.rmsgc, req.errc
}
func (c *Client) Version(ctx context.Context, mSize uint32, version string) (<-chan *RVersion, <-chan error) {
tmsg := &TVersion{
- tag: ^uint16(0),
mSize: mSize,
version: version,
}
- rmsgc, errc := transact(ctx, c.writer, c.reader, tmsg)
+ rmsgc, errc := c.transact(ctx, tmsg)
rmsgc1 := make(chan *RVersion, 1)
errc1 := make(chan error, 1)
go func() {
@@ -147,8 +144,10 @@ func (c *Client) Version(ctx context.Context, mSize uint32, version string) (<-c
}()
return rmsgc1, errc1
}
-
+/*
func (c *Client) Auth(ctx context.Context, afid uint32, uname, aname string) (<-chan *RAuth, <-chan error) {
// tmsg := &TAuth{afid: afid, uname: uname}
return nil, nil
-}
-\ No newline at end of file
+}
+
+*/
diff --git a/req.go b/req.go
@@ -1,6 +1,7 @@
package lib9p
import (
+ "fmt"
"sync"
)
@@ -59,3 +60,72 @@ func (rp *ReqPool) delete(tag uint16) {
defer rp.lock.Unlock()
delete(rp.m, tag)
}
+
+type clientReq struct {
+ tag uint16
+ pool *clientReqPool
+ rmsgc chan Msg
+ errc chan error
+}
+
+type clientReqPool struct {
+ m map[uint16]*clientReq
+ lock *sync.Mutex
+}
+
+func newClientReqPool() *clientReqPool {
+ return &clientReqPool{
+ m: make(map[uint16]*clientReq),
+ lock: new(sync.Mutex),
+ }
+}
+
+func (rp *clientReqPool) nextTag() (uint16, error) {
+ // TODO: optimize
+ rp.lock.Lock()
+ defer rp.lock.Unlock()
+ for i := uint16(0); i < i+1; i++ {
+ if _, ok := rp.m[i]; !ok {
+ return i, nil
+ }
+ }
+ return 0, fmt.Errorf("run out of tag")
+}
+
+func (rp *clientReqPool) add(msg Msg) (*clientReq, error) {
+ var tag uint16
+ if _, ok := msg.(*TVersion); ok {
+ tag = ^uint16(0)
+ } else {
+ var err error
+ tag, err = rp.nextTag()
+ if err != nil {
+ return nil, fmt.Errorf("nextTag: %v", err)
+ }
+ }
+ rp.lock.Lock()
+ defer rp.lock.Unlock()
+
+ msg.SetTag(tag)
+ req := &clientReq{
+ tag: tag,
+ pool: rp,
+ rmsgc: make(chan Msg),
+ errc: make(chan error),
+ }
+ rp.m[tag] = req
+ return req, nil
+}
+
+func (rp *clientReqPool) lookup(tag uint16) (*clientReq, bool) {
+ rp.lock.Lock()
+ defer rp.lock.Unlock()
+ r, ok := rp.m[tag]
+ return r, ok
+}
+
+func (rp *clientReqPool) delete(tag uint16) {
+ rp.lock.Lock()
+ defer rp.lock.Unlock()
+ delete(rp.m, tag)
+}