1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-05 11:20:55 +00:00
Beego/grace/grace.go

523 lines
13 KiB
Go
Raw Normal View History

// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
//
// Usage:
//
// import(
// "log"
// "net/http"
// "os"
//
// "github.com/astaxie/beego/grace"
// )
//
// func handler(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("WORLD!"))
// }
//
// func main() {
// mux := http.NewServeMux()
// mux.HandleFunc("/hello", handler)
//
// err := grace.ListenAndServe("localhost:8080", mux)
// if err != nil {
// log.Println(err)
// }
// log.Println("Server on 8080 stopped")
// os.Exit(0)
// }
package grace
import (
"crypto/tls"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"sync"
"syscall"
"time"
)
// Hook phases. They are the keys of the outer graceServer.SignalHooks map
// and select whether a hook runs before or after the built-in handling of a
// signal.
const (
	PRE_SIGNAL = iota
	POST_SIGNAL
)

// Server lifecycle states stored in graceServer.state. The numbering
// continues after the hook phases above, so STATE_INIT is 2.
const (
	STATE_INIT = iota + 2
	STATE_RUNNING
	STATE_SHUTTING_DOWN
	STATE_TERMINATE
)
// Internal bookkeeping shared by every server created through NewServer.
// The pointer and map values are allocated in init.
var (
	// regLock is held by NewServer and fork while they touch the registry
	// variables below.
	regLock *sync.Mutex
	// runningServers maps a listen address to the server created for it.
	runningServers map[string]*graceServer
	// runningServersOrder lists the addresses in the order NewServer saw them.
	runningServersOrder []string
	// socketPtrOffsetMap maps a listen address to the offset (added to fd 3)
	// of its inherited socket in a forked child.
	socketPtrOffsetMap map[string]uint
	// runningServersForked is set by fork once a child has been launched so
	// that only one server instance forks.
	runningServersForked bool

	// isChild is bound to the -graceful command-line flag.
	isChild bool
	// socketOrder is bound to the -socketorder command-line flag.
	socketOrder string
)

// Tunables copied into (or consulted by) servers created through NewServer.
var (
	// DefaultReadTimeOut becomes http.Server.ReadTimeout.
	DefaultReadTimeOut time.Duration
	// DefaultWriteTimeOut becomes http.Server.WriteTimeout.
	DefaultWriteTimeOut time.Duration
	// DefaultMaxHeaderBytes becomes http.Server.MaxHeaderBytes.
	DefaultMaxHeaderBytes int
	// DefaultTimeout is how long shutdown waits before forcing the server
	// down; a negative value disables the forced shutdown.
	DefaultTimeout time.Duration
)
func init() {
regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
runningServers = make(map[string]*graceServer)
runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint)
DefaultMaxHeaderBytes = 0
// after a restart the parent will finish ongoing requests before
// shutting down. set to a negative value to disable
DefaultTimeout = 60 * time.Second
}
// graceServer wraps an http.Server with the state needed to stop accepting
// connections, wait for in-flight ones and hand its listening socket to a
// forked child process.
type graceServer struct {
// Embedded standard server; NewServer fills in Addr, Handler, the
// read/write timeouts and MaxHeaderBytes.
http.Server
// GraceListener is the listener passed to http.Server.Serve. ListenAndServe
// stores a *graceListener here; ListenAndServeTLS stores a TLS listener
// wrapping tlsInnerListener.
GraceListener net.Listener
// SignalHooks holds user callbacks, indexed first by phase (PRE_SIGNAL or
// POST_SIGNAL) and then by signal; signalHooks runs them.
SignalHooks map[int]map[os.Signal][]func()
// tlsInnerListener is the *graceListener underneath the TLS listener; fork
// uses it to obtain the socket's file when GraceListener is not a
// *graceListener.
tlsInnerListener *graceListener
// wg counts open connections: Accept adds one, graceConn.Close removes one,
// and Serve waits on it before returning.
wg sync.WaitGroup
// sigChan receives the OS signals consumed by handleSignals.
sigChan chan os.Signal
// isChild is the value of the -graceful flag when the server was created.
isChild bool
// state holds one of the STATE_* constants.
state uint8
// Network is the network passed to net.Listen; NewServer sets it to "tcp".
Network string
}
// NewServer returns an initialized graceServer for addr and handler and
// records it in the package-level registry. Calling Serve (usually through
// ListenAndServe or ListenAndServeTLS) on it will actually "start" the
// server.
//
// The command line is parsed here if it has not been parsed yet, because the
// -graceful and -socketorder flags decide whether the server inherits its
// socket from a parent process and at which descriptor offset.
func NewServer(addr string, handler http.Handler) (srv *graceServer) {
	regLock.Lock()
	defer regLock.Unlock()

	if !flag.Parsed() {
		flag.Parse()
	}

	if len(socketOrder) > 0 {
		// The parent told us the order in which it passed the sockets.
		for i, inherited := range strings.Split(socketOrder, ",") {
			socketPtrOffsetMap[inherited] = uint(i)
		}
	} else {
		socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
	}

	hookedSignals := []os.Signal{
		syscall.SIGHUP,
		syscall.SIGUSR1,
		syscall.SIGUSR2,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGTSTP,
	}
	// emptyHooks builds one phase of the hook table with an empty, non-nil
	// slot for every signal handled by handleSignals.
	emptyHooks := func() map[os.Signal][]func() {
		hooks := make(map[os.Signal][]func(), len(hookedSignals))
		for _, sig := range hookedSignals {
			hooks[sig] = []func(){}
		}
		return hooks
	}

	srv = &graceServer{
		// signal.Notify never blocks when delivering, so an unbuffered
		// channel would drop any signal that arrives while the handler is
		// busy. Leave room for one pending value per hooked signal.
		sigChan: make(chan os.Signal, len(hookedSignals)),
		isChild: isChild,
		SignalHooks: map[int]map[os.Signal][]func(){
			PRE_SIGNAL:  emptyHooks(),
			POST_SIGNAL: emptyHooks(),
		},
		state:   STATE_INIT,
		Network: "tcp",
	}
	srv.Server.Addr = addr
	srv.Server.ReadTimeout = DefaultReadTimeOut
	srv.Server.WriteTimeout = DefaultWriteTimeOut
	srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
	srv.Server.Handler = handler

	runningServersOrder = append(runningServersOrder, addr)
	runningServers[addr] = srv
	return srv
}
// ListenAndServe creates a server for addr and handler with NewServer and
// runs its ListenAndServe method, returning whatever error that yields.
func ListenAndServe(addr string, handler http.Handler) error {
	return NewServer(addr, handler).ListenAndServe()
}
// ListenAndServeTLS creates a server for addr and handler with NewServer and
// runs its ListenAndServeTLS method to serve incoming TLS connections.
//
// certFile and keyFile name the files holding the server certificate and the
// matching private key. If the certificate is signed by a certificate
// authority, certFile should be the concatenation of the server's
// certificate followed by the CA's certificate.
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
	return NewServer(addr, handler).ListenAndServeTLS(certFile, keyFile)
}
// Serve marks the server as running and serves connections accepted from
// srv.GraceListener through the embedded http.Server. Once the HTTP server
// stops serving, Serve blocks until the connection wait group drains, marks
// the server as terminated and returns the error from http.Server.Serve.
func (srv *graceServer) Serve() (err error) {
	srv.state = STATE_RUNNING
	serveErr := srv.Server.Serve(srv.GraceListener)

	log.Println(syscall.Getpid(), "Waiting for connections to finish...")
	srv.wg.Wait()

	srv.state = STATE_TERMINATE
	return serveErr
}
// ListenAndServe listens on the network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
// used.
//
// When the process was started as a graceful child, the parent is sent
// SIGTERM once the listener is in place so that it stops accepting and
// drains its connections.
func (srv *graceServer) ListenAndServe() (err error) {
	addr := srv.Addr
	if addr == "" {
		addr = ":http"
	}

	l, err := srv.getListener(addr)
	if err != nil {
		log.Println(err)
		return err
	}
	srv.GraceListener = newGraceListener(l, srv)

	// Install the signal handler only after the listener exists. Starting it
	// earlier left a goroutine behind when listening failed, and that
	// goroutine kept intercepting SIGINT/SIGTERM for the rest of the
	// process lifetime.
	go srv.handleSignals()

	if srv.isChild {
		if killErr := syscall.Kill(syscall.Getppid(), syscall.SIGTERM); killErr != nil {
			log.Println(syscall.Getpid(), "failed to signal parent:", killErr)
		}
	}

	log.Println(syscall.Getpid(), srv.Addr)
	return srv.Serve()
}
// ListenAndServeTLS listens on the network address srv.Addr and then calls
// Serve to handle requests on incoming TLS connections.
//
// Filenames containing a certificate and matching private key for the server must
// be provided. If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate followed by the
// CA's certificate.
//
// If srv.Addr is blank, ":https" is used. srv.TLSConfig, when set, is used as
// the base configuration; its Certificates are replaced by the loaded key
// pair and NextProtos defaults to "http/1.1".
func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
	addr := srv.Addr
	if addr == "" {
		addr = ":https"
	}

	// Clone instead of copying the struct by value: tls.Config carries
	// internal synchronization state that must not be copied.
	var config *tls.Config
	if srv.TLSConfig != nil {
		config = srv.TLSConfig.Clone()
	} else {
		config = &tls.Config{}
	}
	if config.NextProtos == nil {
		config.NextProtos = []string{"http/1.1"}
	}

	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return err
	}
	config.Certificates = []tls.Certificate{cert}

	l, err := srv.getListener(addr)
	if err != nil {
		log.Println(err)
		return err
	}
	srv.tlsInnerListener = newGraceListener(l, srv)
	srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)

	// Install the signal handler only after the listener exists. Starting it
	// earlier left a goroutine behind when setup failed, and that goroutine
	// kept intercepting SIGINT/SIGTERM for the rest of the process lifetime.
	go srv.handleSignals()

	if srv.isChild {
		if killErr := syscall.Kill(syscall.Getppid(), syscall.SIGTERM); killErr != nil {
			log.Println(syscall.Getpid(), "failed to signal parent:", killErr)
		}
	}

	log.Println(syscall.Getpid(), srv.Addr)
	return srv.Serve()
}
// getListener either opens a new socket to listen on, or takes the acceptor
// socket it got passed when restarted.
//
// In a graceful child the socket is expected at file descriptor 3 plus the
// offset recorded for laddr in socketPtrOffsetMap. Otherwise a fresh listener
// is created with net.Listen on srv.Network. The returned error wraps the
// underlying net error.
func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
	if !srv.isChild {
		l, err = net.Listen(srv.Network, laddr)
		if err != nil {
			return nil, fmt.Errorf("net.Listen error: %w", err)
		}
		return l, nil
	}

	var ptrOffset uint
	if len(socketPtrOffsetMap) > 0 {
		ptrOffset = socketPtrOffsetMap[laddr]
		log.Println("laddr", laddr, "ptr offset", ptrOffset)
	}

	f := os.NewFile(uintptr(3+ptrOffset), "")
	// net.FileListener works on its own duplicate of the descriptor, so the
	// inherited one is released here instead of waiting for the garbage
	// collector to finalize f.
	defer f.Close()

	l, err = net.FileListener(f)
	if err != nil {
		return nil, fmt.Errorf("net.FileListener error: %w", err)
	}
	return l, nil
}
// handleSignals subscribes srv.sigChan to the signals the package reacts to
// and then loops forever: for every signal received it runs the PRE_SIGNAL
// hooks, performs the built-in action (fork on SIGHUP, immediate forced
// shutdown on SIGUSR2, graceful shutdown on SIGINT/SIGTERM, log only for the
// rest) and finally runs the POST_SIGNAL hooks.
func (srv *graceServer) handleSignals() {
	signal.Notify(
		srv.sigChan,
		syscall.SIGHUP,
		syscall.SIGUSR1,
		syscall.SIGUSR2,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGTSTP,
	)

	self := syscall.Getpid()
	for {
		received := <-srv.sigChan
		srv.signalHooks(PRE_SIGNAL, received)

		switch received {
		case syscall.SIGHUP:
			log.Println(self, "Received SIGHUP. forking.")
			if forkErr := srv.fork(); forkErr != nil {
				log.Println("Fork err:", forkErr)
			}
		case syscall.SIGUSR1:
			log.Println(self, "Received SIGUSR1.")
		case syscall.SIGUSR2:
			log.Println(self, "Received SIGUSR2.")
			srv.serverTimeout(0)
		case syscall.SIGINT:
			log.Println(self, "Received SIGINT.")
			srv.shutdown()
		case syscall.SIGTERM:
			log.Println(self, "Received SIGTERM.")
			srv.shutdown()
		case syscall.SIGTSTP:
			log.Println(self, "Received SIGTSTP.")
		default:
			log.Printf("Received %v: nothing i care about...\n", received)
		}

		srv.signalHooks(POST_SIGNAL, received)
	}
}
// signalHooks runs, in registration order, every hook stored in
// srv.SignalHooks for the given phase (PRE_SIGNAL or POST_SIGNAL) and signal.
// It does nothing when no entry exists for that combination.
func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
	hooks, registered := srv.SignalHooks[ppFlag][sig]
	if !registered {
		return
	}
	for i := range hooks {
		hooks[i]()
	}
}
// shutdown moves a running server into the shutting-down state and closes
// its listener so that no new connections are accepted. Unless
// DefaultTimeout is negative it also starts a goroutine that forces the
// server down after DefaultTimeout. Calls made while the server is not in
// the running state are ignored.
func (srv *graceServer) shutdown() {
	if srv.state != STATE_RUNNING {
		return
	}
	srv.state = STATE_SHUTTING_DOWN

	if DefaultTimeout >= 0 {
		go srv.serverTimeout(DefaultTimeout)
	}

	if closeErr := srv.GraceListener.Close(); closeErr != nil {
		log.Println(syscall.Getpid(), "Listener.Close() error:", closeErr)
		return
	}
	log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
}
// serverTimeout forces the server to shut down after the delay d - whether it
// finished outstanding requests or not. If Read/WriteTimeout are not set or the
// max header size is very big a connection could otherwise hang.
//
// It only acts when the server is in STATE_SHUTTING_DOWN. srv.Serve() does not
// return until the connection wait group reaches zero; draining the counter
// here unblocks the srv.wg.Wait() in Serve() and thus causes
// ListenAndServe(TLS) to return.
func (srv *graceServer) serverTimeout(d time.Duration) {
defer func() {
// The loop below keeps calling srv.wg.Done() until it panics, which
// happens when Done() is called with the counter already at 0. The
// panic is swallowed here: by then Serve() can return and the parent
// can exit.
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
// Nothing to force unless a shutdown is in progress.
if srv.state != STATE_SHUTTING_DOWN {
return
}
// Grace period for in-flight requests.
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
// Drain the wait group; leaves via the break once Serve() has marked the
// server terminated, or via the recovered panic above.
for {
if srv.state == STATE_TERMINATE {
break
}
srv.wg.Done()
}
}
// fork starts a new copy of the running executable and hands it the
// listening sockets of all registered servers as extra file descriptors.
//
// Only one server instance should fork: the first caller sets
// runningServersForked and later calls return nil without doing anything.
// The child is started with the original arguments (up to a previous
// "-graceful"), plus "-graceful" and, when more than one server is
// registered, "-socketorder" listing the addresses in descriptor order.
//
// If the child cannot be started the error is returned to the caller and the
// forked flag is cleared so that a later restart attempt is possible; the
// running server is left untouched.
func (srv *graceServer) fork() (err error) {
	regLock.Lock()
	defer regLock.Unlock()

	if runningServersForked {
		return nil
	}
	runningServersForked = true

	files := make([]*os.File, len(runningServers))
	orderArgs := make([]string, len(runningServers))
	// Get the acceptor socket fds for _all_ server instances.
	for _, srvPtr := range runningServers {
		idx := socketPtrOffsetMap[srvPtr.Server.Addr]
		switch l := srvPtr.GraceListener.(type) {
		case *graceListener:
			// normal listener
			files[idx] = l.File()
		default:
			// tls listener
			files[idx] = srvPtr.tlsInnerListener.File()
		}
		orderArgs[idx] = srvPtr.Server.Addr
	}
	// The files are duplicates of the listening sockets made for the child;
	// once Start has returned the parent no longer needs its copies.
	defer func() {
		for _, f := range files {
			if f != nil {
				f.Close()
			}
		}
	}()
	log.Println(files)

	path := os.Args[0]
	var args []string
	for _, arg := range os.Args[1:] {
		if arg == "-graceful" {
			break
		}
		args = append(args, arg)
	}
	args = append(args, "-graceful")
	if len(runningServers) > 1 {
		args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
		log.Println(args)
	}

	cmd := exec.Command(path, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.ExtraFiles = files
	if err = cmd.Start(); err != nil {
		// Allow another attempt instead of taking the live server down.
		runningServersForked = false
		return fmt.Errorf("restart: launching %s: %w", path, err)
	}
	return nil
}
// graceListener wraps a net.Listener so that accepted connections are counted
// in the owning server's wait group and so that the listener can be closed
// through a helper goroutine (see newGraceListener and Close).
type graceListener struct {
// Underlying listener that actually accepts connections.
net.Listener
// stop carries the close handshake: Close sends a value to request the
// close and then receives the result of Listener.Close on the same channel.
stop chan error
// stopped is set by the helper goroutine once a close has been requested.
stopped bool
// server is the graceServer whose wait group Accept increments.
server *graceServer
}
// Accept waits for the next connection on the underlying listener and
// returns it wrapped in a graceConn, after adding it to the owning server's
// wait group.
//
// TCP connections get keep-alives enabled with a three minute period (see
// http.tcpKeepAliveListener). Connections of other kinds are passed through
// unchanged, so a server whose Network is not TCP no longer panics here.
func (gl *graceListener) Accept() (c net.Conn, err error) {
	conn, err := gl.Listener.Accept()
	if err != nil {
		return nil, err
	}

	if tc, ok := conn.(*net.TCPConn); ok {
		// Best effort, mirroring net/http: a failure to set keep-alive does
		// not make the connection unusable.
		_ = tc.SetKeepAlive(true)
		_ = tc.SetKeepAlivePeriod(3 * time.Minute)
	}

	gl.server.wg.Add(1)
	return graceConn{
		Conn:   conn,
		server: gl.server,
	}, nil
}
// newGraceListener wraps l for srv and starts the goroutine that services
// Close requests.
//
// Accept blocks inside the underlying listener, so the close is carried out
// by a separate goroutine: it waits for a value on the stop channel, marks
// the listener as stopped, closes the underlying listener (which unblocks
// Accept) and sends the result of that close back on the same channel.
func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) {
	wrapped := &graceListener{
		Listener: l,
		stop:     make(chan error),
		server:   srv,
	}

	go func() {
		<-wrapped.stop
		wrapped.stopped = true
		wrapped.stop <- wrapped.Listener.Close()
	}()

	return wrapped
}
// Close asks the helper goroutine started by newGraceListener to close the
// underlying listener and returns the result of that close. If the listener
// has already been stopped it returns syscall.EINVAL instead.
func (el *graceListener) Close() error {
	if el.stopped {
		return syscall.EINVAL
	}

	// Request the close, then wait for its outcome on the same channel.
	el.stop <- nil
	closeErr := <-el.stop
	return closeErr
}
// File returns an *os.File for a copy of the underlying listener's socket,
// suitable for passing to a child process (fork puts it in cmd.ExtraFiles).
//
// Any listener exposing a File() (*os.File, error) method is supported, which
// covers *net.TCPListener and *net.UnixListener. If the listener has no such
// method, or duplicating the descriptor fails, the problem is logged and nil
// is returned.
func (el *graceListener) File() *os.File {
	filer, ok := el.Listener.(interface {
		File() (*os.File, error)
	})
	if !ok {
		log.Printf("grace: listener of type %T cannot expose a file descriptor", el.Listener)
		return nil
	}

	fl, err := filer.File()
	if err != nil {
		log.Println("grace: duplicating listener descriptor:", err)
		return nil
	}
	return fl
}
// graceConn is the connection type handed out by graceListener.Accept. It
// behaves like the embedded net.Conn except that Close also updates the
// owning server's connection wait group.
type graceConn struct {
// Underlying accepted connection.
net.Conn
// server is the graceServer whose wait group is decremented on Close.
server *graceServer
}
// Close marks the connection as finished in the owning server's wait group
// (balancing the Add in graceListener.Accept) and then closes the underlying
// connection, returning its error.
//
// NOTE(review): every call decrements the wait group, so closing the same
// connection more than once would drive the counter negative - confirm that
// callers close each connection exactly once.
func (c graceConn) Close() error {
c.server.wg.Done()
return c.Conn.Close()
}