lib9p

Go 9P library.
Log | Files | Refs

commit ca647525894af888776171b3f9f44d195af1ac65
parent 2c13f07672c329595634ef4ee1c96002a215978f
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Fri, 20 Oct 2023 08:06:52 +0900

add context to Req

Diffstat:
Mcmd/diskfs/main.go | 3++-
Mcmd/iofs/main.go | 3++-
Mcmd/numfs/main.go | 3++-
Mreq.go | 4+++-
Mserver.go | 433++++++++++++++++++++++++++++++++++++++++---------------------------------------
5 files changed, 226 insertions(+), 220 deletions(-)

diff --git a/cmd/diskfs/main.go b/cmd/diskfs/main.go @@ -2,6 +2,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -46,5 +47,5 @@ func handle(conn net.Conn, disk *diskfs.FS) { if *dFlag { srv.Chatty() } - srv.Serve() + srv.Serve(context.Background()) } diff --git a/cmd/iofs/main.go b/cmd/iofs/main.go @@ -2,6 +2,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -44,5 +45,5 @@ func handle(conn net.Conn, disk *iofs.FS) { if *dFlag { srv.Chatty() } - srv.Serve() + srv.Serve(context.Background()) } diff --git a/cmd/numfs/main.go b/cmd/numfs/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "flag" "fmt" "io" @@ -204,5 +205,5 @@ func handle(conn net.Conn, fs *numFS) { if *dFlag { srv.Chatty() } - srv.Serve() + srv.Serve(context.Background()) } diff --git a/req.go b/req.go @@ -15,10 +15,12 @@ type Req struct { afid *Fid oldReq *Req pool *ReqPool + cancel context.CancelFunc } func (r *Req) flush() { - + // TODO: need mutex? + r.cancel() } type ReqPool struct { diff --git a/server.go b/server.go @@ -1,6 +1,7 @@ package lib9p import ( + "context" "fmt" "io" "io/fs" @@ -39,7 +40,7 @@ type Server struct { speakChan chan<- *Req speakErrChan <-chan error - Auth func(*Req) + Auth func(context.Context, *Req) } func NewServer(fsys FS, mSize uint32, r io.Reader, w io.Writer) *Server { @@ -115,7 +116,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) { } return nil, fmt.Errorf("recv: %v", err) } - req, err := s.rPool.add(bufMsg(buf).Tag()) if err != nil { // duplicate tag: cons up a fake Req @@ -131,7 +131,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) { } return req, ErrDupTag } - req.srv = s req.tag = bufMsg(buf).Tag() req.ifcall, err = unmarshal(buf) @@ -143,7 +142,6 @@ func getReq(r io.Reader, s *Server) (*Req, error) { req.ifcall = bufMsg(buf) return req, err } - if s.chatty9P { fmt.Fprintf(os.Stderr, "<-- %v\n", req.ifcall) } @@ -151,7 +149,7 @@ func getReq(r io.Reader, s *Server) (*Req, error) { } // TODO: abort all outstanding I/O on the same connection. -func sVersion(s *Server, r *Req) { +func sVersion(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TVersion) version := ifcall.Version() if ok := strings.HasPrefix(version, "9P2000"); !ok { @@ -159,18 +157,15 @@ func sVersion(s *Server, r *Req) { } else { version = "9P2000" } - msize := ifcall.MSize() if msize > s.mSize() { msize = s.mSize() } - r.ofcall = &RVersion{ mSize: msize, version: version, } - - respond(r, nil) + respond(ctx, r, nil) } func rVersion(r *Req, err error) { @@ -180,18 +175,18 @@ func rVersion(r *Req, err error) { r.srv.setMSize(r.ofcall.(*RVersion).MSize()) } -func sAuth(s *Server, r *Req) { +func sAuth(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TAuth) var err error + // TODO: this call can block. need ctx? r.afid, err = s.fPool.add(ifcall.AFid()) if err != nil { - respond(r, ErrDupFid) + respond(ctx, r, ErrDupFid) } - if s.Auth != nil { - s.Auth(r) + s.Auth(ctx, r) } else { - respond(r, fmt.Errorf("authentication not required")) + respond(ctx, r, fmt.Errorf("authentication not required")) return } } @@ -203,56 +198,53 @@ func rAuth(r *Req, err error) { } } -func sAttach(s *Server, r *Req) { +func sAttach(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TAttach) fid, err := s.fPool.add(ifcall.Fid()) if err != nil { - respond(r, ErrDupFid) + respond(ctx, r, ErrDupFid) return } - - if s.Auth == nil && ifcall.AFid() != NOFID { - respond(r, ErrBotch) + switch { + case s.Auth == nil && ifcall.AFid() == NOFID: + case s.Auth == nil && ifcall.AFid() != NOFID: + respond(ctx, r, ErrBotch) return - } - if s.Auth != nil && ifcall.AFid() == NOFID { - respond(r, fmt.Errorf("authentication required")) + case s.Auth != nil && ifcall.AFid() == NOFID: + respond(ctx, r, fmt.Errorf("authentication required")) return - } - if s.Auth != nil && ifcall.AFid() != NOFID { + case s.Auth != nil && ifcall.AFid() != NOFID: afid, ok := s.fPool.lookup(ifcall.AFid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } af, ok := afid.File.(*AuthFile) if !ok { - respond(r, fmt.Errorf("not auth file")) + respond(ctx, r, fmt.Errorf("not auth file")) return } if !af.AuthOK { - respond(r, fmt.Errorf("not authenticated")) + respond(ctx, r, fmt.Errorf("not authenticated")) return } } - fid.File, err = s.fs.OpenFile(".", OREAD, 0) if err != nil { - respond(r, fmt.Errorf("open root: %v", err)) + respond(ctx, r, fmt.Errorf("open root: %v", err)) return } fid.Uid = ifcall.UName() fid.OMode = -1 // TODO: right? st, err := fid.File.Stat() if err != nil { - respond(r, fmt.Errorf("stat root: %v", err)) + respond(ctx, r, fmt.Errorf("stat root: %v", err)) return } - r.ofcall = &RAttach{ qid: st.Sys().(*Stat).Qid, } - respond(r, nil) + respond(ctx, r, nil) } func rAttach(r *Req, err error) { @@ -261,10 +253,10 @@ func rAttach(r *Req, err error) { } } -func sFlush(s *Server, r *Req) { +func sFlush(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TFlush) r.oldReq, _ = s.rPool.lookup(ifcall.OldTag()) - respond(r, nil) + respond(ctx, r, nil) } func rFlush(r *Req, err error) { @@ -277,27 +269,26 @@ func rFlush(r *Req, err error) { r.ofcall = &RFlush{} } -func sWalk(s *Server, r *Req) { +func sWalk(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TWalk) oldFid, ok := s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } if oldFid.OMode != -1 { - respond(r, fmt.Errorf("cannot clone open fid")) + respond(ctx, r, fmt.Errorf("cannot clone open fid")) return } oldst, err := oldFid.File.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } if ifcall.NWName() > 0 && oldst.Sys().(*Stat).Qid.Type&QTDIR == 0 { - respond(r, fmt.Errorf("walk on non-dir")) + respond(ctx, r, fmt.Errorf("walk on non-dir")) return } - var newFid *Fid if ifcall.Fid() == ifcall.NewFid() { newFid = oldFid @@ -306,11 +297,10 @@ func sWalk(s *Server, r *Req) { newFid, err = s.fPool.add(ifcall.NewFid()) if err != nil { log.Printf("alloc: %v", err) - respond(r, fmt.Errorf("internal error")) + respond(ctx, r, fmt.Errorf("internal error")) return } } - wqids := make([]Qid, ifcall.NWName()) cwdp := oldFid.path cwdf := oldFid.File @@ -334,11 +324,10 @@ func sWalk(s *Server, r *Req) { newFid.File = cwdf newFid.Uid = oldFid.Uid newFid.path = cwdp - r.ofcall = &RWalk{ qid: wqids[:n], } - respond(r, nil) + respond(ctx, r, nil) } func rWalk(r *Req, err error) { @@ -357,16 +346,16 @@ func rWalk(r *Req, err error) { } } -func sOpen(s *Server, r *Req) { +func sOpen(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TOpen) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } if r.fid.OMode != -1 { - respond(r, ErrBotch) + respond(ctx, r, ErrBotch) return } // Write attempt to a directory is prohibitted by the protocol. @@ -375,12 +364,12 @@ func sOpen(s *Server, r *Req) { // but ORCLOSE is also prohibitted by the protocol... st, err := r.fid.File.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } qid := st.Sys().(*Stat).Qid if qid.Type == QTDIR && ifcall.Mode() != OREAD { - respond(r, fmt.Errorf("is a directory")) + respond(ctx, r, fmt.Errorf("is a directory")) return } var p fs.FileMode @@ -400,35 +389,33 @@ func sOpen(s *Server, r *Req) { p |= AWRITE } if qid.Type&QTDIR != 0 && p != AREAD { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } - if !hasPerm(r.fid.File, r.fid.Uid, p) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } - if ifcall.Mode()&ORCLOSE != 0 { parentPath := path.Dir(r.fid.path) parent, err := s.fs.OpenFile(parentPath, OREAD, 0) defer parent.Close() if err != nil { - respond(r, fmt.Errorf("open parent")) + respond(ctx, r, fmt.Errorf("open parent")) return } if !hasPerm(parent, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } } - r.ofcall = &ROpen{ qid: qid, iounit: s.mSize() - IOHDRSZ, } - respond(r, nil) + respond(ctx, r, nil) } + func rOpen(r *Req, err error) { if err != nil { setError(r, err) @@ -443,36 +430,33 @@ func rOpen(r *Req, err error) { r.fid.File = f } -func sCreate(s *Server, r *Req) { +func sCreate(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TCreate) - var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } dir := r.fid.File dirstat, err := dir.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } if !dirstat.IsDir() { - respond(r, fmt.Errorf("create in non-dir")) + respond(ctx, r, fmt.Errorf("create in non-dir")) return } if !hasPerm(dir, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } - cfdir, ok := r.fid.File.(CreaterFile) if !ok { - respond(r, ErrOperation) + respond(ctx, r, ErrOperation) return } - perm := ifcall.Perm() dirperm := dirstat.Mode() if perm&fs.ModeDir == 0 { @@ -480,31 +464,27 @@ func sCreate(s *Server, r *Req) { } else { perm &= ^FileMode(0777) | (dirperm & FileMode(0777)) } - file, err := cfdir.Create(ifcall.Name(), r.fid.Uid, ifcall.Mode(), perm) if err != nil { - respond(r, fmt.Errorf("create: %v", err)) + respond(ctx, r, fmt.Errorf("create: %v", err)) return } if err := r.fid.File.Close(); err != nil { - respond(r, fmt.Errorf("close: %v", err)) + respond(ctx, r, fmt.Errorf("close: %v", err)) return } - r.fid.File = file r.fid.path = path.Join(r.fid.path, ifcall.Name()) - st, err := r.fid.File.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } - r.ofcall = &RCreate{ qid: st.Sys().(*Stat).Qid, iounit: s.mSize() - IOHDRSZ, } - respond(r, nil) + respond(ctx, r, nil) } func rCreate(r *Req, err error) { if err != nil { @@ -521,86 +501,100 @@ func rCreate(r *Req, err error) { r.fid.File = f } -func sRead(s *Server, r *Req) { +func sRead(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TRead) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } if r.fid.OMode == -1 { - respond(r, fmt.Errorf("not open")) + respond(ctx, r, fmt.Errorf("not open")) return } - if r.fid.OMode != OREAD && r.fid.OMode != ORDWR && r.fid.OMode != OEXEC { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } - data := make([]byte, ifcall.Count()) - var n int - var err error + done := make(chan struct{}) + var ( + n int + err error + wg sync.WaitGroup + ) fi, err := r.fid.File.Stat() if err != nil { log.Printf("Stat: %v", err) - respond(r, fmt.Errorf("internal error")) + respond(ctx, r, fmt.Errorf("internal error")) return } - if fi.IsDir() { - children, err := getChildren(s.fs, r.fid.path) - if err != nil { - log.Printf("get children: %v", err) - } - - if ifcall.Offset() != 0 && ifcall.Offset() != r.fid.dirOffset { - respond(r, fmt.Errorf("invalid dir offset")) - return - } - if ifcall.Offset() == 0 { - r.fid.dirIndex = 0 - r.fid.dirOffset = 0 - } - k := r.fid.dirIndex - for ; k < len(children); k++ { - if children[k] == nil { - continue - } - fi, err := children[k].Stat() + wg.Add(1) + go func() { + go func() { + wg.Wait() + close(done) + }() + defer wg.Done() + if fi.IsDir() { + children, err := getChildren(s.fs, r.fid.path) if err != nil { - log.Printf("stat: %v", err) - continue + log.Printf("get children: %v", err) } - st := fi.Sys().(*Stat) - buf := st.marshal() - if n+len(buf) > len(data) { - break + + if ifcall.Offset() != 0 && ifcall.Offset() != r.fid.dirOffset { + respond(ctx, r, fmt.Errorf("invalid dir offset")) + return } - for i := 0; i < len(buf); i++ { - data[n+i] = buf[i] + if ifcall.Offset() == 0 { + r.fid.dirIndex = 0 + r.fid.dirOffset = 0 } - n += len(buf) - } - r.fid.dirOffset += uint64(n) - r.fid.dirIndex = k - } else { - if reader, ok := r.fid.File.(io.ReaderAt); ok { - n, err = reader.ReadAt(data, int64(ifcall.Offset())) + k := r.fid.dirIndex + for ; k < len(children); k++ { + if children[k] == nil { + continue + } + fi, err := children[k].Stat() + if err != nil { + log.Printf("stat: %v", err) + continue + } + st := fi.Sys().(*Stat) + buf := st.marshal() + if n+len(buf) > len(data) { + break + } + for i := 0; i < len(buf); i++ { + data[n+i] = buf[i] + } + n += len(buf) + } + r.fid.dirOffset += uint64(n) + r.fid.dirIndex = k } else { - n, err = r.fid.File.Read(data) - } - if err != io.EOF && err != nil { - log.Printf("sRead: %v\n", err) - respond(r, err) - return + if reader, ok := r.fid.File.(io.ReaderAt); ok { + n, err = reader.ReadAt(data, int64(ifcall.Offset())) + } else { + n, err = r.fid.File.Read(data) + } + if err != io.EOF && err != nil { + log.Printf("sRead: %v\n", err) + respond(ctx, r, err) + return + } } + }() + select { + case <-done: + case <-ctx.Done(): } r.ofcall = &RRead{ count: uint32(n), data: data[:n], } - respond(r, nil) + respond(ctx, r, nil) } func rRead(r *Req, err error) { if err != nil { @@ -608,12 +602,12 @@ func rRead(r *Req, err error) { } } -func sWrite(s *Server, r *Req) { +func sWrite(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TWrite) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } // TODO: should I use exported function instead of directly @@ -622,37 +616,50 @@ func sWrite(s *Server, r *Req) { ifcall.count = s.mSize() - IOHDRSZ } if !hasPerm(r.fid.File, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } omode := r.fid.OMode & 3 if omode != OWRITE && omode != ORDWR { - respond(r, fmt.Errorf("write on fid with open mode 0x%x", r.fid.OMode)) + respond(ctx, r, fmt.Errorf("write on fid with open mode 0x%x", r.fid.OMode)) return } - ofcall := new(RWrite) - switch file := r.fid.File.(type) { - case io.WriterAt: - n, err := file.WriteAt(ifcall.data, int64(ifcall.offset)) - if err != nil { - respond(r, fmt.Errorf("write: %v", err)) - return - } - ofcall.count = uint32(n) - case io.Writer: - n, err := file.Write(ifcall.data) - if err != nil { - respond(r, fmt.Errorf("write: %v", err)) + var wg sync.WaitGroup + done := make(chan struct{}) + wg.Add(1) + go func() { + go func() { + wg.Wait() + close(done) + }() + switch file := r.fid.File.(type) { + case io.WriterAt: + n, err := file.WriteAt(ifcall.data, int64(ifcall.offset)) + if err != nil { + respond(ctx, r, fmt.Errorf("write: %v", err)) + return + } + ofcall.count = uint32(n) + case io.Writer: + n, err := file.Write(ifcall.data) + if err != nil { + respond(ctx, r, fmt.Errorf("write: %v", err)) + return + } + ofcall.count = uint32(n) + default: + respond(ctx, r, ErrOperation) return } - ofcall.count = uint32(n) - default: - respond(r, ErrOperation) - return + wg.Done() + }() + select { + case <-done: + case <-ctx.Done(): } r.ofcall = ofcall - respond(r, nil) + respond(ctx, r, nil) } func rWrite(r *Req, err error) { @@ -663,52 +670,51 @@ func rWrite(r *Req, err error) { // TODO: Increment Qid.Vers } -func sClunk(s *Server, r *Req) { +func sClunk(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TClunk) _, ok := s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } s.fPool.delete(ifcall.Fid()) r.ofcall = &RClunk{} - respond(r, nil) + respond(ctx, r, nil) } func rClunk(r *Req, err error) {} -func sRemove(s *Server, r *Req) { +func sRemove(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TRemove) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } defer s.fPool.delete(ifcall.Fid()) - parentPath := path.Dir(r.fid.path) parent, err := s.fs.OpenFile(parentPath, OREAD, 0) if err != nil { - respond(r, fmt.Errorf("open parent: %v", err)) + respond(ctx, r, fmt.Errorf("open parent: %v", err)) return } defer parent.Close() if !hasPerm(parent, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } rfile, ok := r.fid.File.(RemoverFile) if !ok { - respond(r, ErrOperation) + respond(ctx, r, ErrOperation) return } if err := rfile.Remove(); err != nil { - respond(r, fmt.Errorf("remove: %v", err)) + respond(ctx, r, fmt.Errorf("remove: %v", err)) return } r.ofcall = &RRemove{} - respond(r, nil) + respond(ctx, r, nil) } func rRemove(r *Req, err error) { if err != nil { @@ -716,24 +722,24 @@ func rRemove(r *Req, err error) { } } -func sStat(s *Server, r *Req) { +func sStat(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TStat) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } fileInfo, err := r.fid.File.Stat() if err != nil { log.Printf("stat %v: %v", r.fid.File, err) - respond(r, fmt.Errorf("internal error")) + respond(ctx, r, fmt.Errorf("internal error")) return } r.ofcall = &RStat{ stat: fileInfo.Sys().(*Stat), } - respond(r, nil) + respond(ctx, r, nil) } func rStat(r *Req, err error) { @@ -742,110 +748,101 @@ func rStat(r *Req, err error) { } } -func sWStat(s *Server, r *Req) { +func sWStat(ctx context.Context, s *Server, r *Req) { ifcall := r.ifcall.(*TWStat) var ok bool r.fid, ok = s.fPool.lookup(ifcall.Fid()) if !ok { - respond(r, ErrUnknownFid) + respond(ctx, r, ErrUnknownFid) return } - wsfile, ok := r.fid.File.(WriterStatFile) if !ok { - respond(r, ErrOperation) + respond(ctx, r, ErrOperation) return } - wstat := ifcall.Stat() fi, err := r.fid.File.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } newStat := fi.Sys().(*Stat) - if wstat.Type != ^uint16(0) || wstat.Dev != ^uint32(0) || wstat.Qid.Type != QidType(^uint8(0)) || wstat.Qid.Vers != ^uint32(0) || wstat.Qid.Path != ^uint64(0) || wstat.Atime != ^uint32(0) || wstat.Uid != "" || wstat.Muid != "" { - respond(r, fmt.Errorf("operation not permitted")) + respond(ctx, r, fmt.Errorf("operation not permitted")) return } - if wstat.Name != "" { parentPath := path.Dir(r.fid.path) parent, err := s.fs.OpenFile(parentPath, OREAD, 0) if err != nil { - respond(r, fmt.Errorf("get parent: %v", err)) + respond(ctx, r, fmt.Errorf("get parent: %v", err)) return } if !hasPerm(parent, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } children, err := getChildren(s.fs, parentPath) if err != nil { - respond(r, fmt.Errorf("get children: %v", err)) + respond(ctx, r, fmt.Errorf("get children: %v", err)) return } for _, f := range children { s, err := f.Stat() if err != nil { - respond(r, fmt.Errorf("stat: %v", err)) + respond(ctx, r, fmt.Errorf("stat: %v", err)) return } if s.Name() == wstat.Name { - respond(r, fmt.Errorf("file already exists")) + respond(ctx, r, fmt.Errorf("file already exists")) return } } newStat.Name = wstat.Name } - if wstat.Length != ^int64(0) { if fi.IsDir() || !hasPerm(r.fid.File, r.fid.Uid, AWRITE) { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } newStat.Length = wstat.Length } - if wstat.Mode != FileMode(^uint32(0)) { // the owner of the file or the group leader of the file's group. if r.fid.Uid != newStat.Uid && r.fid.Uid != newStat.Gid { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } if wstat.Mode&fs.ModeDir != newStat.Mode&fs.ModeDir { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } newStat.Mode = wstat.Mode } - if wstat.Mtime != ^uint32(0) { // the owner of the file or the group leader of the file's group. if r.fid.Uid != newStat.Uid && r.fid.Uid != newStat.Gid { - respond(r, ErrPerm) + respond(ctx, r, ErrPerm) return } newStat.Mtime = wstat.Mtime } if wstat.Gid != "" { // TODO implement - respond(r, fmt.Errorf("not implemented")) + respond(ctx, r, fmt.Errorf("not implemented")) return } - err = wsfile.WStat(newStat) if err != nil { - respond(r, fmt.Errorf("wstat: %v", err)) + respond(ctx, r, fmt.Errorf("wstat: %v", err)) return } - r.ofcall = &RWStat{} - respond(r, nil) + respond(ctx, r, nil) } func rWStat(r *Req, err error) { @@ -854,47 +851,56 @@ func rWStat(r *Req, err error) { } } -func (s *Server) Serve() { +func (s *Server) Serve(ctx context.Context) { L: for { select { + case err := <-s.speakErrChan: + // TODO: handle error + log.Printf("speak: %v", err) + continue L case err := <-s.listenErrChan: - log.Printf("getReq: %v\n", err) + log.Printf("getReq: %v", err) if err == io.EOF { break L } continue L case r := <-s.listenChan: + ctx1, cancel := context.WithCancel(ctx) + r.cancel = cancel go func() { switch r.ifcall.(type) { default: - respond(r, fmt.Errorf("unknown message type: %d", r.ifcall.Type())) + respond(ctx, r, fmt.Errorf("unknown message type: %d", r.ifcall.Type())) case *TVersion: - sVersion(s, r) + sVersion(ctx1, s, r) case *TAuth: - sAuth(s, r) + sAuth(ctx1, s, r) case *TAttach: - sAttach(s, r) + sAttach(ctx1, s, r) case *TWalk: - sWalk(s, r) + sWalk(ctx1, s, r) case *TOpen: - sOpen(s, r) + sOpen(ctx1, s, r) case *TCreate: - sCreate(s, r) + sCreate(ctx1, s, r) case *TRead: - sRead(s, r) + sRead(ctx1, s, r) case *TWrite: - sWrite(s, r) + sWrite(ctx1, s, r) case *TClunk: - sClunk(s, r) + sClunk(ctx1, s, r) case *TRemove: - sRemove(s, r) + sRemove(ctx1, s, r) case *TStat: - sStat(s, r) + sStat(ctx1, s, r) case *TWStat: - sWStat(s, r) + sWStat(ctx1, s, r) } }() + case <-ctx.Done(): + log.Println(ctx.Err()) + break L } } } @@ -902,7 +908,7 @@ L: // Respond responds to the request r with the message r.ofcall if err is nil, // or if err is not nil, with the Rerror with the error message. // If r is nil, or both r.ofcall and err are nil it panics. -func respond(r *Req, err error) { +func respond(ctx context.Context, r *Req, err error) { switch r.ifcall.(type) { default: panic("bug") @@ -931,9 +937,7 @@ func respond(r *Req, err error) { case *TWStat: rWStat(r, err) } - r.ofcall.SetTag(r.tag) - // free tag. if r.pool == nil && err != ErrDupTag { panic("ReqPool is nil but err is not EDupTag") @@ -941,15 +945,12 @@ func respond(r *Req, err error) { if r.pool != nil { r.pool.delete(r.tag) } - - r.srv.speakChan <- r - // TODO: handle error gently - go func() { - err := <-r.srv.speakErrChan - log.Fatalf("speak: %v", err) - }() - - if r.srv.chatty9P { - fmt.Fprintf(os.Stderr, "--> %s\n", r.ofcall) + select { + case r.srv.speakChan <- r: + if r.srv.chatty9P { + fmt.Fprintf(os.Stderr, "--> %s\n", r.ofcall) + } + case <-ctx.Done(): + log.Printf("req flush: %v", r.ifcall) } }