diff --git a/app.go b/app.go index 25ea2a04..16daa3ec 100644 --- a/app.go +++ b/app.go @@ -15,7 +15,10 @@ package beego import ( + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "net" "net/http" "net/http/fcgi" @@ -102,7 +105,7 @@ func (app *App) Run() { if BConfig.Listen.Graceful { httpsAddr := BConfig.Listen.HTTPSAddr app.Server.Addr = httpsAddr - if BConfig.Listen.EnableHTTPS { + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { go func() { time.Sleep(20 * time.Microsecond) if BConfig.Listen.HTTPSPort != 0 { @@ -112,10 +115,19 @@ func (app *App) Run() { server := grace.NewServer(httpsAddr, app.Handlers) server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout - if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true + if BConfig.Listen.EnableMutualHTTPS { + + if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } else { + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } } }() } @@ -139,7 +151,7 @@ func (app *App) Run() { } // run normal mode - if BConfig.Listen.EnableHTTPS { + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { go func() { time.Sleep(20 * time.Microsecond) if BConfig.Listen.HTTPSPort != 0 { @@ -149,6 +161,19 @@ func (app *App) Run() { return } logs.Info("https server Running on https://%s", app.Server.Addr) + if BConfig.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + if err != nil { + BeeLogger.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) time.Sleep(100 * time.Microsecond) diff --git a/config.go b/config.go index e6e99570..d3156a91 100644 --- a/config.go +++ b/config.go @@ -49,22 +49,24 @@ type Config struct { // Listen holds for http and https related config type Listen struct { - Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 - ListenTCP4 bool - EnableHTTP bool - HTTPAddr string - HTTPPort int - EnableHTTPS bool - HTTPSAddr string - HTTPSPort int - HTTPSCertFile string - HTTPSKeyFile string - EnableAdmin bool - AdminAddr string - AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + EnableHTTPS bool + EnableMutualHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + TrustCaFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O } // WebConfig holds web related config diff --git a/grace/server.go b/grace/server.go index b8242335..7c7bcecb 100644 --- a/grace/server.go +++ b/grace/server.go @@ -2,7 +2,9 @@ package grace import ( "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -129,6 +131,61 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { return srv.Serve() } +//ListenAndServeMutualTLS +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(trustFile) + if err != nil { + log.Println(err) + return err + } + pool.AppendCertsFromPEM(data) + srv.TLSConfig.ClientCAs = pool + log.Println("Mutual HTTPS") + go srv.handleSignals() + + l, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + + srv.tlsInnerListener = newGraceListener(l, srv) + srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Kill() + if err != nil { + return err + } + } + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + // getListener either opens a new socket to listen on, or takes the acceptor socket // it got passed when restarted. func (srv *Server) getListener(laddr string) (l net.Listener, err error) {