commit ca647525894af888776171b3f9f44d195af1ac65
parent 2c13f07672c329595634ef4ee1c96002a215978f
Author: Matsuda Kenji <info@mtkn.jp>
Date: Fri, 20 Oct 2023 08:06:52 +0900
add context to Req
Diffstat:
5 files changed, 226 insertions(+), 220 deletions(-)
diff --git a/cmd/diskfs/main.go b/cmd/diskfs/main.go
@@ -2,6 +2,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"log"
@@ -46,5 +47,5 @@ func handle(conn net.Conn, disk *diskfs.FS) {
if *dFlag {
srv.Chatty()
}
- srv.Serve()
+ srv.Serve(context.Background())
}
diff --git a/cmd/iofs/main.go b/cmd/iofs/main.go
@@ -2,6 +2,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"log"
@@ -44,5 +45,5 @@ func handle(conn net.Conn, disk *iofs.FS) {
if *dFlag {
srv.Chatty()
}
- srv.Serve()
+ srv.Serve(context.Background())
}
diff --git a/cmd/numfs/main.go b/cmd/numfs/main.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "context"
"flag"
"fmt"
"io"
@@ -204,5 +205,5 @@ func handle(conn net.Conn, fs *numFS) {
if *dFlag {
srv.Chatty()
}
- srv.Serve()
+ srv.Serve(context.Background())
}
diff --git a/req.go b/req.go
@@ -15,10 +15,12 @@ type Req struct {
afid *Fid
oldReq *Req
pool *ReqPool
+ cancel context.CancelFunc
}
func (r *Req) flush() {
-
+ // TODO: need mutex?
+ r.cancel()
}
type ReqPool struct {
diff --git a/server.go b/server.go
@@ -1,6 +1,7 @@
package lib9p
import (
+ "context"
"fmt"
"io"
"io/fs"
@@ -39,7 +40,7 @@ type Server struct {
speakChan chan<- *Req
speakErrChan <-chan error
- Auth func(*Req)
+ Auth func(context.Context, *Req)
}
func NewServer(fsys FS, mSize uint32, r io.Reader, w io.Writer) *Server {
@@ -115,7 +116,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) {
}
return nil, fmt.Errorf("recv: %v", err)
}
-
req, err := s.rPool.add(bufMsg(buf).Tag())
if err != nil {
// duplicate tag: cons up a fake Req
@@ -131,7 +131,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) {
}
return req, ErrDupTag
}
-
req.srv = s
req.tag = bufMsg(buf).Tag()
req.ifcall, err = unmarshal(buf)
@@ -143,7 +142,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) {
req.ifcall = bufMsg(buf)
return req, err
}
-
if s.chatty9P {
fmt.Fprintf(os.Stderr, "<-- %v\n", req.ifcall)
}
@@ -151,7 +149,7 @@ func getReq(r io.Reader, s *Server) (*Req, error) {
}
// TODO: abort all outstanding I/O on the same connection.
-func sVersion(s *Server, r *Req) {
+func sVersion(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TVersion)
version := ifcall.Version()
if ok := strings.HasPrefix(version, "9P2000"); !ok {
@@ -159,18 +157,15 @@ func sVersion(s *Server, r *Req) {
} else {
version = "9P2000"
}
-
msize := ifcall.MSize()
if msize > s.mSize() {
msize = s.mSize()
}
-
r.ofcall = &RVersion{
mSize: msize,
version: version,
}
-
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rVersion(r *Req, err error) {
@@ -180,18 +175,18 @@ func rVersion(r *Req, err error) {
r.srv.setMSize(r.ofcall.(*RVersion).MSize())
}
-func sAuth(s *Server, r *Req) {
+func sAuth(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TAuth)
var err error
+ // TODO: this call can block. need ctx?
r.afid, err = s.fPool.add(ifcall.AFid())
if err != nil {
- respond(r, ErrDupFid)
+ respond(ctx, r, ErrDupFid)
}
-
if s.Auth != nil {
- s.Auth(r)
+ s.Auth(ctx, r)
} else {
- respond(r, fmt.Errorf("authentication not required"))
+ respond(ctx, r, fmt.Errorf("authentication not required"))
return
}
}
@@ -203,56 +198,53 @@ func rAuth(r *Req, err error) {
}
}
-func sAttach(s *Server, r *Req) {
+func sAttach(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TAttach)
fid, err := s.fPool.add(ifcall.Fid())
if err != nil {
- respond(r, ErrDupFid)
+ respond(ctx, r, ErrDupFid)
return
}
-
- if s.Auth == nil && ifcall.AFid() != NOFID {
- respond(r, ErrBotch)
+ switch {
+ case s.Auth == nil && ifcall.AFid() == NOFID:
+ case s.Auth == nil && ifcall.AFid() != NOFID:
+ respond(ctx, r, ErrBotch)
return
- }
- if s.Auth != nil && ifcall.AFid() == NOFID {
- respond(r, fmt.Errorf("authentication required"))
+ case s.Auth != nil && ifcall.AFid() == NOFID:
+ respond(ctx, r, fmt.Errorf("authentication required"))
return
- }
- if s.Auth != nil && ifcall.AFid() != NOFID {
+ case s.Auth != nil && ifcall.AFid() != NOFID:
afid, ok := s.fPool.lookup(ifcall.AFid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
af, ok := afid.File.(*AuthFile)
if !ok {
- respond(r, fmt.Errorf("not auth file"))
+ respond(ctx, r, fmt.Errorf("not auth file"))
return
}
if !af.AuthOK {
- respond(r, fmt.Errorf("not authenticated"))
+ respond(ctx, r, fmt.Errorf("not authenticated"))
return
}
}
-
fid.File, err = s.fs.OpenFile(".", OREAD, 0)
if err != nil {
- respond(r, fmt.Errorf("open root: %v", err))
+ respond(ctx, r, fmt.Errorf("open root: %v", err))
return
}
fid.Uid = ifcall.UName()
fid.OMode = -1 // TODO: right?
st, err := fid.File.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat root: %v", err))
+ respond(ctx, r, fmt.Errorf("stat root: %v", err))
return
}
-
r.ofcall = &RAttach{
qid: st.Sys().(*Stat).Qid,
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rAttach(r *Req, err error) {
@@ -261,10 +253,10 @@ func rAttach(r *Req, err error) {
}
}
-func sFlush(s *Server, r *Req) {
+func sFlush(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TFlush)
r.oldReq, _ = s.rPool.lookup(ifcall.OldTag())
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rFlush(r *Req, err error) {
@@ -277,27 +269,26 @@ func rFlush(r *Req, err error) {
r.ofcall = &RFlush{}
}
-func sWalk(s *Server, r *Req) {
+func sWalk(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TWalk)
oldFid, ok := s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
if oldFid.OMode != -1 {
- respond(r, fmt.Errorf("cannot clone open fid"))
+ respond(ctx, r, fmt.Errorf("cannot clone open fid"))
return
}
oldst, err := oldFid.File.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
if ifcall.NWName() > 0 && oldst.Sys().(*Stat).Qid.Type&QTDIR == 0 {
- respond(r, fmt.Errorf("walk on non-dir"))
+ respond(ctx, r, fmt.Errorf("walk on non-dir"))
return
}
-
var newFid *Fid
if ifcall.Fid() == ifcall.NewFid() {
newFid = oldFid
@@ -306,11 +297,10 @@ func sWalk(s *Server, r *Req) {
newFid, err = s.fPool.add(ifcall.NewFid())
if err != nil {
log.Printf("alloc: %v", err)
- respond(r, fmt.Errorf("internal error"))
+ respond(ctx, r, fmt.Errorf("internal error"))
return
}
}
-
wqids := make([]Qid, ifcall.NWName())
cwdp := oldFid.path
cwdf := oldFid.File
@@ -334,11 +324,10 @@ func sWalk(s *Server, r *Req) {
newFid.File = cwdf
newFid.Uid = oldFid.Uid
newFid.path = cwdp
-
r.ofcall = &RWalk{
qid: wqids[:n],
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rWalk(r *Req, err error) {
@@ -357,16 +346,16 @@ func rWalk(r *Req, err error) {
}
}
-func sOpen(s *Server, r *Req) {
+func sOpen(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TOpen)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
if r.fid.OMode != -1 {
- respond(r, ErrBotch)
+ respond(ctx, r, ErrBotch)
return
}
// Write attempt to a directory is prohibitted by the protocol.
@@ -375,12 +364,12 @@ func sOpen(s *Server, r *Req) {
// but ORCLOSE is also prohibitted by the protocol...
st, err := r.fid.File.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
qid := st.Sys().(*Stat).Qid
if qid.Type == QTDIR && ifcall.Mode() != OREAD {
- respond(r, fmt.Errorf("is a directory"))
+ respond(ctx, r, fmt.Errorf("is a directory"))
return
}
var p fs.FileMode
@@ -400,35 +389,33 @@ func sOpen(s *Server, r *Req) {
p |= AWRITE
}
if qid.Type&QTDIR != 0 && p != AREAD {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
-
if !hasPerm(r.fid.File, r.fid.Uid, p) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
-
if ifcall.Mode()&ORCLOSE != 0 {
parentPath := path.Dir(r.fid.path)
parent, err := s.fs.OpenFile(parentPath, OREAD, 0)
defer parent.Close()
if err != nil {
- respond(r, fmt.Errorf("open parent"))
+ respond(ctx, r, fmt.Errorf("open parent"))
return
}
if !hasPerm(parent, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
}
-
r.ofcall = &ROpen{
qid: qid,
iounit: s.mSize() - IOHDRSZ,
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
+
func rOpen(r *Req, err error) {
if err != nil {
setError(r, err)
@@ -443,36 +430,33 @@ func rOpen(r *Req, err error) {
r.fid.File = f
}
-func sCreate(s *Server, r *Req) {
+func sCreate(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TCreate)
-
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
dir := r.fid.File
dirstat, err := dir.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
if !dirstat.IsDir() {
- respond(r, fmt.Errorf("create in non-dir"))
+ respond(ctx, r, fmt.Errorf("create in non-dir"))
return
}
if !hasPerm(dir, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
-
cfdir, ok := r.fid.File.(CreaterFile)
if !ok {
- respond(r, ErrOperation)
+ respond(ctx, r, ErrOperation)
return
}
-
perm := ifcall.Perm()
dirperm := dirstat.Mode()
if perm&fs.ModeDir == 0 {
@@ -480,31 +464,27 @@ func sCreate(s *Server, r *Req) {
} else {
perm &= ^FileMode(0777) | (dirperm & FileMode(0777))
}
-
file, err := cfdir.Create(ifcall.Name(), r.fid.Uid, ifcall.Mode(), perm)
if err != nil {
- respond(r, fmt.Errorf("create: %v", err))
+ respond(ctx, r, fmt.Errorf("create: %v", err))
return
}
if err := r.fid.File.Close(); err != nil {
- respond(r, fmt.Errorf("close: %v", err))
+ respond(ctx, r, fmt.Errorf("close: %v", err))
return
}
-
r.fid.File = file
r.fid.path = path.Join(r.fid.path, ifcall.Name())
-
st, err := r.fid.File.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
-
r.ofcall = &RCreate{
qid: st.Sys().(*Stat).Qid,
iounit: s.mSize() - IOHDRSZ,
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rCreate(r *Req, err error) {
if err != nil {
@@ -521,86 +501,100 @@ func rCreate(r *Req, err error) {
r.fid.File = f
}
-func sRead(s *Server, r *Req) {
+func sRead(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TRead)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
if r.fid.OMode == -1 {
- respond(r, fmt.Errorf("not open"))
+ respond(ctx, r, fmt.Errorf("not open"))
return
}
-
if r.fid.OMode != OREAD && r.fid.OMode != ORDWR && r.fid.OMode != OEXEC {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
-
data := make([]byte, ifcall.Count())
- var n int
- var err error
+ done := make(chan struct{})
+ var (
+ n int
+ err error
+ wg sync.WaitGroup
+ )
fi, err := r.fid.File.Stat()
if err != nil {
log.Printf("Stat: %v", err)
- respond(r, fmt.Errorf("internal error"))
+ respond(ctx, r, fmt.Errorf("internal error"))
return
}
- if fi.IsDir() {
- children, err := getChildren(s.fs, r.fid.path)
- if err != nil {
- log.Printf("get children: %v", err)
- }
-
- if ifcall.Offset() != 0 && ifcall.Offset() != r.fid.dirOffset {
- respond(r, fmt.Errorf("invalid dir offset"))
- return
- }
- if ifcall.Offset() == 0 {
- r.fid.dirIndex = 0
- r.fid.dirOffset = 0
- }
- k := r.fid.dirIndex
- for ; k < len(children); k++ {
- if children[k] == nil {
- continue
- }
- fi, err := children[k].Stat()
+ wg.Add(1)
+ go func() {
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+ defer wg.Done()
+ if fi.IsDir() {
+ children, err := getChildren(s.fs, r.fid.path)
if err != nil {
- log.Printf("stat: %v", err)
- continue
+ log.Printf("get children: %v", err)
}
- st := fi.Sys().(*Stat)
- buf := st.marshal()
- if n+len(buf) > len(data) {
- break
+
+ if ifcall.Offset() != 0 && ifcall.Offset() != r.fid.dirOffset {
+ respond(ctx, r, fmt.Errorf("invalid dir offset"))
+ return
}
- for i := 0; i < len(buf); i++ {
- data[n+i] = buf[i]
+ if ifcall.Offset() == 0 {
+ r.fid.dirIndex = 0
+ r.fid.dirOffset = 0
}
- n += len(buf)
- }
- r.fid.dirOffset += uint64(n)
- r.fid.dirIndex = k
- } else {
- if reader, ok := r.fid.File.(io.ReaderAt); ok {
- n, err = reader.ReadAt(data, int64(ifcall.Offset()))
+ k := r.fid.dirIndex
+ for ; k < len(children); k++ {
+ if children[k] == nil {
+ continue
+ }
+ fi, err := children[k].Stat()
+ if err != nil {
+ log.Printf("stat: %v", err)
+ continue
+ }
+ st := fi.Sys().(*Stat)
+ buf := st.marshal()
+ if n+len(buf) > len(data) {
+ break
+ }
+ for i := 0; i < len(buf); i++ {
+ data[n+i] = buf[i]
+ }
+ n += len(buf)
+ }
+ r.fid.dirOffset += uint64(n)
+ r.fid.dirIndex = k
} else {
- n, err = r.fid.File.Read(data)
- }
- if err != io.EOF && err != nil {
- log.Printf("sRead: %v\n", err)
- respond(r, err)
- return
+ if reader, ok := r.fid.File.(io.ReaderAt); ok {
+ n, err = reader.ReadAt(data, int64(ifcall.Offset()))
+ } else {
+ n, err = r.fid.File.Read(data)
+ }
+ if err != io.EOF && err != nil {
+ log.Printf("sRead: %v\n", err)
+ respond(ctx, r, err)
+ return
+ }
}
+ }()
+ select {
+ case <-done:
+ case <-ctx.Done():
}
r.ofcall = &RRead{
count: uint32(n),
data: data[:n],
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rRead(r *Req, err error) {
if err != nil {
@@ -608,12 +602,12 @@ func rRead(r *Req, err error) {
}
}
-func sWrite(s *Server, r *Req) {
+func sWrite(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TWrite)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
// TODO: should I use exported function instead of directly
@@ -622,37 +616,50 @@ func sWrite(s *Server, r *Req) {
ifcall.count = s.mSize() - IOHDRSZ
}
if !hasPerm(r.fid.File, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
omode := r.fid.OMode & 3
if omode != OWRITE && omode != ORDWR {
- respond(r, fmt.Errorf("write on fid with open mode 0x%x", r.fid.OMode))
+ respond(ctx, r, fmt.Errorf("write on fid with open mode 0x%x", r.fid.OMode))
return
}
-
ofcall := new(RWrite)
- switch file := r.fid.File.(type) {
- case io.WriterAt:
- n, err := file.WriteAt(ifcall.data, int64(ifcall.offset))
- if err != nil {
- respond(r, fmt.Errorf("write: %v", err))
- return
- }
- ofcall.count = uint32(n)
- case io.Writer:
- n, err := file.Write(ifcall.data)
- if err != nil {
- respond(r, fmt.Errorf("write: %v", err))
+ var wg sync.WaitGroup
+ done := make(chan struct{})
+ wg.Add(1)
+ go func() {
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+ switch file := r.fid.File.(type) {
+ case io.WriterAt:
+ n, err := file.WriteAt(ifcall.data, int64(ifcall.offset))
+ if err != nil {
+ respond(ctx, r, fmt.Errorf("write: %v", err))
+ return
+ }
+ ofcall.count = uint32(n)
+ case io.Writer:
+ n, err := file.Write(ifcall.data)
+ if err != nil {
+ respond(ctx, r, fmt.Errorf("write: %v", err))
+ return
+ }
+ ofcall.count = uint32(n)
+ default:
+ respond(ctx, r, ErrOperation)
return
}
- ofcall.count = uint32(n)
- default:
- respond(r, ErrOperation)
- return
+ wg.Done()
+ }()
+ select {
+ case <-done:
+ case <-ctx.Done():
}
r.ofcall = ofcall
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rWrite(r *Req, err error) {
@@ -663,52 +670,51 @@ func rWrite(r *Req, err error) {
// TODO: Increment Qid.Vers
}
-func sClunk(s *Server, r *Req) {
+func sClunk(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TClunk)
_, ok := s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
s.fPool.delete(ifcall.Fid())
r.ofcall = &RClunk{}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rClunk(r *Req, err error) {}
-func sRemove(s *Server, r *Req) {
+func sRemove(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TRemove)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
defer s.fPool.delete(ifcall.Fid())
-
parentPath := path.Dir(r.fid.path)
parent, err := s.fs.OpenFile(parentPath, OREAD, 0)
if err != nil {
- respond(r, fmt.Errorf("open parent: %v", err))
+ respond(ctx, r, fmt.Errorf("open parent: %v", err))
return
}
defer parent.Close()
if !hasPerm(parent, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
rfile, ok := r.fid.File.(RemoverFile)
if !ok {
- respond(r, ErrOperation)
+ respond(ctx, r, ErrOperation)
return
}
if err := rfile.Remove(); err != nil {
- respond(r, fmt.Errorf("remove: %v", err))
+ respond(ctx, r, fmt.Errorf("remove: %v", err))
return
}
r.ofcall = &RRemove{}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rRemove(r *Req, err error) {
if err != nil {
@@ -716,24 +722,24 @@ func rRemove(r *Req, err error) {
}
}
-func sStat(s *Server, r *Req) {
+func sStat(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TStat)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
fileInfo, err := r.fid.File.Stat()
if err != nil {
log.Printf("stat %v: %v", r.fid.File, err)
- respond(r, fmt.Errorf("internal error"))
+ respond(ctx, r, fmt.Errorf("internal error"))
return
}
r.ofcall = &RStat{
stat: fileInfo.Sys().(*Stat),
}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rStat(r *Req, err error) {
@@ -742,110 +748,101 @@ func rStat(r *Req, err error) {
}
}
-func sWStat(s *Server, r *Req) {
+func sWStat(ctx context.Context, s *Server, r *Req) {
ifcall := r.ifcall.(*TWStat)
var ok bool
r.fid, ok = s.fPool.lookup(ifcall.Fid())
if !ok {
- respond(r, ErrUnknownFid)
+ respond(ctx, r, ErrUnknownFid)
return
}
-
wsfile, ok := r.fid.File.(WriterStatFile)
if !ok {
- respond(r, ErrOperation)
+ respond(ctx, r, ErrOperation)
return
}
-
wstat := ifcall.Stat()
fi, err := r.fid.File.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
newStat := fi.Sys().(*Stat)
-
if wstat.Type != ^uint16(0) || wstat.Dev != ^uint32(0) ||
wstat.Qid.Type != QidType(^uint8(0)) || wstat.Qid.Vers != ^uint32(0) ||
wstat.Qid.Path != ^uint64(0) || wstat.Atime != ^uint32(0) ||
wstat.Uid != "" || wstat.Muid != "" {
- respond(r, fmt.Errorf("operation not permitted"))
+ respond(ctx, r, fmt.Errorf("operation not permitted"))
return
}
-
if wstat.Name != "" {
parentPath := path.Dir(r.fid.path)
parent, err := s.fs.OpenFile(parentPath, OREAD, 0)
if err != nil {
- respond(r, fmt.Errorf("get parent: %v", err))
+ respond(ctx, r, fmt.Errorf("get parent: %v", err))
return
}
if !hasPerm(parent, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
children, err := getChildren(s.fs, parentPath)
if err != nil {
- respond(r, fmt.Errorf("get children: %v", err))
+ respond(ctx, r, fmt.Errorf("get children: %v", err))
return
}
for _, f := range children {
s, err := f.Stat()
if err != nil {
- respond(r, fmt.Errorf("stat: %v", err))
+ respond(ctx, r, fmt.Errorf("stat: %v", err))
return
}
if s.Name() == wstat.Name {
- respond(r, fmt.Errorf("file already exists"))
+ respond(ctx, r, fmt.Errorf("file already exists"))
return
}
}
newStat.Name = wstat.Name
}
-
if wstat.Length != ^int64(0) {
if fi.IsDir() || !hasPerm(r.fid.File, r.fid.Uid, AWRITE) {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
newStat.Length = wstat.Length
}
-
if wstat.Mode != FileMode(^uint32(0)) {
// the owner of the file or the group leader of the file's group.
if r.fid.Uid != newStat.Uid && r.fid.Uid != newStat.Gid {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
if wstat.Mode&fs.ModeDir != newStat.Mode&fs.ModeDir {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
newStat.Mode = wstat.Mode
}
-
if wstat.Mtime != ^uint32(0) {
// the owner of the file or the group leader of the file's group.
if r.fid.Uid != newStat.Uid && r.fid.Uid != newStat.Gid {
- respond(r, ErrPerm)
+ respond(ctx, r, ErrPerm)
return
}
newStat.Mtime = wstat.Mtime
}
if wstat.Gid != "" {
// TODO implement
- respond(r, fmt.Errorf("not implemented"))
+ respond(ctx, r, fmt.Errorf("not implemented"))
return
}
-
err = wsfile.WStat(newStat)
if err != nil {
- respond(r, fmt.Errorf("wstat: %v", err))
+ respond(ctx, r, fmt.Errorf("wstat: %v", err))
return
}
-
r.ofcall = &RWStat{}
- respond(r, nil)
+ respond(ctx, r, nil)
}
func rWStat(r *Req, err error) {
@@ -854,47 +851,56 @@ func rWStat(r *Req, err error) {
}
}
-func (s *Server) Serve() {
+func (s *Server) Serve(ctx context.Context) {
L:
for {
select {
+ case err := <-s.speakErrChan:
+ // TODO: handle error
+ log.Printf("speak: %v", err)
+ continue L
case err := <-s.listenErrChan:
- log.Printf("getReq: %v\n", err)
+ log.Printf("getReq: %v", err)
if err == io.EOF {
break L
}
continue L
case r := <-s.listenChan:
+ ctx1, cancel := context.WithCancel(ctx)
+ r.cancel = cancel
go func() {
switch r.ifcall.(type) {
default:
- respond(r, fmt.Errorf("unknown message type: %d", r.ifcall.Type()))
+ respond(ctx, r, fmt.Errorf("unknown message type: %d", r.ifcall.Type()))
case *TVersion:
- sVersion(s, r)
+ sVersion(ctx1, s, r)
case *TAuth:
- sAuth(s, r)
+ sAuth(ctx1, s, r)
case *TAttach:
- sAttach(s, r)
+ sAttach(ctx1, s, r)
case *TWalk:
- sWalk(s, r)
+ sWalk(ctx1, s, r)
case *TOpen:
- sOpen(s, r)
+ sOpen(ctx1, s, r)
case *TCreate:
- sCreate(s, r)
+ sCreate(ctx1, s, r)
case *TRead:
- sRead(s, r)
+ sRead(ctx1, s, r)
case *TWrite:
- sWrite(s, r)
+ sWrite(ctx1, s, r)
case *TClunk:
- sClunk(s, r)
+ sClunk(ctx1, s, r)
case *TRemove:
- sRemove(s, r)
+ sRemove(ctx1, s, r)
case *TStat:
- sStat(s, r)
+ sStat(ctx1, s, r)
case *TWStat:
- sWStat(s, r)
+ sWStat(ctx1, s, r)
}
}()
+ case <-ctx.Done():
+ log.Println(ctx.Err())
+ break L
}
}
}
@@ -902,7 +908,7 @@ L:
// Respond responds to the request r with the message r.ofcall if err is nil,
// or if err is not nil, with the Rerror with the error message.
// If r is nil, or both r.ofcall and err are nil it panics.
-func respond(r *Req, err error) {
+func respond(ctx context.Context, r *Req, err error) {
switch r.ifcall.(type) {
default:
panic("bug")
@@ -931,9 +937,7 @@ func respond(r *Req, err error) {
case *TWStat:
rWStat(r, err)
}
-
r.ofcall.SetTag(r.tag)
-
// free tag.
if r.pool == nil && err != ErrDupTag {
panic("ReqPool is nil but err is not EDupTag")
@@ -941,15 +945,12 @@ func respond(r *Req, err error) {
if r.pool != nil {
r.pool.delete(r.tag)
}
-
- r.srv.speakChan <- r
- // TODO: handle error gently
- go func() {
- err := <-r.srv.speakErrChan
- log.Fatalf("speak: %v", err)
- }()
-
- if r.srv.chatty9P {
- fmt.Fprintf(os.Stderr, "--> %s\n", r.ofcall)
+ select {
+ case r.srv.speakChan <- r:
+ if r.srv.chatty9P {
+ fmt.Fprintf(os.Stderr, "--> %s\n", r.ofcall)
+ }
+ case <-ctx.Done():
+ log.Printf("req flush: %v", r.ifcall)
}
}