client.go (11236B)
1 package lib9p 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "sync" 8 ) 9 10 type Client struct { 11 msize uint32 12 mSizeLock *sync.Mutex 13 uname string 14 fPool *clientFidPool 15 txc chan<- *clientReq 16 errc chan error 17 cancel context.CancelFunc 18 rootFid *clientFid 19 wg *sync.WaitGroup 20 } 21 22 func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client { 23 ctx, cancel := context.WithCancel(context.Background()) 24 c := &Client{ 25 msize: mSize, 26 mSizeLock: new(sync.Mutex), 27 uname: uname, 28 fPool: allocClientFidPool(), 29 errc: make(chan error), 30 cancel: cancel, 31 wg: new(sync.WaitGroup), 32 } 33 tmsgc := c.runSpeaker(ctx, w) 34 rmsgc := c.runListener(ctx, r) 35 c.txc = c.runMultiplexer(ctx, tmsgc, rmsgc) 36 return c 37 } 38 39 func (c *Client) Stop() { 40 c.cancel() 41 c.wg.Wait() 42 close(c.errc) 43 } 44 45 func (c *Client) mSize() uint32 { 46 c.mSizeLock.Lock() 47 defer c.mSizeLock.Unlock() 48 return c.msize 49 } 50 51 func (c *Client) setMSize(mSize uint32) { 52 c.mSizeLock.Lock() 53 defer c.mSizeLock.Unlock() 54 c.msize = mSize 55 } 56 57 // RunListener runs listener goroutine. 58 // Listener reads byte array of 9P messages from r and make each of them into 59 // corresponding struct that implements Msg, and sends it to the returned channel. 60 // Listener goroutine returns when ctx is canceled. 61 // Listener goroutine reports errors to the client's errc channel. 62 func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan Msg { 63 c.wg.Add(1) 64 // TODO: terminate with ctx.Done() 65 rmsgc := make(chan Msg, 3) 66 go func() { 67 wg := new(sync.WaitGroup) 68 defer func() { 69 wg.Wait() 70 close(rmsgc) 71 c.wg.Done() 72 }() 73 for { 74 select { 75 case <-ctx.Done(): 76 // TODO: should return error via ec?? 77 // TODO: should close r? 78 return 79 default: 80 done := make(chan struct{}) 81 var ( 82 msg Msg 83 err error 84 ) 85 go func() { 86 defer close(done) 87 msg, err = recv(r) 88 }() 89 select { 90 case <-done: 91 case <-ctx.Done(): 92 } 93 if err != nil { 94 c.errc <- fmt.Errorf("recv: %v", err) 95 continue 96 } 97 wg.Add(1) 98 go func() { 99 defer wg.Done() 100 select { 101 case rmsgc <- msg: 102 case <-ctx.Done(): 103 } 104 }() 105 } 106 } 107 }() 108 return rmsgc 109 } 110 111 // RunSpeaker runs speaker goroutine. 112 // Speaker goroutine recieves 9P Msgs from the returned channel, marshal them 113 // into byte arrays and sends them to w. 114 // It reports any errors to the clients errc channel. 115 // It returnes when ctx is canceled. 116 func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- Msg { 117 c.wg.Add(1) 118 tmsgc := make(chan Msg, 3) 119 go func() { 120 defer c.wg.Done() 121 for { 122 select { 123 case <-ctx.Done(): 124 return 125 case msg := <-tmsgc: 126 if msg == nil { 127 // tmsgc is closed, which means ctx.Done() is also closed. 128 // but this code breaks semantics? 129 return 130 } 131 if err := send(msg, w); err != nil { 132 c.errc <- fmt.Errorf("send: %v", err) 133 } 134 } 135 } 136 }() 137 return tmsgc 138 } 139 140 // RunMultiplexer runs multiplexer goroutine. 141 // Multiplexer goroutines, one for recieving Rmsg and another for sending Tmsg. 142 // The goroutine for Tmsg recieves *clientReq from the returned channel, 143 // and send the 9P Msg to the speaker goroutine via tmsgc. 144 // The goroutine for Rmsg recieves *clientReq from the Tmsg goroutine and waits for 145 // the reply to the corresponding message from the listener goroutine via rmsgc. 146 // After recieving the reply, it sets the *clientReq.rmsg and sends it t the 147 // *clientReq.rxc. 148 // It reports any errors to the client's errc channel. 149 func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- Msg, rmsgc <-chan Msg) chan<- *clientReq { 150 c.wg.Add(2) 151 txc := make(chan *clientReq) 152 reqc := make(chan *clientReq) 153 // Rmsg 154 go func(reqc <-chan *clientReq) { 155 wg := new(sync.WaitGroup) 156 defer func() { 157 wg.Wait() 158 c.wg.Done() 159 }() 160 rPool := make(map[uint16]*clientReq) 161 for { 162 select { 163 case <-ctx.Done(): 164 return 165 case req := <-reqc: 166 if req == nil { 167 // ctx is canceled. 168 continue 169 } 170 if _, ok := rPool[req.tag]; ok { 171 c.errc <- fmt.Errorf("mux: duplicate tag: %d", req.tag) 172 continue 173 } 174 rPool[req.tag] = req // TODO: wait for req.ctxDone channel. 175 wg.Add(1) 176 go func() { 177 defer wg.Done() 178 <-req.ctxDone 179 }() 180 case msg := <-rmsgc: 181 if msg == nil { 182 // ctx is canceled. 183 continue 184 } 185 req, ok := rPool[msg.Tag()] 186 if !ok { 187 c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg) 188 continue 189 } 190 delete(rPool, msg.Tag()) 191 req.rmsg = msg 192 go func() { 193 defer close(req.rxc) 194 select { 195 case <-req.ctxDone: 196 case req.rxc <- req: 197 } 198 }() 199 } 200 } 201 }(reqc) 202 // Tmsg 203 go func(reqc chan<- *clientReq) { 204 wg := new(sync.WaitGroup) 205 defer func() { 206 wg.Wait() 207 close(reqc) 208 close(tmsgc) 209 c.wg.Done() 210 }() 211 for { 212 select { 213 case <-ctx.Done(): 214 return 215 case req := <-txc: 216 select { 217 case reqc <- req: 218 case <-ctx.Done(): 219 return 220 } 221 wg.Add(1) 222 go func() { 223 defer wg.Done() 224 tmsgc <- req.tmsg 225 }() 226 } 227 } 228 }(reqc) 229 return txc 230 } 231 232 // Transact send 9P Msg of req to the multiplexer goroutines and recieves 233 // the reply. 234 func (c *Client) transact(ctx context.Context, tmsg Msg) (Msg, error) { 235 ctx1, cancel1 := context.WithCancel(ctx) 236 req := newClientReq(ctx1, tmsg) 237 select { 238 case <-ctx.Done(): 239 return nil, ctx.Err() 240 case c.txc <- req: 241 } 242 select { 243 case req = <-req.rxc: // TODO: this assignment is not required. 244 cancel1() 245 return req.rmsg, req.err 246 case <-ctx.Done(): 247 return nil, ctx.Err() 248 } 249 } 250 251 func (c *Client) Version(ctx context.Context, tag uint16, mSize uint32, version string) (uint32, string, error) { 252 tmsg := &TVersion{tag: tag, mSize: mSize, version: version} 253 rmsg, err := c.transact(ctx, tmsg) 254 if err != nil { 255 return 0, "", fmt.Errorf("transact: %v", err) 256 } 257 switch rmsg := rmsg.(type) { 258 case *RVersion: 259 return rmsg.mSize, rmsg.version, nil 260 case *RError: 261 return 0, "", rmsg.ename 262 default: 263 return 0, "", fmt.Errorf("invalid reply: %v", rmsg) 264 } 265 } 266 267 func (c *Client) Auth(ctx context.Context, tag uint16, afid uint32, uname, aname string) (Qid, error) { 268 tmsg := &TAuth{tag: tag, afid: afid, uname: uname} 269 rmsg, err := c.transact(ctx, tmsg) 270 if err != nil { 271 return Qid{}, fmt.Errorf("transact: %v", err) 272 } 273 switch rmsg := rmsg.(type) { 274 case *RAuth: 275 return rmsg.aqid, nil 276 case *RError: 277 return Qid{}, rmsg.ename 278 default: 279 return Qid{}, fmt.Errorf("invalid reply: %v", rmsg) 280 } 281 } 282 283 func (c *Client) Attach(ctx context.Context, tag uint16, fid, afid uint32, uname, aname string) (Qid, error) { 284 tmsg := &TAttach{tag: tag, fid: fid, afid: afid, uname: uname, aname: aname} 285 rmsg, err := c.transact(ctx, tmsg) 286 if err != nil { 287 return Qid{}, fmt.Errorf("transact: %v", err) 288 } 289 switch rmsg := rmsg.(type) { 290 case *RAttach: 291 return rmsg.qid, nil 292 case *RError: 293 return Qid{}, rmsg.ename 294 default: 295 return Qid{}, fmt.Errorf("invalid reply: %v", rmsg) 296 } 297 } 298 299 func (c *Client) Flush(ctx context.Context, tag, oldtag uint16) error { 300 tmsg := &TFlush{tag: tag, oldtag: oldtag} 301 rmsg, err := c.transact(ctx, tmsg) 302 if err != nil { 303 return fmt.Errorf("transact: %v", err) 304 } 305 switch rmsg := rmsg.(type) { 306 case *RFlush: 307 return nil 308 case *RError: 309 return rmsg.ename 310 default: 311 return fmt.Errorf("invalid reply: %v", rmsg) 312 } 313 } 314 func (c *Client) Walk(ctx context.Context, tag uint16, fid, newFid uint32, wname []string) (wqid []Qid, err error) { 315 tmsg := &TWalk{tag: tag, fid: fid, newFid: newFid, wname: wname} 316 rmsg, err := c.transact(ctx, tmsg) 317 if err != nil { 318 return nil, fmt.Errorf("transact: %v", err) 319 } 320 switch rmsg := rmsg.(type) { 321 case *RWalk: 322 return rmsg.qid, nil 323 case *RError: 324 return nil, rmsg.ename 325 default: 326 return nil, fmt.Errorf("invalid reply: %v", rmsg) 327 } 328 } 329 func (c *Client) Open(ctx context.Context, tag uint16, fid uint32, mode OpenMode) (qid Qid, iounit uint32, err error) { 330 tmsg := &TOpen{tag: tag, fid: fid, mode: mode} 331 rmsg, err := c.transact(ctx, tmsg) 332 if err != nil { 333 return Qid{}, 0, fmt.Errorf("transact: %v", err) 334 } 335 switch rmsg := rmsg.(type) { 336 case *ROpen: 337 return rmsg.qid, rmsg.iounit, nil 338 case *RError: 339 return Qid{}, 0, rmsg.ename 340 default: 341 return Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg) 342 } 343 } 344 func (c *Client) Create(ctx context.Context, tag uint16, fid uint32, name string, perm FileMode, mode OpenMode) (Qid, uint32, error) { 345 tmsg := &TCreate{tag: tag, fid: fid, name: name, perm: perm, mode: mode} 346 rmsg, err := c.transact(ctx, tmsg) 347 if err != nil { 348 return Qid{}, 0, fmt.Errorf("transact: %v", err) 349 } 350 switch rmsg := rmsg.(type) { 351 case *RCreate: 352 return rmsg.qid, rmsg.iounit, nil 353 case *RError: 354 return Qid{}, 0, rmsg.ename 355 default: 356 return Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg) 357 } 358 } 359 func (c *Client) Read(ctx context.Context, tag uint16, fid uint32, offset uint64, count uint32) (data []byte, err error) { 360 tmsg := &TRead{tag: tag, fid: fid, offset: offset, count: count} 361 rmsg, err := c.transact(ctx, tmsg) 362 if err != nil { 363 return nil, fmt.Errorf("transact: %v", err) 364 } 365 switch rmsg := rmsg.(type) { 366 case *RRead: 367 return rmsg.data, nil 368 case *RError: 369 return nil, rmsg.ename 370 default: 371 return nil, fmt.Errorf("invalid reply: %v", rmsg) 372 } 373 } 374 func (c *Client) Write(ctx context.Context, tag uint16, fid uint32, offset uint64, count uint32, data []byte) (uint32, error) { 375 tmsg := &TWrite{tag: tag, fid: fid, offset: offset, count: count, data: data} 376 rmsg, err := c.transact(ctx, tmsg) 377 if err != nil { 378 return 0, fmt.Errorf("transact: %v", err) 379 } 380 switch rmsg := rmsg.(type) { 381 case *RWrite: 382 return rmsg.count, nil 383 case *RError: 384 return 0, rmsg.ename 385 default: 386 return 0, fmt.Errorf("invalid reply: %v", rmsg) 387 } 388 } 389 func (c *Client) Clunk(ctx context.Context, tag uint16, fid uint32) error { 390 tmsg := &TClunk{tag: tag, fid: fid} 391 rmsg, err := c.transact(ctx, tmsg) 392 if err != nil { 393 return fmt.Errorf("transact: %v", err) 394 } 395 switch rmsg := rmsg.(type) { 396 case *RClunk: 397 return nil 398 case *RError: 399 return rmsg.ename 400 default: 401 return fmt.Errorf("invalid reply: %v", rmsg) 402 } 403 } 404 func (c *Client) Remove(ctx context.Context, tag uint16, fid uint32) error { 405 tmsg := &TRemove{tag: tag, fid: fid} 406 rmsg, err := c.transact(ctx, tmsg) 407 if err != nil { 408 return fmt.Errorf("transact: %v", err) 409 } 410 switch rmsg := rmsg.(type) { 411 case *RRemove: 412 return nil 413 case *RError: 414 return rmsg.ename 415 default: 416 return fmt.Errorf("invalid reply: %v", rmsg) 417 } 418 } 419 func (c *Client) Stat(ctx context.Context, tag uint16, fid uint32) (*Stat, error) { 420 tmsg := &TStat{tag: tag, fid: fid} 421 rmsg, err := c.transact(ctx, tmsg) 422 if err != nil { 423 return nil, fmt.Errorf("transact: %v", err) 424 } 425 switch rmsg := rmsg.(type) { 426 case *RStat: 427 return rmsg.stat, nil 428 case *RError: 429 return nil, rmsg.ename 430 default: 431 return nil, fmt.Errorf("invalid reply: %v", rmsg) 432 } 433 } 434 func (c *Client) Wstat(ctx context.Context, tag uint16, fid uint32, stat *Stat) error { 435 tmsg := &TWStat{tag: tag, fid: fid, stat: stat} 436 rmsg, err := c.transact(ctx, tmsg) 437 if err != nil { 438 return fmt.Errorf("transact: %v", err) 439 } 440 switch rmsg := rmsg.(type) { 441 case *RWStat: 442 return nil 443 case *RError: 444 return rmsg.ename 445 default: 446 return fmt.Errorf("invalid reply: %v", rmsg) 447 } 448 }