commit 05a26fbcb320ed68aa4e56c373b5534260aef9cf
parent 7bd2d64bb275555671213e0870dcf54d0ca1280b
Author: Matsuda Kenji <info@mtkn.jp>
Date:   Thu, 28 Dec 2023 09:34:27 +0900
add select
Diffstat:
| M | server.go | | | 120 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- | 
1 file changed, 99 insertions(+), 21 deletions(-)
diff --git a/server.go b/server.go
@@ -225,7 +225,11 @@ func sVersion(ctx context.Context, s *Server, c <-chan *Req) {
 				Version: version,
 			}
 			s.setMSize(r.Ofcall.(*RVersion).Msize)
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -248,20 +252,32 @@ func sAuth(ctx context.Context, s *Server, c <-chan *Req) {
 			}
 			if authc == nil {
 				setError(r, fmt.Errorf("authentication not required"))
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			ifcall := r.Ifcall.(*TAuth)
 			var err error
 			if ifcall.Afid == NOFID {
 				setError(r, fmt.Errorf("NOFID can't be used for afid")) // TODO: really?
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			r.Afid, err = s.fPool.add(ifcall.Afid)
 			if err != nil {
 				setError(r, ErrDupFid)
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			authc <- r // TODO: I think Req.listenErr should be exported.
@@ -330,7 +346,11 @@ func sAttach(ctx context.Context, s *Server, c <-chan *Req) {
 				}
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -348,7 +368,11 @@ func sFlush(ctx context.Context, s *Server, c <-chan *Req) {
 				r.Oldreq.flush()
 			}
 			r.Ofcall = &RFlush{}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -437,7 +461,11 @@ func sWalk(ctx context.Context, s *Server, c <-chan *Req) {
 					}
 				}
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -534,22 +562,38 @@ func sOpen(ctx context.Context, s *Server, c <-chan *Req) {
 		resp:
 			if r.err != nil {
 				setError(r, r.err)
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			r.Fid.OMode = r.Ifcall.(*TOpen).Mode
 			if _, ok := r.Fid.File.(*AuthFile); ok {
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			f, err := r.Srv.fs.OpenFile(r.Fid.path, r.Fid.OMode)
 			if err != nil {
 				setError(r, err)
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			r.Fid.File = f
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -626,7 +670,11 @@ func sCreate(ctx context.Context, s *Server, c <-chan *Req) {
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -726,6 +774,8 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) {
 				}
 			case <-r.done:
 				continue
+			case <-ctx.Done():
+				return
 			}
 			r.Ofcall = &RRead{
 				Count: uint32(n),
@@ -735,7 +785,11 @@ func sRead(ctx context.Context, s *Server, c <-chan *Req) {
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -801,13 +855,19 @@ func sWrite(ctx context.Context, s *Server, c <-chan *Req) {
 				}
 			case <-r.done:
 				continue
+			case <-ctx.Done():
+				return
 			}
 			r.Ofcall = ofcall
 		resp:
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -840,7 +900,11 @@ func sClunk(ctx context.Context, s *Server, c <-chan *Req) {
 				log.Printf("clunk: %v", r.err)
 				r.Ofcall = &RClunk{}
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -894,7 +958,11 @@ func sRemove(ctx context.Context, s *Server, c <-chan *Req) {
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -930,7 +998,11 @@ func sStat(ctx context.Context, s *Server, c <-chan *Req) {
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
@@ -1072,15 +1144,17 @@ func sWStat(ctx context.Context, s *Server, c <-chan *Req) {
 			if r.err != nil {
 				setError(r, r.err)
 			}
-			s.respChan <- r
+			select {
+			case s.respChan <- r:
+			case <-ctx.Done():
+				return
+			}
 		}
 	}
 }
 
 // Serve serves 9P conversation.
 func (s *Server) Serve(ctx context.Context) {
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
 	s.runListener(ctx)
 	s.runSpeaker(ctx)
 	var (
@@ -1146,7 +1220,11 @@ L:
 			switch r.Ifcall.(type) {
 			default:
 				setError(r, fmt.Errorf("unknown message type: %d", r.Ifcall.Type()))
-				s.respChan <- r
+				select {
+				case s.respChan <- r:
+				case <-ctx.Done():
+					return
+				}
 			case *TVersion:
 				versionChan <- r
 			case *TAuth: