commit a0f21856338a38df393820641f9416bb0ed556ba
parent 1f3935ea80fa09e85f9a8d5deadbd8c94d93e67b
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Wed, 10 Jan 2024 09:59:36 +0900
flush outstanding requests when a request with NOTAG arived
Diffstat:
6 files changed, 71 insertions(+), 45 deletions(-)
diff --git a/client/fs2_test.go b/client/fs2_test.go
@@ -11,7 +11,7 @@ import (
 )
 
 func TestCleanPath(t *testing.T) {
-	tests := []struct{
+	tests := []struct {
 		name string
 		want string
 	}{
diff --git a/diskfs/fs.go b/diskfs/fs.go
@@ -47,9 +47,9 @@ func Open(name string) (*FS, error) {
 func (fsys *FS) OpenFile(name string, omode lib9p.OpenMode) (lib9p.File, error) {
 	if !fs.ValidPath(name) {
 		return nil, &fs.PathError{
-			Op: "openfile",
+			Op:   "openfile",
 			Path: name,
-			Err: fs.ErrInvalid,
+			Err:  fs.ErrInvalid,
 		}
 	}
 	fp := filepath.Join(fsys.rootPath, name)
@@ -67,17 +67,17 @@ func (fsys *FS) OpenFile(name string, omode lib9p.OpenMode) (lib9p.File, error) 
 	}
 	if omode&lib9p.ORCLOSE != 0 {
 		return nil, &fs.PathError{
-			Op: "openfile",
+			Op:   "openfile",
 			Path: name,
-			Err: fmt.Errorf("orclose not implemented"),
+			Err:  fmt.Errorf("orclose not implemented"),
 		}
 	}
 	osf, err := os.OpenFile(fp, m, 0)
 	if err != nil {
 		return nil, &fs.PathError{
-			Op: "openfile",
+			Op:   "openfile",
 			Path: name,
-			Err: err,
+			Err:  err,
 		}
 	}
 	return &File{fs: fsys, path: name, file: osf}, nil
@@ -88,9 +88,9 @@ func (fsys *FS) OpenFile(name string, omode lib9p.OpenMode) (lib9p.File, error) 
 func (fsys *FS) Create(name string, uid string, omode lib9p.OpenMode, perm lib9p.FileMode) (lib9p.File, error) {
 	if !fs.ValidPath(name) {
 		return nil, &fs.PathError{
-			Op: "create",
+			Op:   "create",
 			Path: name,
-			Err: fs.ErrInvalid,
+			Err:  fs.ErrInvalid,
 		}
 	}
 	paths := append([]string{fsys.rootPath}, strings.Split(name, "/")...)
@@ -109,9 +109,9 @@ func (fsys *FS) Create(name string, uid string, omode lib9p.OpenMode, perm lib9p
 	}
 	if omode&lib9p.ORCLOSE != 0 {
 		return nil, &fs.PathError{
-			Op: "create",
+			Op:   "create",
 			Path: name,
-			Err: fmt.Errorf("orclose not implemented"),
+			Err:  fmt.Errorf("orclose not implemented"),
 		}
 	}
 	var (
@@ -120,18 +120,18 @@ func (fsys *FS) Create(name string, uid string, omode lib9p.OpenMode, perm lib9p
 	)
 	if perm&os.ModeDir != 0 {
 		if err := os.Mkdir(ospath, perm); err != nil {
-			return nil,  &fs.PathError{
-				Op: "create",
+			return nil, &fs.PathError{
+				Op:   "create",
 				Path: name,
-				Err: fmt.Errorf("mkdir: %v", err),
+				Err:  fmt.Errorf("mkdir: %v", err),
 			}
 		}
 		osfile, err = os.OpenFile(ospath, flag, 0)
 		if err != nil {
 			return nil, &fs.PathError{
-				Op: "create",
+				Op:   "create",
 				Path: name,
-				Err: err,
+				Err:  err,
 			}
 		}
 	} else {
@@ -139,9 +139,9 @@ func (fsys *FS) Create(name string, uid string, omode lib9p.OpenMode, perm lib9p
 		osfile, err = os.OpenFile(ospath, flag, perm)
 		if err != nil {
 			return nil, &fs.PathError{
-				Op: "create",
+				Op:   "create",
 				Path: name,
-				Err: err,
+				Err:  err,
 			}
 		}
 	}
diff --git a/fid.go b/fid.go
@@ -30,13 +30,13 @@ const (
 // and does not change between 9P sessions.
 type fid struct {
 	fid       uint32
-	omode     OpenMode /* -1 = not open */
-	path      string   // The path from the root of the FS.
-	fs        FS       // The associated FS.
-	file      File     // The associated File.
-	uid       string   // The user id derived from the attach message.
-	dirOffset uint64   // Used when reading directory.
-	dirIndex  int      // Used when reading directory.
+	omode     OpenMode      /* -1 = not open */
+	path      string        // The path from the root of the FS.
+	fs        FS            // The associated FS.
+	file      File          // The associated File.
+	uid       string        // The user id derived from the attach message.
+	dirOffset uint64        // Used when reading directory.
+	dirIndex  int           // Used when reading directory.
 	dirEnts   []fs.DirEntry // DirEntry cache.
 }
 
diff --git a/req.go b/req.go
@@ -40,32 +40,41 @@ func (r *request) flush() {
 
 // reqPool is the pool of Reqs the server is dealing with.
 type reqPool struct {
-	m    map[uint16]*request
-	lock *sync.Mutex
+	m map[uint16]*request
+	*sync.Mutex
 }
 
 // newReqPool allocats a reqPool.
 func newReqPool() *reqPool {
 	return &reqPool{
-		m:    make(map[uint16]*request),
-		lock: new(sync.Mutex),
+		make(map[uint16]*request),
+		new(sync.Mutex),
 	}
 }
 
 // Add allocates a request with the specified tag in reqPool rp.
-// It returns (nil, ErrDupTag) if there is already a request with the specified tag.
+// It returns (nil, ErrDupTag) if there is already a request with the specified tag,
+// exept that if the tag is NOTAG, it deletes all the request in the map and
+// allocats an req with that tag.
 func (rp *reqPool) add(tag uint16) (*request, error) {
 	return reqPoolAdd(rp, tag)
 }
 
 var reqPoolAdd = func(rp *reqPool, tag uint16) (*request, error) {
-	rp.lock.Lock()
-	defer rp.lock.Unlock()
+	rp.Lock()
+	defer rp.Unlock()
+	if tag == NOTAG {
+		for _, r := range rp.m {
+			close(r.done)
+		}
+		rp.m = make(map[uint16]*request)
+	}
 	if _, ok := rp.m[tag]; ok {
 		return nil, ErrDupTag
 	}
 	req := &request{
 		pool:         rp,
+		done:         make(chan struct{}),
 		speakErrChan: make(chan error),
 	}
 	rp.m[tag] = req
@@ -76,15 +85,15 @@ var reqPoolAdd = func(rp *reqPool, tag uint16) (*request, error) {
 // If found, it returns the found request and true, otherwise
 // it returns nil and false.
 func (rp *reqPool) lookup(tag uint16) (*request, bool) {
-	rp.lock.Lock()
-	defer rp.lock.Unlock()
+	rp.Lock()
+	defer rp.Unlock()
 	r, ok := rp.m[tag]
 	return r, ok
 }
 
 // delete delets the request with tag from the pool.
 func (rp *reqPool) delete(tag uint16) {
-	rp.lock.Lock()
-	defer rp.lock.Unlock()
+	rp.Lock()
+	defer rp.Unlock()
 	delete(rp.m, tag)
 }
diff --git a/req_test.go b/req_test.go
@@ -61,3 +61,22 @@ func TestFlush(t *testing.T) {
 		t.Error("req not flushed.")
 	}
 }
+
+func TestReqPool(t *testing.T) {
+	rp := newReqPool()
+	for i := 0; i < 10; i++ {
+		if _, err := rp.add(uint16(i)); err != nil {
+			t.Fatal(err)
+		}
+	}
+	r, err := rp.add(NOTAG)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(rp.m) != 1 {
+		t.Fatalf("requests not deleted: %v", rp.m)
+	}
+	if rp.m[NOTAG] != r {
+		t.Fatalf("requests mismatch: %v", rp.m)
+	}
+}
diff --git a/server.go b/server.go
@@ -120,7 +120,7 @@ func (c *conn) setMSize(mSize uint32) {
 
 // runListener runs the listener goroutine.
 // Listener goroutine reads 9P messages from s.r by calling getReq
-// and allocats request for each of them, and sends it to the server's listenChan.
+// and sends them to c.listenChan.
 func (c *conn) runListener(ctx context.Context, rp *reqPool) {
 	rc := make(chan *request)
 	c.listenChan = rc
@@ -137,7 +137,7 @@ func (c *conn) runListener(ctx context.Context, rp *reqPool) {
 }
 
 // runResponder runs the responder goroutine.
-// Responder goroutine wait for reply Requests from the returned channel,
+// Responder goroutine wait for reply Requests from c.respChan,
 // and marshalls each of them into 9P messages and writes it to c.w.
 func (c *conn) runResponder(ctx context.Context, rp *reqPool) {
 	rc := make(chan *request)
@@ -168,11 +168,11 @@ func (c *conn) runResponder(ctx context.Context, rp *reqPool) {
 	}()
 }
 
-// GetReq reads 9P message from r, allocates request, adds it to s.rPool,
+// GetReq reads 9P message from r, allocates request, adds it to rp,
 // and returns it.
 // Any error it encountered is embedded into the request struct.
 // This function is called only by the server's listener goroutine,
-// and does not need to lock s.r.
+// and does not need to lock r.
 func getReq(r io.Reader, rp *reqPool, chatty bool) *request {
 	ifcall, err := RecvMsg(r)
 	if err != nil {
@@ -187,7 +187,6 @@ func getReq(r io.Reader, rp *reqPool, chatty bool) *request {
 		req := new(request)
 		req.ifcall = ifcall
 		req.listenErr = ErrDupTag
-		req.done = make(chan struct{})
 		if chatty {
 			fmt.Fprintf(os.Stderr, "<-- %v\n", req.ifcall)
 		}
@@ -195,7 +194,6 @@ func getReq(r io.Reader, rp *reqPool, chatty bool) *request {
 	}
 	req.tag = ifcall.GetTag()
 	req.ifcall = ifcall
-	req.done = make(chan struct{})
 	if ifcall, ok := req.ifcall.(*TFlush); ok {
 		req.oldreq, _ = rp.lookup(ifcall.Oldtag)
 	}
@@ -302,8 +300,8 @@ func sAttach(ctx context.Context, c *conn, rc <-chan *request) {
 		case r, ok := <-rc:
 			var (
 				fsys FS
-				st  fs.FileInfo
-				err error
+				st   fs.FileInfo
+				err  error
 			)
 			if !ok {
 				return
@@ -1251,10 +1249,10 @@ L:
 				break L
 			}
 			if r.listenErr != nil {
-				log.Printf("listen: %v", r.listenErr)
 				if r.listenErr == io.EOF {
 					break L
 				}
+				log.Printf("listen: %v", r.listenErr)
 				continue L
 			}
 			switch r.ifcall.(type) {