2020-07-22 14:50:08 +00:00
|
|
|
package grace
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/x509"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"log"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"os/exec"
|
|
|
|
"os/signal"
|
|
|
|
"strings"
|
|
|
|
"syscall"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Server wraps an *http.Server and carries the extra state used by this
// package to serve, react to signals and hand its listener to a forked child.
type Server struct {
	// Embedded standard library server; Addr, Handler, TLSConfig etc. are
	// reached through it.
	*http.Server

	// ln is the listener that Serve passes to the embedded server.
	ln net.Listener

	// SignalHooks holds the registered hook functions, keyed first by the
	// pre/post flag and then by signal.
	SignalHooks map[int]map[os.Signal][]func()

	// sigChan is the channel handleSignals registers with signal.Notify and
	// reads from.
	sigChan chan os.Signal

	// isChild makes getListener rebuild the listener from an inherited file
	// descriptor, and makes the ListenAndServe* methods send SIGTERM to the
	// parent process once the listener is ready.
	isChild bool

	// state tracks the lifecycle of the server; Serve and shutdown update it.
	state uint8

	// Network is the network name handed to net.Listen.
	Network string

	// terminalChan carries the result of http.Server.Shutdown from shutdown
	// to Serve.
	terminalChan chan error
}
|
|
|
|
|
2020-08-06 15:07:18 +00:00
|
|
|
// Serve accepts incoming connections on the Listener l
|
|
|
|
// and creates a new service goroutine for each.
|
2020-07-22 14:50:08 +00:00
|
|
|
// The service goroutines read requests and then call srv.Handler to reply to them.
|
|
|
|
func (srv *Server) Serve() (err error) {
|
|
|
|
srv.state = StateRunning
|
|
|
|
defer func() { srv.state = StateTerminate }()
|
|
|
|
|
|
|
|
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
|
|
|
// immediately return ErrServerClosed. Make sure the program doesn't exit
|
|
|
|
// and waits instead for Shutdown to return.
|
|
|
|
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
|
|
|
|
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
|
|
|
|
// wait for Shutdown to return
|
|
|
|
if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
|
|
|
|
return shutdownErr
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
|
|
|
|
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
|
|
|
|
// used.
|
|
|
|
func (srv *Server) ListenAndServe() (err error) {
|
|
|
|
addr := srv.Addr
|
|
|
|
if addr == "" {
|
|
|
|
addr = ":http"
|
|
|
|
}
|
|
|
|
|
|
|
|
go srv.handleSignals()
|
|
|
|
|
|
|
|
srv.ln, err = srv.getListener(addr)
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if srv.isChild {
|
|
|
|
process, err := os.FindProcess(os.Getppid())
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = process.Signal(syscall.SIGTERM)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println(os.Getpid(), srv.Addr)
|
|
|
|
return srv.Serve()
|
|
|
|
}
|
|
|
|
|
|
|
|
// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
|
|
|
|
// Serve to handle requests on incoming TLS connections.
|
|
|
|
//
|
|
|
|
// Filenames containing a certificate and matching private key for the server must
|
|
|
|
// be provided. If the certificate is signed by a certificate authority, the
|
|
|
|
// certFile should be the concatenation of the server's certificate followed by the
|
|
|
|
// CA's certificate.
|
|
|
|
//
|
|
|
|
// If srv.Addr is blank, ":https" is used.
|
|
|
|
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
|
|
|
|
addr := srv.Addr
|
|
|
|
if addr == "" {
|
|
|
|
addr = ":https"
|
|
|
|
}
|
|
|
|
|
|
|
|
if srv.TLSConfig == nil {
|
|
|
|
srv.TLSConfig = &tls.Config{}
|
|
|
|
}
|
|
|
|
if srv.TLSConfig.NextProtos == nil {
|
|
|
|
srv.TLSConfig.NextProtos = []string{"http/1.1"}
|
|
|
|
}
|
|
|
|
|
|
|
|
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
|
|
|
|
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
go srv.handleSignals()
|
|
|
|
|
|
|
|
ln, err := srv.getListener(addr)
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
|
|
|
|
|
|
|
if srv.isChild {
|
|
|
|
process, err := os.FindProcess(os.Getppid())
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = process.Signal(syscall.SIGTERM)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println(os.Getpid(), srv.Addr)
|
|
|
|
return srv.Serve()
|
|
|
|
}
|
|
|
|
|
|
|
|
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
|
|
|
|
// Serve to handle requests on incoming mutual TLS connections.
|
|
|
|
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
|
|
|
|
addr := srv.Addr
|
|
|
|
if addr == "" {
|
|
|
|
addr = ":https"
|
|
|
|
}
|
|
|
|
|
|
|
|
if srv.TLSConfig == nil {
|
|
|
|
srv.TLSConfig = &tls.Config{}
|
|
|
|
}
|
|
|
|
if srv.TLSConfig.NextProtos == nil {
|
|
|
|
srv.TLSConfig.NextProtos = []string{"http/1.1"}
|
|
|
|
}
|
|
|
|
|
|
|
|
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
|
|
|
|
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
|
|
pool := x509.NewCertPool()
|
|
|
|
data, err := ioutil.ReadFile(trustFile)
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
pool.AppendCertsFromPEM(data)
|
|
|
|
srv.TLSConfig.ClientCAs = pool
|
|
|
|
log.Println("Mutual HTTPS")
|
|
|
|
go srv.handleSignals()
|
|
|
|
|
|
|
|
ln, err := srv.getListener(addr)
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
|
|
|
|
|
|
|
if srv.isChild {
|
|
|
|
process, err := os.FindProcess(os.Getppid())
|
|
|
|
if err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = process.Signal(syscall.SIGTERM)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println(os.Getpid(), srv.Addr)
|
|
|
|
return srv.Serve()
|
|
|
|
}
|
|
|
|
|
|
|
|
// getListener either opens a new socket to listen on, or takes the acceptor socket
|
|
|
|
// it got passed when restarted.
|
|
|
|
func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
|
|
|
|
if srv.isChild {
|
|
|
|
var ptrOffset uint
|
|
|
|
if len(socketPtrOffsetMap) > 0 {
|
|
|
|
ptrOffset = socketPtrOffsetMap[laddr]
|
|
|
|
log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
|
|
|
|
}
|
|
|
|
|
|
|
|
f := os.NewFile(uintptr(3+ptrOffset), "")
|
|
|
|
l, err = net.FileListener(f)
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("net.FileListener error: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
l, err = net.Listen(srv.Network, laddr)
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("net.Listen error: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// tcpKeepAliveListener embeds a *net.TCPListener so that its Accept method
// can enable TCP keep-alive on every accepted connection.
type tcpKeepAliveListener struct {
	*net.TCPListener
}
|
|
|
|
|
|
|
|
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
|
|
|
tc, err := ln.AcceptTCP()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
tc.SetKeepAlive(true)
|
|
|
|
tc.SetKeepAlivePeriod(3 * time.Minute)
|
|
|
|
return tc, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// handleSignals listens for os Signals and calls any hooked in function that the
|
|
|
|
// user had registered with the signal.
|
|
|
|
func (srv *Server) handleSignals() {
|
|
|
|
var sig os.Signal
|
|
|
|
|
|
|
|
signal.Notify(
|
|
|
|
srv.sigChan,
|
|
|
|
hookableSignals...,
|
|
|
|
)
|
|
|
|
|
|
|
|
pid := syscall.Getpid()
|
|
|
|
for {
|
|
|
|
sig = <-srv.sigChan
|
|
|
|
srv.signalHooks(PreSignal, sig)
|
|
|
|
switch sig {
|
|
|
|
case syscall.SIGHUP:
|
|
|
|
log.Println(pid, "Received SIGHUP. forking.")
|
|
|
|
err := srv.fork()
|
|
|
|
if err != nil {
|
|
|
|
log.Println("Fork err:", err)
|
|
|
|
}
|
|
|
|
case syscall.SIGINT:
|
|
|
|
log.Println(pid, "Received SIGINT.")
|
|
|
|
srv.shutdown()
|
|
|
|
case syscall.SIGTERM:
|
|
|
|
log.Println(pid, "Received SIGTERM.")
|
|
|
|
srv.shutdown()
|
|
|
|
default:
|
|
|
|
log.Printf("Received %v: nothing i care about...\n", sig)
|
|
|
|
}
|
|
|
|
srv.signalHooks(PostSignal, sig)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
|
|
|
|
if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
for _, f := range srv.SignalHooks[ppFlag][sig] {
|
|
|
|
f()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// shutdown closes the listener so that no new connections are accepted. it also
|
|
|
|
// starts a goroutine that will serverTimeout (stop all running requests) the server
|
|
|
|
// after DefaultTimeout.
|
|
|
|
func (srv *Server) shutdown() {
|
|
|
|
if srv.state != StateRunning {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
srv.state = StateShuttingDown
|
|
|
|
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
|
|
|
|
ctx := context.Background()
|
|
|
|
if DefaultTimeout >= 0 {
|
|
|
|
var cancel context.CancelFunc
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
|
|
|
|
defer cancel()
|
|
|
|
}
|
|
|
|
srv.terminalChan <- srv.Server.Shutdown(ctx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (srv *Server) fork() (err error) {
|
|
|
|
regLock.Lock()
|
|
|
|
defer regLock.Unlock()
|
|
|
|
if runningServersForked {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
runningServersForked = true
|
|
|
|
|
|
|
|
var files = make([]*os.File, len(runningServers))
|
|
|
|
var orderArgs = make([]string, len(runningServers))
|
|
|
|
for _, srvPtr := range runningServers {
|
|
|
|
f, _ := srvPtr.ln.(*net.TCPListener).File()
|
|
|
|
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
|
|
|
|
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println(files)
|
|
|
|
path := os.Args[0]
|
|
|
|
var args []string
|
|
|
|
if len(os.Args) > 1 {
|
|
|
|
for _, arg := range os.Args[1:] {
|
|
|
|
if arg == "-graceful" {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
args = append(args, arg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
args = append(args, "-graceful")
|
|
|
|
if len(runningServers) > 1 {
|
|
|
|
args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
|
|
|
|
log.Println(args)
|
|
|
|
}
|
|
|
|
cmd := exec.Command(path, args...)
|
|
|
|
cmd.Stdout = os.Stdout
|
|
|
|
cmd.Stderr = os.Stderr
|
|
|
|
cmd.ExtraFiles = files
|
|
|
|
err = cmd.Start()
|
|
|
|
if err != nil {
|
|
|
|
log.Fatalf("Restart: Failed to launch, error: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
|
|
|
|
func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
|
|
|
|
if ppFlag != PreSignal && ppFlag != PostSignal {
|
|
|
|
err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
for _, s := range hookableSignals {
|
|
|
|
if s == sig {
|
|
|
|
srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
err = fmt.Errorf("Signal '%v' is not supported", sig)
|
|
|
|
return
|
|
|
|
}
|