commit d92515436c042cbc39ea8d0566ccd8c0c6ab5c24
parent 2f9b6106085ab5766b4838d5e94a5410c69a6d5f
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sat, 6 Jan 2024 12:52:03 +0900
change client API.
Diffstat:
7 files changed, 99 insertions(+), 121 deletions(-)
diff --git a/client/client.go b/client/client.go
@@ -2,6 +2,7 @@ package client
import (
"context"
+ "errors"
"fmt"
"io"
"log"
@@ -44,15 +45,13 @@ type Client struct {
// NewClient creates a Client.
// It also runs several goroutines to handle requests.
// And the returned client should be stopped by calling *Client.Stop afterwards.
-func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
- ctx, cancel := context.WithCancel(context.Background())
+func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
c := &Client{
msize: mSize,
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
errc: make(chan error),
- cancel: cancel,
wg: new(sync.WaitGroup),
}
tmsgc := c.runSpeaker(ctx, w)
@@ -62,13 +61,6 @@ func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
return c
}
-// Stop stops the Client.
-func (c *Client) Stop() {
- c.cancel()
- c.wg.Wait()
- close(c.errc)
-}
-
// mSize returns the maximum message size of the Client.
func (c *Client) mSize() uint32 {
c.mSizeLock.Lock()
@@ -89,10 +81,14 @@ func (c *Client) runErrorReporter(ctx context.Context) {
go func() {
for {
select {
- case err := <-c.errc:
- if err == nil {
+ case err, ok := <-c.errc:
+ if !ok {
return
}
+ switch {
+ case errors.Is(err, io.EOF):
+ default:
+ }
log.Println("client err:", err)
case <-ctx.Done():
return
@@ -108,12 +104,9 @@ func (c *Client) runErrorReporter(ctx context.Context) {
// Listener goroutine reports errors to the client's errc channel.
func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg {
c.wg.Add(1)
- // TODO: terminate with ctx.Done()
rmsgc := make(chan lib9p.Msg, 3)
go func() {
- wg := new(sync.WaitGroup)
defer func() {
- wg.Wait()
close(rmsgc)
c.wg.Done()
}()
@@ -137,17 +130,17 @@ func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan lib9p.Msg
return
}
if err != nil {
- c.errc <- fmt.Errorf("recv: %v", err)
+ if err == io.EOF {
+ c.errc <- err
+ } else {
+ c.errc <- fmt.Errorf("recv: %v", err)
+ }
continue
}
- wg.Add(1)
- go func() {
- defer wg.Done()
- select {
- case rmsgc <- msg:
- case <-ctx.Done():
- }
- }()
+ select {
+ case rmsgc <- msg:
+ case <-ctx.Done():
+ }
}
}
}()
@@ -196,35 +189,29 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
reqc := make(chan *req)
// Rmsg
go func(reqc <-chan *req) {
- wg := new(sync.WaitGroup)
+ rPool := make(map[uint16]*req)
defer func() {
- wg.Wait()
+ for _, r := range rPool {
+ <-r.ctxDone
+ }
c.wg.Done()
}()
- rPool := make(map[uint16]*req)
for {
select {
case <-ctx.Done():
return
- case r := <-reqc:
- if r == nil {
- // ctx is canceled.
- continue
+ case r, ok := <-reqc:
+ if !ok {
+ return
}
if _, ok := rPool[r.tag]; ok {
r.errc <- fmt.Errorf("mux: duplicate tag: %d", r.tag)
continue
}
- rPool[r.tag] = r // TODO: wait for r.ctxDone channel.
- wg.Add(1)
- go func() {
- defer wg.Done()
- <-r.ctxDone
- }()
- case msg := <-rmsgc:
- if msg == nil {
- // ctx is canceled.
- continue
+ rPool[r.tag] = r
+ case msg, ok := <-rmsgc:
+ if !ok {
+ return
}
r, ok := rPool[msg.GetTag()]
if !ok {
@@ -236,23 +223,23 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
if _, ok := msg.(*lib9p.RFlush); !ok {
r.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
}
+ oldreq, ok := rPool[tflush.Oldtag]
+ if ok {
+ oldreq.Close()
+ }
delete(rPool, tflush.Oldtag)
}
r.rmsg = msg
- go func() {
- select {
- case <-r.ctxDone:
- case r.rxc <- r:
- }
- }()
+ select {
+ case <-r.ctxDone:
+ case r.rxc <- r:
+ }
}
}
}(reqc)
// Tmsg
go func(reqc chan<- *req) {
- wg := new(sync.WaitGroup)
defer func() {
- wg.Wait()
close(reqc)
close(tmsgc)
c.wg.Done()
@@ -267,11 +254,11 @@ func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- lib9p.Msg, rms
case <-ctx.Done():
return
}
- wg.Add(1)
- go func() {
- defer wg.Done()
- tmsgc <- r.tmsg
- }()
+ select {
+ case tmsgc <- r.tmsg:
+ case <-ctx.Done():
+ return
+ }
}
}
}(reqc)
@@ -284,16 +271,21 @@ func (c *Client) transact(ctx context.Context, tmsg lib9p.Msg) (lib9p.Msg, error
ctx1, cancel1 := context.WithCancel(ctx)
defer cancel1()
r := newReq(ctx1, tmsg)
- defer r.Close()
select {
case <-ctx.Done():
return nil, ctx.Err()
case c.txc <- r:
}
select {
- case r = <-r.rxc: // This assignment is not required.
+ case r, ok := <-r.rxc:
+ if !ok {
+ return nil, errors.New("reqest canceled")
+ }
return r.rmsg, r.err
- case err := <-r.errc: // Client side error.
+ case err, ok := <-r.errc: // Client side error.
+ if !ok {
+ return nil, errors.New("reqest canceled")
+ }
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
diff --git a/client/client_test.go b/client/client_test.go
@@ -18,7 +18,7 @@ func setupClientAndServer(fs lib9p.FS) (*Client, context.CancelFunc) {
s := lib9p.NewServer(fs)
ctx, cancel := context.WithCancel(context.Background())
go s.Serve(ctx, sr, sw)
- c := NewClient(8*1024, "glenda", cr, cw)
+ c := NewClient(ctx, 8*1024, "glenda", cr, cw)
return c, cancel
}
@@ -59,15 +59,13 @@ func TestDupTag(t *testing.T) {
}
}
-func newClientForTest(msize uint32, uname string) (*Client, <-chan lib9p.Msg, chan<- lib9p.Msg) {
- ctx, cancel := context.WithCancel(context.Background())
+func newClientForTest(ctx context.Context, msize uint32, uname string) (*Client, <-chan lib9p.Msg, chan<- lib9p.Msg) {
c := &Client{
msize: msize,
mSizeLock: new(sync.Mutex),
uname: uname,
fPool: allocClientFidPool(),
errc: make(chan error),
- cancel: cancel,
wg: new(sync.WaitGroup),
}
tmsgc, rmsgc := make(chan lib9p.Msg), make(chan lib9p.Msg)
@@ -90,9 +88,8 @@ func TestVersion(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotmsize uint32
@@ -156,9 +153,8 @@ func TestAuth(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotaqid lib9p.Qid
@@ -220,9 +216,8 @@ func TestAttach(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -281,9 +276,8 @@ func TestFlush(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -340,9 +334,8 @@ func TestWalk(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqids []lib9p.Qid
@@ -404,9 +397,8 @@ func TestOpen(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -469,9 +461,8 @@ func TestCreate(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotqid lib9p.Qid
@@ -534,9 +525,8 @@ func TestRead(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotdata []byte
@@ -598,9 +588,8 @@ func TestWrite(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotcount uint32
@@ -662,9 +651,8 @@ func TestClunk(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -721,9 +709,8 @@ func TestRemove(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
@@ -777,9 +764,8 @@ func TestStat(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
gotstat *lib9p.Stat
@@ -841,9 +827,8 @@ func TestWstat(t *testing.T) {
}
for _, test := range tests {
func() {
- c, tmsgc, rmsgc := newClientForTest(1024, "glenda")
- defer c.Stop()
ctx, cancel := context.WithCancel(context.Background())
+ c, tmsgc, rmsgc := newClientForTest(ctx, 1024, "glenda")
defer cancel()
var (
goterr error
diff --git a/client/file_test.go b/client/file_test.go
@@ -14,41 +14,41 @@ import (
)
// Mount runs a server with fs and mounts it as *FS.
-// cancel is to stop the server.
-func mount(fs lib9p.FS) (cfs *FS, cancel context.CancelFunc, err error) {
+func mount(ctx context.Context, fs lib9p.FS) (cfs *FS, err error) {
cr, sw := io.Pipe()
sr, cw := io.Pipe()
s := lib9p.NewServer(fs)
- ctx, cancel := context.WithCancel(context.Background())
go s.Serve(ctx, sr, sw)
- cfs, err = Mount(cr, cw, "glenda", "")
+ cfs, err = Mount(ctx, cr, cw, "glenda", "")
if err != nil {
- return nil, nil, fmt.Errorf("mount: %v", err)
+ return nil, fmt.Errorf("mount: %v", err)
}
- return cfs, cancel, nil
+ return cfs, nil
}
// TestFileStat tests whether Stat returns the same lib9p.Stat as testfs defines.
// TODO: work in progress.
func TestFileStat(t *testing.T) {
- cfs, cancel, err := mount(testfs)
+ ctx, cancel := context.WithCancel(context.Background())
+ cfs, err := mount(ctx, testfs)
if err != nil {
t.Fatal(err)
}
defer cancel()
- defer cfs.Unmount()
var testFileStat = func(path string, d fs.DirEntry, err error) error {
- t.Log("walk:", path)
- fi, err := fs.Stat(lib9p.ExportFS{FS: testfs}, path)
+ cfi, err := fs.Stat(lib9p.ExportFS{cfs}, path)
if err != nil {
+ if strings.Contains(err.Error(), "permission") {
+ return nil
+ }
t.Error(err)
}
- cfi, err := d.Info()
+ tfi, err := fs.Stat(lib9p.ExportFS{testfs}, path)
if err != nil {
t.Error(err)
}
- if !reflect.DeepEqual(cfi, fi) {
- t.Errorf("fi not match:\n\twant: %v\n\tgot: %v", fi, cfi)
+ if !reflect.DeepEqual(cfi, tfi) {
+ t.Errorf("fi not match:\n\twant: %v\n\tgot: %v", tfi, cfi)
}
return nil
}
@@ -60,12 +60,12 @@ func TestFileStat(t *testing.T) {
// TestClose checks it *File.Close calls are transmitted to the server and
// *testFile.Close is called for each *File.Close.
func TestClose(t *testing.T) {
- cfs, cancel, err := mount(testfs)
+ ctx, cancel := context.WithCancel(context.Background())
+ cfs, err := mount(ctx, testfs)
if err != nil {
t.Fatal(err)
}
defer cancel()
- defer cfs.Unmount()
var walk = func(path string, d fs.DirEntry, err error) error {
if err != nil {
if err == io.EOF {
@@ -98,12 +98,12 @@ func TestClose(t *testing.T) {
// TestFileRead checks *File.Read reads the same []byte as *testFile.content.
func TestFileRead(t *testing.T) {
- cfs, cancel, err := mount(testfs)
+ ctx, cancel := context.WithCancel(context.Background())
+ cfs, err := mount(ctx, testfs)
if err != nil {
t.Fatal(err)
}
defer cancel()
- defer cfs.Unmount()
var walk = func(path string, d fs.DirEntry, err error) error {
if err != nil {
if err == io.EOF {
@@ -147,12 +147,12 @@ func TestFileRead(t *testing.T) {
// TestReadDir tests whether ReadDir returns the same dir entries as testfs
// by comparing their *lib9p.Stat.
func TestReadDir(t *testing.T) {
- cfs, cancel, err := mount(testfs)
+ ctx, cancel := context.WithCancel(context.Background())
+ cfs, err := mount(ctx, testfs)
if err != nil {
t.Fatal(err)
}
defer cancel()
- defer cfs.Unmount()
var walk = func(path string, d fs.DirEntry, err error) error {
t.Log(path)
if err != nil {
diff --git a/client/fs.go b/client/fs.go
@@ -91,22 +91,26 @@ func (fsys *FS) walkFile(name string) (*File, error) {
// Mount initiates a 9P session and returns the resulting file system.
// The 9P session is established by writing to w and reading from r.
-func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) {
+// When fs is not needed anymore, ctx should be canceled to stop the
+// underlying client's goroutines.
+// If non-nil error is returned, underlying client is stopped by this function,
+// and there is no need to cancel ctx.
+func Mount(ctx context.Context, r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) {
var (
mSize uint32 = 8192
version = "9P2000"
- ctx = context.TODO()
)
+ ctx0, cancel0 := context.WithCancel(ctx)
cfs := &FS{
- c: NewClient(mSize, uname, r, w),
+ c: NewClient(ctx0, mSize, uname, r, w),
tPool: newTagPool(),
}
defer func() {
if err != nil {
- cfs.c.Stop()
+ cancel0()
}
}()
- rmSize, rver, err := cfs.c.Version(ctx, lib9p.NOTAG, mSize, version)
+ rmSize, rver, err := cfs.c.Version(ctx0, lib9p.NOTAG, mSize, version)
if err != nil {
return nil, fmt.Errorf("version: %v", err)
}
@@ -125,7 +129,7 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) {
if err != nil {
return nil, err
}
- _, err = cfs.c.Attach(ctx, tag, fid.fid, lib9p.NOFID, uname, aname)
+ _, err = cfs.c.Attach(ctx0, tag, fid.fid, lib9p.NOFID, uname, aname)
cfs.tPool.delete(tag)
if err != nil {
return nil, fmt.Errorf("attach: %v", err)
@@ -133,5 +137,3 @@ func Mount(r io.Reader, w io.Writer, uname, aname string) (fs *FS, err error) {
cfs.c.rootFid = fid
return cfs, nil
}
-
-func (fsys *FS) Unmount() { fsys.c.Stop() }
diff --git a/client/fs2_test.go b/client/fs2_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"io/fs"
"reflect"
"strings"
@@ -13,7 +14,8 @@ import (
// other than permission error and
// every opened file is the same as that of testfs by comparing their lib9p.Stat.
func TestOpenFile(t *testing.T) {
- cfs, cancel, err := mount(testfs)
+ ctx, cancel := context.WithCancel(context.Background())
+ cfs, err := mount(ctx, testfs)
if err != nil {
t.Fatal(err)
}
diff --git a/diskfs/file_test.go b/diskfs/file_test.go
@@ -25,8 +25,7 @@ func BenchmarkRead(b *testing.B) {
}
s := lib9p.NewServer(disk)
go s.Serve(ctx, sr, sw)
- c := client.NewClient(8*1024, "kenji", cr, cw)
- defer c.Stop()
+ c := client.NewClient(ctx, 8*1024, "kenji", cr, cw)
_, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "")
if err != nil {
b.Fatalf("attach: %v", err)
diff --git a/diskfs/stat_unix_test.go b/diskfs/stat_unix_test.go
@@ -28,8 +28,7 @@ func BenchmarkGID(b *testing.B) {
}
s := lib9p.NewServer(disk)
go s.Serve(ctx, sr, sw)
- c := client.NewClient(8*1024, "kenji", cr, cw)
- defer c.Stop()
+ c := client.NewClient(ctx, 8*1024, "kenji", cr, cw)
_, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "")
if err != nil {
b.Fatalf("attach: %v", err)
@@ -57,8 +56,7 @@ func TestChgrp(t *testing.T) {
}
s := lib9p.NewServer(disk)
go s.Serve(ctx, sr, sw)
- c := client.NewClient(8*1024, "kenji", cr, cw)
- defer c.Stop()
+ c := client.NewClient(ctx, 8*1024, "kenji", cr, cw)
_, err = c.Attach(ctx, ^uint16(0), 0, lib9p.NOFID, "kenji", "")
if err != nil {
t.Fatalf("attach: %v", err)