commit 7f859b6f3974be1a81eb24098f1d6dfc00658690
parent 0d0f72add8ba7cde8e315b016bcb0c7840f7ebec
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Fri, 12 Jan 2024 15:31:16 +0900
add fid.qidpath
Diffstat:
3 files changed, 65 insertions(+), 30 deletions(-)
diff --git a/fid.go b/fid.go
@@ -71,6 +71,7 @@ type fid struct {
 	fid       uint32
 	omode     OpenMode      /* -1 = not open */
 	path      string        // The path from the root of the FS.
+	qidpath   uint64
 	fs        FS            // The associated FS.
 	file      File          // The associated File.
 	uid       string        // The user id derived from the attach message.
diff --git a/server.go b/server.go
@@ -280,6 +280,10 @@ func sAuth(ctx context.Context, c *conn, rc <-chan *request) {
 				setError(r, r.err)
 			} else {
 				c.s.Auth(ctx, r)
+				// TODO: move this code to c.s.Auth
+				if rauth, ok := r.ofcall.(*RAuth); ok {
+					r.afid.qidpath = rauth.Aqid.Path
+				}
 			}
 			select {
 			case c.respChan <- r:
@@ -347,6 +351,7 @@ func sAttach(ctx context.Context, c *conn, rc <-chan *request) {
 				r.err = fmt.Errorf("stat root: %v", err)
 				goto resp
 			}
+			r.fid.qidpath = fi.Sys().(*Stat).Qid.Path
 			r.ofcall = &RAttach{
 				Qid: fi.Sys().(*Stat).Qid,
 			}
@@ -404,7 +409,6 @@ func sWalk(ctx context.Context, c *conn, rc <-chan *request) {
 				newFid *fid
 				wqids  []Qid
 				cwdp   string
-				n      int
 			)
 			ifcall := r.ifcall.(*TWalk)
 			oldFid, ok := c.fPool.lookup(ifcall.Fid)
@@ -434,9 +438,9 @@ func sWalk(ctx context.Context, c *conn, rc <-chan *request) {
 					goto resp
 				}
 			}
-			wqids = make([]Qid, len(ifcall.Wnames))
+			wqids = make([]Qid, 0, len(ifcall.Wnames))
 			cwdp = oldFid.path
-			for i, name := range ifcall.Wnames {
+			for _, name := range ifcall.Wnames {
 				cwdp = path.Clean(path.Join(cwdp, name))
 				if cwdp == ".." {
 					cwdp = "." // parent of the root is itself.
@@ -445,14 +449,18 @@ func sWalk(ctx context.Context, c *conn, rc <-chan *request) {
 				if err != nil {
 					break
 				}
-				wqids[i] = stat.Sys().(*Stat).Qid
-				n++
+				wqids = append(wqids, stat.Sys().(*Stat).Qid)
+			}
+			if len(wqids) == 0 {
+				newFid.qidpath = oldFid.qidpath
+			} else {
+				newFid.qidpath = wqids[len(wqids)-1].Path
 			}
 			newFid.path = cwdp
 			newFid.uid = oldFid.uid
 			newFid.fs = oldFid.fs
 			r.ofcall = &RWalk{
-				Qids: wqids[:n],
+				Qids: wqids,
 			}
 		resp:
 			if r.ofcall == nil {
@@ -573,34 +581,32 @@ func sOpen(ctx context.Context, c *conn, rc <-chan *request) {
 		resp:
 			if r.err != nil {
 				setError(r, r.err)
-				select {
-				case c.respChan <- r:
-				case <-ctx.Done():
-					return
-				}
-				continue
+				goto send
 			}
 			if _, ok := r.fid.file.(*AuthFile); ok {
 				r.fid.omode = ifcall.Mode
-				select {
-				case c.respChan <- r:
-				case <-ctx.Done():
-					return
-				}
-				continue
+				goto send
 			}
 			r.fid.file, err = r.fid.fs.OpenFile(r.fid.path, ModeToFlag(ifcall.Mode))
 			if err != nil {
 				setError(r, err)
-				select {
-				case c.respChan <- r:
-				case <-ctx.Done():
-					return
-				}
-				continue
+				goto send
+			}
+			fi, err = r.fid.file.Stat()
+			if err != nil {
+				r.fid.file.Close()
+				setError(r, err)
+				goto send
+			}
+			if fi.Sys().(*Stat).Qid.Path != r.fid.qidpath {
+				log.Println("open:", fi.Sys().(*Stat).Qid.Path, r.fid.qidpath)
+				r.fid.file.Close()
+				setError(r, fmt.Errorf("qid path mismatch"))
+				goto send
 			}
 			// omode should be set after successfully opening it.
 			r.fid.omode = ifcall.Mode
+		send:
 			select {
 			case c.respChan <- r:
 			case <-ctx.Done():
@@ -676,6 +682,7 @@ func sCreate(ctx context.Context, c *conn, rc <-chan *request) {
 				r.err = fmt.Errorf("stat: %v", err)
 				goto resp
 			}
+			r.fid.qidpath = fi.Sys().(*Stat).Qid.Path
 			r.ofcall = &RCreate{
 				Qid:    fi.Sys().(*Stat).Qid,
 				Iounit: c.mSize() - IOHDRSZ,
@@ -954,7 +961,8 @@ func sRemove(ctx context.Context, c *conn, rc <-chan *request) {
 			}
 			var (
 				parentPath string
-				pstat      fs.FileInfo
+				pfi        fs.FileInfo
+				fi         fs.FileInfo
 				err        error
 				rfs        RemoverFS
 			)
@@ -969,21 +977,33 @@ func sRemove(ctx context.Context, c *conn, rc <-chan *request) {
 				r.fid.file.Close()
 			}
 			parentPath = path.Dir(r.fid.path)
-			pstat, err = fs.Stat(ExportFS{r.fid.fs}, parentPath)
+			pfi, err = fs.Stat(ExportFS{r.fid.fs}, parentPath)
 			if err != nil {
 				r.err = fmt.Errorf("stat parent: %v", err)
 				goto resp
 			}
-			if !hasPerm(r.fid.fs, pstat, r.fid.uid, AWRITE) {
+			if !hasPerm(r.fid.fs, pfi, r.fid.uid, AWRITE) {
 				r.err = ErrPerm
 				goto resp
 			}
+			// BUG: race. Remove call below uses r.fid.path, so I need to
+			// check whether the underlying qid is the same.
+			// But other connection can move the same file and then create
+			// new one with the same name.
+			fi, err = fs.Stat(ExportFS{FS: r.fid.fs}, r.fid.path)
+			if err != nil {
+				r.err = err
+				goto resp
+			}
+			if r.fid.qidpath != fi.Sys().(*Stat).Qid.Path {
+				r.err = fmt.Errorf("qid path mismatch")
+				goto resp
+			}
 			rfs, ok = r.fid.fs.(RemoverFS)
 			if !ok {
 				r.err = ErrOperation
 				goto resp
 			}
-			// TODO: this assumes files can be identified by its path.
 			// I think the argument of RemoverFS.Remove should be Qid.Path.
 			if err = rfs.Remove(r.fid.path); err != nil {
 				r.err = fmt.Errorf("remove: %v", err)
@@ -1072,7 +1092,6 @@ func sWStat(ctx context.Context, c *conn, rc <-chan *request) {
 					r.err = fmt.Errorf("open: %v", err)
 					goto resp
 				}
-				defer r.fid.file.Close()
 			}
 			wsfile, ok = r.fid.file.(WriterStatFile)
 			if !ok {
@@ -1086,6 +1105,10 @@ func sWStat(ctx context.Context, c *conn, rc <-chan *request) {
 				goto resp
 			}
 			newStat = fi.Sys().(*Stat)
+			if r.fid.qidpath != newStat.Qid.Path {
+				r.err = fmt.Errorf("qid mismatch")
+				goto resp
+			}
 			if wstat.Type != ^uint16(0) && wstat.Type != newStat.Type ||
 				wstat.Dev != ^uint32(0) && wstat.Dev != newStat.Dev ||
 				wstat.Qid.Type != QidType(^uint8(0)) && wstat.Qid.Type != newStat.Qid.Type ||
@@ -1178,9 +1201,14 @@ func sWStat(ctx context.Context, c *conn, rc <-chan *request) {
 				r.err = fmt.Errorf("wstat: %v", err)
 				goto resp
 			}
+			if path.Base(r.fid.path) != newStat.Name {
+				r.fid.path = path.Join(path.Dir(r.fid.path), newStat.Name)
+			}
 			r.ofcall = &RWstat{}
-			// TODO: update r.fid.path
 		resp:
+			if r.fid.omode == -1 && r.fid.file != nil {
+				r.fid.file.Close()
+			}
 			if r.err != nil {
 				setError(r, r.err)
 			}
diff --git a/server_test.go b/server_test.go
@@ -432,6 +432,7 @@ func TestSOpen(t *testing.T) {
 			continue
 		}
 		fid.file = f
+		fid.qidpath = f.stat.Qid.Path
 		tc <- &request{ifcall: test.input}
 		ofcall := (<-rc).ofcall
 		switch test.wantMsg.(type) {
@@ -791,6 +792,11 @@ func TestSRemove(t *testing.T) {
 			}
 			defer f.Close()
 			fid.file = f
+			fi, err := f.Stat()
+			if err != nil {
+				t.Fatal(i, err)
+			}
+			fid.qidpath = fi.Sys().(*Stat).Qid.Path
 			fid.uid = test.ruid
 			fid.fs = testfs
 			tc <- &request{ifcall: test.ifcall}