// Copyright 2014 beego Author. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ // // Usage: // // import( // "log" // "net/http" // "os" // // "github.com/astaxie/beego/grace" // ) // // func handler(w http.ResponseWriter, r *http.Request) { // w.Write([]byte("WORLD!")) // } // // func main() { // mux := http.NewServeMux() // mux.HandleFunc("/hello", handler) // // err := grace.ListenAndServe("localhost:8080", mux1) // if err != nil { // log.Println(err) // } // log.Println("Server on 8080 stopped") // os.Exit(0) // } package grace import ( "crypto/tls" "flag" "fmt" "log" "net" "net/http" "os" "os/exec" "os/signal" "strings" "sync" "syscall" "time" ) const ( PRE_SIGNAL = iota POST_SIGNAL STATE_INIT STATE_RUNNING STATE_SHUTTING_DOWN STATE_TERMINATE ) var ( regLock *sync.Mutex runningServers map[string]*graceServer runningServersOrder []string socketPtrOffsetMap map[string]uint runningServersForked bool DefaultReadTimeOut time.Duration DefaultWriteTimeOut time.Duration DefaultMaxHeaderBytes int DefaultTimeout time.Duration isChild bool socketOrder string ) func init() { regLock = &sync.Mutex{} flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") runningServers = make(map[string]*graceServer) runningServersOrder = []string{} socketPtrOffsetMap = make(map[string]uint) DefaultMaxHeaderBytes = 0 // after a restart the parent will finish ongoing requests before // shutting down. set to a negative value to disable DefaultTimeout = 60 * time.Second } type graceServer struct { http.Server GraceListener net.Listener SignalHooks map[int]map[os.Signal][]func() tlsInnerListener *graceListener wg sync.WaitGroup sigChan chan os.Signal isChild bool state uint8 Network string } // NewServer returns an intialized graceServer. Calling Serve on it will // actually "start" the server. func NewServer(addr string, handler http.Handler) (srv *graceServer) { regLock.Lock() defer regLock.Unlock() if !flag.Parsed() { flag.Parse() } if len(socketOrder) > 0 { for i, addr := range strings.Split(socketOrder, ",") { socketPtrOffsetMap[addr] = uint(i) } } else { socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) } srv = &graceServer{ wg: sync.WaitGroup{}, sigChan: make(chan os.Signal), isChild: isChild, SignalHooks: map[int]map[os.Signal][]func(){ PRE_SIGNAL: map[os.Signal][]func(){ syscall.SIGHUP: []func(){}, syscall.SIGUSR1: []func(){}, syscall.SIGUSR2: []func(){}, syscall.SIGINT: []func(){}, syscall.SIGTERM: []func(){}, syscall.SIGTSTP: []func(){}, }, POST_SIGNAL: map[os.Signal][]func(){ syscall.SIGHUP: []func(){}, syscall.SIGUSR1: []func(){}, syscall.SIGUSR2: []func(){}, syscall.SIGINT: []func(){}, syscall.SIGTERM: []func(){}, syscall.SIGTSTP: []func(){}, }, }, state: STATE_INIT, Network: "tcp", } srv.Server.Addr = addr srv.Server.ReadTimeout = DefaultReadTimeOut srv.Server.WriteTimeout = DefaultWriteTimeOut srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes srv.Server.Handler = handler runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv return } // ListenAndServe listens on the TCP network address addr // and then calls Serve to handle requests on incoming connections. func ListenAndServe(addr string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServe() } // ListenAndServeTLS listens on the TCP network address addr and then calls // Serve to handle requests on incoming TLS connections. // // Filenames containing a certificate and matching private key for the server must be provided. // If the certificate is signed by a certificate authority, // the certFile should be the concatenation of the server's certificate followed by the CA's certificate. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServeTLS(certFile, keyFile) } // Serve accepts incoming connections on the Listener l, // creating a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *graceServer) Serve() (err error) { srv.state = STATE_RUNNING err = srv.Server.Serve(srv.GraceListener) log.Println(syscall.Getpid(), "Waiting for connections to finish...") srv.wg.Wait() srv.state = STATE_TERMINATE return } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // to handle requests on incoming connections. If srv.Addr is blank, ":http" is // used. func (srv *graceServer) ListenAndServe() (err error) { addr := srv.Addr if addr == "" { addr = ":http" } go srv.handleSignals() l, err := srv.getListener(addr) if err != nil { log.Println(err) return } srv.GraceListener = newGraceListener(l, srv) if srv.isChild { syscall.Kill(syscall.Getppid(), syscall.SIGTERM) } log.Println(syscall.Getpid(), srv.Addr) return srv.Serve() } // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls // Serve to handle requests on incoming TLS connections. // // Filenames containing a certificate and matching private key for the server must // be provided. If the certificate is signed by a certificate authority, the // certFile should be the concatenation of the server's certificate followed by the // CA's certificate. // // If srv.Addr is blank, ":https" is used. func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) { addr := srv.Addr if addr == "" { addr = ":https" } config := &tls.Config{} if srv.TLSConfig != nil { *config = *srv.TLSConfig } if config.NextProtos == nil { config.NextProtos = []string{"http/1.1"} } config.Certificates = make([]tls.Certificate, 1) config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return } go srv.handleSignals() l, err := srv.getListener(addr) if err != nil { log.Println(err) return } srv.tlsInnerListener = newGraceListener(l, srv) srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config) if srv.isChild { syscall.Kill(syscall.Getppid(), syscall.SIGTERM) } log.Println(syscall.Getpid(), srv.Addr) return srv.Serve() } // getListener either opens a new socket to listen on, or takes the acceptor socket // it got passed when restarted. func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) { if srv.isChild { var ptrOffset uint = 0 if len(socketPtrOffsetMap) > 0 { ptrOffset = socketPtrOffsetMap[laddr] log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) } f := os.NewFile(uintptr(3+ptrOffset), "") l, err = net.FileListener(f) if err != nil { err = fmt.Errorf("net.FileListener error: %v", err) return } } else { l, err = net.Listen(srv.Network, laddr) if err != nil { err = fmt.Errorf("net.Listen error: %v", err) return } } return } // handleSignals listens for os Signals and calls any hooked in function that the // user had registered with the signal. func (srv *graceServer) handleSignals() { var sig os.Signal signal.Notify( srv.sigChan, syscall.SIGHUP, syscall.SIGUSR1, syscall.SIGUSR2, syscall.SIGINT, syscall.SIGTERM, syscall.SIGTSTP, ) pid := syscall.Getpid() for { sig = <-srv.sigChan srv.signalHooks(PRE_SIGNAL, sig) switch sig { case syscall.SIGHUP: log.Println(pid, "Received SIGHUP. forking.") err := srv.fork() if err != nil { log.Println("Fork err:", err) } case syscall.SIGUSR1: log.Println(pid, "Received SIGUSR1.") case syscall.SIGUSR2: log.Println(pid, "Received SIGUSR2.") srv.serverTimeout(0 * time.Second) case syscall.SIGINT: log.Println(pid, "Received SIGINT.") srv.shutdown() case syscall.SIGTERM: log.Println(pid, "Received SIGTERM.") srv.shutdown() case syscall.SIGTSTP: log.Println(pid, "Received SIGTSTP.") default: log.Printf("Received %v: nothing i care about...\n", sig) } srv.signalHooks(POST_SIGNAL, sig) } } func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) { if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { return } for _, f := range srv.SignalHooks[ppFlag][sig] { f() } return } // shutdown closes the listener so that no new connections are accepted. it also // starts a goroutine that will hammer (stop all running requests) the server // after DefaultTimeout. func (srv *graceServer) shutdown() { if srv.state != STATE_RUNNING { return } srv.state = STATE_SHUTTING_DOWN if DefaultTimeout >= 0 { go srv.serverTimeout(DefaultTimeout) } err := srv.GraceListener.Close() if err != nil { log.Println(syscall.Getpid(), "Listener.Close() error:", err) } else { log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.") } } // hammerTime forces the server to shutdown in a given timeout - whether it // finished outstanding requests or not. if Read/WriteTimeout are not set or the // max header size is very big a connection could hang... // // srv.Serve() will not return until all connections are served. this will // unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to // return. func (srv *graceServer) serverTimeout(d time.Duration) { defer func() { // we are calling srv.wg.Done() until it panics which means we called // Done() when the counter was already at 0 and we're done. // (and thus Serve() will return and the parent will exit) if r := recover(); r != nil { log.Println("WaitGroup at 0", r) } }() if srv.state != STATE_SHUTTING_DOWN { return } time.Sleep(d) log.Println("[STOP - Hammer Time] Forcefully shutting down parent") for { if srv.state == STATE_TERMINATE { break } srv.wg.Done() } } func (srv *graceServer) fork() (err error) { // only one server isntance should fork! regLock.Lock() defer regLock.Unlock() if runningServersForked { return } runningServersForked = true var files = make([]*os.File, len(runningServers)) var orderArgs = make([]string, len(runningServers)) // get the accessor socket fds for _all_ server instances for _, srvPtr := range runningServers { // introspect.PrintTypeDump(srvPtr.EndlessListener) switch srvPtr.GraceListener.(type) { case *graceListener: // normal listener files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File() default: // tls listener files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() } orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr } log.Println(files) path := os.Args[0] var args []string if len(os.Args) > 1 { for _, arg := range os.Args[1:] { if arg == "-graceful" { break } args = append(args, arg) } } args = append(args, "-graceful") if len(runningServers) > 1 { args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) log.Println(args) } cmd := exec.Command(path, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.ExtraFiles = files err = cmd.Start() if err != nil { log.Fatalf("Restart: Failed to launch, error: %v", err) } return } type graceListener struct { net.Listener stop chan error stopped bool server *graceServer } func (gl *graceListener) Accept() (c net.Conn, err error) { tc, err := gl.Listener.(*net.TCPListener).AcceptTCP() if err != nil { return } tc.SetKeepAlive(true) // see http.tcpKeepAliveListener tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener c = graceConn{ Conn: tc, server: gl.server, } gl.server.wg.Add(1) return } func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) { el = &graceListener{ Listener: l, stop: make(chan error), server: srv, } // Starting the listener for the stop signal here because Accept blocks on // el.Listener.(*net.TCPListener).AcceptTCP() // The goroutine will unblock it by closing the listeners fd go func() { _ = <-el.stop el.stopped = true el.stop <- el.Listener.Close() }() return } func (el *graceListener) Close() error { if el.stopped { return syscall.EINVAL } el.stop <- nil return <-el.stop } func (el *graceListener) File() *os.File { // returns a dup(2) - FD_CLOEXEC flag *not* set tl := el.Listener.(*net.TCPListener) fl, _ := tl.File() return fl } type graceConn struct { net.Conn server *graceServer } func (c graceConn) Close() error { c.server.wg.Done() return c.Conn.Close() }