package grace import ( "crypto/tls" "crypto/x509" "fmt" "io/ioutil" "log" "net" "net/http" "os" "os/exec" "os/signal" "strings" "sync" "syscall" "time" ) // Server embedded http.Server type Server struct { *http.Server GraceListener net.Listener SignalHooks map[int]map[os.Signal][]func() tlsInnerListener *graceListener wg sync.WaitGroup sigChan chan os.Signal isChild bool state uint8 Network string } // Serve accepts incoming connections on the Listener l, // creating a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { srv.state = StateRunning err = srv.Server.Serve(srv.GraceListener) log.Println(syscall.Getpid(), "Waiting for connections to finish...") srv.wg.Wait() srv.state = StateTerminate return } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // to handle requests on incoming connections. If srv.Addr is blank, ":http" is // used. func (srv *Server) ListenAndServe() (err error) { addr := srv.Addr if addr == "" { addr = ":http" } go srv.handleSignals() l, err := srv.getListener(addr) if err != nil { log.Println(err) return err } srv.GraceListener = newGraceListener(l, srv) if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { log.Println(err) return err } err = process.Signal(syscall.SIGTERM) if err != nil { return err } } log.Println(os.Getpid(), srv.Addr) return srv.Serve() } // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls // Serve to handle requests on incoming TLS connections. // // Filenames containing a certificate and matching private key for the server must // be provided. If the certificate is signed by a certificate authority, the // certFile should be the concatenation of the server's certificate followed by the // CA's certificate. // // If srv.Addr is blank, ":https" is used. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { addr := srv.Addr if addr == "" { addr = ":https" } if srv.TLSConfig == nil { srv.TLSConfig = &tls.Config{} } if srv.TLSConfig.NextProtos == nil { srv.TLSConfig.NextProtos = []string{"http/1.1"} } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return } go srv.handleSignals() l, err := srv.getListener(addr) if err != nil { log.Println(err) return err } srv.tlsInnerListener = newGraceListener(l, srv) srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { log.Println(err) return err } err = process.Signal(syscall.SIGTERM) if err != nil { return err } } log.Println(os.Getpid(), srv.Addr) return srv.Serve() } //ListenAndServeMutualTLS func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { addr := srv.Addr if addr == "" { addr = ":https" } if srv.TLSConfig == nil { srv.TLSConfig = &tls.Config{} } if srv.TLSConfig.NextProtos == nil { srv.TLSConfig.NextProtos = []string{"http/1.1"} } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return } srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert pool := x509.NewCertPool() data, err := ioutil.ReadFile(trustFile) if err != nil { log.Println(err) return err } pool.AppendCertsFromPEM(data) srv.TLSConfig.ClientCAs = pool log.Println("Mutual HTTPS") go srv.handleSignals() l, err := srv.getListener(addr) if err != nil { log.Println(err) return err } srv.tlsInnerListener = newGraceListener(l, srv) srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { log.Println(err) return err } err = process.Kill() if err != nil { return err } } log.Println(os.Getpid(), srv.Addr) return srv.Serve() } // getListener either opens a new socket to listen on, or takes the acceptor socket // it got passed when restarted. func (srv *Server) getListener(laddr string) (l net.Listener, err error) { if srv.isChild { var ptrOffset uint if len(socketPtrOffsetMap) > 0 { ptrOffset = socketPtrOffsetMap[laddr] log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) } f := os.NewFile(uintptr(3+ptrOffset), "") l, err = net.FileListener(f) if err != nil { err = fmt.Errorf("net.FileListener error: %v", err) return } } else { l, err = net.Listen(srv.Network, laddr) if err != nil { err = fmt.Errorf("net.Listen error: %v", err) return } } return } // handleSignals listens for os Signals and calls any hooked in function that the // user had registered with the signal. func (srv *Server) handleSignals() { var sig os.Signal signal.Notify( srv.sigChan, hookableSignals..., ) pid := syscall.Getpid() for { sig = <-srv.sigChan srv.signalHooks(PreSignal, sig) switch sig { case syscall.SIGHUP: log.Println(pid, "Received SIGHUP. forking.") err := srv.fork() if err != nil { log.Println("Fork err:", err) } case syscall.SIGINT: log.Println(pid, "Received SIGINT.") srv.shutdown() case syscall.SIGTERM: log.Println(pid, "Received SIGTERM.") srv.shutdown() default: log.Printf("Received %v: nothing i care about...\n", sig) } srv.signalHooks(PostSignal, sig) } } func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { return } for _, f := range srv.SignalHooks[ppFlag][sig] { f() } } // shutdown closes the listener so that no new connections are accepted. it also // starts a goroutine that will serverTimeout (stop all running requests) the server // after DefaultTimeout. func (srv *Server) shutdown() { if srv.state != StateRunning { return } srv.state = StateShuttingDown if DefaultTimeout >= 0 { go srv.serverTimeout(DefaultTimeout) } err := srv.GraceListener.Close() if err != nil { log.Println(syscall.Getpid(), "Listener.Close() error:", err) } else { log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.") } } // serverTimeout forces the server to shutdown in a given timeout - whether it // finished outstanding requests or not. if Read/WriteTimeout are not set or the // max header size is very big a connection could hang func (srv *Server) serverTimeout(d time.Duration) { defer func() { if r := recover(); r != nil { log.Println("WaitGroup at 0", r) } }() if srv.state != StateShuttingDown { return } time.Sleep(d) log.Println("[STOP - Hammer Time] Forcefully shutting down parent") for { if srv.state == StateTerminate { break } srv.wg.Done() } } func (srv *Server) fork() (err error) { regLock.Lock() defer regLock.Unlock() if runningServersForked { return } runningServersForked = true var files = make([]*os.File, len(runningServers)) var orderArgs = make([]string, len(runningServers)) for _, srvPtr := range runningServers { switch srvPtr.GraceListener.(type) { case *graceListener: files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File() default: files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() } orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr } log.Println(files) path := os.Args[0] var args []string if len(os.Args) > 1 { for _, arg := range os.Args[1:] { if arg == "-graceful" { break } args = append(args, arg) } } args = append(args, "-graceful") if len(runningServers) > 1 { args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) log.Println(args) } cmd := exec.Command(path, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.ExtraFiles = files err = cmd.Start() if err != nil { log.Fatalf("Restart: Failed to launch, error: %v", err) } return } // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { if ppFlag != PreSignal && ppFlag != PostSignal { err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") return } for _, s := range hookableSignals { if s == sig { srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) return } } err = fmt.Errorf("Signal '%v' is not supported", sig) return }