1
0
mirror of https://github.com/beego/bee.git synced 2025-07-04 16:40:19 +00:00
This commit is contained in:
astaxie
2017-03-17 13:33:15 +08:00
parent 20c6a26952
commit dd3f690bf7
96 changed files with 865 additions and 8639 deletions

View File

@ -13,16 +13,25 @@ import (
"math/rand"
"net"
"strconv"
"sync"
"time"
"unicode/utf8"
)
const (
// Frame header byte 0 bits from Section 5.2 of RFC 6455
finalBit = 1 << 7
rsv1Bit = 1 << 6
rsv2Bit = 1 << 5
rsv3Bit = 1 << 4
// Frame header byte 1 bits from Section 5.2 of RFC 6455
maskBit = 1 << 7
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7
writeWait = time.Second
writeWait = time.Second
defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
@ -172,6 +181,11 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame")
)
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok && e.Temporary() {
err = &netError{msg: e.Error(), timeout: e.Timeout()}
@ -210,39 +224,28 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
// Conn represents a WebSocket connection.
// The Conn type represents a WebSocket connection.
type Conn struct {
conn net.Conn
isServer bool
subprotocol string
// Write fields
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent
mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection
// Message writer fields.
writeErr error
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeDeadline time.Time
isWriting bool // for best-effort concurrent write detection
messageWriter *messageWriter // the current writer
writeErrMu sync.Mutex
writeErr error
enableWriteCompression bool
compressionLevel int
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
@ -253,34 +256,83 @@ type Conn struct {
readMaskKey [4]byte
handlePong func(string) error
handlePing func(string) error
handleClose func(int, string) error
readErrCount int
messageReader *messageReader // the current reader
messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.ReadCloser
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
if br == nil {
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
br = bufio.NewReaderSize(conn, readBufferSize)
}
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
var writeBuf []byte
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
// Use the bufio.Writer's buffer if the buffer has a useful size. This
// code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
brw.Writer.Reset(&wh)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}
if writeBuf == nil {
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
}
c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
isServer: isServer,
br: br,
conn: conn,
mu: mu,
readFinal: true,
writeBuf: writeBuf,
enableWriteCompression: true,
compressionLevel: defaultCompressionLevel,
}
c.SetCloseHandler(nil)
c.SetPingHandler(nil)
c.SetPongHandler(nil)
return c
@ -308,29 +360,40 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods
func (c *Conn) writeFatal(err error) error {
err = hideTempErr(err)
c.writeErrMu.Lock()
if c.writeErr == nil {
c.writeErr = err
}
c.writeErrMu.Unlock()
return err
}
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if frameType == CloseMessage {
c.closeSent = true
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return err
}
c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs {
if len(buf) > 0 {
n, err := c.conn.Write(buf)
if n != len(buf) {
// Close on partial write.
c.conn.Close()
}
_, err := c.conn.Write(buf)
if err != nil {
return err
return c.writeFatal(err)
}
}
}
if frameType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return nil
}
@ -379,18 +442,41 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
}
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if messageType == CloseMessage {
c.closeSent = true
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return err
}
c.conn.SetWriteDeadline(deadline)
n, err := c.conn.Write(buf)
if n != 0 && n != len(buf) {
c.conn.Close()
_, err = c.conn.Write(buf)
if err != nil {
return c.writeFatal(err)
}
return hideTempErr(err)
if messageType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return err
}
func (c *Conn) prepWrite(messageType int) error {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
if c.writer != nil {
c.writer.Close()
c.writer = nil
}
if !isControl(messageType) && !isData(messageType) {
return errBadWriteOpCode
}
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
return err
}
// NextWriter returns a writer for the next message to send. The writer's Close
@ -399,42 +485,61 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if c.writeErr != nil {
return nil, c.writeErr
if err := c.prepWrite(messageType); err != nil {
return nil, err
}
if c.writeFrameType != noFrame {
if err := c.flushFrame(true, nil); err != nil {
return nil, err
}
mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
if !isControl(messageType) && !isData(messageType) {
return nil, errBadWriteOpCode
c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true
c.writer = w
}
c.writeFrameType = messageType
w := &messageWriter{c}
c.messageWriter = w
return w, nil
return c.writer, nil
}
func (c *Conn) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra)
type messageWriter struct {
c *Conn
compress bool // whether next call to flushFrame should set RSV1
pos int // end of data in writeBuf.
frameType int // type of the current frame.
err error
}
func (w *messageWriter) fatal(err error) error {
if w.err != nil {
w.err = err
w.c.writer = nil
}
return err
}
// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c := w.c
length := w.pos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames.
if isControl(c.writeFrameType) &&
if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) {
c.messageWriter = nil
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
return w.fatal(errInvalidControlFrame)
}
b0 := byte(c.writeFrameType)
b0 := byte(w.frameType)
if final {
b0 |= finalBit
}
if w.compress {
b0 |= rsv1Bit
}
w.compress = false
b1 := byte(0)
if !c.isServer {
b1 |= maskBit
@ -466,10 +571,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if !c.isServer {
key := newMaskKey()
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
c.writeErr = errors.New("websocket: internal error, extra used in client mode")
return c.writeErr
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
}
}
@ -482,43 +586,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
}
c.isWriting = true
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false
// Setup for next frame.
c.writePos = maxFrameHeaderSize
c.writeFrameType = continuationFrame
if err != nil {
return w.fatal(err)
}
if final {
c.messageWriter = nil
c.writeFrameType = noFrame
c.writer = nil
return nil
}
return c.writeErr
}
type messageWriter struct{ c *Conn }
func (w *messageWriter) err() error {
c := w.c
if c.messageWriter != w {
return errWriteClosed
}
if c.writeErr != nil {
return c.writeErr
}
// Setup for next frame.
w.pos = maxFrameHeaderSize
w.frameType = continuationFrame
return nil
}
func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos
n := len(w.c.writeBuf) - w.pos
if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil {
if err := w.flushFrame(false, nil); err != nil {
return 0, err
}
n = len(w.c.writeBuf) - w.c.writePos
n = len(w.c.writeBuf) - w.pos
}
if n > max {
n = max
@ -526,14 +622,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
return n, nil
}
func (w *messageWriter) write(final bool, p []byte) (int, error) {
if err := w.err(); err != nil {
return 0, err
func (w *messageWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages.
err := w.c.flushFrame(final, p)
err := w.flushFrame(false, p)
if err != nil {
return 0, err
}
@ -546,20 +642,16 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
copy(w.c.writeBuf[w.pos:], p[:n])
w.pos += n
p = p[n:]
}
return nn, nil
}
func (w *messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}
func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil {
return 0, err
if w.err != nil {
return 0, w.err
}
nn := len(p)
@ -568,27 +660,27 @@ func (w *messageWriter) WriteString(p string) (int, error) {
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
copy(w.c.writeBuf[w.pos:], p[:n])
w.pos += n
p = p[n:]
}
return nn, nil
}
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil {
return 0, err
if w.err != nil {
return 0, w.err
}
for {
if w.c.writePos == len(w.c.writeBuf) {
err = w.c.flushFrame(false, nil)
if w.pos == len(w.c.writeBuf) {
err = w.flushFrame(false, nil)
if err != nil {
break
}
}
var n int
n, err = r.Read(w.c.writeBuf[w.c.writePos:])
w.c.writePos += n
n, err = r.Read(w.c.writeBuf[w.pos:])
w.pos += n
nn += int64(n)
if err != nil {
if err == io.EOF {
@ -601,27 +693,59 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
}
func (w *messageWriter) Close() error {
if err := w.err(); err != nil {
if w.err != nil {
return w.err
}
if err := w.flushFrame(true, nil); err != nil {
return err
}
return w.c.flushFrame(true, nil)
w.err = errWriteClosed
return nil
}
// WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel,
})
if err != nil {
return err
}
if c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = true
err = c.write(frameType, c.writeDeadline, frameData, nil)
if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false
return err
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil {
return err
}
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
return mw.flushFrame(true, data)
}
w, err := c.NextWriter(messageType)
if err != nil {
return err
}
if _, ok := w.(*messageWriter); ok && c.isServer {
// Optimize write as a single frame.
n := copy(c.writeBuf[c.writePos:], data)
c.writePos += n
data = data[n:]
err = c.flushFrame(true, data)
return err
}
if _, err = w.Write(data); err != nil {
return err
}
@ -658,12 +782,17 @@ func (c *Conn) advanceFrame() (int, error) {
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
reserved := int((p[0] >> 4) & 0x7)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)
if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
}
switch frameType {
@ -759,11 +888,9 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, err
}
case CloseMessage:
echoMessage := []byte{}
closeCode := CloseNoStatusReceived
closeText := ""
if len(payload) >= 2 {
echoMessage = payload[:2]
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
@ -773,7 +900,9 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
}
}
c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait))
if err := c.handleClose(closeCode, closeText); err != nil {
return noFrame, err
}
return noFrame, &CloseError{Code: closeCode, Text: closeText}
}
@ -796,6 +925,11 @@ func (c *Conn) handleProtocolError(message string) error {
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression.
if c.reader != nil {
c.reader.Close()
c.reader = nil
}
c.messageReader = nil
c.readLength = 0
@ -807,9 +941,12 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
break
}
if frameType == TextMessage || frameType == BinaryMessage {
r := &messageReader{c}
c.messageReader = r
return frameType, r, nil
c.messageReader = &messageReader{c}
c.reader = c.messageReader
if c.readDecompress {
c.reader = c.newDecompressionReader(c.reader)
}
return frameType, c.reader, nil
}
}
@ -871,6 +1008,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
return 0, err
}
func (r *messageReader) Close() error {
return nil
}
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
@ -898,6 +1039,38 @@ func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit
}
// CloseHandler returns the current close handler
func (c *Conn) CloseHandler() func(code int, text string) error {
return c.handleClose
}
// SetCloseHandler sets the handler for close messages received from the peer.
// The code argument to h is the received close code or CloseNoStatusReceived
// if the close message is empty. The default close handler sends a close frame
// back to the peer.
//
// The application must read the connection to process close messages as
// described in the section on Control Frames above.
//
// The connection read methods return a CloseError when a close frame is
// received. Most applications should handle close messages as part of their
// normal error handling. Applications should only set a close handler when the
// application must perform some action before sending a close frame back to
// the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil {
h = func(code int, text string) error {
message := []byte{}
if code != CloseNoStatusReceived {
message = FormatCloseMessage(code, "")
}
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil
}
}
c.handleClose = h
}
// PingHandler returns the current ping handler
func (c *Conn) PingHandler() func(appData string) error {
return c.handlePing
@ -906,6 +1079,9 @@ func (c *Conn) PingHandler() func(appData string) error {
// SetPingHandler sets the handler for ping messages received from the peer.
// The appData argument to h is the PING frame application data. The default
// ping handler sends a pong to the peer.
//
// The application must read the connection to process ping messages as
// described in the section on Control Frames above.
func (c *Conn) SetPingHandler(h func(appData string) error) {
if h == nil {
h = func(message string) error {
@ -929,6 +1105,9 @@ func (c *Conn) PongHandler() func(appData string) error {
// SetPongHandler sets the handler for pong messages received from the peer.
// The appData argument to h is the PONG frame application data. The default
// pong handler does nothing.
//
// The application must read the connection to process ping messages as
// described in the section on Control Frames above.
func (c *Conn) SetPongHandler(h func(appData string) error) {
if h == nil {
h = func(string) error { return nil }
@ -942,6 +1121,25 @@ func (c *Conn) UnderlyingConn() net.Conn {
return c.conn
}
// EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
c.enableWriteCompression = enable
}
// SetCompressionLevel sets the flate compression level for subsequent text and
// binary messages. This function is a noop if compression was not negotiated
// with the peer. See the compress/flate package for a description of
// compression levels.
func (c *Conn) SetCompressionLevel(level int) error {
if !isValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level")
}
c.compressionLevel = level
return nil
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text))