Merge pull request #3494 from nuczzz/develop

simplify beego grace with http.Shutdown
This commit is contained in:
astaxie 2019-02-26 16:31:40 +08:00 committed by GitHub
commit 422e8285b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 176 deletions

View File

@ -1,38 +0,0 @@
package grace
import (
"errors"
"net"
"sync"
)
type graceConn struct {
net.Conn
server *Server
m sync.Mutex
closed bool
}
func (c *graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("Unknown panic")
}
}
}()
c.m.Lock()
defer c.m.Unlock()
if c.closed {
return
}
c.server.wg.Done()
c.closed = true
return c.Conn.Close()
}

View File

@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
} }
srv = &Server{ srv = &Server{
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal), sigChan: make(chan os.Signal),
isChild: isChild, isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){ SignalHooks: map[int]map[os.Signal][]func(){
@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
syscall.SIGTERM: {}, syscall.SIGTERM: {},
}, },
}, },
state: StateInit, state: StateInit,
Network: "tcp", Network: "tcp",
terminalChan: make(chan error), //no cache channel
}
srv.Server = &http.Server{
Addr: addr,
ReadTimeout: DefaultReadTimeOut,
WriteTimeout: DefaultWriteTimeOut,
MaxHeaderBytes: DefaultMaxHeaderBytes,
Handler: handler,
} }
srv.Server = &http.Server{}
srv.Server.Addr = addr
srv.Server.ReadTimeout = DefaultReadTimeOut
srv.Server.WriteTimeout = DefaultWriteTimeOut
srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
srv.Server.Handler = handler
runningServersOrder = append(runningServersOrder, addr) runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv runningServers[addr] = srv
return srv
return
} }
// ListenAndServe refer http.ListenAndServe // ListenAndServe refer http.ListenAndServe

View File

@ -1,62 +0,0 @@
package grace
import (
"net"
"os"
"syscall"
"time"
)
type graceListener struct {
net.Listener
stop chan error
stopped bool
server *Server
}
func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{
Listener: l,
stop: make(chan error),
server: srv,
}
go func() {
<-el.stop
el.stopped = true
el.stop <- el.Listener.Close()
}()
return
}
func (gl *graceListener) Accept() (c net.Conn, err error) {
tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
c = &graceConn{
Conn: tc,
server: gl.server,
}
gl.server.wg.Add(1)
return
}
func (gl *graceListener) Close() error {
if gl.stopped {
return syscall.EINVAL
}
gl.stop <- nil
return <-gl.stop
}
func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}

View File

@ -1,6 +1,7 @@
package grace package grace
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -12,7 +13,6 @@ import (
"os/exec" "os/exec"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
) )
@ -20,31 +20,33 @@ import (
// Server embedded http.Server // Server embedded http.Server
type Server struct { type Server struct {
*http.Server *http.Server
GraceListener net.Listener ln net.Listener
SignalHooks map[int]map[os.Signal][]func() SignalHooks map[int]map[os.Signal][]func()
tlsInnerListener *graceListener sigChan chan os.Signal
wg sync.WaitGroup isChild bool
sigChan chan os.Signal state uint8
isChild bool Network string
state uint8 terminalChan chan error
Network string
} }
// Serve accepts incoming connections on the Listener l, // Serve accepts incoming connections on the Listener l,
// creating a new service goroutine for each. // creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them. // The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) { func (srv *Server) Serve() (err error) {
defer func() {
if r := recover(); r != nil {
log.Println("wait group counter is negative", r)
}
}()
srv.state = StateRunning srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener) defer func() { srv.state = StateTerminate }()
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait() // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
srv.state = StateTerminate // immediately return ErrServerClosed. Make sure the program doesn't exit
return // and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err
}
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
// wait for Shutdown to return
return <-srv.terminalChan
} }
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
@ -58,14 +60,12 @@ func (srv *Server) ListenAndServe() (err error) {
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) srv.ln, err = srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.GraceListener = newGraceListener(l, srv)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
if err != nil { if err != nil {
@ -112,14 +112,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
@ -132,6 +130,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
return err return err
} }
} }
log.Println(os.Getpid(), srv.Addr) log.Println(os.Getpid(), srv.Addr)
return srv.Serve() return srv.Serve()
} }
@ -168,14 +167,12 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
log.Println("Mutual HTTPS") log.Println("Mutual HTTPS")
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
@ -188,6 +185,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
return err return err
} }
} }
log.Println(os.Getpid(), srv.Addr) log.Println(os.Getpid(), srv.Addr)
return srv.Serve() return srv.Serve()
} }
@ -218,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
return return
} }
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// handleSignals listens for os Signals and calls any hooked in function that the // handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal. // user had registered with the signal.
func (srv *Server) handleSignals() { func (srv *Server) handleSignals() {
@ -270,37 +282,14 @@ func (srv *Server) shutdown() {
} }
srv.state = StateShuttingDown srv.state = StateShuttingDown
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
ctx := context.Background()
if DefaultTimeout >= 0 { if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout) var cancel context.CancelFunc
} ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
err := srv.GraceListener.Close() defer cancel()
if err != nil {
log.Println(syscall.Getpid(), "Listener.Close() error:", err)
} else {
log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
}
}
// serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang
func (srv *Server) serverTimeout(d time.Duration) {
defer func() {
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
if srv.state != StateShuttingDown {
return
}
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for {
if srv.state == StateTerminate {
break
}
srv.wg.Done()
} }
srv.terminalChan <- srv.Server.Shutdown(ctx)
} }
func (srv *Server) fork() (err error) { func (srv *Server) fork() (err error) {
@ -314,12 +303,8 @@ func (srv *Server) fork() (err error) {
var files = make([]*os.File, len(runningServers)) var files = make([]*os.File, len(runningServers))
var orderArgs = make([]string, len(runningServers)) var orderArgs = make([]string, len(runningServers))
for _, srvPtr := range runningServers { for _, srvPtr := range runningServers {
switch srvPtr.GraceListener.(type) { f, _ := srvPtr.ln.(*net.TCPListener).File()
case *graceListener: files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
default:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
}
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
} }