lib9p

Go 9P library.
Log | Files | Refs

commit 27938facf38b31090b0292c7495e811e3d9cde1a
parent 11053d7793b1f0299e8fbc104ef1d02d176bc0eb
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Tue,  3 Oct 2023 10:04:44 +0900

add transact()

Diffstat:
Mclient.go | 114+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mreq.go | 70++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 126 insertions(+), 58 deletions(-)

diff --git a/client.go b/client.go @@ -8,25 +8,27 @@ import ( ) type Client struct { - msize uint32 + msize uint32 mSizeLock *sync.Mutex - uname string - fPool *FidPool - rPool *ReqPool - reader io.Reader - writer io.Writer + uname string + fPool *FidPool + rPool *clientReqPool + tmsgc chan<- Msg + terrc <-chan error + rerrc <-chan error } func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client { - return &Client{ - msize: mSize, + c := &Client{ + msize: mSize, mSizeLock: new(sync.Mutex), - uname: uname, - fPool: allocFidPool(), - rPool: allocReqPool(), - reader: r, - writer: w, + uname: uname, + fPool: allocFidPool(), + rPool: newClientReqPool(), } + c.tmsgc, c.terrc = c.runSpeaker(context.TODO(), w) + c.rerrc = c.runListener(context.TODO(), r) + return c } func (c *Client) mSize() uint32 { @@ -41,23 +43,38 @@ func (c *Client) setMSize(mSize uint32) { c.msize = mSize } -func (c *Client) runListnener(ctx context.Context, r io.Reader) (<-chan Msg, <-chan error) { +func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan error { // TODO: terminate with ctx.Done() - mc := make(chan Msg) ec := make(chan error) go func() { - defer close(mc) defer close(ec) for { - msg, err := recv(r) - if err != nil { - ec <- err - continue + select { + case <-ctx.Done(): + // TODO: should return error via ec?? + return + default: + msg, err := recv(r) + if err != nil { + ec <- err + continue + } + req, ok := c.rPool.lookup(msg.Tag()) + if !ok { + ec <- fmt.Errorf("unknown tag: %d", msg.Tag()) + } + go func() { + defer close(req.rmsgc) + select { + case req.rmsgc <- msg: + c.rPool.delete(msg.Tag()) + case <-ctx.Done(): + } + }() } - mc <- msg } }() - return mc, ec + return ec } func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-chan error) { @@ -70,7 +87,7 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-cha select { case <-ctx.Done(): return - case msg := <- mc: + case msg := <-mc: if err := send(msg, w); err != nil { ec <- err } @@ -80,50 +97,30 @@ func (c *Client) runSpeaker(ctx context.Context, w io.Writer) (chan<- Msg, <-cha return mc, ec } -func transact(ctx context.Context, w io.Writer, r io.Reader, tmsg Msg) (<-chan Msg, <-chan error) { - err := send(tmsg, w) - rmsgc := make(chan Msg, 1) - errc := make(chan error, 1) - go func() { - defer close(rmsgc) +// tag of tmsg is managed by this function. +func (c *Client) transact(ctx context.Context, tmsg Msg) (<-chan Msg, <-chan error) { + req, err := c.rPool.add(tmsg) + if err != nil { + errc := make(chan error, 1) defer close(errc) - if err != nil { - errc <- err - return - } - - rmsgc1 := make(chan Msg, 1) - errc1 := make(chan error, 1) - // TODO: cancel recv() with ctx.Done() - go func() { - defer close(rmsgc1) - defer close(errc1) - rmsg, err := recv(r) - if err != nil { - errc1 <- err - return - } - rmsgc1 <- rmsg - }() + errc <- fmt.Errorf("add clientReq: %v", err) + return nil, errc + } + go func() { select { - case rmsg := <-rmsgc1: - rmsgc <- rmsg - case err := <-errc1: - errc <- err + case c.tmsgc <- tmsg: case <-ctx.Done(): - errc <- fmt.Errorf("wait rmsg: %w", ctx.Err()) } }() - return rmsgc, errc + return req.rmsgc, req.errc } func (c *Client) Version(ctx context.Context, mSize uint32, version string) (<-chan *RVersion, <-chan error) { tmsg := &TVersion{ - tag: ^uint16(0), mSize: mSize, version: version, } - rmsgc, errc := transact(ctx, c.writer, c.reader, tmsg) + rmsgc, errc := c.transact(ctx, tmsg) rmsgc1 := make(chan *RVersion, 1) errc1 := make(chan error, 1) go func() { @@ -147,8 +144,10 @@ func (c *Client) Version(ctx context.Context, mSize uint32, version string) (<-c }() return rmsgc1, errc1 } - +/* func (c *Client) Auth(ctx context.Context, afid uint32, uname, aname string) (<-chan *RAuth, <-chan error) { // tmsg := &TAuth{afid: afid, uname: uname} return nil, nil -} -\ No newline at end of file +} + +*/ diff --git a/req.go b/req.go @@ -1,6 +1,7 @@ package lib9p import ( + "fmt" "sync" ) @@ -59,3 +60,72 @@ func (rp *ReqPool) delete(tag uint16) { defer rp.lock.Unlock() delete(rp.m, tag) } + +type clientReq struct { + tag uint16 + pool *clientReqPool + rmsgc chan Msg + errc chan error +} + +type clientReqPool struct { + m map[uint16]*clientReq + lock *sync.Mutex +} + +func newClientReqPool() *clientReqPool { + return &clientReqPool{ + m: make(map[uint16]*clientReq), + lock: new(sync.Mutex), + } +} + +func (rp *clientReqPool) nextTag() (uint16, error) { + // TODO: optimize + rp.lock.Lock() + defer rp.lock.Unlock() + for i := uint16(0); i < i+1; i++ { + if _, ok := rp.m[i]; !ok { + return i, nil + } + } + return 0, fmt.Errorf("run out of tag") +} + +func (rp *clientReqPool) add(msg Msg) (*clientReq, error) { + var tag uint16 + if _, ok := msg.(*TVersion); ok { + tag = ^uint16(0) + } else { + var err error + tag, err = rp.nextTag() + if err != nil { + return nil, fmt.Errorf("nextTag: %v", err) + } + } + rp.lock.Lock() + defer rp.lock.Unlock() + + msg.SetTag(tag) + req := &clientReq{ + tag: tag, + pool: rp, + rmsgc: make(chan Msg), + errc: make(chan error), + } + rp.m[tag] = req + return req, nil +} + +func (rp *clientReqPool) lookup(tag uint16) (*clientReq, bool) { + rp.lock.Lock() + defer rp.lock.Unlock() + r, ok := rp.m[tag] + return r, ok +} + +func (rp *clientReqPool) delete(tag uint16) { + rp.lock.Lock() + defer rp.lock.Unlock() + delete(rp.m, tag) +}