lib9p

Go 9P library.
Log | Files | Refs | LICENSE

client.go (11236B)


      1 package lib9p
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"io"
      7 	"sync"
      8 )
      9 
// Client is a 9P client. It holds the state shared between the caller-facing
// methods and the background goroutines started by NewClient.
type Client struct {
	// msize is the message size; it is read and written only through
	// mSize and setMSize, which hold mSizeLock.
	msize     uint32
	// mSizeLock guards msize.
	mSizeLock *sync.Mutex
	// uname is the user name passed to NewClient.
	uname     string
	// fPool is the fid pool created by allocClientFidPool.
	fPool     *clientFidPool
	// txc is the channel returned by runMultiplexer; transact sends
	// requests on it.
	txc       chan<- *clientReq
	// errc receives errors reported by the background goroutines.
	// It is unbuffered and is closed by Stop.
	errc      chan error
	// cancel cancels the context handed to the background goroutines.
	cancel    context.CancelFunc
	// rootFid is not set by NewClient; presumably assigned elsewhere — TODO confirm.
	rootFid   *clientFid
	// wg counts the background goroutines so that Stop can wait for them.
	wg        *sync.WaitGroup
}
     21 
     22 func NewClient(mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
     23 	ctx, cancel := context.WithCancel(context.Background())
     24 	c := &Client{
     25 		msize:     mSize,
     26 		mSizeLock: new(sync.Mutex),
     27 		uname:     uname,
     28 		fPool:     allocClientFidPool(),
     29 		errc:      make(chan error),
     30 		cancel:    cancel,
     31 		wg:        new(sync.WaitGroup),
     32 	}
     33 	tmsgc := c.runSpeaker(ctx, w)
     34 	rmsgc := c.runListener(ctx, r)
     35 	c.txc = c.runMultiplexer(ctx, tmsgc, rmsgc)
     36 	return c
     37 }
     38 
// Stop shuts the client down: it cancels the context shared by the goroutines
// started in NewClient, waits for them to finish, and then closes the error
// channel.
// NOTE(review): a second call would close errc again, which panics — confirm
// that callers invoke Stop at most once.
func (c *Client) Stop() {
	c.cancel()
	c.wg.Wait()
	// Closed only after every tracked goroutine has exited, so none of them
	// can send on a closed channel.
	close(c.errc)
}
     44 
     45 func (c *Client) mSize() uint32 {
     46 	c.mSizeLock.Lock()
     47 	defer c.mSizeLock.Unlock()
     48 	return c.msize
     49 }
     50 
     51 func (c *Client) setMSize(mSize uint32) {
     52 	c.mSizeLock.Lock()
     53 	defer c.mSizeLock.Unlock()
     54 	c.msize = mSize
     55 }
     56 
     57 // RunListener runs listener goroutine.
     58 // Listener reads byte array of 9P messages from r and make each of them into
     59 // corresponding struct that implements Msg, and sends it to the returned channel.
     60 // Listener goroutine returns when ctx is canceled.
     61 // Listener goroutine reports errors to the client's errc channel.
     62 func (c *Client) runListener(ctx context.Context, r io.Reader) <-chan Msg {
     63 	c.wg.Add(1)
     64 	// TODO: terminate with ctx.Done()
     65 	rmsgc := make(chan Msg, 3)
     66 	go func() {
     67 		wg := new(sync.WaitGroup)
     68 		defer func() {
     69 			wg.Wait()
     70 			close(rmsgc)
     71 			c.wg.Done()
     72 		}()
     73 		for {
     74 			select {
     75 			case <-ctx.Done():
     76 				// TODO: should return error via ec??
     77 				// TODO: should close r?
     78 				return
     79 			default:
     80 				done := make(chan struct{})
     81 				var (
     82 					msg Msg
     83 					err error
     84 				)
     85 				go func() {
     86 					defer close(done)
     87 					msg, err = recv(r)
     88 				}()
     89 				select {
     90 				case <-done:
     91 				case <-ctx.Done():
     92 				}
     93 				if err != nil {
     94 					c.errc <- fmt.Errorf("recv: %v", err)
     95 					continue
     96 				}
     97 				wg.Add(1)
     98 				go func() {
     99 					defer wg.Done()
    100 					select {
    101 					case rmsgc <- msg:
    102 					case <-ctx.Done():
    103 					}
    104 				}()
    105 			}
    106 		}
    107 	}()
    108 	return rmsgc
    109 }
    110 
    111 // RunSpeaker runs speaker goroutine.
    112 // Speaker goroutine recieves 9P Msgs from the returned channel, marshal them
    113 // into byte arrays and sends them to w.
    114 // It reports any errors to the clients errc channel.
    115 // It returnes when ctx is canceled.
    116 func (c *Client) runSpeaker(ctx context.Context, w io.Writer) chan<- Msg {
    117 	c.wg.Add(1)
    118 	tmsgc := make(chan Msg, 3)
    119 	go func() {
    120 		defer c.wg.Done()
    121 		for {
    122 			select {
    123 			case <-ctx.Done():
    124 				return
    125 			case msg := <-tmsgc:
    126 				if msg == nil {
    127 					// tmsgc is closed, which means ctx.Done() is also closed.
    128 					// but this code breaks semantics?
    129 					return
    130 				}
    131 				if err := send(msg, w); err != nil {
    132 					c.errc <- fmt.Errorf("send: %v", err)
    133 				}
    134 			}
    135 		}
    136 	}()
    137 	return tmsgc
    138 }
    139 
    140 // RunMultiplexer runs multiplexer goroutine.
    141 // Multiplexer goroutines, one for recieving Rmsg and another for sending Tmsg.
    142 // The goroutine for Tmsg recieves *clientReq from the returned channel,
    143 // and send the 9P Msg to the speaker goroutine via tmsgc.
    144 // The goroutine for Rmsg recieves *clientReq from the Tmsg goroutine and waits for
    145 // the reply to the corresponding message from the listener goroutine via rmsgc.
    146 // After recieving the reply, it sets the *clientReq.rmsg and sends it t the
    147 // *clientReq.rxc.
    148 // It reports any errors to the client's errc channel.
    149 func (c *Client) runMultiplexer(ctx context.Context, tmsgc chan<- Msg, rmsgc <-chan Msg) chan<- *clientReq {
    150 	c.wg.Add(2)
    151 	txc := make(chan *clientReq)
    152 	reqc := make(chan *clientReq)
    153 	// Rmsg
    154 	go func(reqc <-chan *clientReq) {
    155 		wg := new(sync.WaitGroup)
    156 		defer func() {
    157 			wg.Wait()
    158 			c.wg.Done()
    159 		}()
    160 		rPool := make(map[uint16]*clientReq)
    161 		for {
    162 			select {
    163 			case <-ctx.Done():
    164 				return
    165 			case req := <-reqc:
    166 				if req == nil {
    167 					// ctx is canceled.
    168 					continue
    169 				}
    170 				if _, ok := rPool[req.tag]; ok {
    171 					c.errc <- fmt.Errorf("mux: duplicate tag: %d", req.tag)
    172 					continue
    173 				}
    174 				rPool[req.tag] = req // TODO: wait for req.ctxDone channel.
    175 				wg.Add(1)
    176 				go func() {
    177 					defer wg.Done()
    178 					<-req.ctxDone
    179 				}()
    180 			case msg := <-rmsgc:
    181 				if msg == nil {
    182 					// ctx is canceled.
    183 					continue
    184 				}
    185 				req, ok := rPool[msg.Tag()]
    186 				if !ok {
    187 					c.errc <- fmt.Errorf("mux: unknown tag for msg: %v", msg)
    188 					continue
    189 				}
    190 				delete(rPool, msg.Tag())
    191 				req.rmsg = msg
    192 				go func() {
    193 					defer close(req.rxc)
    194 					select {
    195 					case <-req.ctxDone:
    196 					case req.rxc <- req:
    197 					}
    198 				}()
    199 			}
    200 		}
    201 	}(reqc)
    202 	// Tmsg
    203 	go func(reqc chan<- *clientReq) {
    204 		wg := new(sync.WaitGroup)
    205 		defer func() {
    206 			wg.Wait()
    207 			close(reqc)
    208 			close(tmsgc)
    209 			c.wg.Done()
    210 		}()
    211 		for {
    212 			select {
    213 			case <-ctx.Done():
    214 				return
    215 			case req := <-txc:
    216 				select {
    217 				case reqc <- req:
    218 				case <-ctx.Done():
    219 					return
    220 				}
    221 				wg.Add(1)
    222 				go func() {
    223 					defer wg.Done()
    224 					tmsgc <- req.tmsg
    225 				}()
    226 			}
    227 		}
    228 	}(reqc)
    229 	return txc
    230 }
    231 
    232 // Transact send 9P Msg of req to the multiplexer goroutines and recieves
    233 // the reply.
    234 func (c *Client) transact(ctx context.Context, tmsg Msg) (Msg, error) {
    235 	ctx1, cancel1 := context.WithCancel(ctx)
    236 	req := newClientReq(ctx1, tmsg)
    237 	select {
    238 	case <-ctx.Done():
    239 		return nil, ctx.Err()
    240 	case c.txc <- req:
    241 	}
    242 	select {
    243 	case req = <-req.rxc: // TODO: this assignment is not required.
    244 		cancel1()
    245 		return req.rmsg, req.err
    246 	case <-ctx.Done():
    247 		return nil, ctx.Err()
    248 	}
    249 }
    250 
    251 func (c *Client) Version(ctx context.Context, tag uint16, mSize uint32, version string) (uint32, string, error) {
    252 	tmsg := &TVersion{tag: tag, mSize: mSize, version: version}
    253 	rmsg, err := c.transact(ctx, tmsg)
    254 	if err != nil {
    255 		return 0, "", fmt.Errorf("transact: %v", err)
    256 	}
    257 	switch rmsg := rmsg.(type) {
    258 	case *RVersion:
    259 		return rmsg.mSize, rmsg.version, nil
    260 	case *RError:
    261 		return 0, "", rmsg.ename
    262 	default:
    263 		return 0, "", fmt.Errorf("invalid reply: %v", rmsg)
    264 	}
    265 }
    266 
    267 func (c *Client) Auth(ctx context.Context, tag uint16, afid uint32, uname, aname string) (Qid, error) {
    268 	tmsg := &TAuth{tag: tag, afid: afid, uname: uname}
    269 	rmsg, err := c.transact(ctx, tmsg)
    270 	if err != nil {
    271 		return Qid{}, fmt.Errorf("transact: %v", err)
    272 	}
    273 	switch rmsg := rmsg.(type) {
    274 	case *RAuth:
    275 		return rmsg.aqid, nil
    276 	case *RError:
    277 		return Qid{}, rmsg.ename
    278 	default:
    279 		return Qid{}, fmt.Errorf("invalid reply: %v", rmsg)
    280 	}
    281 }
    282 
    283 func (c *Client) Attach(ctx context.Context, tag uint16, fid, afid uint32, uname, aname string) (Qid, error) {
    284 	tmsg := &TAttach{tag: tag, fid: fid, afid: afid, uname: uname, aname: aname}
    285 	rmsg, err := c.transact(ctx, tmsg)
    286 	if err != nil {
    287 		return Qid{}, fmt.Errorf("transact: %v", err)
    288 	}
    289 	switch rmsg := rmsg.(type) {
    290 	case *RAttach:
    291 		return rmsg.qid, nil
    292 	case *RError:
    293 		return Qid{}, rmsg.ename
    294 	default:
    295 		return Qid{}, fmt.Errorf("invalid reply: %v", rmsg)
    296 	}
    297 }
    298 
    299 func (c *Client) Flush(ctx context.Context, tag, oldtag uint16) error {
    300 	tmsg := &TFlush{tag: tag, oldtag: oldtag}
    301 	rmsg, err := c.transact(ctx, tmsg)
    302 	if err != nil {
    303 		return fmt.Errorf("transact: %v", err)
    304 	}
    305 	switch rmsg := rmsg.(type) {
    306 	case *RFlush:
    307 		return nil
    308 	case *RError:
    309 		return rmsg.ename
    310 	default:
    311 		return fmt.Errorf("invalid reply: %v", rmsg)
    312 	}
    313 }
    314 func (c *Client) Walk(ctx context.Context, tag uint16, fid, newFid uint32, wname []string) (wqid []Qid, err error) {
    315 	tmsg := &TWalk{tag: tag, fid: fid, newFid: newFid, wname: wname}
    316 	rmsg, err := c.transact(ctx, tmsg)
    317 	if err != nil {
    318 		return nil, fmt.Errorf("transact: %v", err)
    319 	}
    320 	switch rmsg := rmsg.(type) {
    321 	case *RWalk:
    322 		return rmsg.qid, nil
    323 	case *RError:
    324 		return nil, rmsg.ename
    325 	default:
    326 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    327 	}
    328 }
    329 func (c *Client) Open(ctx context.Context, tag uint16, fid uint32, mode OpenMode) (qid Qid, iounit uint32, err error) {
    330 	tmsg := &TOpen{tag: tag, fid: fid, mode: mode}
    331 	rmsg, err := c.transact(ctx, tmsg)
    332 	if err != nil {
    333 		return Qid{}, 0, fmt.Errorf("transact: %v", err)
    334 	}
    335 	switch rmsg := rmsg.(type) {
    336 	case *ROpen:
    337 		return rmsg.qid, rmsg.iounit, nil
    338 	case *RError:
    339 		return Qid{}, 0, rmsg.ename
    340 	default:
    341 		return Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg)
    342 	}
    343 }
    344 func (c *Client) Create(ctx context.Context, tag uint16, fid uint32, name string, perm FileMode, mode OpenMode) (Qid, uint32, error) {
    345 	tmsg := &TCreate{tag: tag, fid: fid, name: name, perm: perm, mode: mode}
    346 	rmsg, err := c.transact(ctx, tmsg)
    347 	if err != nil {
    348 		return Qid{}, 0, fmt.Errorf("transact: %v", err)
    349 	}
    350 	switch rmsg := rmsg.(type) {
    351 	case *RCreate:
    352 		return rmsg.qid, rmsg.iounit, nil
    353 	case *RError:
    354 		return Qid{}, 0, rmsg.ename
    355 	default:
    356 		return Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg)
    357 	}
    358 }
    359 func (c *Client) Read(ctx context.Context, tag uint16, fid uint32, offset uint64, count uint32) (data []byte, err error) {
    360 	tmsg := &TRead{tag: tag, fid: fid, offset: offset, count: count}
    361 	rmsg, err := c.transact(ctx, tmsg)
    362 	if err != nil {
    363 		return nil, fmt.Errorf("transact: %v", err)
    364 	}
    365 	switch rmsg := rmsg.(type) {
    366 	case *RRead:
    367 		return rmsg.data, nil
    368 	case *RError:
    369 		return nil, rmsg.ename
    370 	default:
    371 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    372 	}
    373 }
    374 func (c *Client) Write(ctx context.Context, tag uint16, fid uint32, offset uint64, count uint32, data []byte) (uint32, error) {
    375 	tmsg := &TWrite{tag: tag, fid: fid, offset: offset, count: count, data: data}
    376 	rmsg, err := c.transact(ctx, tmsg)
    377 	if err != nil {
    378 		return 0, fmt.Errorf("transact: %v", err)
    379 	}
    380 	switch rmsg := rmsg.(type) {
    381 	case *RWrite:
    382 		return rmsg.count, nil
    383 	case *RError:
    384 		return 0, rmsg.ename
    385 	default:
    386 		return 0, fmt.Errorf("invalid reply: %v", rmsg)
    387 	}
    388 }
    389 func (c *Client) Clunk(ctx context.Context, tag uint16, fid uint32) error {
    390 	tmsg := &TClunk{tag: tag, fid: fid}
    391 	rmsg, err := c.transact(ctx, tmsg)
    392 	if err != nil {
    393 		return fmt.Errorf("transact: %v", err)
    394 	}
    395 	switch rmsg := rmsg.(type) {
    396 	case *RClunk:
    397 		return nil
    398 	case *RError:
    399 		return rmsg.ename
    400 	default:
    401 		return fmt.Errorf("invalid reply: %v", rmsg)
    402 	}
    403 }
    404 func (c *Client) Remove(ctx context.Context, tag uint16, fid uint32) error {
    405 	tmsg := &TRemove{tag: tag, fid: fid}
    406 	rmsg, err := c.transact(ctx, tmsg)
    407 	if err != nil {
    408 		return fmt.Errorf("transact: %v", err)
    409 	}
    410 	switch rmsg := rmsg.(type) {
    411 	case *RRemove:
    412 		return nil
    413 	case *RError:
    414 		return rmsg.ename
    415 	default:
    416 		return fmt.Errorf("invalid reply: %v", rmsg)
    417 	}
    418 }
    419 func (c *Client) Stat(ctx context.Context, tag uint16, fid uint32) (*Stat, error) {
    420 	tmsg := &TStat{tag: tag, fid: fid}
    421 	rmsg, err := c.transact(ctx, tmsg)
    422 	if err != nil {
    423 		return nil, fmt.Errorf("transact: %v", err)
    424 	}
    425 	switch rmsg := rmsg.(type) {
    426 	case *RStat:
    427 		return rmsg.stat, nil
    428 	case *RError:
    429 		return nil, rmsg.ename
    430 	default:
    431 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    432 	}
    433 }
    434 func (c *Client) Wstat(ctx context.Context, tag uint16, fid uint32, stat *Stat) error {
    435 	tmsg := &TWStat{tag: tag, fid: fid, stat: stat}
    436 	rmsg, err := c.transact(ctx, tmsg)
    437 	if err != nil {
    438 		return fmt.Errorf("transact: %v", err)
    439 	}
    440 	switch rmsg := rmsg.(type) {
    441 	case *RWStat:
    442 		return nil
    443 	case *RError:
    444 		return rmsg.ename
    445 	default:
    446 		return fmt.Errorf("invalid reply: %v", rmsg)
    447 	}
    448 }