diff --git a/grace/conn.go b/grace/conn.go deleted file mode 100644 index 32623650..00000000 --- a/grace/conn.go +++ /dev/null @@ -1,38 +0,0 @@ -package grace - -import ( - "errors" - "net" - "sync" -) - -type graceConn struct { - net.Conn - server *Server - m sync.Mutex - closed bool -} - -func (c *graceConn) Close() (err error) { - defer func() { - if r := recover(); r != nil { - switch x := r.(type) { - case string: - err = errors.New(x) - case error: - err = x - default: - err = errors.New("Unknown panic") - } - } - }() - - c.m.Lock() - defer c.m.Unlock() - if c.closed { - return - } - c.server.wg.Done() - c.closed = true - return c.Conn.Close() -} diff --git a/grace/grace.go b/grace/grace.go index 6ebf8455..5a8bc3b8 100644 --- a/grace/grace.go +++ b/grace/grace.go @@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { } srv = &Server{ - wg: sync.WaitGroup{}, sigChan: make(chan os.Signal), isChild: isChild, SignalHooks: map[int]map[os.Signal][]func(){ @@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { syscall.SIGTERM: {}, }, }, - state: StateInit, - Network: "tcp", + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, } - srv.Server = &http.Server{} - srv.Server.Addr = addr - srv.Server.ReadTimeout = DefaultReadTimeOut - srv.Server.WriteTimeout = DefaultWriteTimeOut - srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes - srv.Server.Handler = handler runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv - - return + return srv } // ListenAndServe refer http.ListenAndServe diff --git a/grace/listener.go b/grace/listener.go deleted file mode 100644 index 7ede63a3..00000000 --- a/grace/listener.go +++ /dev/null @@ -1,62 +0,0 @@ -package grace - -import ( - "net" - "os" - "syscall" - "time" -) - -type graceListener struct { - net.Listener - stop chan error - stopped bool - server *Server -} - -func newGraceListener(l net.Listener, srv *Server) (el *graceListener) { - el = &graceListener{ - Listener: l, - stop: make(chan error), - server: srv, - } - go func() { - <-el.stop - el.stopped = true - el.stop <- el.Listener.Close() - }() - return -} - -func (gl *graceListener) Accept() (c net.Conn, err error) { - tc, err := gl.Listener.(*net.TCPListener).AcceptTCP() - if err != nil { - return - } - - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - - c = &graceConn{ - Conn: tc, - server: gl.server, - } - - gl.server.wg.Add(1) - return -} - -func (gl *graceListener) Close() error { - if gl.stopped { - return syscall.EINVAL - } - gl.stop <- nil - return <-gl.stop -} - -func (gl *graceListener) File() *os.File { - // returns a dup(2) - FD_CLOEXEC flag *not* set - tl := gl.Listener.(*net.TCPListener) - fl, _ := tl.File() - return fl -} diff --git a/grace/server.go b/grace/server.go index ef5cbe7e..1ce8bc78 100644 --- a/grace/server.go +++ b/grace/server.go @@ -1,6 +1,7 @@ package grace import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -12,7 +13,6 @@ import ( "os/exec" "os/signal" "strings" - "sync" "syscall" "time" ) @@ -20,31 +20,33 @@ import ( // Server embedded http.Server type Server struct { *http.Server - GraceListener net.Listener - SignalHooks map[int]map[os.Signal][]func() - tlsInnerListener *graceListener - wg sync.WaitGroup - sigChan chan os.Signal - isChild bool - state uint8 - Network string + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error } // Serve accepts incoming connections on the Listener l, // creating a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { - defer func() { - if r := recover(); r != nil { - log.Println("wait group counter is negative", r) - } - }() srv.state = StateRunning - err = srv.Server.Serve(srv.GraceListener) - log.Println(syscall.Getpid(), "Waiting for connections to finish...") - srv.wg.Wait() - srv.state = StateTerminate - return + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + return <-srv.terminalChan } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve @@ -58,14 +60,12 @@ func (srv *Server) ListenAndServe() (err error) { go srv.handleSignals() - l, err := srv.getListener(addr) + srv.ln, err = srv.getListener(addr) if err != nil { log.Println(err) return err } - srv.GraceListener = newGraceListener(l, srv) - if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { @@ -112,14 +112,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { go srv.handleSignals() - l, err := srv.getListener(addr) + ln, err := srv.getListener(addr) if err != nil { log.Println(err) return err } - - srv.tlsInnerListener = newGraceListener(l, srv) - srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) @@ -132,6 +130,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { return err } } + log.Println(os.Getpid(), srv.Addr) return srv.Serve() } @@ -168,14 +167,12 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) log.Println("Mutual HTTPS") go srv.handleSignals() - l, err := srv.getListener(addr) + ln, err := srv.getListener(addr) if err != nil { log.Println(err) return err } - - srv.tlsInnerListener = newGraceListener(l, srv) - srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) @@ -188,6 +185,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) return err } } + log.Println(os.Getpid(), srv.Addr) return srv.Serve() } @@ -218,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) { return } +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + // handleSignals listens for os Signals and calls any hooked in function that the // user had registered with the signal. func (srv *Server) handleSignals() { @@ -270,37 +282,14 @@ func (srv *Server) shutdown() { } srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() if DefaultTimeout >= 0 { - go srv.serverTimeout(DefaultTimeout) - } - err := srv.GraceListener.Close() - if err != nil { - log.Println(syscall.Getpid(), "Listener.Close() error:", err) - } else { - log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.") - } -} - -// serverTimeout forces the server to shutdown in a given timeout - whether it -// finished outstanding requests or not. if Read/WriteTimeout are not set or the -// max header size is very big a connection could hang -func (srv *Server) serverTimeout(d time.Duration) { - defer func() { - if r := recover(); r != nil { - log.Println("WaitGroup at 0", r) - } - }() - if srv.state != StateShuttingDown { - return - } - time.Sleep(d) - log.Println("[STOP - Hammer Time] Forcefully shutting down parent") - for { - if srv.state == StateTerminate { - break - } - srv.wg.Done() + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() } + srv.terminalChan <- srv.Server.Shutdown(ctx) } func (srv *Server) fork() (err error) { @@ -314,12 +303,8 @@ func (srv *Server) fork() (err error) { var files = make([]*os.File, len(runningServers)) var orderArgs = make([]string, len(runningServers)) for _, srvPtr := range runningServers { - switch srvPtr.GraceListener.(type) { - case *graceListener: - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File() - default: - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() - } + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr }