lib9p

Go 9P library.
Log | Files | Refs | LICENSE

client.go (10157B)


      1 package client
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 	"io"
      8 	"log"
      9 	"sync"
     10 
     11 	"git.mtkn.jp/lib9p"
     12 )
     13 
     14 // Client is a client side of the 9P conversation.
     15 type Client struct {
     16 	// Msize is the maximum message size in length
     17 	msize uint32
     18 	// MSizeLock is the mutex used when msize is to be changed.
     19 	mSizeLock *sync.Mutex
     20 
     21 	// Uname is used to communicate with a server.
     22 	uname string
     23 
     24 	// RPool is the set of all outstanding requests.
     25 	rPool *reqPool
     26 
     27 	// Txc is used to send a reqest to the multiplexer goroutine
     28 	txc chan<- *req
     29 
     30 	// Wg is the WaitGroup of all goroutines evoked by this client and its
     31 	// descendants.
     32 	wg *sync.WaitGroup
     33 
     34 	// Done is closed when the context passed to NewClient is canceled.
     35 	// This is used to notify transact() that the client is already
     36 	// canceled and the transaction should also be canceled.
     37 	done <-chan struct{}
     38 }
     39 
     40 // NewClient creates a Client and prepare to transact with a server via r and w.
     41 // It runs several goroutines to handle requests.
     42 // And the returned client should be stopped afterwords by cancelling ctx
     43 // and then closing r and w.
     44 func NewClient(ctx context.Context, mSize uint32, uname string, r io.Reader, w io.Writer) *Client {
     45 	c := &Client{
     46 		msize:     mSize,
     47 		mSizeLock: new(sync.Mutex),
     48 		uname:     uname,
     49 		rPool:     newReqPool(),
     50 		wg:        new(sync.WaitGroup),
     51 		done:      ctx.Done(),
     52 	}
     53 	c.txc = c.runMultiplexer(ctx, r, w)
     54 	return c
     55 }
     56 
     57 // mSize returns the maximum message size of the Client.
     58 func (c *Client) mSize() uint32 {
     59 	c.mSizeLock.Lock()
     60 	defer c.mSizeLock.Unlock()
     61 	return c.msize
     62 }
     63 
     64 // setMSize changes the maximum message size of the Client.
     65 func (c *Client) setMSize(mSize uint32) {
     66 	c.mSizeLock.Lock()
     67 	defer c.mSizeLock.Unlock()
     68 	c.msize = mSize
     69 }
     70 
     71 // RunMultiplexer runs two goroutines,
     72 // one for recieving Rmsg and another for sending Tmsg.
     73 // The goroutine for Tmsg recieves *req from the returned channel,
     74 // and sends the lib9p.Msg to w.
     75 // The goroutine for Rmsg reads lib9p.Msg from r and sends it to rxc of the
     76 // corresponding request.
     77 func (c *Client) runMultiplexer(ctx context.Context, r io.Reader, w io.Writer) chan<- *req {
     78 	c.wg.Add(2)
     79 	txc := make(chan *req)
     80 	// Rmsg
     81 	go func() {
     82 		defer c.wg.Done()
     83 		for {
     84 			select {
     85 			case <-ctx.Done():
     86 				return
     87 			default:
     88 			}
     89 			msg, err := lib9p.RecvMsg(r)
     90 			if err != nil {
     91 				c.rPool.cancelAll(fmt.Errorf("recv: %v", err))
     92 				continue // TODO: should return?
     93 			}
     94 			rq, ok := c.rPool.lookup(msg.GetTag())
     95 			if !ok {
     96 				log.Printf("mux: unknown tag for msg: %v", msg)
     97 				continue // TODO: how to recover?
     98 			}
     99 			c.rPool.delete(msg.GetTag())
    100 			if tflush, ok := rq.tmsg.(*lib9p.TFlush); ok {
    101 				if _, ok := msg.(*lib9p.RFlush); !ok {
    102 					rq.errc <- fmt.Errorf("mux: response to Tflush is not Rflush")
    103 				}
    104 				if oldreq, ok := c.rPool.lookup(tflush.Oldtag); ok {
    105 					oldreq.errc <- errors.New("request flushed")
    106 					c.rPool.delete(tflush.Oldtag)
    107 				}
    108 			}
    109 			rq.rmsg = msg
    110 			select {
    111 			case <-ctx.Done():
    112 			case rq.rxc <- rq:
    113 			}
    114 		}
    115 	}()
    116 	// Tmsg
    117 	go func() {
    118 		defer c.wg.Done()
    119 		for {
    120 			select {
    121 			case <-ctx.Done():
    122 				return
    123 			case r, ok := <-txc:
    124 				if !ok {
    125 					return
    126 				}
    127 				if _, ok := c.rPool.lookup(r.tag); ok {
    128 					r.errc <- fmt.Errorf("mux: %w: %d", lib9p.ErrDupTag, r.tag)
    129 					continue
    130 				}
    131 				c.rPool.add(r)
    132 				if err := lib9p.SendMsg(r.tmsg, w); err != nil {
    133 					r.errc <- fmt.Errorf("send: %v", err)
    134 				}
    135 			}
    136 		}
    137 	}()
    138 	return txc
    139 }
    140 
    141 // Transact sends 9P tmsg to the multiplexer and recieves
    142 // the reply.
    143 func (c *Client) transact(tmsg lib9p.Msg) (lib9p.Msg, error) {
    144 	r := newReq(tmsg)
    145 	select {
    146 	case c.txc <- r:
    147 	case <-c.done:
    148 		return nil, errors.New("client stopped")
    149 	}
    150 	select {
    151 	case r := <-r.rxc:
    152 		return r.rmsg, r.err
    153 	case err := <-r.errc: // Client side error.
    154 		return nil, err
    155 	case <-c.done:
    156 		return nil, errors.New("client stopped")
    157 	}
    158 }
    159 
    160 // Version sends Tversion message to the server and returns the resulting
    161 // data of Rversion or non nil error if any.
    162 // This function and other Tmessage functions don't have a context.Contex
    163 // as their argument.
    164 // The caller can call *Client.Flush to cancel a pending request if the
    165 // connection to the server is helthy.
    166 // And even if the connection has some problem sending/recieving, there
    167 // is no way to cancel blocking reads/writes. In this case, the caller
    168 // can close the connection.
    169 func (c *Client) Version(tag uint16, msize uint32, version string) (uint32, string, error) {
    170 	tmsg := &lib9p.TVersion{Tag: tag, Msize: msize, Version: version}
    171 	rmsg, err := c.transact(tmsg)
    172 	if err != nil {
    173 		return 0, "", fmt.Errorf("transact: %w", err)
    174 	}
    175 	switch rmsg := rmsg.(type) {
    176 	case *lib9p.RVersion:
    177 		return rmsg.Msize, rmsg.Version, nil
    178 	case *lib9p.RError:
    179 		return 0, "", rmsg.Ename
    180 	default:
    181 		return 0, "", fmt.Errorf("invalid reply: %v", rmsg)
    182 	}
    183 }
    184 
    185 func (c *Client) Auth(tag uint16, afid uint32, uname, aname string) (lib9p.Qid, error) {
    186 	tmsg := &lib9p.TAuth{Tag: tag, Afid: afid, Uname: uname}
    187 	rmsg, err := c.transact(tmsg)
    188 	if err != nil {
    189 		return lib9p.Qid{}, fmt.Errorf("transact: %w", err)
    190 	}
    191 	switch rmsg := rmsg.(type) {
    192 	case *lib9p.RAuth:
    193 		return rmsg.Aqid, nil
    194 	case *lib9p.RError:
    195 		return lib9p.Qid{}, rmsg.Ename
    196 	default:
    197 		return lib9p.Qid{}, fmt.Errorf("invalid reply: %v", rmsg)
    198 	}
    199 }
    200 
    201 func (c *Client) Attach(tag uint16, fid, afid uint32, uname, aname string) (lib9p.Qid, error) {
    202 	tmsg := &lib9p.TAttach{Tag: tag, Fid: fid, Afid: afid, Uname: uname, Aname: aname}
    203 	rmsg, err := c.transact(tmsg)
    204 	if err != nil {
    205 		return lib9p.Qid{}, fmt.Errorf("transact: %w", err)
    206 	}
    207 	switch rmsg := rmsg.(type) {
    208 	case *lib9p.RAttach:
    209 		return rmsg.Qid, nil
    210 	case *lib9p.RError:
    211 		return lib9p.Qid{}, rmsg.Ename
    212 	default:
    213 		return lib9p.Qid{}, fmt.Errorf("invalid reply: %v", rmsg)
    214 	}
    215 }
    216 
    217 func (c *Client) Flush(tag, oldtag uint16) error {
    218 	tmsg := &lib9p.TFlush{Tag: tag, Oldtag: oldtag}
    219 	rmsg, err := c.transact(tmsg)
    220 	if err != nil {
    221 		return fmt.Errorf("transact: %w", err)
    222 	}
    223 	switch rmsg := rmsg.(type) {
    224 	case *lib9p.RFlush:
    225 		return nil
    226 	case *lib9p.RError:
    227 		return rmsg.Ename
    228 	default:
    229 		return fmt.Errorf("invalid reply: %v", rmsg)
    230 	}
    231 }
    232 func (c *Client) Walk(tag uint16, fid, newfid uint32, wnames []string) (wqid []lib9p.Qid, err error) {
    233 	tmsg := &lib9p.TWalk{Tag: tag, Fid: fid, Newfid: newfid, Wnames: wnames}
    234 	rmsg, err := c.transact(tmsg)
    235 	if err != nil {
    236 		return nil, fmt.Errorf("transact: %w", err)
    237 	}
    238 	switch rmsg := rmsg.(type) {
    239 	case *lib9p.RWalk:
    240 		return rmsg.Qids, nil
    241 	case *lib9p.RError:
    242 		return nil, rmsg.Ename
    243 	default:
    244 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    245 	}
    246 }
    247 func (c *Client) Open(tag uint16, fid uint32, mode lib9p.OpenMode) (qid lib9p.Qid, iounit uint32, err error) {
    248 	tmsg := &lib9p.TOpen{Tag: tag, Fid: fid, Mode: mode}
    249 	rmsg, err := c.transact(tmsg)
    250 	if err != nil {
    251 		return lib9p.Qid{}, 0, fmt.Errorf("transact: %w", err)
    252 	}
    253 	switch rmsg := rmsg.(type) {
    254 	case *lib9p.ROpen:
    255 		return rmsg.Qid, rmsg.Iounit, nil
    256 	case *lib9p.RError:
    257 		return lib9p.Qid{}, 0, rmsg.Ename
    258 	default:
    259 		return lib9p.Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg)
    260 	}
    261 }
    262 func (c *Client) Create(tag uint16, fid uint32, name string, perm lib9p.FileMode, mode lib9p.OpenMode) (lib9p.Qid, uint32, error) {
    263 	tmsg := &lib9p.TCreate{Tag: tag, Fid: fid, Name: name, Perm: perm, Mode: mode}
    264 	rmsg, err := c.transact(tmsg)
    265 	if err != nil {
    266 		return lib9p.Qid{}, 0, fmt.Errorf("transact: %w", err)
    267 	}
    268 	switch rmsg := rmsg.(type) {
    269 	case *lib9p.RCreate:
    270 		return rmsg.Qid, rmsg.Iounit, nil
    271 	case *lib9p.RError:
    272 		return lib9p.Qid{}, 0, rmsg.Ename
    273 	default:
    274 		return lib9p.Qid{}, 0, fmt.Errorf("invalid reply: %v", rmsg)
    275 	}
    276 }
    277 
    278 // Read doesn't return io.EOF at the end of file,
    279 // but returns empty data and nil error.
    280 func (c *Client) Read(tag uint16, fid uint32, offset uint64, count uint32) (data []byte, err error) {
    281 	tmsg := &lib9p.TRead{Tag: tag, Fid: fid, Offset: offset, Count: count}
    282 	rmsg, err := c.transact(tmsg)
    283 	if err != nil {
    284 		return nil, fmt.Errorf("transact: %w", err)
    285 	}
    286 	switch rmsg := rmsg.(type) {
    287 	case *lib9p.RRead:
    288 		return rmsg.Data, nil
    289 	case *lib9p.RError:
    290 		return nil, rmsg.Ename
    291 	default:
    292 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    293 	}
    294 }
    295 func (c *Client) Write(tag uint16, fid uint32, offset uint64, count uint32, data []byte) (uint32, error) {
    296 	tmsg := &lib9p.TWrite{Tag: tag, Fid: fid, Offset: offset, Count: count, Data: data}
    297 	rmsg, err := c.transact(tmsg)
    298 	if err != nil {
    299 		return 0, fmt.Errorf("transact: %w", err)
    300 	}
    301 	switch rmsg := rmsg.(type) {
    302 	case *lib9p.RWrite:
    303 		return rmsg.Count, nil
    304 	case *lib9p.RError:
    305 		return 0, rmsg.Ename
    306 	default:
    307 		return 0, fmt.Errorf("invalid reply: %v", rmsg)
    308 	}
    309 }
    310 func (c *Client) Clunk(tag uint16, fid uint32) error {
    311 	tmsg := &lib9p.TClunk{Tag: tag, Fid: fid}
    312 	rmsg, err := c.transact(tmsg)
    313 	if err != nil {
    314 		return fmt.Errorf("transact: %w", err)
    315 	}
    316 	switch rmsg := rmsg.(type) {
    317 	case *lib9p.RClunk:
    318 		return nil
    319 	case *lib9p.RError:
    320 		return rmsg.Ename
    321 	default:
    322 		return fmt.Errorf("invalid reply: %v", rmsg)
    323 	}
    324 }
    325 func (c *Client) Remove(tag uint16, fid uint32) error {
    326 	tmsg := &lib9p.TRemove{Tag: tag, Fid: fid}
    327 	rmsg, err := c.transact(tmsg)
    328 	if err != nil {
    329 		return fmt.Errorf("transact: %w", err)
    330 	}
    331 	switch rmsg := rmsg.(type) {
    332 	case *lib9p.RRemove:
    333 		return nil
    334 	case *lib9p.RError:
    335 		return rmsg.Ename
    336 	default:
    337 		return fmt.Errorf("invalid reply: %v", rmsg)
    338 	}
    339 }
    340 func (c *Client) Stat(tag uint16, fid uint32) (*lib9p.Stat, error) {
    341 	tmsg := &lib9p.TStat{Tag: tag, Fid: fid}
    342 	rmsg, err := c.transact(tmsg)
    343 	if err != nil {
    344 		return nil, fmt.Errorf("transact: %w", err)
    345 	}
    346 	switch rmsg := rmsg.(type) {
    347 	case *lib9p.RStat:
    348 		return rmsg.Stat, nil
    349 	case *lib9p.RError:
    350 		return nil, rmsg.Ename
    351 	default:
    352 		return nil, fmt.Errorf("invalid reply: %v", rmsg)
    353 	}
    354 }
    355 func (c *Client) Wstat(tag uint16, fid uint32, stat *lib9p.Stat) error {
    356 	tmsg := &lib9p.TWstat{Tag: tag, Fid: fid, Stat: stat}
    357 	rmsg, err := c.transact(tmsg)
    358 	if err != nil {
    359 		return fmt.Errorf("transact: %w", err)
    360 	}
    361 	switch rmsg := rmsg.(type) {
    362 	case *lib9p.RWstat:
    363 		return nil
    364 	case *lib9p.RError:
    365 		return rmsg.Ename
    366 	default:
    367 		return fmt.Errorf("invalid reply: %v", rmsg)
    368 	}
    369 }