commit 12d5d20995251f28c26fb73f0b1042b88be1b652
parent fdd86832f13940ed0a179d5c83d13ec59fa54f94
Author: Matsuda Kenji <info@mtkn.jp>
Date: Sun, 17 Sep 2023 09:55:49 +0900
modify Req
Diffstat:
| M | fcall.go | | | 86 | ++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------- |
| M | req.go | | | 3 | +++ |
| M | server.go | | | 64 | +++++++++++++++++++++++++++++++++++++++++----------------------- |
3 files changed, 101 insertions(+), 52 deletions(-)
diff --git a/fcall.go b/fcall.go
@@ -47,6 +47,7 @@ type Msg interface {
Type() MsgType
// Tag returns the tag of message.
Tag() uint16
+ SetTag(uint16)
// Marshal convert Msg to byte array to be transmitted.
marshal() []byte
@@ -59,6 +60,7 @@ type bufMsg []byte
func (msg bufMsg) Size() uint32 { return gbit32(msg[0:4]) }
func (msg bufMsg) Type() MsgType { return MsgType(msg[4]) }
func (msg bufMsg) Tag() uint16 { return gbit16(msg[5:7]) }
+func (msg bufMsg) SetTag(t uint16) { pbit16(msg[5:7], t) }
func (msg bufMsg) marshal() []byte { return []byte(msg)[:msg.Size()] }
func (msg bufMsg) String() string { return fmt.Sprint([]byte(msg[:msg.Size()])) }
@@ -82,6 +84,7 @@ func newTVersion(buf []byte) *TVersion {
func (msg *TVersion) Size() uint32 { return msg.size }
func (msg *TVersion) Type() MsgType { return Tversion }
func (msg *TVersion) Tag() uint16 { return msg.tag }
+func (msg *TVersion) SetTag(t uint16) { msg.tag = t }
func (msg *TVersion) MSize() uint32 { return msg.mSize }
func (msg *TVersion) Version() string { return msg.version }
func (msg *TVersion) marshal() []byte {
@@ -182,13 +185,14 @@ func newTAuth(buf []byte) *TAuth {
}
return msg
}
-func (msg TAuth) Size() uint32 { return msg.size }
-func (msg TAuth) Type() MsgType { return Tauth }
-func (msg TAuth) Tag() uint16 { return msg.tag }
-func (msg TAuth) AFid() uint32 { return msg.afid }
-func (msg TAuth) UName() string { return msg.uname }
-func (msg TAuth) AName() string { return msg.aname }
-func (msg TAuth) marshal() []byte {
+func (msg *TAuth) Size() uint32 { return msg.size }
+func (msg *TAuth) Type() MsgType { return Tauth }
+func (msg *TAuth) Tag() uint16 { return msg.tag }
+func (msg *TAuth) SetTag(t uint16) { msg.tag = t }
+func (msg *TAuth) AFid() uint32 { return msg.afid }
+func (msg *TAuth) UName() string { return msg.uname }
+func (msg *TAuth) AName() string { return msg.aname }
+func (msg *TAuth) marshal() []byte {
cur := 0
buf := make([]byte, msg.Size())
pbit32(buf[cur:cur+4], msg.Size())
@@ -218,7 +222,7 @@ func (msg TAuth) marshal() []byte {
}
return buf
}
-func (msg TAuth) String() string {
+func (msg *TAuth) String() string {
afid := int64(msg.AFid())
if afid == int64(NOFID) {
afid = -1
@@ -238,11 +242,12 @@ func newRAuth(buf []byte) *RAuth {
msg.aqid = unmarshalQid(buf[7:])
return msg
}
-func (msg RAuth) Size() uint32 { return 4 + 1 + 2 + 13 }
-func (msg RAuth) Type() MsgType { return Rauth }
-func (msg RAuth) Tag() uint16 { return msg.tag }
-func (msg RAuth) AQid() Qid { return msg.aqid }
-func (msg RAuth) marshal() []byte {
+func (msg *RAuth) Size() uint32 { return 4 + 1 + 2 + 13 }
+func (msg *RAuth) Type() MsgType { return Rauth }
+func (msg *RAuth) Tag() uint16 { return msg.tag }
+func (msg *RAuth) SetTag(t uint16) { msg.tag = t }
+func (msg *RAuth) AQid() Qid { return msg.aqid }
+func (msg *RAuth) marshal() []byte {
cur := 0
buf := make([]byte, msg.Size())
pbit32(buf[cur:cur+4], msg.Size())
@@ -261,7 +266,7 @@ func (msg RAuth) marshal() []byte {
}
return buf
}
-func (msg RAuth) String() string {
+func (msg *RAuth) String() string {
return fmt.Sprintf("Tauth tag %d aqid %v",
msg.Tag(), msg.AQid())
}
@@ -300,14 +305,15 @@ func newTAttach(buf []byte) *TAttach {
}
return msg
}
-func (msg TAttach) Size() uint32 { return msg.size }
-func (msg TAttach) Type() MsgType { return Tattach }
-func (msg TAttach) Tag() uint16 { return msg.tag }
-func (msg TAttach) Fid() uint32 { return msg.fid }
-func (msg TAttach) AFid() uint32 { return msg.afid }
-func (msg TAttach) UName() string { return msg.uname }
-func (msg TAttach) AName() string { return msg.aname }
-func (msg TAttach) marshal() []byte {
+func (msg *TAttach) Size() uint32 { return msg.size }
+func (msg *TAttach) Type() MsgType { return Tattach }
+func (msg *TAttach) Tag() uint16 { return msg.tag }
+func (msg *TAttach) SetTag(t uint16) { msg.tag = t }
+func (msg *TAttach) Fid() uint32 { return msg.fid }
+func (msg *TAttach) AFid() uint32 { return msg.afid }
+func (msg *TAttach) UName() string { return msg.uname }
+func (msg *TAttach) AName() string { return msg.aname }
+func (msg *TAttach) marshal() []byte {
cur := 0
buf := make([]byte, msg.Size())
pbit32(buf[cur:cur+4], msg.Size())
@@ -339,7 +345,7 @@ func (msg TAttach) marshal() []byte {
}
return buf
}
-func (msg TAttach) String() string {
+func (msg *TAttach) String() string {
fid := int64(msg.Fid())
if fid == int64(NOFID) {
fid = -1
@@ -363,11 +369,12 @@ func newRAttach(buf []byte) *RAttach {
msg.qid = unmarshalQid(buf[7:])
return msg
}
-func (msg RAttach) Size() uint32 { return 4 + 1 + 2 + 13 }
-func (msg RAttach) Type() MsgType { return Rattach }
-func (msg RAttach) Tag() uint16 { return msg.tag }
-func (msg RAttach) Qid() Qid { return msg.qid }
-func (msg RAttach) marshal() []byte {
+func (msg *RAttach) Size() uint32 { return 4 + 1 + 2 + 13 }
+func (msg *RAttach) Type() MsgType { return Rattach }
+func (msg *RAttach) Tag() uint16 { return msg.tag }
+func (msg *RAttach) SetTag(t uint16) { msg.tag = t }
+func (msg *RAttach) Qid() Qid { return msg.qid }
+func (msg *RAttach) marshal() []byte {
cur := 0
buf := make([]byte, msg.Size())
pbit32(buf[cur:cur+4], msg.Size())
@@ -386,7 +393,7 @@ func (msg RAttach) marshal() []byte {
}
return buf
}
-func (msg RAttach) String() string {
+func (msg *RAttach) String() string {
return fmt.Sprintf("Rattach tag %d qid %v",
msg.Tag(), msg.Qid())
}
@@ -406,6 +413,7 @@ func newRError(buf []byte) *RError {
func (msg *RError) Size() uint32 { return uint32(4 + 1 + 2 + 2 + len(msg.EName())) }
func (msg *RError) Type() MsgType { return Rerror }
func (msg *RError) Tag() uint16 { return msg.tag }
+func (msg *RError) SetTag(t uint16) { msg.tag = t }
func (msg *RError) EName() string { return msg.ename.Error() }
func (msg *RError) marshal() []byte {
cur := 0
@@ -448,6 +456,7 @@ func newTFlush(buf []byte) *TFlush {
func (msg *TFlush) Size() uint32 { return 9 }
func (msg *TFlush) Type() MsgType { return Tflush }
func (msg *TFlush) Tag() uint16 { return msg.tag }
+func (msg *TFlush) SetTag(t uint16) { msg.tag = t }
func (msg *TFlush) OldTag() uint16 { return msg.oldtag }
func (msg *TFlush) marshal() []byte {
buf := make([]byte, msg.Size())
@@ -474,6 +483,7 @@ func newRFlush(buf []byte) *RFlush {
func (msg *RFlush) Size() uint32 { return 7 }
func (msg *RFlush) Type() MsgType { return Rflush }
func (msg *RFlush) Tag() uint16 { return msg.tag }
+func (msg *RFlush) SetTag(t uint16) { msg.tag = t }
func (msg *RFlush) marshal() []byte {
buf := make([]byte, msg.Size())
pbit32(buf[0:4], msg.Size())
@@ -523,6 +533,7 @@ func newTWalk(buf []byte) *TWalk {
func (msg *TWalk) Size() uint32 { return msg.size }
func (msg *TWalk) Type() MsgType { return Twalk }
func (msg *TWalk) Tag() uint16 { return msg.tag }
+func (msg *TWalk) SetTag(t uint16) { msg.tag = t }
func (msg *TWalk) Fid() uint32 { return msg.fid }
func (msg *TWalk) NewFid() uint32 { return msg.newFid }
func (msg *TWalk) NWName() uint16 { return uint16(len(msg.wname)) }
@@ -588,6 +599,7 @@ func (msg *RWalk) Size() uint32 {
}
func (msg *RWalk) Type() MsgType { return Rwalk }
func (msg *RWalk) Tag() uint16 { return msg.tag }
+func (msg *RWalk) SetTag(t uint16) { msg.tag = t }
func (msg *RWalk) Qid() []Qid { return msg.qid }
func (msg *RWalk) marshal() []byte {
cur := 0
@@ -646,6 +658,7 @@ func newTOpen(buf []byte) *TOpen {
func (msg *TOpen) Size() uint32 { return msg.size }
func (msg *TOpen) Type() MsgType { return Topen }
func (msg *TOpen) Tag() uint16 { return msg.tag }
+func (msg *TOpen) SetTag(t uint16) { msg.tag = t }
func (msg *TOpen) Fid() uint32 { return msg.fid }
func (msg *TOpen) Mode() OpenMode { return msg.mode }
func (msg *TOpen) marshal() []byte {
@@ -690,6 +703,7 @@ func (msg *ROpen) Size() uint32 {
}
func (msg *ROpen) Type() MsgType { return Ropen }
func (msg *ROpen) Tag() uint16 { return msg.tag }
+func (msg *ROpen) SetTag(t uint16) { msg.tag = t }
func (msg *ROpen) Qid() Qid { return msg.qid }
func (msg *ROpen) IoUnit() uint32 { return msg.iounit }
func (msg *ROpen) marshal() []byte {
@@ -752,6 +766,7 @@ func (msg *TCreate) Size() uint32 {
}
func (msg *TCreate) Type() MsgType { return Tcreate }
func (msg *TCreate) Tag() uint16 { return msg.tag }
+func (msg *TCreate) SetTag(t uint16) { msg.tag = t }
func (msg *TCreate) Fid() uint32 { return msg.fid }
func (msg *TCreate) Name() string { return msg.name }
func (msg *TCreate) Perm() FileMode { return msg.perm }
@@ -808,6 +823,7 @@ func (msg *RCreate) Size() uint32 {
}
func (msg *RCreate) Type() MsgType { return Rcreate }
func (msg *RCreate) Tag() uint16 { return msg.tag }
+func (msg *RCreate) SetTag(t uint16) { msg.tag = t }
func (msg *RCreate) Qid() Qid { return msg.qid }
func (msg *RCreate) IoUnit() uint32 { return msg.iounit }
func (msg *RCreate) marshal() []byte {
@@ -866,6 +882,7 @@ func newTRead(buf []byte) *TRead {
func (msg *TRead) Size() uint32 { return msg.size }
func (msg *TRead) Type() MsgType { return Tread }
func (msg *TRead) Tag() uint16 { return msg.tag }
+func (msg *TRead) SetTag(t uint16) { msg.tag = t }
func (msg *TRead) Fid() uint32 { return msg.fid }
func (msg *TRead) Offset() uint64 { return msg.offset }
func (msg *TRead) Count() uint32 { return msg.count }
@@ -915,6 +932,7 @@ func (msg *RRead) Size() uint32 {
}
func (msg *RRead) Type() MsgType { return Rread }
func (msg *RRead) Tag() uint16 { return msg.tag }
+func (msg *RRead) SetTag(t uint16) { msg.tag = t }
func (msg *RRead) Count() uint32 { return msg.count }
func (msg *RRead) Data() []byte { return msg.data }
func (msg *RRead) marshal() []byte {
@@ -995,6 +1013,7 @@ func (msg *TWrite) Size() uint32 {
}
func (msg *TWrite) Type() MsgType { return Twrite }
func (msg *TWrite) Tag() uint16 { return msg.tag }
+func (msg *TWrite) SetTag(t uint16) { msg.tag = t }
func (msg *TWrite) Fid() uint32 { return msg.fid }
func (msg *TWrite) Offset() uint64 { return msg.offset }
func (msg *TWrite) Count() uint32 { return msg.count }
@@ -1044,6 +1063,7 @@ func (msg *RWrite) Size() uint32 {
}
func (msg *RWrite) Type() MsgType { return Rwrite }
func (msg *RWrite) Tag() uint16 { return msg.tag }
+func (msg *RWrite) SetTag(t uint16) { msg.tag = t }
func (msg *RWrite) Count() uint32 { return msg.count }
func (msg *RWrite) marshal() []byte {
cur := 0
@@ -1083,6 +1103,7 @@ func newTClunk(buf []byte) *TClunk {
func (msg *TClunk) Size() uint32 { return msg.size }
func (msg *TClunk) Type() MsgType { return Tclunk }
func (msg *TClunk) Tag() uint16 { return msg.tag }
+func (msg *TClunk) SetTag(t uint16) { msg.tag = t }
func (msg *TClunk) Fid() uint32 { return msg.fid }
func (msg *TClunk) marshal() []byte {
m := make([]byte, msg.Size())
@@ -1109,6 +1130,7 @@ func newRClunk(buf []byte) *RClunk {
func (msg *RClunk) Size() uint32 { return 7 }
func (msg *RClunk) Type() MsgType { return Rclunk }
func (msg *RClunk) Tag() uint16 { return msg.tag }
+func (msg *RClunk) SetTag(t uint16) { msg.tag = t }
func (msg *RClunk) marshal() []byte {
m := make([]byte, msg.Size())
pbit32(m[0:4], msg.Size())
@@ -1134,6 +1156,7 @@ func newTRemove(buf []byte) *TRemove {
func (msg *TRemove) Size() uint32 { return 4 + 1 + 2 + 4 }
func (msg *TRemove) Type() MsgType { return Tremove }
func (msg *TRemove) Tag() uint16 { return msg.tag }
+func (msg *TRemove) SetTag(t uint16) { msg.tag = t }
func (msg *TRemove) Fid() uint32 { return msg.fid }
func (msg *TRemove) marshal() []byte {
m := make([]byte, msg.Size())
@@ -1159,6 +1182,7 @@ func newRRemove(buf []byte) *RRemove {
func (msg *RRemove) Size() uint32 { return 4 + 1 + 2 }
func (msg *RRemove) Type() MsgType { return Rremove }
func (msg *RRemove) Tag() uint16 { return msg.tag }
+func (msg *RRemove) SetTag(t uint16) { msg.tag = t }
func (msg *RRemove) marshal() []byte {
m := make([]byte, msg.Size())
pbit32(m[0:4], msg.Size())
@@ -1186,6 +1210,7 @@ func newTStat(buf []byte) *TStat {
func (msg *TStat) Size() uint32 { return msg.size }
func (msg *TStat) Type() MsgType { return Tstat }
func (msg *TStat) Tag() uint16 { return msg.tag }
+func (msg *TStat) SetTag(t uint16) { msg.tag = t }
func (msg *TStat) Fid() uint32 { return msg.fid }
func (msg *TStat) marshal() []byte {
m := make([]byte, msg.Size())
@@ -1216,6 +1241,7 @@ func (msg *RStat) Size() uint32 {
}
func (msg *RStat) Type() MsgType { return Rstat }
func (msg *RStat) Tag() uint16 { return msg.tag }
+func (msg *RStat) SetTag(t uint16) { msg.tag = t }
func (msg *RStat) marshal() []byte {
buf := make([]byte, msg.Size())
pbit32(buf[0:4], msg.Size())
@@ -1252,6 +1278,7 @@ func (msg *TWStat) Size() uint32 {
}
func (msg *TWStat) Type() MsgType { return Twstat }
func (msg *TWStat) Tag() uint16 { return msg.tag }
+func (msg *TWStat) SetTag(t uint16) { msg.tag = t }
func (msg *TWStat) Fid() uint32 { return msg.fid }
func (msg *TWStat) Stat() *Stat { return msg.stat }
func (msg *TWStat) marshal() []byte {
@@ -1286,6 +1313,7 @@ func newRWStat(buf []byte) *RWStat {
func (msg *RWStat) Size() uint32 { return 4 + 1 + 2 }
func (msg *RWStat) Type() MsgType { return Rwstat }
func (msg *RWStat) Tag() uint16 { return msg.tag }
+func (msg *RWStat) SetTag(t uint16) { msg.tag = t }
func (msg *RWStat) marshal() []byte {
m := make([]byte, msg.Size())
pbit32(m[0:4], msg.Size())
diff --git a/req.go b/req.go
@@ -5,9 +5,12 @@ import (
)
type Req struct {
+ tag uint16
srv *Server
ifcall Msg
ofcall Msg
+ fid *Fid
+ afid *Fid
pool *ReqPool
}
diff --git a/server.go b/server.go
@@ -28,7 +28,6 @@ var (
func setError(r *Req, err error) {
r.ofcall = &RError{
- tag: r.ifcall.Tag(),
ename: err,
}
}
@@ -42,6 +41,8 @@ type Server struct {
rlock *sync.Mutex
writer io.Writer
wlock *sync.Mutex
+
+ auth func(*Req, *Fid)
}
func NewServer(fsys FS, mSize uint32, r io.Reader, w io.Writer) *Server {
@@ -97,6 +98,7 @@ func (s *Server) getReq() (*Req, error) {
}
r.srv = s
+ r.tag = bufMsg(buf).Tag()
r.ifcall, err = unmarshal(buf)
if err != nil {
log.Printf("unmarshal: %v", err)
@@ -129,7 +131,6 @@ func sVersion(s *Server, r *Req) {
}
r.ofcall = &RVersion{
- tag: ifcall.Tag(),
mSize: msize,
version: version,
}
@@ -145,10 +146,25 @@ func rVersion(r *Req, err error) {
}
func sAuth(s *Server, r *Req) {
- respond(r, fmt.Errorf("authentication not implemented"))
+ ifcall := r.ifcall.(*TAuth)
+ afid, err := s.fPool.alloc(ifcall.AFid())
+ if err != nil {
+ respond(r, ErrDupFid)
+ }
+
+ if s.auth != nil {
+ s.auth(r, afid)
+ } else {
+ respond(r, fmt.Errorf("authentication not required"))
+ return
+ }
}
-func rAuth(r *Req, err error) {}
+func rAuth(r *Req, err error) {
+ if err != nil {
+ r.srv.fPool.delete(r.ifcall.(*TAuth).AFid())
+ }
+}
func sAttach(s *Server, r *Req) {
ifcall := r.ifcall.(*TAttach)
@@ -158,7 +174,16 @@ func sAttach(s *Server, r *Req) {
return
}
- f, err := s.fs.Open(".") // TODO: use aname? // TODO: open mode?
+ // TODO: implement afid
+ if ifcall.AFid() != NOFID {
+ _, ok := s.fPool.lookup(ifcall.AFid())
+ if !ok {
+ respond(r, ErrUnknownFid)
+ return
+ }
+ }
+
+ f, err := s.fs.Open(".") // TODO: open mode?
if err != nil {
log.Printf("open fs: %v", err)
respond(r, fmt.Errorf("unable to open file tree"))
@@ -172,14 +197,12 @@ func sAttach(s *Server, r *Req) {
info, err := fid.File.Stat()
if err != nil {
log.Printf("Stat %s, %v", fid.File, err)
- s.fPool.delete(ifcall.Fid())
respond(r, fmt.Errorf("internal error"))
return
}
fid.Qid = info.Qid()
r.ofcall = &RAttach{
- tag: ifcall.Tag(),
qid: fid.Qid,
}
respond(r, nil)
@@ -250,7 +273,6 @@ func sWalk(s *Server, r *Req) {
}
r.ofcall = &RWalk{
- tag: r.ifcall.Tag(),
qid: wqids[:n],
}
respond(r, nil)
@@ -313,7 +335,7 @@ func sOpen(s *Server, r *Req) {
}
if !hasPerm(fid.File, fid.Uid, p) {
- respond(r, fmt.Errorf("permission denied"))
+ respond(r, ErrPerm)
return
}
@@ -329,13 +351,16 @@ func sOpen(s *Server, r *Req) {
fid.OMode = ifcall.Mode()
r.ofcall = &ROpen{
- tag: ifcall.Tag(),
qid: fid.Qid,
iounit: s.mSize - 23,
}
respond(r, nil)
}
-func rOpen(r *Req, err error) {}
+func rOpen(r *Req, err error) {
+ if err != nil {
+ return
+ }
+}
func sCreate(s *Server, r *Req) {
ifcall := r.ifcall.(*TCreate)
@@ -393,7 +418,6 @@ func sCreate(s *Server, r *Req) {
fid.OMode = ifcall.Mode()
r.ofcall = &RCreate{
- tag: ifcall.Tag(),
qid: fi.Qid(),
iounit: s.mSize - 23,
}
@@ -467,7 +491,6 @@ func sRead(s *Server, r *Req) {
}
}
ofcall := &RRead{
- tag: ifcall.Tag(),
count: uint32(n),
data: data[:n],
}
@@ -491,7 +514,6 @@ func sWrite(s *Server, r *Req) {
}
ofcall := new(RWrite)
- ofcall.tag = ifcall.Tag()
switch file := file.(type) {
case io.WriterAt:
n, err := file.WriteAt(ifcall.data, int64(ifcall.offset))
@@ -525,9 +547,7 @@ func sClunk(s *Server, r *Req) {
return
}
s.fPool.delete(ifcall.Fid())
- r.ofcall = &RClunk{
- tag: ifcall.Tag(),
- }
+ r.ofcall = &RClunk{}
respond(r, nil)
}
@@ -565,7 +585,7 @@ func sRemove(s *Server, r *Req) {
return
}
- r.ofcall = &RRemove{tag: ifcall.Tag()}
+ r.ofcall = &RRemove{}
respond(r, nil)
}
func rRemove(r *Req, err error) {}
@@ -585,7 +605,6 @@ func sStat(s *Server, r *Req) {
}
r.ofcall = &RStat{
- tag: ifcall.Tag(),
stat: fileInfo.Sys().(*Stat),
}
respond(r, nil)
@@ -693,9 +712,7 @@ func sWStat(s *Server, r *Req) {
return
}
- r.ofcall = &RWStat{
- tag: ifcall.Tag(),
- }
+ r.ofcall = &RWStat{}
respond(r, nil)
}
@@ -796,11 +813,12 @@ func respond(r *Req, err error) {
}
// free tag.
+ r.ofcall.SetTag(r.tag)
if r.pool == nil && err != ErrDupTag {
panic("ReqPool is nil but err is not EDupTag")
}
if r.pool != nil {
- r.pool.delete(r.ifcall.Tag())
+ r.pool.delete(r.tag)
}
r.srv.wlock.Lock()