package session import ( "bytes" "crypto/cipher" "crypto/hmac" "crypto/rand" "crypto/sha1" "crypto/subtle" "encoding/base64" "encoding/gob" "errors" "fmt" "io" "strconv" "time" ) func init() { gob.Register([]interface{}{}) gob.Register(map[int]interface{}{}) gob.Register(map[string]interface{}{}) gob.Register(map[interface{}]interface{}{}) gob.Register(map[string]string{}) gob.Register(map[int]string{}) gob.Register(map[int]int{}) gob.Register(map[int]int64{}) } func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { buf := bytes.NewBuffer(nil) enc := gob.NewEncoder(buf) err := enc.Encode(obj) if err != nil { return []byte(""), err } return buf.Bytes(), nil } func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { buf := bytes.NewBuffer(encoded) dec := gob.NewDecoder(buf) var out map[interface{}]interface{} err := dec.Decode(&out) if err != nil { return nil, err } return out, nil } // generateRandomKey creates a random key with the given strength. func generateRandomKey(strength int) []byte { k := make([]byte, strength) if _, err := io.ReadFull(rand.Reader, k); err != nil { return nil } return k } // Encryption ----------------------------------------------------------------- // encrypt encrypts a value using the given block in counter mode. // // A random initialization vector (http://goo.gl/zF67k) with the length of the // block size is prepended to the resulting ciphertext. func encrypt(block cipher.Block, value []byte) ([]byte, error) { iv := generateRandomKey(block.BlockSize()) if iv == nil { return nil, errors.New("encrypt: failed to generate random iv") } // Encrypt it. stream := cipher.NewCTR(block, iv) stream.XORKeyStream(value, value) // Return iv + ciphertext. return append(iv, value...), nil } // decrypt decrypts a value using the given block in counter mode. // // The value to be decrypted must be prepended by a initialization vector // (http://goo.gl/zF67k) with the length of the block size. func decrypt(block cipher.Block, value []byte) ([]byte, error) { size := block.BlockSize() if len(value) > size { // Extract iv. iv := value[:size] // Extract ciphertext. value = value[size:] // Decrypt it. stream := cipher.NewCTR(block, iv) stream.XORKeyStream(value, value) return value, nil } return nil, errors.New("decrypt: the value could not be decrypted") } func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { var err error var b []byte // 1. encodeGob. if b, err = encodeGob(value); err != nil { return "", err } // 2. Encrypt (optional). if b, err = encrypt(block, b); err != nil { return "", err } b = encode(b) // 3. Create MAC for "name|date|value". Extra pipe to be used later. b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) h := hmac.New(sha1.New, []byte(hashKey)) h.Write(b) sig := h.Sum(nil) // Append mac, remove name. b = append(b, sig...)[len(name)+1:] // 4. Encode to base64. b = encode(b) // Done. return string(b), nil } func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { // 1. Decode from base64. b, err := decode([]byte(value)) if err != nil { return nil, err } // 2. Verify MAC. Value is "date|value|mac". parts := bytes.SplitN(b, []byte("|"), 3) if len(parts) != 3 { return nil, errors.New("Decode: invalid value %v") } b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) h := hmac.New(sha1.New, []byte(hashKey)) h.Write(b) sig := h.Sum(nil) if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { return nil, errors.New("Decode: the value is not valid") } // 3. Verify date ranges. var t1 int64 if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { return nil, errors.New("Decode: invalid timestamp") } t2 := time.Now().UTC().Unix() if t1 > t2 { return nil, errors.New("Decode: timestamp is too new") } if t1 < t2-gcmaxlifetime { return nil, errors.New("Decode: expired timestamp") } // 4. Decrypt (optional). b, err = decode(parts[1]) if err != nil { return nil, err } if b, err = decrypt(block, b); err != nil { return nil, err } // 5. decodeGob. if dst, err := decodeGob(b); err != nil { return nil, err } else { return dst, nil } // Done. return nil, nil } // Encoding ------------------------------------------------------------------- // encode encodes a value using base64. func encode(value []byte) []byte { encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) base64.URLEncoding.Encode(encoded, value) return encoded } // decode decodes a cookie using base64. func decode(value []byte) ([]byte, error) { decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) b, err := base64.URLEncoding.Decode(decoded, value) if err != nil { return nil, err } return decoded[:b], nil }