lib9p

Go 9P library.
Log | Files | Refs | LICENSE

commit d92515436c042cbc39ea8d0566ccd8c0c6ab5c24
parent 2f9b6106085ab5766b4838d5e94a5410c69a6d5f
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Sat,  6 Jan 2024 12:52:03 +0900

change client API.

Diffstat:
Mclient/client.go | 106+++++++++++++++++++++++++++++++++++++------------------------------------------
Mclient/client_test.go | 45+++++++++++++++------------------------------
Mclient/file_test.go | 38+++++++++++++++++++-------------------
Mclient/fs.go | 18++++++++++--------
Mclient/fs2_test.go | 4+++-
Mdiskfs/file_test.go | 3+--
Mdiskfs/stat_unix_test.go | 6++----
7 files changed, 99 insertions(+), 121 deletions(-)

diff --git a/client/client.go b/client/client.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "io" "log" @@ -44,15 +45,13 @@ type Client struct { // NewClient creates a Client. // It also runs several goroutines to handle requests. // And the returned client should be stopped by calling *Client.Stop afterwards. -func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client { - ctx, cancel := context.WithCancel(context.Background()) +func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client { c := &Client{ msize: mSize, mSizeLock: new(sync.Mutex), uname: uname, fPool: allocClientFidPool(), errc: make(chan error), - cancel: cancel, wg: new(sync.WaitGroup), } tmsgc := c.runSpeaker(ctx, w) @@ -62,13 +61,6 @@ func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client { return c } -// Stop stops the Client. -func (c *Client) Stop() { - c.cancel() - c.wg.Wait() - close(c.errc) -} - // mSize returns the maximum message size of the Client. func (c *Client) mSize() uint32 { c.mSizeLock.Lock() @@ -89,10 +81,14 @@ func (c *Client) runErrorReporter(ctx context.Context) { go func() { for { select { - case err := <-c.errc: - if err == nil { + case err, ok := <-c.errc: + if !ok { return } + switch { + case errors.Is(err, io.EOF): + default: + } log.Println("client err:", err) case <-ctx.Done(): return @@ -108,12 +104,9 @@ func (c *Client) runErrorReporter(ctx context.Context) { // Listener goroutine reports errors to the client's errc channel. func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg { c.wg.Add(1) - // TODO: terminate with ctx.Done() rmsgc := make(chan lib9p.Msg, 3) go func() { - wg := new(sync.WaitGroup) defer func() { - wg.Wait() close(rmsgc) c.wg.Done() }() @@ -137,17 +130,17 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg return } if err != nil { - c.errc <- fmt.Errorf("recv: %v", err) + if err == io.EOF { + c.errc <- err + } else { + c.errc <- fmt.Errorf("recv: %v", err) + } continue } - wg.Add(1) - go func() { - defer wg.Done() - select { - case rmsgc <- msg: - case <-ctx.Done(): - } - }() + select { + case rmsgc <- msg: + case <-ctx.Done(): + } } } }() @@ -196,35 +189,29 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms reqc := make(chan *req) // Rmsg go func(reqc <-chan *req) { - wg := new(sync.WaitGroup) + rPool := make(map[uint16]*req) defer func() { - wg.Wait() + for _, r := range rPool { + <-r.ctxDone + } c.wg.Done() }() - rPool := make(map[uint16]*req) for { select { case <-ctx.Done(): return - case r := <-reqc: - if r == nil { - // ctx is canceled. - continue + case r, ok := <-reqc: + if !ok { + return } if _, ok := rPool[r.tag]; ok { r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag) continue } - rPool[r.tag] = r // TODO: wait for r.ctxDone channel. - wg.Add(1) - go func() { - defer wg.Done() - <-r.ctxDone - }() - case msg := <-rmsgc: - if msg == nil { - // ctx is canceled. - continue + rPool[r.tag] = r + case msg, ok := <-rmsgc: + if !ok { + return } r, ok := rPool[msg.GetTag()] if !ok { @@ -236,23 +223,23 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms if _, ok := msg.(*lib9p.RFlush); !ok { r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush") } + oldreq, ok := rPool[tflush.Oldtag] + if ok { + oldreq.Close() + } delete(rPool, tflush.Oldtag) } r.rmsg = msg - go func() { - select { - case <-r.ctxDone: - case r.rxc <- r: - } - }() + select { + case <-r.ctxDone: + case r.rxc <- r: + } } } }(reqc) // Tmsg go func(reqc chan<- *req) { - wg := new(sync.WaitGroup) defer func() { - wg.Wait() close(reqc) close(tmsgc) c.wg.Done() @@ -267,11 +254,11 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms case <-ctx.Done(): return } - wg.Add(1) - go func() { - defer wg.Done() - tmsgc <- r.tmsg - }() + select { + case tmsgc <- r.tmsg: + case <-ctx.Done(): + return + } } } }(reqc) @@ -284,16 +271,21 @@ func (c *Client) transact(ctx context.Context, tmsg lib9p.Msg) (lib9p.Msg, error ctx1, cancel1 := context.WithCancel(ctx) defer cancel1() r := newReq(ctx1, tmsg) - defer r.Close() select { case <-ctx.Done(): return nil, ctx.Err() case c.txc <- r: } select { - case r = <-r.rxc: // This assignment is not required. + case r, ok := <-r.rxc: + if !ok { + return nil, errors.New("reqest canceled") + } return r.rmsg, r.err - case err := <-r.errc: // Client side error. + case err, ok := <-r.errc: // Client side error. + if !ok { + return nil, errors.New("reqest canceled") + } return nil, err case <-ctx.Done(): return nil, ctx.Err() diff --git a/client/client_test.go b/client/client_test.go @@ -18,7 +18,7 @@ func setupClientAndServer(fs lib9p.FS) (*Client, context.CancelFunc) { s := lib9p.NewServer(fs) ctx, cancel := context.WithCancel(context.Background()) go s.Serve(ctx, sr, sw) - c := NewClient(8*1024, "glenda", cr, cw) + c := NewClient(ctx, 8*1024, "glenda", cr, cw) return c, cancel } @@ -59,15 +59,13 @@ func TestDupTag(t *testing.T) { } } -func newClientForTest(msize uint32, uname string) (*Client, <-chan lib9p.Msg, chan<- lib9p.Msg) { - ctx, cancel := context.WithCancel(context.Background()) +func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, chan<- lib9p.Msg) { c := &Client{ msize: msize, mSizeLock: new(sync.Mutex), uname: uname, fPool: allocClientFidPool(), errc: make(chan error), - cancel: cancel, wg: new(sync.WaitGroup), } tmsgc, rmsgc := make(chan lib9p.Msg), make(chan lib9p.Msg) @@ -90,9 +88,8 @@ func TestVersion(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotmsize uint32 @@ -156,9 +153,8 @@ func TestAuth(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotaqid lib9p.Qid @@ -220,9 +216,8 @@ func TestAttach(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotqid lib9p.Qid @@ -281,9 +276,8 @@ func TestFlush(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( goterr error @@ -340,9 +334,8 @@ func TestWalk(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotqids []lib9p.Qid @@ -404,9 +397,8 @@ func TestOpen(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotqid lib9p.Qid @@ -469,9 +461,8 @@ func TestCreate(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotqid lib9p.Qid @@ -534,9 +525,8 @@ func TestRead(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotdata []byte @@ -598,9 +588,8 @@ func TestWrite(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotcount uint32 @@ -662,9 +651,8 @@ func TestClunk(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( goterr error @@ -721,9 +709,8 @@ func TestRemove(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( goterr error @@ -777,9 +764,8 @@ func TestStat(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( gotstat *lib9p.Stat @@ -841,9 +827,8 @@ func TestWstat(t *testing.T) { } for _, test := range tests { func() { - c, tmsgc, rmsgc := newClientForTest(1024, "glenda") - defer c.Stop() ctx, cancel := context.WithCancel(context.Background()) + c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda") defer cancel() var ( goterr error diff --git a/client/file_test.go b/client/file_test.go @@ -14,41 +14,41 @@ import ( ) // Mount runs a server with fs and mounts it as *FS. -// cancel is to stop the server. -func mount(fs lib9p.FS) (cfs *FS, cancel context.CancelFunc, err error) { +func mount(ctx context.Context, fs lib9p.FS) (cfs *FS, err error) { cr, sw := io.Pipe() sr, cw := io.Pipe() s := lib9p.NewServer(fs) - ctx, cancel := context.WithCancel(context.Background()) go s.Serve(ctx, sr, sw) - cfs, err = Mount(cr, cw, "glenda", "") + cfs, err = Mount(ctx, cr, cw, "glenda", "") if err != nil { - return nil, nil, fmt.Errorf("mount: %v", err) + return nil, fmt.Errorf("mount: %v", err) } - return cfs, cancel, nil + return cfs, nil } // TestFileStat tests whether Stat returns the same lib9p.Stat as testfs defines. // TODO: work in progress. func TestFileStat(t *testing.T) { - cfs, cancel, err := mount(testfs) + ctx, cancel := context.WithCancel(context.Background()) + cfs, err := mount(ctx, testfs) if err != nil { t.Fatal(err) } defer cancel() - defer cfs.Unmount() var testFileStat = func(path string, d fs.DirEntry, err error) error { - t.Log("walk:", path) - fi, err := fs.Stat(lib9p.ExportFS{FS: testfs}, path) + cfi, err := fs.Stat(lib9p.ExportFS{cfs}, path) if err != nil { + if strings.Contains(err.Error(), "permission") { + return nil + } t.Error(err) } - cfi, err := d.Info() + tfi, err := fs.Stat(lib9p.ExportFS{testfs}, path) if err != nil { t.Error(err) } - if !reflect.DeepEqual(cfi, fi) { - t.Errorf("fi not match:\n\twant: %v\n\tgot: %v", fi, cfi) + if !reflect.DeepEqual(cfi, tfi) { + t.Errorf("fi not match:\n\twant: %v\n\tgot: %v", tfi, cfi) } return nil } @@ -60,12 +60,12 @@ func TestFileStat(t *testing.T) { // TestClose checks it *File.Close calls are transmitted to the server and // *testFile.Close is called for each *File.Close. func TestClose(t *testing.T) { - cfs, cancel, err := mount(testfs) + ctx, cancel := context.WithCancel(context.Background()) + cfs, err := mount(ctx, testfs) if err != nil { t.Fatal(err) } defer cancel() - defer cfs.Unmount() var walk = func(path string, d fs.DirEntry, err error) error { if err != nil { if err == io.EOF { @@ -98,12 +98,12 @@ func TestClose(t *testing.T) { // TestFileRead checks *File.Read reads the same []byte as *testFile.content. func TestFileRead(t *testing.T) { - cfs, cancel, err := mount(testfs) + ctx, cancel := context.WithCancel(context.Background()) + cfs, err := mount(ctx, testfs) if err != nil { t.Fatal(err) } defer cancel() - defer cfs.Unmount() var walk = func(path string, d fs.DirEntry, err error) error { if err != nil { if err == io.EOF { @@ -147,12 +147,12 @@ func TestFileRead(t *testing.T) { // TestReadDir tests whether ReadDir returns the same dir entries as testfs // by comparing their *lib9p.Stat. func TestReadDir(t *testing.T) { - cfs, cancel, err := mount(testfs) + ctx, cancel := context.WithCancel(context.Background()) + cfs, err := mount(ctx, testfs) if err != nil { t.Fatal(err) } defer cancel() - defer cfs.Unmount() var walk = func(path string, d fs.DirEntry, err error) error { t.Log(path) if err != nil { diff --git a/client/fs.go b/client/fs.go @@ -91,22 +91,26 @@ func (fsys *FS) walkFile(name string) (*File, error) { // Mount initiates a 9P session and returns the resulting file system. // The 9P session is established by writing to w and reading from r. -func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) { +// When fs is not needed anymore, ctx should be canceled to stop the +// underlying client's goroutines. +// If non-nil error is returned, underlying client is stopped by this function, +// and there is no need to cancel ctx. +func Mount(ctx context.Context, r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) { var ( mSize uint32 = 8192 version = "9P2000" - ctx = context.TODO() ) + ctx0, cancel0 := context.WithCancel(ctx) cfs := &FS{ - c: NewClient(mSize, uname, r, w), + c: NewClient(ctx0, mSize, uname, r, w), tPool: newTagPool(), } defer func() { if err != nil { - cfs.c.Stop() + cancel0() } }() - rmSize, rver, err := cfs.c.Version(ctx, lib9p.NOTAG, mSize, version) + rmSize, rver, err := cfs.c.Version(ctx0, lib9p.NOTAG, mSize, version) if err != nil { return nil, fmt.Errorf("version: %v", err) } @@ -125,7 +129,7 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) { if err != nil { return nil, err } - _, err = cfs.c.Attach(ctx, tag, fid.fid, lib9p.NOFID, uname, aname) + _, err = cfs.c.Attach(ctx0, tag, fid.fid, lib9p.NOFID, uname, aname) cfs.tPool.delete(tag) if err != nil { return nil, fmt.Errorf("attach: %v", err) @@ -133,5 +137,3 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) { cfs.c.rootFid = fid return cfs, nil } - -func (fsys *FS) Unmount() { fsys.c.Stop() } diff --git a/client/fs2_test.go b/client/fs2_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "io/fs" "reflect" "strings" @@ -13,7 +14,8 @@ import ( // other than permission error and // every opened file is the same as that of testfs by comparing their lib9p.Stat. func TestOpenFile(t *testing.T) { - cfs, cancel, err := mount(testfs) + ctx, cancel := context.WithCancel(context.Background()) + cfs, err := mount(ctx, testfs) if err != nil { t.Fatal(err) } diff --git a/diskfs/file_test.go b/diskfs/file_test.go @@ -25,8 +25,7 @@ func BenchmarkRead(b *testing.B) { } s := lib9p.NewServer(disk) go s.Serve(ctx, sr, sw) - c := client.NewClient(8*1024, "kenji", cr, cw) - defer c.Stop() + c := client.NewClient(ctx, 8*1024, "kenji", cr, cw) _, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "") if err != nil { b.Fatalf("attach: %v", err) diff --git a/diskfs/stat_unix_test.go b/diskfs/stat_unix_test.go @@ -28,8 +28,7 @@ func BenchmarkGID(b *testing.B) { } s := lib9p.NewServer(disk) go s.Serve(ctx, sr, sw) - c := client.NewClient(8*1024, "kenji", cr, cw) - defer c.Stop() + c := client.NewClient(ctx, 8*1024, "kenji", cr, cw) _, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "") if err != nil { b.Fatalf("attach: %v", err) @@ -57,8 +56,7 @@ func TestChgrp(t *testing.T) { } s := lib9p.NewServer(disk) go s.Serve(ctx, sr, sw) - c := client.NewClient(8*1024, "kenji", cr, cw) - defer c.Stop() + c := client.NewClient(ctx, 8*1024, "kenji", cr, cw) _, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "") if err != nil { t.Fatalf("attach: %v", err)