diff --git a/admin.go b/admin.go index 7eb07cc0..52388c76 100644 --- a/admin.go +++ b/admin.go @@ -19,9 +19,11 @@ import ( "encoding/json" "fmt" "net/http" + "os" "text/template" "time" + "github.com/astaxie/beego/grace" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" ) @@ -458,8 +460,14 @@ func (admin *adminApp) Run() { http.Handle(p, f) } BeeLogger.Info("Admin server Running on %s", addr) - err := http.ListenAndServe(addr, nil) + + var err error + if Graceful { + err = grace.ListenAndServe(addr, nil) + } else { + err = http.ListenAndServe(addr, nil) + } if err != nil { - BeeLogger.Critical("Admin ListenAndServe: ", err) + BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprint(os.Getpid())) } } diff --git a/app.go b/app.go index 35040f33..4740e7e5 100644 --- a/app.go +++ b/app.go @@ -22,6 +22,7 @@ import ( "os" "time" + "github.com/astaxie/beego/grace" "github.com/astaxie/beego/utils" ) @@ -76,57 +77,93 @@ func (app *App) Run() { err = fcgi.Serve(l, app.Handlers) } } else { - app.Server.Addr = addr - app.Server.Handler = app.Handlers - app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second + if Graceful { + if EnableHttpTLS { + go func() { + time.Sleep(20 * time.Microsecond) + if HttpsPort != 0 { + addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) + } + server := grace.NewServer(addr, app.Handlers) + server.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second + server.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second + err := server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) + if err != nil { + BeeLogger.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + if EnableHttpListen { + go func() { + server := grace.NewServer(addr, app.Handlers) + server.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second + server.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second + if ListenTCP4 && HttpAddr == "" { + server.Network = "tcp4" + } + err := server.ListenAndServe() + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprint(os.Getpid())) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + } else { + app.Server.Addr = addr + app.Server.Handler = app.Handlers + app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second - if EnableHttpTLS { - go func() { - time.Sleep(20 * time.Microsecond) - if HttpsPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) - } - BeeLogger.Info("https server Running on %s", app.Server.Addr) - err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) - if err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() + if EnableHttpTLS { + go func() { + time.Sleep(20 * time.Microsecond) + if HttpsPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) + } + BeeLogger.Info("https server Running on %s", app.Server.Addr) + err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) + if err != nil { + BeeLogger.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + + if EnableHttpListen { + go func() { + app.Server.Addr = addr + BeeLogger.Info("http server Running on %s", app.Server.Addr) + if ListenTCP4 && HttpAddr == "" { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + err = app.Server.Serve(ln) + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + err := app.Server.ListenAndServe() + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } } - if EnableHttpListen { - go func() { - app.Server.Addr = addr - BeeLogger.Info("http server Running on %s", app.Server.Addr) - if ListenTCP4 && HttpAddr == "" { - ln, err := net.Listen("tcp4", app.Server.Addr) - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - err = app.Server.Serve(ln) - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - } else { - err := app.Server.ListenAndServe() - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } } - <-endRunning } diff --git a/config.go b/config.go index b22db163..09d0df24 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,7 @@ var ( EnableDocs bool // enable generate docs & server docs API Swagger RouterCaseSensitive bool // router case sensitive default is true AccessLogs bool // print access logs, default is false + Graceful bool // use graceful start the server ) type beegoAppConfig struct { @@ -509,5 +510,8 @@ func ParseConfig() (err error) { if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil { RouterCaseSensitive = casesensitive } + if graceful, err := AppConfig.Bool("Graceful"); err == nil { + Graceful = graceful + } return nil } diff --git a/grace/grace.go b/grace/grace.go new file mode 100644 index 00000000..e77112d0 --- /dev/null +++ b/grace/grace.go @@ -0,0 +1,522 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux1) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "strings" + "sync" + "syscall" + "time" +) + +const ( + PRE_SIGNAL = iota + POST_SIGNAL + + STATE_INIT + STATE_RUNNING + STATE_SHUTTING_DOWN + STATE_TERMINATE +) + +var ( + regLock *sync.Mutex + runningServers map[string]*graceServer + runningServersOrder []string + socketPtrOffsetMap map[string]uint + runningServersForked bool + + DefaultReadTimeOut time.Duration + DefaultWriteTimeOut time.Duration + DefaultMaxHeaderBytes int + DefaultTimeout time.Duration + + isChild bool + socketOrder string +) + +func init() { + regLock = &sync.Mutex{} + flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") + flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + runningServers = make(map[string]*graceServer) + runningServersOrder = []string{} + socketPtrOffsetMap = make(map[string]uint) + + DefaultMaxHeaderBytes = 0 + + // after a restart the parent will finish ongoing requests before + // shutting down. set to a negative value to disable + DefaultTimeout = 60 * time.Second +} + +type graceServer struct { + http.Server + GraceListener net.Listener + SignalHooks map[int]map[os.Signal][]func() + tlsInnerListener *graceListener + wg sync.WaitGroup + sigChan chan os.Signal + isChild bool + state uint8 + Network string +} + +// NewServer returns an intialized graceServer. Calling Serve on it will +// actually "start" the server. +func NewServer(addr string, handler http.Handler) (srv *graceServer) { + regLock.Lock() + defer regLock.Unlock() + if !flag.Parsed() { + flag.Parse() + } + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + } + + srv = &graceServer{ + wg: sync.WaitGroup{}, + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PRE_SIGNAL: map[os.Signal][]func(){ + syscall.SIGHUP: []func(){}, + syscall.SIGUSR1: []func(){}, + syscall.SIGUSR2: []func(){}, + syscall.SIGINT: []func(){}, + syscall.SIGTERM: []func(){}, + syscall.SIGTSTP: []func(){}, + }, + POST_SIGNAL: map[os.Signal][]func(){ + syscall.SIGHUP: []func(){}, + syscall.SIGUSR1: []func(){}, + syscall.SIGUSR2: []func(){}, + syscall.SIGINT: []func(){}, + syscall.SIGTERM: []func(){}, + syscall.SIGTSTP: []func(){}, + }, + }, + state: STATE_INIT, + Network: "tcp", + } + + srv.Server.Addr = addr + srv.Server.ReadTimeout = DefaultReadTimeOut + srv.Server.WriteTimeout = DefaultWriteTimeOut + srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes + srv.Server.Handler = handler + + runningServersOrder = append(runningServersOrder, addr) + runningServers[addr] = srv + + return +} + +// ListenAndServe listens on the TCP network address addr +// and then calls Serve to handle requests on incoming connections. +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS listens on the TCP network address addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must be provided. +// If the certificate is signed by a certificate authority, +// the certFile should be the concatenation of the server's certificate followed by the CA's certificate. +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *graceServer) Serve() (err error) { + srv.state = STATE_RUNNING + err = srv.Server.Serve(srv.GraceListener) + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + srv.wg.Wait() + srv.state = STATE_TERMINATE + return +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *graceServer) ListenAndServe() (err error) { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + + go srv.handleSignals() + + l, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return + } + + srv.GraceListener = newGraceListener(l, srv) + + if srv.isChild { + syscall.Kill(syscall.Getppid(), syscall.SIGTERM) + } + + log.Println(syscall.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + config := &tls.Config{} + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + + go srv.handleSignals() + + l, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return + } + + srv.tlsInnerListener = newGraceListener(l, srv) + srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config) + + if srv.isChild { + syscall.Kill(syscall.Getppid(), syscall.SIGTERM) + } + + log.Println(syscall.Getpid(), srv.Addr) + return srv.Serve() +} + +// getListener either opens a new socket to listen on, or takes the acceptor socket +// it got passed when restarted. +func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint = 0 + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen(srv.Network, laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + +// handleSignals listens for os Signals and calls any hooked in function that the +// user had registered with the signal. +func (srv *graceServer) handleSignals() { + var sig os.Signal + + signal.Notify( + srv.sigChan, + syscall.SIGHUP, + syscall.SIGUSR1, + syscall.SIGUSR2, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGTSTP, + ) + + pid := syscall.Getpid() + for { + sig = <-srv.sigChan + srv.signalHooks(PRE_SIGNAL, sig) + switch sig { + case syscall.SIGHUP: + log.Println(pid, "Received SIGHUP. forking.") + err := srv.fork() + if err != nil { + log.Println("Fork err:", err) + } + case syscall.SIGUSR1: + log.Println(pid, "Received SIGUSR1.") + case syscall.SIGUSR2: + log.Println(pid, "Received SIGUSR2.") + srv.serverTimeout(0 * time.Second) + case syscall.SIGINT: + log.Println(pid, "Received SIGINT.") + srv.shutdown() + case syscall.SIGTERM: + log.Println(pid, "Received SIGTERM.") + srv.shutdown() + case syscall.SIGTSTP: + log.Println(pid, "Received SIGTSTP.") + default: + log.Printf("Received %v: nothing i care about...\n", sig) + } + srv.signalHooks(POST_SIGNAL, sig) + } +} + +func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) { + if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { + return + } + for _, f := range srv.SignalHooks[ppFlag][sig] { + f() + } + return +} + +// shutdown closes the listener so that no new connections are accepted. it also +// starts a goroutine that will hammer (stop all running requests) the server +// after DefaultTimeout. +func (srv *graceServer) shutdown() { + if srv.state != STATE_RUNNING { + return + } + + srv.state = STATE_SHUTTING_DOWN + if DefaultTimeout >= 0 { + go srv.serverTimeout(DefaultTimeout) + } + err := srv.GraceListener.Close() + if err != nil { + log.Println(syscall.Getpid(), "Listener.Close() error:", err) + } else { + log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.") + } +} + +// hammerTime forces the server to shutdown in a given timeout - whether it +// finished outstanding requests or not. if Read/WriteTimeout are not set or the +// max header size is very big a connection could hang... +// +// srv.Serve() will not return until all connections are served. this will +// unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to +// return. +func (srv *graceServer) serverTimeout(d time.Duration) { + defer func() { + // we are calling srv.wg.Done() until it panics which means we called + // Done() when the counter was already at 0 and we're done. + // (and thus Serve() will return and the parent will exit) + if r := recover(); r != nil { + log.Println("WaitGroup at 0", r) + } + }() + if srv.state != STATE_SHUTTING_DOWN { + return + } + time.Sleep(d) + log.Println("[STOP - Hammer Time] Forcefully shutting down parent") + for { + if srv.state == STATE_TERMINATE { + break + } + srv.wg.Done() + } +} + +func (srv *graceServer) fork() (err error) { + // only one server isntance should fork! + regLock.Lock() + defer regLock.Unlock() + if runningServersForked { + return + } + runningServersForked = true + + var files = make([]*os.File, len(runningServers)) + var orderArgs = make([]string, len(runningServers)) + // get the accessor socket fds for _all_ server instances + for _, srvPtr := range runningServers { + // introspect.PrintTypeDump(srvPtr.EndlessListener) + switch srvPtr.GraceListener.(type) { + case *graceListener: + // normal listener + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File() + default: + // tls listener + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() + } + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + log.Println(files) + path := os.Args[0] + var args []string + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-graceful" { + break + } + args = append(args, arg) + } + } + args = append(args, "-graceful") + if len(runningServers) > 1 { + args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + log.Println(args) + } + cmd := exec.Command(path, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + err = cmd.Start() + if err != nil { + log.Fatalf("Restart: Failed to launch, error: %v", err) + } + + return +} + +type graceListener struct { + net.Listener + stop chan error + stopped bool + server *graceServer +} + +func (gl *graceListener) Accept() (c net.Conn, err error) { + tc, err := gl.Listener.(*net.TCPListener).AcceptTCP() + if err != nil { + return + } + + tc.SetKeepAlive(true) // see http.tcpKeepAliveListener + tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener + + c = graceConn{ + Conn: tc, + server: gl.server, + } + + gl.server.wg.Add(1) + return +} + +func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) { + el = &graceListener{ + Listener: l, + stop: make(chan error), + server: srv, + } + + // Starting the listener for the stop signal here because Accept blocks on + // el.Listener.(*net.TCPListener).AcceptTCP() + // The goroutine will unblock it by closing the listeners fd + go func() { + _ = <-el.stop + el.stopped = true + el.stop <- el.Listener.Close() + }() + return +} + +func (el *graceListener) Close() error { + if el.stopped { + return syscall.EINVAL + } + el.stop <- nil + return <-el.stop +} + +func (el *graceListener) File() *os.File { + // returns a dup(2) - FD_CLOEXEC flag *not* set + tl := el.Listener.(*net.TCPListener) + fl, _ := tl.File() + return fl +} + +type graceConn struct { + net.Conn + server *graceServer +} + +func (c graceConn) Close() error { + c.server.wg.Done() + return c.Conn.Close() +}