1
0
mirror of https://github.com/astaxie/beego.git synced 2025-06-12 13:10:39 +00:00

Merge log_format

This commit is contained in:
Ming Deng
2020-09-10 23:31:49 +08:00
402 changed files with 17840 additions and 2268 deletions

View File

@ -0,0 +1,20 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
// ApplicationContext define for future
// when we decide to support DI, IoC, this will be core API
type ApplicationContext interface {
}

View File

@ -0,0 +1,17 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// bean is a basic package
// it should not depend on other modules except common module, log module and config module
package bean

View File

@ -0,0 +1,25 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
)
// AutoWireBeanFactory wire the bean based on ApplicationContext and context.Context
type AutoWireBeanFactory interface {
// AutoWire will wire the bean.
AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error
}

View File

@ -0,0 +1,28 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
// BeanMetadata, in other words, bean's config.
// it could be read from config file
type BeanMetadata struct {
// Fields: field name => field metadata
Fields map[string]*FieldMetadata
}
// FieldMetadata contains metadata
type FieldMetadata struct {
// default value in string format
DftValue string
}

View File

@ -0,0 +1,231 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"fmt"
"reflect"
"strconv"
"github.com/pkg/errors"
"github.com/astaxie/beego/pkg/infrastructure/logs"
)
const DefaultValueTagKey = "default"
// TagAutoWireBeanFactory wire the bean based on Fields' tag
// if field's value is "zero value", we will execute injection
// see reflect.Value.IsZero()
// If field's kind is one of(reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice
// reflect.UnsafePointer, reflect.Array, reflect.Uintptr, reflect.Complex64, reflect.Complex128
// reflect.Ptr, reflect.Struct),
// it will be ignored
type TagAutoWireBeanFactory struct {
// we allow user register their TypeAdapter
Adapters map[string]TypeAdapter
// FieldTagParser is an extension point which means that you can custom how to read field's metadata from tag
FieldTagParser func(field reflect.StructField) *FieldMetadata
}
// NewTagAutoWireBeanFactory create an instance of TagAutoWireBeanFactory
// by default, we register Time adapter, the time will be parse by using layout "2006-01-02 15:04:05"
// If you need more adapter, you can implement interface TypeAdapter
func NewTagAutoWireBeanFactory() *TagAutoWireBeanFactory {
return &TagAutoWireBeanFactory{
Adapters: map[string]TypeAdapter{
"Time": &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"},
},
FieldTagParser: func(field reflect.StructField) *FieldMetadata {
return &FieldMetadata{
DftValue: field.Tag.Get(DefaultValueTagKey),
}
},
}
}
// AutoWire use value from appCtx to wire the bean, or use default value, or do nothing
func (t *TagAutoWireBeanFactory) AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error {
if bean == nil {
return nil
}
v := reflect.Indirect(reflect.ValueOf(bean))
bm := t.getConfig(v)
// field name, field metadata
for fn, fm := range bm.Fields {
fValue := v.FieldByName(fn)
if len(fm.DftValue) == 0 || !t.needInject(fValue) || !fValue.CanSet() {
continue
}
// handle type adapter
typeName := fValue.Type().Name()
if adapter, ok := t.Adapters[typeName]; ok {
dftValue, err := adapter.DefaultValue(ctx, fm.DftValue)
if err == nil {
fValue.Set(reflect.ValueOf(dftValue))
continue
} else {
return err
}
}
switch fValue.Kind() {
case reflect.Bool:
if v, err := strconv.ParseBool(fm.DftValue); err != nil {
return errors.WithMessage(err,
fmt.Sprintf("can not convert the field[%s]'s default value[%s] to bool value",
fn, fm.DftValue))
} else {
fValue.SetBool(v)
continue
}
case reflect.Int:
if err := t.setIntXValue(fm.DftValue, 0, fn, fValue); err != nil {
return err
}
continue
case reflect.Int8:
if err := t.setIntXValue(fm.DftValue, 8, fn, fValue); err != nil {
return err
}
continue
case reflect.Int16:
if err := t.setIntXValue(fm.DftValue, 16, fn, fValue); err != nil {
return err
}
continue
case reflect.Int32:
if err := t.setIntXValue(fm.DftValue, 32, fn, fValue); err != nil {
return err
}
continue
case reflect.Int64:
if err := t.setIntXValue(fm.DftValue, 64, fn, fValue); err != nil {
return err
}
continue
case reflect.Uint:
if err := t.setUIntXValue(fm.DftValue, 0, fn, fValue); err != nil {
return err
}
case reflect.Uint8:
if err := t.setUIntXValue(fm.DftValue, 8, fn, fValue); err != nil {
return err
}
continue
case reflect.Uint16:
if err := t.setUIntXValue(fm.DftValue, 16, fn, fValue); err != nil {
return err
}
continue
case reflect.Uint32:
if err := t.setUIntXValue(fm.DftValue, 32, fn, fValue); err != nil {
return err
}
continue
case reflect.Uint64:
if err := t.setUIntXValue(fm.DftValue, 64, fn, fValue); err != nil {
return err
}
continue
case reflect.Float32:
if err := t.setFloatXValue(fm.DftValue, 32, fn, fValue); err != nil {
return err
}
continue
case reflect.Float64:
if err := t.setFloatXValue(fm.DftValue, 64, fn, fValue); err != nil {
return err
}
continue
case reflect.String:
fValue.SetString(fm.DftValue)
continue
// case reflect.Ptr:
// case reflect.Struct:
default:
logs.Warn("this field[%s] has default setting, but we don't support this type: %s",
fn, fValue.Kind().String())
}
}
return nil
}
func (t *TagAutoWireBeanFactory) setFloatXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error {
if v, err := strconv.ParseFloat(dftValue, bitSize); err != nil {
return errors.WithMessage(err,
fmt.Sprintf("can not convert the field[%s]'s default value[%s] to float%d value",
fn, dftValue, bitSize))
} else {
fv.SetFloat(v)
return nil
}
}
func (t *TagAutoWireBeanFactory) setUIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error {
if v, err := strconv.ParseUint(dftValue, 10, bitSize); err != nil {
return errors.WithMessage(err,
fmt.Sprintf("can not convert the field[%s]'s default value[%s] to uint%d value",
fn, dftValue, bitSize))
} else {
fv.SetUint(v)
return nil
}
}
func (t *TagAutoWireBeanFactory) setIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error {
if v, err := strconv.ParseInt(dftValue, 10, bitSize); err != nil {
return errors.WithMessage(err,
fmt.Sprintf("can not convert the field[%s]'s default value[%s] to int%d value",
fn, dftValue, bitSize))
} else {
fv.SetInt(v)
return nil
}
}
func (t *TagAutoWireBeanFactory) needInject(fValue reflect.Value) bool {
return fValue.IsZero()
}
// getConfig never return nil
func (t *TagAutoWireBeanFactory) getConfig(beanValue reflect.Value) *BeanMetadata {
fms := make(map[string]*FieldMetadata, beanValue.NumField())
for i := 0; i < beanValue.NumField(); i++ {
// f => StructField
f := beanValue.Type().Field(i)
fms[f.Name] = t.FieldTagParser(f)
}
return &BeanMetadata{
Fields: fms,
}
}

View File

@ -0,0 +1,75 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) {
factory := NewTagAutoWireBeanFactory()
bm := &ComplicateStruct{}
err := factory.AutoWire(context.Background(), nil, bm)
assert.Nil(t, err)
assert.Equal(t, 12, bm.IntValue)
assert.Equal(t, "hello, strValue", bm.StrValue)
assert.Equal(t, int8(8), bm.Int8Value)
assert.Equal(t, int16(16), bm.Int16Value)
assert.Equal(t, int32(32), bm.Int32Value)
assert.Equal(t, int64(64), bm.Int64Value)
assert.Equal(t, uint(13), bm.UintValue)
assert.Equal(t, uint8(88), bm.Uint8Value)
assert.Equal(t, uint16(1616), bm.Uint16Value)
assert.Equal(t, uint32(3232), bm.Uint32Value)
assert.Equal(t, uint64(6464), bm.Uint64Value)
assert.Equal(t, float32(32.32), bm.Float32Value)
assert.Equal(t, float64(64.64), bm.Float64Value)
assert.True(t, bm.BoolValue)
assert.Equal(t, 0, bm.ignoreInt)
assert.NotNil(t, bm.TimeValue)
}
type ComplicateStruct struct {
IntValue int `default:"12"`
StrValue string `default:"hello, strValue"`
Int8Value int8 `default:"8"`
Int16Value int16 `default:"16"`
Int32Value int32 `default:"32"`
Int64Value int64 `default:"64"`
UintValue uint `default:"13"`
Uint8Value uint8 `default:"88"`
Uint16Value uint16 `default:"1616"`
Uint32Value uint32 `default:"3232"`
Uint64Value uint64 `default:"6464"`
Float32Value float32 `default:"32.32"`
Float64Value float64 `default:"64.64"`
BoolValue bool `default:"true"`
ignoreInt int `default:"11"`
TimeValue time.Time `default:"2018-02-03 12:13:14.000"`
}

View File

@ -0,0 +1,35 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"time"
)
// TimeTypeAdapter process the time.Time
type TimeTypeAdapter struct {
Layout string
}
// DefaultValue parse the DftValue to time.Time
// and if the DftValue == now
// time.Now() is returned
func (t *TimeTypeAdapter) DefaultValue(ctx context.Context, dftValue string) (interface{}, error) {
if dftValue == "now" {
return time.Now(), nil
}
return time.Parse(t.Layout, dftValue)
}

View File

@ -0,0 +1,29 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTimeTypeAdapter_DefaultValue(t *testing.T) {
typeAdapter := &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"}
tm, err := typeAdapter.DefaultValue(context.Background(), "2018-02-03 12:34:11")
assert.Nil(t, err)
assert.NotNil(t, tm)
}

View File

@ -0,0 +1,26 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
)
// TypeAdapter is an abstraction that define some behavior of target type
// usually, we don't use this to support basic type since golang has many restriction for basic types
// This is an important extension point
type TypeAdapter interface {
DefaultValue(ctx context.Context, dftValue string) (interface{}, error)
}

View File

@ -0,0 +1,72 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBaseConfiger_DefaultBool(t *testing.T) {
bc := newBaseConfier("true")
assert.True(t, bc.DefaultBool(context.Background(), "key1", false))
assert.True(t, bc.DefaultBool(context.Background(), "key2", true))
}
func TestBaseConfiger_DefaultFloat(t *testing.T) {
bc := newBaseConfier("12.3")
assert.Equal(t, 12.3, bc.DefaultFloat(context.Background(), "key1", 0.1))
assert.Equal(t, 0.1, bc.DefaultFloat(context.Background(), "key2", 0.1))
}
func TestBaseConfiger_DefaultInt(t *testing.T) {
bc := newBaseConfier("10")
assert.Equal(t, 10, bc.DefaultInt(context.Background(), "key1", 8))
assert.Equal(t, 8, bc.DefaultInt(context.Background(), "key2", 8))
}
func TestBaseConfiger_DefaultInt64(t *testing.T) {
bc := newBaseConfier("64")
assert.Equal(t, int64(64), bc.DefaultInt64(context.Background(), "key1", int64(8)))
assert.Equal(t, int64(8), bc.DefaultInt64(context.Background(), "key2", int64(8)))
}
func TestBaseConfiger_DefaultString(t *testing.T) {
bc := newBaseConfier("Hello")
assert.Equal(t, "Hello", bc.DefaultString(context.Background(), "key1", "world"))
assert.Equal(t, "world", bc.DefaultString(context.Background(), "key2", "world"))
}
func TestBaseConfiger_DefaultStrings(t *testing.T) {
bc := newBaseConfier("Hello;world")
assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings(context.Background(), "key1", []string{"world"}))
assert.Equal(t, []string{"world"}, bc.DefaultStrings(context.Background(), "key2", []string{"world"}))
}
func newBaseConfier(str1 string) *BaseConfiger {
return &BaseConfiger{
reader: func(ctx context.Context, key string) (string, error) {
if key == "key1" {
return str1, nil
} else {
return "", errors.New("mock error")
}
},
}
}

View File

@ -0,0 +1,379 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package config is used to parse config.
// Usage:
// import "github.com/astaxie/beego/config"
// Examples.
//
// cnf, err := config.NewConfig("ini", "config.conf")
//
// cnf APIS:
//
// cnf.Set(key, val string) error
// cnf.String(key string) string
// cnf.Strings(key string) []string
// cnf.Int(key string) (int, error)
// cnf.Int64(key string) (int64, error)
// cnf.Bool(key string) (bool, error)
// cnf.Float(key string) (float64, error)
// cnf.DefaultString(key string, defaultVal string) string
// cnf.DefaultStrings(key string, defaultVal []string) []string
// cnf.DefaultInt(key string, defaultVal int) int
// cnf.DefaultInt64(key string, defaultVal int64) int64
// cnf.DefaultBool(key string, defaultVal bool) bool
// cnf.DefaultFloat(key string, defaultVal float64) float64
// cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error
// More docs http://beego.me/docs/module/config.md
package config
import (
"context"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"time"
)
// Configer defines how to get and set value from configuration raw data.
type Configer interface {
// support section::key type in given key when using ini type.
Set(ctx context.Context, key, val string) error
// support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
String(ctx context.Context, key string) (string, error)
// get string slice
Strings(ctx context.Context, key string) ([]string, error)
Int(ctx context.Context, key string) (int, error)
Int64(ctx context.Context, key string) (int64, error)
Bool(ctx context.Context, key string) (bool, error)
Float(ctx context.Context, key string) (float64, error)
// support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
DefaultString(ctx context.Context, key string, defaultVal string) string
// get string slice
DefaultStrings(ctx context.Context, key string, defaultVal []string) []string
DefaultInt(ctx context.Context, key string, defaultVal int) int
DefaultInt64(ctx context.Context, key string, defaultVal int64) int64
DefaultBool(ctx context.Context, key string, defaultVal bool) bool
DefaultFloat(ctx context.Context, key string, defaultVal float64) float64
DIY(ctx context.Context, key string) (interface{}, error)
GetSection(ctx context.Context, section string) (map[string]string, error)
Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error
Sub(ctx context.Context, key string) (Configer, error)
OnChange(ctx context.Context, key string, fn func(value string))
SaveConfigFile(ctx context.Context, filename string) error
}
type BaseConfiger struct {
// The reader should support key like "a.b.c"
reader func(ctx context.Context, key string) (string, error)
}
func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error)) BaseConfiger {
return BaseConfiger{
reader: reader,
}
}
func (c *BaseConfiger) Int(ctx context.Context, key string) (int, error) {
res, err := c.reader(context.TODO(), key)
if err != nil {
return 0, err
}
return strconv.Atoi(res)
}
func (c *BaseConfiger) Int64(ctx context.Context, key string) (int64, error) {
res, err := c.reader(context.TODO(), key)
if err != nil {
return 0, err
}
return strconv.ParseInt(res, 10, 64)
}
func (c *BaseConfiger) Bool(ctx context.Context, key string) (bool, error) {
res, err := c.reader(context.TODO(), key)
if err != nil {
return false, err
}
return ParseBool(res)
}
func (c *BaseConfiger) Float(ctx context.Context, key string) (float64, error) {
res, err := c.reader(context.TODO(), key)
if err != nil {
return 0, err
}
return strconv.ParseFloat(res, 64)
}
// DefaultString returns the string value for a given key.
// if err != nil or value is empty return defaultval
func (c *BaseConfiger) DefaultString(ctx context.Context, key string, defaultVal string) string {
if res, err := c.String(ctx, key); res != "" && err == nil {
return res
}
return defaultVal
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaultval
func (c *BaseConfiger) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string {
if res, err := c.Strings(ctx, key); len(res) > 0 && err == nil {
return res
}
return defaultVal
}
func (c *BaseConfiger) DefaultInt(ctx context.Context, key string, defaultVal int) int {
if res, err := c.Int(ctx, key); err == nil {
return res
}
return defaultVal
}
func (c *BaseConfiger) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
if res, err := c.Int64(ctx, key); err == nil {
return res
}
return defaultVal
}
func (c *BaseConfiger) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
if res, err := c.Bool(ctx, key); err == nil {
return res
}
return defaultVal
}
func (c *BaseConfiger) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
if res, err := c.Float(ctx, key); err == nil {
return res
}
return defaultVal
}
func (c *BaseConfiger) String(ctx context.Context, key string) (string, error) {
return c.reader(context.TODO(), key)
}
// Strings returns the []string value for a given key.
// Return nil if config value does not exist or is empty.
func (c *BaseConfiger) Strings(ctx context.Context, key string) ([]string, error) {
res, err := c.String(nil, key)
if err != nil || res == "" {
return nil, err
}
return strings.Split(res, ";"), nil
}
// TODO remove this before release v2.0.0
func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error {
return errors.New("unsupported operation")
}
// TODO remove this before release v2.0.0
func (c *BaseConfiger) Sub(ctx context.Context, key string) (Configer, error) {
return nil, errors.New("unsupported operation")
}
// TODO remove this before release v2.0.0
func (c *BaseConfiger) OnChange(ctx context.Context, key string, fn func(value string)) {
// do nothing
}
// Config is the adapter interface for parsing config file to get raw data to Configer.
type Config interface {
Parse(key string) (Configer, error)
ParseData(data []byte) (Configer, error)
}
var adapters = make(map[string]Config)
// Register makes a config adapter available by the adapter name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, adapter Config) {
if adapter == nil {
panic("config: Register adapter is nil")
}
if _, ok := adapters[name]; ok {
panic("config: Register called twice for adapter " + name)
}
adapters[name] = adapter
}
// NewConfig adapterName is ini/json/xml/yaml.
// filename is the config file path.
func NewConfig(adapterName, filename string) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
}
return adapter.Parse(filename)
}
// NewConfigData adapterName is ini/json/xml/yaml.
// data is the config data.
func NewConfigData(adapterName string, data []byte) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
}
return adapter.ParseData(data)
}
// ExpandValueEnvForMap convert all string value with environment variable.
func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} {
for k, v := range m {
switch value := v.(type) {
case string:
m[k] = ExpandValueEnv(value)
case map[string]interface{}:
m[k] = ExpandValueEnvForMap(value)
case map[string]string:
for k2, v2 := range value {
value[k2] = ExpandValueEnv(v2)
}
m[k] = value
}
}
return m
}
// ExpandValueEnv returns value of convert with environment variable.
//
// Return environment variable if value start with "${" and end with "}".
// Return default value if environment variable is empty or not exist.
//
// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue".
// Examples:
// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable.
// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/".
// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie".
func ExpandValueEnv(value string) (realValue string) {
realValue = value
vLen := len(value)
// 3 = ${}
if vLen < 3 {
return
}
// Need start with "${" and end with "}", then return.
if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' {
return
}
key := ""
defaultV := ""
// value start with "${"
for i := 2; i < vLen; i++ {
if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
key = value[2:i]
defaultV = value[i+2 : vLen-1] // other string is default value.
break
} else if value[i] == '}' {
key = value[2:i]
break
}
}
realValue = os.Getenv(key)
if realValue == "" {
realValue = defaultV
}
return
}
// ParseBool returns the boolean value represented by the string.
//
// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On,
// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off.
// Any other value returns an error.
func ParseBool(val interface{}) (value bool, err error) {
if val != nil {
switch v := val.(type) {
case bool:
return v, nil
case string:
switch v {
case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On":
return true, nil
case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off":
return false, nil
}
case int8, int32, int64:
strV := fmt.Sprintf("%d", v)
if strV == "1" {
return true, nil
} else if strV == "0" {
return false, nil
}
case float64:
if v == 1.0 {
return true, nil
} else if v == 0.0 {
return false, nil
}
}
return false, fmt.Errorf("parsing %q: invalid syntax", val)
}
return false, fmt.Errorf("parsing <nil>: invalid syntax")
}
// ToString converts values of any type to string.
func ToString(x interface{}) string {
switch y := x.(type) {
// Handle dates with special logic
// This needs to come above the fmt.Stringer
// test since time.Time's have a .String()
// method
case time.Time:
return y.Format("A Monday")
// Handle type string
case string:
return y
// Handle type with .String() method
case fmt.Stringer:
return y.String()
// Handle type with .Error() method
case error:
return y.Error()
}
// Handle named string type
if v := reflect.ValueOf(x); v.Kind() == reflect.String {
return v.String()
}
// Fallback to fmt package for anything else like numeric types
return fmt.Sprint(x)
}
type DecodeOption func(options decodeOptions)
type decodeOptions struct {
}

View File

@ -0,0 +1,55 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"os"
"testing"
)
func TestExpandValueEnv(t *testing.T) {
testCases := []struct {
item string
want string
}{
{"", ""},
{"$", "$"},
{"{", "{"},
{"{}", "{}"},
{"${}", ""},
{"${|}", ""},
{"${}", ""},
{"${{}}", ""},
{"${{||}}", "}"},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}}", "}"},
{"${pwd||{{||}}}", "{{||}}"},
{"${GOPATH}", os.Getenv("GOPATH")},
{"${GOPATH||}", os.Getenv("GOPATH")},
{"${GOPATH||root}", os.Getenv("GOPATH")},
{"${GOPATH_NOT||root}", "root"},
{"${GOPATH_NOT||||root}", "||root"},
}
for _, c := range testCases {
if got := ExpandValueEnv(c.item); got != c.want {
t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got)
}
}
}

87
pkg/infrastructure/config/env/env.go vendored Normal file
View File

@ -0,0 +1,87 @@
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package env is used to parse environment.
package env
import (
"fmt"
"os"
"strings"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
var env *utils.BeeMap
func init() {
env = utils.NewBeeMap()
for _, e := range os.Environ() {
splits := strings.Split(e, "=")
env.Set(splits[0], os.Getenv(splits[0]))
}
}
// Get returns a value for a given key.
// If the key does not exist, the default value will be returned.
func Get(key string, defVal string) string {
if val := env.Get(key); val != nil {
return val.(string)
}
return defVal
}
// MustGet returns a value by key.
// If the key does not exist, it will return an error.
func MustGet(key string) (string, error) {
if val := env.Get(key); val != nil {
return val.(string), nil
}
return "", fmt.Errorf("no env variable with %s", key)
}
// Set sets a value in the ENV copy.
// This does not affect the child process environment.
func Set(key string, value string) {
env.Set(key, value)
}
// MustSet sets a value in the ENV copy and the child process environment.
// It returns an error in case the set operation failed.
func MustSet(key string, value string) error {
err := os.Setenv(key, value)
if err != nil {
return err
}
env.Set(key, value)
return nil
}
// GetAll returns all keys/values in the current child process environment.
func GetAll() map[string]string {
items := env.Items()
envs := make(map[string]string, env.Count())
for key, val := range items {
switch key := key.(type) {
case string:
switch val := val.(type) {
case string:
envs[key] = val
}
}
}
return envs
}

View File

@ -0,0 +1,75 @@
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package env
import (
"os"
"testing"
)
func TestEnvGet(t *testing.T) {
gopath := Get("GOPATH", "")
if gopath != os.Getenv("GOPATH") {
t.Error("expected GOPATH not empty.")
}
noExistVar := Get("NOEXISTVAR", "foo")
if noExistVar != "foo" {
t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar)
}
}
func TestEnvMustGet(t *testing.T) {
gopath, err := MustGet("GOPATH")
if err != nil {
t.Error(err)
}
if gopath != os.Getenv("GOPATH") {
t.Errorf("expected GOPATH to be the same, got %s.", gopath)
}
_, err = MustGet("NOEXISTVAR")
if err == nil {
t.Error("expected error to be non-nil")
}
}
func TestEnvSet(t *testing.T) {
Set("MYVAR", "foo")
myVar := Get("MYVAR", "bar")
if myVar != "foo" {
t.Errorf("expected MYVAR to equal foo, got %s.", myVar)
}
}
func TestEnvMustSet(t *testing.T) {
err := MustSet("FOO", "bar")
if err != nil {
t.Error(err)
}
fooVar := os.Getenv("FOO")
if fooVar != "bar" {
t.Errorf("expected FOO variable to equal bar, got %s.", fooVar)
}
}
func TestEnvGetAll(t *testing.T) {
envMap := GetAll()
if len(envMap) == 0 {
t.Error("expected environment not empty.")
}
}

View File

@ -0,0 +1,214 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package etcd
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/coreos/etcd/clientv3"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"google.golang.org/grpc"
"github.com/astaxie/beego/pkg/infrastructure/config"
"github.com/astaxie/beego/pkg/infrastructure/logs"
)
const etcdOpts = "etcdOpts"
type EtcdConfiger struct {
prefix string
client *clientv3.Client
config.BaseConfiger
}
func newEtcdConfiger(client *clientv3.Client, prefix string) *EtcdConfiger {
res := &EtcdConfiger{
client: client,
prefix: prefix,
}
res.BaseConfiger = config.NewBaseConfiger(res.reader)
return res
}
// reader is an general implementation that read config from etcd.
func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) {
resp, err := get(e.client, ctx, e.prefix+key)
if err != nil {
return "", err
}
if resp.Count > 0 {
return string(resp.Kvs[0].Value), nil
}
return "", nil
}
// Set do nothing and return an error
// I think write data to remote config center is not a good practice
func (e *EtcdConfiger) Set(ctx context.Context, key, val string) error {
return errors.New("Unsupported operation")
}
// DIY return the original response from etcd
// be careful when you decide to use this
func (e *EtcdConfiger) DIY(ctx context.Context, key string) (interface{}, error) {
return get(e.client, context.TODO(), key)
}
// GetSection in this implementation, we use section as prefix
func (e *EtcdConfiger) GetSection(ctx context.Context, section string) (map[string]string, error) {
var (
resp *clientv3.GetResponse
err error
)
if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok {
opts = append(opts, clientv3.WithPrefix())
resp, err = e.client.Get(context.TODO(), e.prefix+section, opts...)
} else {
resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix())
}
if err != nil {
return nil, errors.WithMessage(err, "GetSection failed")
}
res := make(map[string]string, len(resp.Kvs))
for _, kv := range resp.Kvs {
res[string(kv.Key)] = string(kv.Value)
}
return res, nil
}
func (e *EtcdConfiger) SaveConfigFile(ctx context.Context, filename string) error {
return errors.New("Unsupported operation")
}
// Unmarshaler is not very powerful because we lost the type information when we get configuration from etcd
// for example, when we got "5", we are not sure whether it's int 5, or it's string "5"
// TODO(support more complicated decoder)
func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error {
res, err := e.GetSection(ctx, prefix)
if err != nil {
return errors.WithMessage(err, fmt.Sprintf("could not read config with prefix: %s", prefix))
}
prefixLen := len(e.prefix + prefix)
m := make(map[string]string, len(res))
for k, v := range res {
m[k[prefixLen:]] = v
}
return mapstructure.Decode(m, obj)
}
// Sub return an sub configer.
func (e *EtcdConfiger) Sub(ctx context.Context, key string) (config.Configer, error) {
return newEtcdConfiger(e.client, e.prefix+key), nil
}
// TODO remove this before release v2.0.0
func (e *EtcdConfiger) OnChange(ctx context.Context, key string, fn func(value string)) {
buildOptsFunc := func() []clientv3.OpOption {
if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok {
opts = append(opts, clientv3.WithCreatedNotify())
return opts
}
return []clientv3.OpOption{}
}
rch := e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...)
go func() {
for {
for resp := range rch {
if err := resp.Err(); err != nil {
logs.Error("listen to key but got error callback", err)
break
}
for _, e := range resp.Events {
if e.Kv == nil {
continue
}
fn(string(e.Kv.Value))
}
}
time.Sleep(time.Second)
rch = e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...)
}
}()
}
type EtcdConfigerProvider struct {
}
// Parse = ParseData([]byte(key))
// key must be json
func (provider *EtcdConfigerProvider) Parse(key string) (config.Configer, error) {
return provider.ParseData([]byte(key))
}
// ParseData try to parse key as clientv3.Config, using this to build etcdClient
func (provider *EtcdConfigerProvider) ParseData(data []byte) (config.Configer, error) {
cfg := &clientv3.Config{}
err := json.Unmarshal(data, cfg)
if err != nil {
return nil, errors.WithMessage(err, "parse data to etcd config failed, please check your input")
}
cfg.DialOptions = []grpc.DialOption{
grpc.WithBlock(),
grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor),
}
client, err := clientv3.New(*cfg)
if err != nil {
return nil, errors.WithMessage(err, "create etcd client failed")
}
return newEtcdConfiger(client, ""), nil
}
func get(client *clientv3.Client, ctx context.Context, key string) (*clientv3.GetResponse, error) {
var (
resp *clientv3.GetResponse
err error
)
if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok {
resp, err = client.Get(ctx, key, opts...)
} else {
resp, err = client.Get(ctx, key)
}
if err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("read config from etcd with key %s failed", key))
}
return resp, err
}
func WithEtcdOption(ctx context.Context, opts ...clientv3.OpOption) context.Context {
return context.WithValue(ctx, etcdOpts, opts)
}
func init() {
config.Register("json", &EtcdConfigerProvider{})
}

View File

@ -0,0 +1,123 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package etcd
import (
"context"
"encoding/json"
"os"
"testing"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/stretchr/testify/assert"
)
func TestWithEtcdOption(t *testing.T) {
ctx := WithEtcdOption(context.Background(), clientv3.WithPrefix())
assert.NotNil(t, ctx.Value(etcdOpts))
}
func TestEtcdConfigerProvider_Parse(t *testing.T) {
provider := &EtcdConfigerProvider{}
cfger, err := provider.Parse(readEtcdConfig())
assert.Nil(t, err)
assert.NotNil(t, cfger)
}
func TestEtcdConfiger(t *testing.T) {
provider := &EtcdConfigerProvider{}
cfger, _ := provider.Parse(readEtcdConfig())
subCfger, err := cfger.Sub(nil, "sub.")
assert.Nil(t, err)
assert.NotNil(t, subCfger)
subSubCfger, err := subCfger.Sub(nil, "sub.")
assert.NotNil(t, subSubCfger)
assert.Nil(t, err)
str, err := subSubCfger.String(nil, "key1")
assert.Nil(t, err)
assert.Equal(t, "sub.sub.key", str)
// we cannot test it
subSubCfger.OnChange(context.Background(), "watch", func(value string) {
// do nothing
})
defStr := cfger.DefaultString(nil, "not_exit", "default value")
assert.Equal(t, "default value", defStr)
defInt64 := cfger.DefaultInt64(nil, "not_exit", -1)
assert.Equal(t, int64(-1), defInt64)
defInt := cfger.DefaultInt(nil, "not_exit", -2)
assert.Equal(t, -2, defInt)
defFlt := cfger.DefaultFloat(nil, "not_exit", 12.3)
assert.Equal(t, 12.3, defFlt)
defBl := cfger.DefaultBool(nil, "not_exit", true)
assert.True(t, defBl)
defStrs := cfger.DefaultStrings(nil, "not_exit", []string{"hello"})
assert.Equal(t, []string{"hello"}, defStrs)
fl, err := cfger.Float(nil, "current.float")
assert.Nil(t, err)
assert.Equal(t, 1.23, fl)
bl, err := cfger.Bool(nil, "current.bool")
assert.Nil(t, err)
assert.True(t, bl)
it, err := cfger.Int(nil, "current.int")
assert.Nil(t, err)
assert.Equal(t, 11, it)
str, err = cfger.String(nil, "current.string")
assert.Nil(t, err)
assert.Equal(t, "hello", str)
tn := &TestEntity{}
err = cfger.Unmarshaler(context.Background(), "current.serialize.", tn)
assert.Nil(t, err)
assert.Equal(t, "test", tn.Name)
}
type TestEntity struct {
Name string `yaml:"name"`
Sub SubEntity `yaml:"sub"`
}
type SubEntity struct {
SubName string `yaml:"subName"`
}
func readEtcdConfig() string {
addr := os.Getenv("ETCD_ADDR")
if addr == "" {
addr = "localhost:2379"
}
obj := clientv3.Config{
Endpoints: []string{addr},
DialTimeout: 3 * time.Second,
}
cfg, _ := json.Marshal(obj)
return string(cfg)
}

View File

@ -0,0 +1,112 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"context"
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
BaseConfiger
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
return c.data[strings.ToLower(key)]
}
func (c *fakeConfigContainer) Set(ctx context.Context, key, val string) error {
c.data[strings.ToLower(key)] = val
return nil
}
func (c *fakeConfigContainer) Int(ctx context.Context, key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int {
v, err := c.Int(ctx, key)
if err != nil {
return defaultVal
}
return v
}
func (c *fakeConfigContainer) Int64(ctx context.Context, key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
v, err := c.Int64(ctx, key)
if err != nil {
return defaultVal
}
return v
}
func (c *fakeConfigContainer) Bool(ctx context.Context, key string) (bool, error) {
return ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
v, err := c.Bool(ctx, key)
if err != nil {
return defaultVal
}
return v
}
func (c *fakeConfigContainer) Float(ctx context.Context, key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
v, err := c.Float(ctx, key)
if err != nil {
return defaultVal
}
return v
}
func (c *fakeConfigContainer) DIY(ctx context.Context, key string) (interface{}, error) {
if v, ok := c.data[strings.ToLower(key)]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
func (c *fakeConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) {
return nil, errors.New("not implement in the fakeConfigContainer")
}
func (c *fakeConfigContainer) SaveConfigFile(ctx context.Context, filename string) error {
return errors.New("not implement in the fakeConfigContainer")
}
var _ Configer = new(fakeConfigContainer)
// NewFakeConfig return a fake Configer
func NewFakeConfig() Configer {
res := &fakeConfigContainer{
data: make(map[string]string),
}
res.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) {
return res.getData(key), nil
})
return res
}

View File

@ -0,0 +1,510 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
)
var (
defaultSection = "default" // default section means if some ini items not in a section, make them in default section,
bNumComment = []byte{'#'} // number signal
bSemComment = []byte{';'} // semicolon signal
bEmpty = []byte{}
bEqual = []byte{'='} // equal signal
bDQuote = []byte{'"'} // quote signal
sectionStart = []byte{'['} // section start signal
sectionEnd = []byte{']'} // section end signal
lineBreak = "\n"
)
// IniConfig implements Config to parse ini file.
type IniConfig struct {
}
// Parse creates a new Config and parses the file configuration from the named file.
func (ini *IniConfig) Parse(name string) (Configer, error) {
return ini.parseFile(name)
}
func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
data, err := ioutil.ReadFile(name)
if err != nil {
return nil, err
}
return ini.parseData(filepath.Dir(name), data)
}
func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) {
cfg := &IniConfigContainer{
data: make(map[string]map[string]string),
sectionComment: make(map[string]string),
keyComment: make(map[string]string),
RWMutex: sync.RWMutex{},
}
cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) {
return cfg.getdata(key), nil
})
cfg.Lock()
defer cfg.Unlock()
var comment bytes.Buffer
buf := bufio.NewReader(bytes.NewBuffer(data))
// check the BOM
head, err := buf.Peek(3)
if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 {
for i := 1; i <= 3; i++ {
buf.ReadByte()
}
}
section := defaultSection
tmpBuf := bytes.NewBuffer(nil)
for {
tmpBuf.Reset()
shouldBreak := false
for {
tmp, isPrefix, err := buf.ReadLine()
if err == io.EOF {
shouldBreak = true
break
}
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
tmpBuf.Write(tmp)
if isPrefix {
continue
}
if !isPrefix {
break
}
}
if shouldBreak {
break
}
line := tmpBuf.Bytes()
line = bytes.TrimSpace(line)
if bytes.Equal(line, bEmpty) {
continue
}
var bComment []byte
switch {
case bytes.HasPrefix(line, bNumComment):
bComment = bNumComment
case bytes.HasPrefix(line, bSemComment):
bComment = bSemComment
}
if bComment != nil {
line = bytes.TrimLeft(line, string(bComment))
// Need append to a new line if multi-line comments.
if comment.Len() > 0 {
comment.WriteByte('\n')
}
comment.Write(line)
continue
}
if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) {
section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive
if comment.Len() > 0 {
cfg.sectionComment[section] = comment.String()
comment.Reset()
}
if _, ok := cfg.data[section]; !ok {
cfg.data[section] = make(map[string]string)
}
continue
}
if _, ok := cfg.data[section]; !ok {
cfg.data[section] = make(map[string]string)
}
keyValue := bytes.SplitN(line, bEqual, 2)
key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive
key = strings.ToLower(key)
// handle include "other.conf"
if len(keyValue) == 1 && strings.HasPrefix(key, "include") {
includefiles := strings.Fields(key)
if includefiles[0] == "include" && len(includefiles) == 2 {
otherfile := strings.Trim(includefiles[1], "\"")
if !filepath.IsAbs(otherfile) {
otherfile = filepath.Join(dir, otherfile)
}
i, err := ini.parseFile(otherfile)
if err != nil {
return nil, err
}
for sec, dt := range i.data {
if _, ok := cfg.data[sec]; !ok {
cfg.data[sec] = make(map[string]string)
}
for k, v := range dt {
cfg.data[sec][k] = v
}
}
for sec, comm := range i.sectionComment {
cfg.sectionComment[sec] = comm
}
for k, comm := range i.keyComment {
cfg.keyComment[k] = comm
}
continue
}
}
if len(keyValue) != 2 {
return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val")
}
val := bytes.TrimSpace(keyValue[1])
if bytes.HasPrefix(val, bDQuote) {
val = bytes.Trim(val, `"`)
}
cfg.data[section][key] = ExpandValueEnv(string(val))
if comment.Len() > 0 {
cfg.keyComment[section+"."+key] = comment.String()
comment.Reset()
}
}
return cfg, nil
}
// ParseData parse ini the data
// When include other.conf,other.conf is either absolute directory
// or under beego in default temporary directory(/tmp/beego[-username]).
func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
dir := "beego"
currentUser, err := user.Current()
if err == nil {
dir = "beego-" + currentUser.Username
}
dir = filepath.Join(os.TempDir(), dir)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, err
}
return ini.parseData(dir, data)
}
// IniConfigContainer is a config which represents the ini configuration.
// When set and get value, support key as section:name type.
type IniConfigContainer struct {
BaseConfiger
data map[string]map[string]string // section=> key:val
sectionComment map[string]string // section : comment
keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment.
sync.RWMutex
}
// Bool returns the boolean value for a given key.
func (c *IniConfigContainer) Bool(ctx context.Context, key string) (bool, error) {
return ParseBool(c.getdata(key))
}
// DefaultBool returns the boolean value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
v, err := c.Bool(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int returns the integer value for a given key.
func (c *IniConfigContainer) Int(ctx context.Context, key string) (int, error) {
return strconv.Atoi(c.getdata(key))
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int {
v, err := c.Int(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int64 returns the int64 value for a given key.
func (c *IniConfigContainer) Int64(ctx context.Context, key string) (int64, error) {
return strconv.ParseInt(c.getdata(key), 10, 64)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
v, err := c.Int64(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Float returns the float value for a given key.
func (c *IniConfigContainer) Float(ctx context.Context, key string) (float64, error) {
return strconv.ParseFloat(c.getdata(key), 64)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
v, err := c.Float(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// String returns the string value for a given key.
func (c *IniConfigContainer) String(ctx context.Context, key string) (string, error) {
return c.getdata(key), nil
}
// DefaultString returns the string value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string {
v, err := c.String(nil, key)
if v == "" || err != nil {
return defaultVal
}
return v
}
// Strings returns the []string value for a given key.
// Return nil if config value does not exist or is empty.
func (c *IniConfigContainer) Strings(ctx context.Context, key string) ([]string, error) {
v, err := c.String(nil, key)
if v == "" || err != nil {
return nil, err
}
return strings.Split(v, ";"), nil
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaultVal
func (c *IniConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string {
v, err := c.Strings(ctx, key)
if v == nil || err != nil {
return defaultVal
}
return v
}
// GetSection returns map for the given section
func (c *IniConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v, nil
}
return nil, errors.New("not exist section")
}
// SaveConfigFile save the config into file.
//
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function.
func (c *IniConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
// Get section or key comments. Fixed #1607
getCommentStr := func(section, key string) string {
var (
comment string
ok bool
)
if len(key) == 0 {
comment, ok = c.sectionComment[section]
} else {
comment, ok = c.keyComment[section+"."+key]
}
if ok {
// Empty comment
if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 {
return string(bNumComment)
}
prefix := string(bNumComment)
// Add the line head character "#"
return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1)
}
return ""
}
buf := bytes.NewBuffer(nil)
// Save default section at first place
if dt, ok := c.data[defaultSection]; ok {
for key, val := range dt {
if key != " " {
// Write key comments.
if v := getCommentStr(defaultSection, key); len(v) > 0 {
if _, err = buf.WriteString(v + lineBreak); err != nil {
return err
}
}
// Write key and value.
if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil {
return err
}
}
}
// Put a line between sections.
if _, err = buf.WriteString(lineBreak); err != nil {
return err
}
}
// Save named sections
for section, dt := range c.data {
if section != defaultSection {
// Write section comments.
if v := getCommentStr(section, ""); len(v) > 0 {
if _, err = buf.WriteString(v + lineBreak); err != nil {
return err
}
}
// Write section name.
if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil {
return err
}
for key, val := range dt {
if key != " " {
// Write key comments.
if v := getCommentStr(section, key); len(v) > 0 {
if _, err = buf.WriteString(v + lineBreak); err != nil {
return err
}
}
// Write key and value.
if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil {
return err
}
}
}
// Put a line between sections.
if _, err = buf.WriteString(lineBreak); err != nil {
return err
}
}
}
_, err = buf.WriteTo(f)
return err
}
// Set writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
func (c *IniConfigContainer) Set(ctx context.Context, key, val string) error {
c.Lock()
defer c.Unlock()
if len(key) == 0 {
return errors.New("key is empty")
}
var (
section, k string
sectionKey = strings.Split(strings.ToLower(key), "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
section = defaultSection
k = sectionKey[0]
}
if _, ok := c.data[section]; !ok {
c.data[section] = make(map[string]string)
}
c.data[section][k] = val
return nil
}
// DIY returns the raw value by a given key.
func (c *IniConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) {
if v, ok := c.data[strings.ToLower(key)]; ok {
return v, nil
}
return v, errors.New("key not find")
}
// section.key or key
func (c *IniConfigContainer) getdata(key string) string {
if len(key) == 0 {
return ""
}
c.RLock()
defer c.RUnlock()
var (
section, k string
sectionKey = strings.Split(strings.ToLower(key), "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
section = defaultSection
k = sectionKey[0]
}
if v, ok := c.data[section]; ok {
if vv, ok := v[k]; ok {
return vv
}
}
return ""
}
func init() {
Register("ini", &IniConfig{})
}

View File

@ -0,0 +1,191 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
)
func TestIni(t *testing.T) {
var (
inicontext = `
;comment one
#comment two
appname = beeapi
httpport = 8080
mysqlport = 3600
PI = 3.1415976
runmode = "dev"
autorender = false
copyrequestbody = true
session= on
cookieon= off
newreg = OFF
needlogin = ON
enableSession = Y
enableCookie = N
flag = 1
path1 = ${GOPATH}
path2 = ${GOPATH||/home/go}
[demo]
key1="asta"
key2 = "xie"
CaseInsensitive = true
peers = one;two;three
password = ${GOPATH}
`
keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"pi": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"session": true,
"cookieon": false,
"newreg": false,
"needlogin": true,
"enableSession": true,
"enableCookie": false,
"flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"demo::key1": "asta",
"demo::key2": "xie",
"demo::CaseInsensitive": true,
"demo::peers": []string{"one", "two", "three"},
"demo::password": os.Getenv("GOPATH"),
"null": "",
"demo2::key1": "",
"error": "",
"emptystrings": []string{},
}
)
f, err := os.Create("testini.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(inicontext)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testini.conf")
iniconf, err := NewConfig("ini", "testini.conf")
if err != nil {
t.Fatal(err)
}
for k, v := range keyValue {
var err error
var value interface{}
switch v.(type) {
case int:
value, err = iniconf.Int(nil, k)
case int64:
value, err = iniconf.Int64(nil, k)
case float64:
value, err = iniconf.Float(nil, k)
case bool:
value, err = iniconf.Bool(nil, k)
case []string:
value, err = iniconf.Strings(nil, k)
case string:
value, err = iniconf.String(nil, k)
default:
value, err = iniconf.DIY(nil, k)
}
if err != nil {
t.Fatalf("get key %q value fail,err %s", k, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Fatalf("get key %q value, want %v got %v .", k, v, value)
}
}
if err = iniconf.Set(nil, "name", "astaxie"); err != nil {
t.Fatal(err)
}
res, _ := iniconf.String(nil, "name")
if res != "astaxie" {
t.Fatal("get name error")
}
}
func TestIniSave(t *testing.T) {
const (
inicontext = `
app = app
;comment one
#comment two
# comment three
appname = beeapi
httpport = 8080
# DB Info
# enable db
[dbinfo]
# db type name
# suport mysql,sqlserver
name = mysql
`
saveResult = `
app=app
#comment one
#comment two
# comment three
appname=beeapi
httpport=8080
# DB Info
# enable db
[dbinfo]
# db type name
# suport mysql,sqlserver
name=mysql
`
)
cfg, err := NewConfigData("ini", []byte(inicontext))
if err != nil {
t.Fatal(err)
}
name := "newIniConfig.ini"
if err := cfg.SaveConfigFile(nil, name); err != nil {
t.Fatal(err)
}
defer os.Remove(name)
if data, err := ioutil.ReadFile(name); err != nil {
t.Fatal(err)
} else {
cfgData := string(data)
datas := strings.Split(saveResult, "\n")
for _, line := range datas {
if !strings.Contains(cfgData, line+"\n") {
t.Fatalf("different after save ini config file. need contains %q", line)
}
}
}
}

View File

@ -0,0 +1,313 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package json
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"github.com/mitchellh/mapstructure"
"github.com/astaxie/beego/pkg/infrastructure/config"
"github.com/astaxie/beego/pkg/infrastructure/logs"
)
// JSONConfig is a json config parser and implements Config interface.
type JSONConfig struct {
}
// Parse returns a ConfigContainer with parsed json config map.
func (js *JSONConfig) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
return js.ParseData(content)
}
// ParseData returns a ConfigContainer with json string
func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) {
x := &JSONConfigContainer{
data: make(map[string]interface{}),
}
err := json.Unmarshal(data, &x.data)
if err != nil {
var wrappingArray []interface{}
err2 := json.Unmarshal(data, &wrappingArray)
if err2 != nil {
return nil, err
}
x.data["rootArray"] = wrappingArray
}
x.data = config.ExpandValueEnvForMap(x.data)
return x, nil
}
// JSONConfigContainer is a config which represents the json configuration.
// Only when get value, support key as section:name type.
type JSONConfigContainer struct {
data map[string]interface{}
sync.RWMutex
}
func (c *JSONConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error {
sub, err := c.sub(ctx, prefix)
if err != nil {
return err
}
return mapstructure.Decode(sub, obj)
}
func (c *JSONConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) {
sub, err := c.sub(ctx, key)
if err != nil {
return nil, err
}
return &JSONConfigContainer{
data: sub,
}, nil
}
func (c *JSONConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) {
if key == "" {
return c.data, nil
}
value, ok := c.data[key]
if !ok {
return nil, errors.New(fmt.Sprintf("key is not found: %s", key))
}
res, ok := value.(map[string]interface{})
if !ok {
return nil, errors.New(fmt.Sprintf("the type of value is invalid, key: %s", key))
}
return res, nil
}
func (c *JSONConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) {
logs.Warn("unsupported operation")
}
// Bool returns the boolean value for a given key.
func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error) {
val := c.getData(key)
if val != nil {
return config.ParseBool(val)
}
return false, fmt.Errorf("not exist key: %q", key)
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
func (c *JSONConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
if v, err := c.Bool(ctx, key); err == nil {
return v
}
return defaultVal
}
// Int returns the integer value for a given key.
func (c *JSONConfigContainer) Int(ctx context.Context, key string) (int, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
return int(v), nil
} else if v, ok := val.(string); ok {
return strconv.Atoi(v)
}
return 0, errors.New("not valid value")
}
return 0, errors.New("not exist key:" + key)
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int {
if v, err := c.Int(ctx, key); err == nil {
return v
}
return defaultVal
}
// Int64 returns the int64 value for a given key.
func (c *JSONConfigContainer) Int64(ctx context.Context, key string) (int64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
return int64(v), nil
}
return 0, errors.New("not int64 value")
}
return 0, errors.New("not exist key:" + key)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
if v, err := c.Int64(ctx, key); err == nil {
return v
}
return defaultVal
}
// Float returns the float value for a given key.
func (c *JSONConfigContainer) Float(ctx context.Context, key string) (float64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
return v, nil
}
return 0.0, errors.New("not float64 value")
}
return 0.0, errors.New("not exist key:" + key)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
if v, err := c.Float(ctx, key); err == nil {
return v
}
return defaultVal
}
// String returns the string value for a given key.
func (c *JSONConfigContainer) String(ctx context.Context, key string) (string, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(string); ok {
return v, nil
}
}
return "", nil
}
// DefaultString returns the string value for a given key.
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string {
// TODO FIXME should not use "" to replace non existence
if v, err := c.String(ctx, key); v != "" && err == nil {
return v
}
return defaultVal
}
// Strings returns the []string value for a given key.
func (c *JSONConfigContainer) Strings(ctx context.Context, key string) ([]string, error) {
stringVal, err := c.String(nil, key)
if stringVal == "" || err != nil {
return nil, err
}
return strings.Split(stringVal, ";"), nil
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string {
if v, err := c.Strings(ctx, key); v != nil && err == nil {
return v
}
return defaultVal
}
// GetSection returns map for the given section
func (c *JSONConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
return nil, errors.New("nonexist section " + section)
}
// SaveConfigFile save the config into file
func (c *JSONConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
b, err := json.MarshalIndent(c.data, "", " ")
if err != nil {
return err
}
_, err = f.Write(b)
return err
}
// Set writes a new value for key.
func (c *JSONConfigContainer) Set(ctx context.Context, key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
return nil
}
// DIY returns the raw value by a given key.
func (c *JSONConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) {
val := c.getData(key)
if val != nil {
return val, nil
}
return nil, errors.New("not exist key")
}
// section.key or key
func (c *JSONConfigContainer) getData(key string) interface{} {
if len(key) == 0 {
return nil
}
c.RLock()
defer c.RUnlock()
sectionKeys := strings.Split(key, "::")
if len(sectionKeys) >= 2 {
curValue, ok := c.data[sectionKeys[0]]
if !ok {
return nil
}
for _, key := range sectionKeys[1:] {
if v, ok := curValue.(map[string]interface{}); ok {
if curValue, ok = v[key]; !ok {
return nil
}
}
}
return curValue
}
if v, ok := c.data[key]; ok {
return v
}
return nil
}
func init() {
config.Register("json", &JSONConfig{})
}

View File

@ -0,0 +1,252 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package json
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/infrastructure/config"
)
func TestJsonStartsWithArray(t *testing.T) {
const jsoncontextwitharray = `[
{
"url": "user",
"serviceAPI": "http://www.test.com/user"
},
{
"url": "employee",
"serviceAPI": "http://www.test.com/employee"
}
]`
f, err := os.Create("testjsonWithArray.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(jsoncontextwitharray)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testjsonWithArray.conf")
jsonconf, err := config.NewConfig("json", "testjsonWithArray.conf")
if err != nil {
t.Fatal(err)
}
rootArray, err := jsonconf.DIY(nil, "rootArray")
if err != nil {
t.Error("array does not exist as element")
}
rootArrayCasted := rootArray.([]interface{})
if rootArrayCasted == nil {
t.Error("array from root is nil")
} else {
elem := rootArrayCasted[0].(map[string]interface{})
if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" {
t.Error("array[0] values are not valid")
}
elem2 := rootArrayCasted[1].(map[string]interface{})
if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" {
t.Error("array[1] values are not valid")
}
}
}
func TestJson(t *testing.T) {
var (
jsoncontext = `{
"appname": "beeapi",
"testnames": "foo;bar",
"httpport": 8080,
"mysqlport": 3600,
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"session": "on",
"cookieon": "off",
"newreg": "OFF",
"needlogin": "ON",
"enableSession": "Y",
"enableCookie": "N",
"flag": 1,
"path1": "${GOPATH}",
"path2": "${GOPATH||/home/go}",
"database": {
"host": "host",
"port": "port",
"database": "database",
"username": "username",
"password": "${GOPATH}",
"conns":{
"maxconnection":12,
"autoconnect":true,
"connectioninfo":"info",
"root": "${GOPATH}"
}
}
}`
keyValue = map[string]interface{}{
"appname": "beeapi",
"testnames": []string{"foo", "bar"},
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"session": true,
"cookieon": false,
"newreg": false,
"needlogin": true,
"enableSession": true,
"enableCookie": false,
"flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"database::host": "host",
"database::port": "port",
"database::database": "database",
"database::password": os.Getenv("GOPATH"),
"database::conns::maxconnection": 12,
"database::conns::autoconnect": true,
"database::conns::connectioninfo": "info",
"database::conns::root": os.Getenv("GOPATH"),
"unknown": "",
}
)
f, err := os.Create("testjson.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(jsoncontext)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testjson.conf")
jsonconf, err := config.NewConfig("json", "testjson.conf")
if err != nil {
t.Fatal(err)
}
for k, v := range keyValue {
var err error
var value interface{}
switch v.(type) {
case int:
value, err = jsonconf.Int(nil, k)
case int64:
value, err = jsonconf.Int64(nil, k)
case float64:
value, err = jsonconf.Float(nil, k)
case bool:
value, err = jsonconf.Bool(nil, k)
case []string:
value, err = jsonconf.Strings(nil, k)
case string:
value, err = jsonconf.String(nil, k)
default:
value, err = jsonconf.DIY(nil, k)
}
if err != nil {
t.Fatalf("get key %q value fatal,%v err %s", k, v, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Fatalf("get key %q value, want %v got %v .", k, v, value)
}
}
if err = jsonconf.Set(nil, "name", "astaxie"); err != nil {
t.Fatal(err)
}
res, _ := jsonconf.String(nil, "name")
if res != "astaxie" {
t.Fatal("get name error")
}
if db, err := jsonconf.DIY(nil, "database"); err != nil {
t.Fatal(err)
} else if m, ok := db.(map[string]interface{}); !ok {
t.Log(db)
t.Fatal("db not map[string]interface{}")
} else {
if m["host"].(string) != "host" {
t.Fatal("get host err")
}
}
if _, err := jsonconf.Int(nil, "unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an Int")
}
if _, err := jsonconf.Int64(nil, "unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an Int64")
}
if _, err := jsonconf.Float(nil, "unknown"); err == nil {
t.Error("unknown keys should return an error when expecting a Float")
}
if _, err := jsonconf.DIY(nil, "unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an interface{}")
}
if val, _ := jsonconf.String(nil, "unknown"); val != "" {
t.Error("unknown keys should return an empty string when expecting a String")
}
if _, err := jsonconf.Bool(nil, "unknown"); err == nil {
t.Error("unknown keys should return an error when expecting a Bool")
}
if !jsonconf.DefaultBool(nil, "unknown", true) {
t.Error("unknown keys with default value wrong")
}
sub, err := jsonconf.Sub(context.Background(), "database")
assert.Nil(t, err)
assert.NotNil(t, sub)
sub, err = sub.Sub(context.Background(), "conns")
assert.Nil(t, err)
maxCon, _ := sub.Int(context.Background(), "maxconnection")
assert.Equal(t, 12, maxCon)
dbCfg := &DatabaseConfig{}
err = sub.Unmarshaler(context.Background(), "", dbCfg)
assert.Nil(t, err)
assert.Equal(t, 12, dbCfg.MaxConnection)
assert.True(t, dbCfg.Autoconnect)
assert.Equal(t, "info", dbCfg.Connectioninfo)
}
type DatabaseConfig struct {
MaxConnection int `json:"maxconnection"`
Autoconnect bool `json:"autoconnect"`
Connectioninfo string `json:"connectioninfo"`
}

View File

@ -0,0 +1,276 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package xml for config provider.
//
// depend on github.com/beego/x2j.
//
// go install github.com/beego/x2j.
//
// Usage:
// import(
// _ "github.com/astaxie/beego/config/xml"
// "github.com/astaxie/beego/config"
// )
//
// cnf, err := config.NewConfig("xml", "config.xml")
//
// More docs http://beego.me/docs/module/config.md
package xml
import (
"context"
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"github.com/mitchellh/mapstructure"
"github.com/astaxie/beego/pkg/infrastructure/config"
"github.com/astaxie/beego/pkg/infrastructure/logs"
"github.com/beego/x2j"
)
// Config is a xml config parser and implements Config interface.
// xml configurations should be included in <config></config> tag.
// only support key/value pair as <key>value</key> as each item.
type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map.
func (xc *Config) Parse(filename string) (config.Configer, error) {
context, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
return xc.ParseData(context)
}
// ParseData xml data
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
x := &ConfigContainer{data: make(map[string]interface{})}
d, err := x2j.DocToMap(string(data))
if err != nil {
return nil, err
}
x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil
}
// ConfigContainer is a Config which represents the xml configuration.
type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Unmarshaler is a little be inconvenient since the xml library doesn't know type.
// So when you use
// <id>1</id>
// The "1" is a string, not int
func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error {
sub, err := c.sub(ctx, prefix)
if err != nil {
return err
}
return mapstructure.Decode(sub, obj)
}
func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) {
sub, err := c.sub(ctx, key)
if err != nil {
return nil, err
}
return &ConfigContainer{
data: sub,
}, nil
}
func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) {
if key == "" {
return c.data, nil
}
value, ok := c.data[key]
if !ok {
return nil, errors.New(fmt.Sprintf("the key is not found: %s", key))
}
res, ok := value.(map[string]interface{})
if !ok {
return nil, errors.New(fmt.Sprintf("the value of this key is not a structure: %s", key))
}
return res, nil
}
func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) {
logs.Warn("Unsupported operation")
}
// Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) {
if v := c.data[key]; v != nil {
return config.ParseBool(v)
}
return false, fmt.Errorf("not exist key: %q", key)
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultVal
func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
v, err := c.Bool(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int returns the integer value for a given key.
func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) {
return strconv.Atoi(c.data[key].(string))
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int {
v, err := c.Int(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int64 returns the int64 value for a given key.
func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
v, err := c.Int64(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Float returns the float value for a given key.
func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
v, err := c.Float(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// String returns the string value for a given key.
func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) {
if v, ok := c.data[key].(string); ok {
return v, nil
}
return "", nil
}
// DefaultString returns the string value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string {
v, err := c.String(ctx, key)
if v == "" || err != nil {
return defaultVal
}
return v
}
// Strings returns the []string value for a given key.
func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) {
v, err := c.String(ctx, key)
if v == "" || err != nil {
return nil, err
}
return strings.Split(v, ";"), nil
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string {
v, err := c.Strings(ctx, key)
if v == nil || err != nil {
return defaultVal
}
return v
}
// GetSection returns map for the given section
func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) {
if v, ok := c.data[section].(map[string]interface{}); ok {
mapstr := make(map[string]string)
for k, val := range v {
mapstr[k] = config.ToString(val)
}
return mapstr, nil
}
return nil, fmt.Errorf("section '%s' not found", section)
}
// SaveConfigFile save the config into file
func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
b, err := xml.MarshalIndent(c.data, " ", " ")
if err != nil {
return err
}
_, err = f.Write(b)
return err
}
// Set writes a new value for key.
func (c *ConfigContainer) Set(ctx context.Context, key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
return nil
}
// DIY returns the raw value by a given key.
func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("not exist key")
}
func init() {
config.Register("xml", &Config{})
}

View File

@ -0,0 +1,158 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package xml
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/infrastructure/config"
)
func TestXML(t *testing.T) {
var (
// xml parse should incluce in <config></config> tags
xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
<config>
<appname>beeapi</appname>
<httpport>8080</httpport>
<mysqlport>3600</mysqlport>
<PI>3.1415976</PI>
<runmode>dev</runmode>
<autorender>false</autorender>
<copyrequestbody>true</copyrequestbody>
<path1>${GOPATH}</path1>
<path2>${GOPATH||/home/go}</path2>
<mysection>
<id>1</id>
<name>MySection</name>
</mysection>
</config>
`
keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
f, err := os.Create("testxml.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(xmlcontext)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testxml.conf")
xmlconf, err := config.NewConfig("xml", "testxml.conf")
if err != nil {
t.Fatal(err)
}
var xmlsection map[string]string
xmlsection, err = xmlconf.GetSection(nil, "mysection")
if err != nil {
t.Fatal(err)
}
if len(xmlsection) == 0 {
t.Error("section should not be empty")
}
for k, v := range keyValue {
var (
value interface{}
err error
)
switch v.(type) {
case int:
value, err = xmlconf.Int(nil, k)
case int64:
value, err = xmlconf.Int64(nil, k)
case float64:
value, err = xmlconf.Float(nil, k)
case bool:
value, err = xmlconf.Bool(nil, k)
case []string:
value, err = xmlconf.Strings(nil, k)
case string:
value, err = xmlconf.String(nil, k)
default:
value, err = xmlconf.DIY(nil, k)
}
if err != nil {
t.Errorf("get key %q value fatal,%v err %s", k, v, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
}
}
if err = xmlconf.Set(nil, "name", "astaxie"); err != nil {
t.Fatal(err)
}
res, _ := xmlconf.String(context.Background(), "name")
if res != "astaxie" {
t.Fatal("get name error")
}
sub, err := xmlconf.Sub(context.Background(), "mysection")
assert.Nil(t, err)
assert.NotNil(t, sub)
name, err := sub.String(context.Background(), "name")
assert.Nil(t, err)
assert.Equal(t, "MySection", name)
id, err := sub.Int(context.Background(), "id")
assert.Nil(t, err)
assert.Equal(t, 1, id)
sec := &Section{}
err = sub.Unmarshaler(context.Background(), "", sec)
assert.Nil(t, err)
assert.Equal(t, "MySection", sec.Name)
sec = &Section{}
err = xmlconf.Unmarshaler(context.Background(), "mysection", sec)
assert.Nil(t, err)
assert.Equal(t, "MySection", sec.Name)
}
type Section struct {
Name string `xml:"name"`
}

View File

@ -0,0 +1,375 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package yaml for config provider
//
// depend on github.com/beego/goyaml2
//
// go install github.com/beego/goyaml2
//
// Usage:
// import(
// _ "github.com/astaxie/beego/config/yaml"
// "github.com/astaxie/beego/config"
// )
//
// cnf, err := config.NewConfig("yaml", "config.yaml")
//
// More docs http://beego.me/docs/module/config.md
package yaml
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"strings"
"sync"
"github.com/beego/goyaml2"
"gopkg.in/yaml.v2"
"github.com/astaxie/beego/pkg/infrastructure/config"
"github.com/astaxie/beego/pkg/infrastructure/logs"
)
// Config is a yaml config parser and implements Config interface.
type Config struct{}
// Parse returns a ConfigContainer with parsed yaml config map.
func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
cnf, err := ReadYmlReader(filename)
if err != nil {
return
}
y = &ConfigContainer{
data: cnf,
}
return
}
// ParseData parse yaml data
func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
cnf, err := parseYML(data)
if err != nil {
return nil, err
}
return &ConfigContainer{
data: cnf,
}, nil
}
// ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
buf, err := ioutil.ReadFile(path)
if err != nil {
return
}
return parseYML(buf)
}
// parseYML parse yaml formatted []byte to map.
func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
if len(buf) < 3 {
return
}
if string(buf[0:1]) == "{" {
log.Println("Look like a Json, try json umarshal")
err = json.Unmarshal(buf, &cnf)
if err == nil {
log.Println("It is Json Map")
return
}
}
data, err := goyaml2.Read(bytes.NewReader(buf))
if err != nil {
log.Println("Goyaml2 ERR>", string(buf), err)
return
}
if data == nil {
log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf))
return
}
cnf, ok := data.(map[string]interface{})
if !ok {
log.Println("Not a Map? >> ", string(buf), data)
cnf = nil
}
cnf = config.ExpandValueEnvForMap(cnf)
return
}
// ConfigContainer is a config which represents the yaml configuration.
type ConfigContainer struct {
data map[string]interface{}
sync.RWMutex
}
// Unmarshaler is similar to Sub
func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error {
sub, err := c.sub(ctx, prefix)
if err != nil {
return err
}
bytes, err := yaml.Marshal(sub)
if err != nil {
return err
}
return yaml.Unmarshal(bytes, obj)
}
func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) {
sub, err := c.sub(ctx, key)
if err != nil {
return nil, err
}
return &ConfigContainer{
data: sub,
}, nil
}
func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) {
tmpData := c.data
keys := strings.Split(key, ".")
for idx, k := range keys {
if v, ok := tmpData[k]; ok {
switch v.(type) {
case map[string]interface{}:
{
tmpData = v.(map[string]interface{})
if idx == len(keys)-1 {
return tmpData, nil
}
}
default:
return nil, errors.New(fmt.Sprintf("the key is invalid: %s", key))
}
}
}
return tmpData, nil
}
func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) {
// do nothing
logs.Warn("Unsupported operation: OnChange")
}
// Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) {
v, err := c.getData(key)
if err != nil {
return false, err
}
return config.ParseBool(v)
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultVal
func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool {
v, err := c.Bool(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int returns the integer value for a given key.
func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) {
if v, err := c.getData(key); err != nil {
return 0, err
} else if vv, ok := v.(int); ok {
return vv, nil
} else if vv, ok := v.(int64); ok {
return int(vv), nil
}
return 0, errors.New("not int value")
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int {
v, err := c.Int(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Int64 returns the int64 value for a given key.
func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) {
if v, err := c.getData(key); err != nil {
return 0, err
} else if vv, ok := v.(int64); ok {
return vv, nil
}
return 0, errors.New("not bool value")
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 {
v, err := c.Int64(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// Float returns the float value for a given key.
func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) {
if v, err := c.getData(key); err != nil {
return 0.0, err
} else if vv, ok := v.(float64); ok {
return vv, nil
} else if vv, ok := v.(int); ok {
return float64(vv), nil
} else if vv, ok := v.(int64); ok {
return float64(vv), nil
}
return 0.0, errors.New("not float64 value")
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 {
v, err := c.Float(ctx, key)
if err != nil {
return defaultVal
}
return v
}
// String returns the string value for a given key.
func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) {
if v, err := c.getData(key); err == nil {
if vv, ok := v.(string); ok {
return vv, nil
}
}
return "", nil
}
// DefaultString returns the string value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string {
v, err := c.String(nil, key)
if v == "" || err != nil {
return defaultVal
}
return v
}
// Strings returns the []string value for a given key.
func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) {
v, err := c.String(nil, key)
if v == "" || err != nil {
return nil, err
}
return strings.Split(v, ";"), nil
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaultVal
func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string {
v, err := c.Strings(ctx, key)
if v == nil || err != nil {
return defaultVal
}
return v
}
// GetSection returns map for the given section
func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
return nil, errors.New("not exist section")
}
// SaveConfigFile save the config into file
func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
err = goyaml2.Write(f, c.data)
return err
}
// Set writes a new value for key.
func (c *ConfigContainer) Set(ctx context.Context, key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
return nil
}
// DIY returns the raw value by a given key.
func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) {
return c.getData(key)
}
func (c *ConfigContainer) getData(key string) (interface{}, error) {
if len(key) == 0 {
return nil, errors.New("key is empty")
}
c.RLock()
defer c.RUnlock()
keys := strings.Split(c.key(key), ".")
tmpData := c.data
for idx, k := range keys {
if v, ok := tmpData[k]; ok {
switch v.(type) {
case map[string]interface{}:
{
tmpData = v.(map[string]interface{})
if idx == len(keys)-1 {
return tmpData, nil
}
}
default:
{
return v, nil
}
}
}
}
return nil, fmt.Errorf("not exist key %q", key)
}
func (c *ConfigContainer) key(key string) string {
return key
}
func init() {
config.Register("yaml", &Config{})
}

View File

@ -0,0 +1,152 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package yaml
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/infrastructure/config"
)
func TestYaml(t *testing.T) {
var (
yamlcontext = `
"appname": beeapi
"httpport": 8080
"mysqlport": 3600
"PI": 3.1415976
"runmode": dev
"autorender": false
"copyrequestbody": true
"PATH": GOPATH
"path1": ${GOPATH}
"path2": ${GOPATH||/home/go}
"empty": ""
"user":
"name": "tom"
"age": 13
`
keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"PATH": "GOPATH",
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
f, err := os.Create("testyaml.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(yamlcontext)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testyaml.conf")
yamlconf, err := config.NewConfig("yaml", "testyaml.conf")
if err != nil {
t.Fatal(err)
}
res, _ := yamlconf.String(nil, "appname")
if res != "beeapi" {
t.Fatal("appname not equal to beeapi")
}
for k, v := range keyValue {
var (
value interface{}
err error
)
switch v.(type) {
case int:
value, err = yamlconf.Int(nil, k)
case int64:
value, err = yamlconf.Int64(nil, k)
case float64:
value, err = yamlconf.Float(nil, k)
case bool:
value, err = yamlconf.Bool(nil, k)
case []string:
value, err = yamlconf.Strings(nil, k)
case string:
value, err = yamlconf.String(nil, k)
default:
value, err = yamlconf.DIY(nil, k)
}
if err != nil {
t.Errorf("get key %q value fatal,%v err %s", k, v, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
}
}
if err = yamlconf.Set(nil, "name", "astaxie"); err != nil {
t.Fatal(err)
}
res, _ = yamlconf.String(nil, "name")
if res != "astaxie" {
t.Fatal("get name error")
}
sub, err := yamlconf.Sub(context.Background(), "user")
assert.Nil(t, err)
assert.NotNil(t, sub)
name, err := sub.String(context.Background(), "name")
assert.Nil(t, err)
assert.Equal(t, "tom", name)
age, err := sub.Int(context.Background(), "age")
assert.Nil(t, err)
assert.Equal(t, 13, age)
user := &User{}
err = sub.Unmarshaler(context.Background(), "", user)
assert.Nil(t, err)
assert.Equal(t, "tom", user.Name)
assert.Equal(t, 13, user.Age)
user = &User{}
err = yamlconf.Unmarshaler(context.Background(), "user", user)
assert.Nil(t, err)
assert.Equal(t, "tom", user.Name)
assert.Equal(t, 13, user.Age)
}
type User struct {
Name string `yaml:"name"`
Age int `yaml:"age"`
}

View File

@ -0,0 +1,48 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package governor healthcheck
//
// type DatabaseCheck struct {
// }
//
// func (dc *DatabaseCheck) Check() error {
// if dc.isConnected() {
// return nil
// } else {
// return errors.New("can't connect database")
// }
// }
//
// AddHealthCheck("database",&DatabaseCheck{})
//
// more docs: http://beego.me/docs/module/toolbox.md
package governor
// AdminCheckList holds health checker map
var AdminCheckList map[string]HealthChecker
// HealthChecker health checker interface
type HealthChecker interface {
Check() error
}
// AddHealthCheck add health checker with name string
func AddHealthCheck(name string, hc HealthChecker) {
AdminCheckList[name] = hc
}
func init() {
AdminCheckList = make(map[string]HealthChecker)
}

View File

@ -0,0 +1,158 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package governor
import (
"fmt"
"io"
"log"
"os"
"path"
"runtime"
"runtime/debug"
"runtime/pprof"
"strconv"
"time"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
var startTime = time.Now()
var pid int
func init() {
pid = os.Getpid()
}
// ProcessInput parse input command string
func ProcessInput(input string, w io.Writer) {
switch input {
case "lookup goroutine":
p := pprof.Lookup("goroutine")
p.WriteTo(w, 2)
case "lookup heap":
p := pprof.Lookup("heap")
p.WriteTo(w, 2)
case "lookup threadcreate":
p := pprof.Lookup("threadcreate")
p.WriteTo(w, 2)
case "lookup block":
p := pprof.Lookup("block")
p.WriteTo(w, 2)
case "get cpuprof":
GetCPUProfile(w)
case "get memprof":
MemProf(w)
case "gc summary":
PrintGCSummary(w)
}
}
// MemProf record memory profile in pprof
func MemProf(w io.Writer) {
filename := "mem-" + strconv.Itoa(pid) + ".memprof"
if f, err := os.Create(filename); err != nil {
fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error())
log.Fatal("record heap profile failed: ", err)
} else {
runtime.GC()
pprof.WriteHeapProfile(f)
f.Close()
fmt.Fprintf(w, "create heap profile %s \n", filename)
_, fl := path.Split(os.Args[0])
fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename)
}
}
// GetCPUProfile start cpu profile monitor
func GetCPUProfile(w io.Writer) {
sec := 30
filename := "cpu-" + strconv.Itoa(pid) + ".pprof"
f, err := os.Create(filename)
if err != nil {
fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err)
log.Fatal("record cpu profile failed: ", err)
}
pprof.StartCPUProfile(f)
time.Sleep(time.Duration(sec) * time.Second)
pprof.StopCPUProfile()
fmt.Fprintf(w, "create cpu profile %s \n", filename)
_, fl := path.Split(os.Args[0])
fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename)
}
// PrintGCSummary print gc information to io.Writer
func PrintGCSummary(w io.Writer) {
memStats := &runtime.MemStats{}
runtime.ReadMemStats(memStats)
gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)}
debug.ReadGCStats(gcstats)
printGC(memStats, gcstats, w)
}
func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) {
if gcstats.NumGC > 0 {
lastPause := gcstats.Pause[0]
elapsed := time.Now().Sub(startTime)
overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100
allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds()
fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n",
gcstats.NumGC,
utils.ToShortTimeFormat(lastPause),
utils.ToShortTimeFormat(avg(gcstats.Pause)),
overhead,
toH(memStats.Alloc),
toH(memStats.Sys),
toH(uint64(allocatedRate)),
utils.ToShortTimeFormat(gcstats.PauseQuantiles[94]),
utils.ToShortTimeFormat(gcstats.PauseQuantiles[98]),
utils.ToShortTimeFormat(gcstats.PauseQuantiles[99]))
} else {
// while GC has disabled
elapsed := time.Now().Sub(startTime)
allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds()
fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n",
toH(memStats.Alloc),
toH(memStats.Sys),
toH(uint64(allocatedRate)))
}
}
func avg(items []time.Duration) time.Duration {
var sum time.Duration
for _, item := range items {
sum += item
}
return time.Duration(int64(sum) / int64(len(items)))
}
// format bytes number friendly
func toH(bytes uint64) string {
switch {
case bytes < 1024:
return fmt.Sprintf("%dB", bytes)
case bytes < 1024*1024:
return fmt.Sprintf("%.2fK", float64(bytes)/1024)
case bytes < 1024*1024*1024:
return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024)
default:
return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024)
}
}

View File

@ -0,0 +1,28 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package governor
import (
"os"
"testing"
)
func TestProcessInput(t *testing.T) {
ProcessInput("lookup goroutine", os.Stdout)
ProcessInput("lookup heap", os.Stdout)
ProcessInput("lookup threadcreate", os.Stdout)
ProcessInput("lookup block", os.Stdout)
ProcessInput("gc summary", os.Stdout)
}

View File

@ -0,0 +1,72 @@
## logs
logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` .
## How to install?
go get github.com/astaxie/beego/logs
## What adapters are supported?
As of now this logs support console, file,smtp and conn.
## How to use it?
First you must import it
```golang
import (
"github.com/astaxie/beego/logs"
)
```
Then init a Log (example with console adapter)
```golang
log := logs.NewLogger(10000)
log.SetLogger("console", "")
```
> the first params stand for how many channel
Use it like this:
```golang
log.Trace("trace")
log.Info("info")
log.Warn("warning")
log.Debug("debug")
log.Critical("critical")
```
## File adapter
Configure file adapter like this:
```golang
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
```
## Conn adapter
Configure like this:
```golang
log := NewLogger(1000)
log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
log.Info("info")
```
## Smtp adapter
Configure like this:
```golang
log := NewLogger(10000)
log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
log.Critical("sendmail critical")
time.Sleep(time.Second * 30)
```

View File

@ -0,0 +1,88 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
)
const (
apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s"
apacheFormat = "APACHE_FORMAT"
jsonFormat = "JSON_FORMAT"
)
// AccessLogRecord is astruct for holding access log data.
type AccessLogRecord struct {
RemoteAddr string `json:"remote_addr"`
RequestTime time.Time `json:"request_time"`
RequestMethod string `json:"request_method"`
Request string `json:"request"`
ServerProtocol string `json:"server_protocol"`
Host string `json:"host"`
Status int `json:"status"`
BodyBytesSent int64 `json:"body_bytes_sent"`
ElapsedTime time.Duration `json:"elapsed_time"`
HTTPReferrer string `json:"http_referrer"`
HTTPUserAgent string `json:"http_user_agent"`
RemoteUser string `json:"remote_user"`
}
func (r *AccessLogRecord) json() ([]byte, error) {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
disableEscapeHTML(encoder)
err := encoder.Encode(r)
return buffer.Bytes(), err
}
func disableEscapeHTML(i interface{}) {
if e, ok := i.(interface {
SetEscapeHTML(bool)
}); ok {
e.SetEscapeHTML(false)
}
}
// AccessLog - Format and print access log.
func AccessLog(r *AccessLogRecord, format string) {
var msg string
switch format {
case apacheFormat:
timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05")
msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent,
r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent)
case jsonFormat:
fallthrough
default:
jsonData, err := r.json()
if err != nil {
msg = fmt.Sprintf(`{"Error": "%s"}`, err)
} else {
msg = string(jsonData)
}
}
lm := &LogMsg{
Msg: strings.TrimSpace(msg),
When: time.Now(),
Level: levelLoggerImpl,
}
beeLogger.writeMsg(lm)
}

View File

@ -0,0 +1,204 @@
package alils
import (
"encoding/json"
"strings"
"sync"
"github.com/astaxie/beego/pkg/infrastructure/logs"
"github.com/astaxie/beego/pkg/infrastructure/utils"
"github.com/gogo/protobuf/proto"
)
const (
// CacheSize sets the flush size
CacheSize int = 64
// Delimiter defines the topic delimiter
Delimiter string = "##"
)
// Config is the Config for Ali Log
type Config struct {
Project string `json:"project"`
Endpoint string `json:"endpoint"`
KeyID string `json:"key_id"`
KeySecret string `json:"key_secret"`
LogStore string `json:"log_store"`
Topics []string `json:"topics"`
Source string `json:"source"`
Level int `json:"level"`
FlushWhen int `json:"flush_when"`
}
// aliLSWriter implements LoggerInterface.
// Writes messages in keep-live tcp connection.
type aliLSWriter struct {
store *LogStore
group []*LogGroup
withMap bool
groupMap map[string]*LogGroup
lock *sync.Mutex
customFormatter func(*logs.LogMsg) string
Config
}
// NewAliLS creates a new Logger
func NewAliLS() logs.Logger {
alils := new(aliLSWriter)
alils.Level = logs.LevelTrace
return alils
}
// Init parses config and initializes struct
func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := logs.GetFormatter(elem)
if err != nil {
return err
}
c.customFormatter = formatter
}
}
json.Unmarshal([]byte(jsonConfig), c)
if c.FlushWhen > CacheSize {
c.FlushWhen = CacheSize
}
prj := &LogProject{
Name: c.Project,
Endpoint: c.Endpoint,
AccessKeyID: c.KeyID,
AccessKeySecret: c.KeySecret,
}
store, err := prj.GetLogStore(c.LogStore)
if err != nil {
return err
}
c.store = store
// Create default Log Group
c.group = append(c.group, &LogGroup{
Topic: proto.String(""),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
})
// Create other Log Group
c.groupMap = make(map[string]*LogGroup)
for _, topic := range c.Topics {
lg := &LogGroup{
Topic: proto.String(topic),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
}
c.group = append(c.group, lg)
c.groupMap[topic] = lg
}
if len(c.group) == 1 {
c.withMap = false
} else {
c.withMap = true
}
c.lock = &sync.Mutex{}
return nil
}
func (c *aliLSWriter) Format(lm *logs.LogMsg) string {
return lm.Msg
}
// WriteMsg writes a message in connection.
// If connection is down, try to re-connect.
func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error {
if lm.Level > c.Level {
return nil
}
var topic string
var content string
var lg *LogGroup
if c.withMap {
// TopicLogGroup
strs := strings.SplitN(lm.Msg, Delimiter, 2)
if len(strs) == 2 {
pos := strings.LastIndex(strs[0], " ")
topic = strs[0][pos+1 : len(strs[0])]
lg = c.groupMap[topic]
}
// send to empty Topic
if lg == nil {
lg = c.group[0]
}
} else {
lg = c.group[0]
}
if c.customFormatter != nil {
content = c.customFormatter(lm)
} else {
content = c.Format(lm)
}
c1 := &LogContent{
Key: proto.String("msg"),
Value: proto.String(content),
}
l := &Log{
Time: proto.Uint32(uint32(lm.When.Unix())),
Contents: []*LogContent{
c1,
},
}
c.lock.Lock()
lg.Logs = append(lg.Logs, l)
c.lock.Unlock()
if len(lg.Logs) >= c.FlushWhen {
c.flush(lg)
}
return nil
}
// Flush implementing method. empty.
func (c *aliLSWriter) Flush() {
// flush all group
for _, lg := range c.group {
c.flush(lg)
}
}
// Destroy destroy connection writer and close tcp listener.
func (c *aliLSWriter) Destroy() {
}
func (c *aliLSWriter) flush(lg *LogGroup) {
c.lock.Lock()
defer c.lock.Unlock()
err := c.store.PutLogs(lg)
if err != nil {
return
}
lg.Logs = make([]*Log, 0, c.FlushWhen)
}
func init() {
logs.Register(logs.AdapterAliLS, NewAliLS)
}

View File

@ -0,0 +1,13 @@
package alils
const (
version = "0.5.0" // SDK version
signatureMethod = "hmac-sha1" // Signature method
// OffsetNewest is the log head offset, i.e. the offset that will be
// assigned to the next message that will be produced to the shard.
OffsetNewest = "end"
// OffsetOldest is the the oldest offset available on the logstore for a
// shard.
OffsetOldest = "begin"
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,42 @@
package alils
// InputDetail defines log detail
type InputDetail struct {
LogType string `json:"logType"`
LogPath string `json:"logPath"`
FilePattern string `json:"filePattern"`
LocalStorage bool `json:"localStorage"`
TimeFormat string `json:"timeFormat"`
LogBeginRegex string `json:"logBeginRegex"`
Regex string `json:"regex"`
Keys []string `json:"key"`
FilterKeys []string `json:"filterKey"`
FilterRegex []string `json:"filterRegex"`
TopicFormat string `json:"topicFormat"`
}
// OutputDetail defines the output detail
type OutputDetail struct {
Endpoint string `json:"endpoint"`
LogStoreName string `json:"logstoreName"`
}
// LogConfig defines Log Config
type LogConfig struct {
Name string `json:"configName"`
InputType string `json:"inputType"`
InputDetail InputDetail `json:"inputDetail"`
OutputType string `json:"outputType"`
OutputDetail OutputDetail `json:"outputDetail"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// GetAppliedMachineGroup returns applied machine group of this config.
func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
groupNames, err = c.project.GetAppliedMachineGroups(c.Name)
return
}

View File

@ -0,0 +1,819 @@
/*
Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS).
For more description about SLS, please read this article:
http://gitlab.alibaba-inc.com/sls/doc.
*/
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
)
// Error message in SLS HTTP response.
type errorMessage struct {
Code string `json:"errorCode"`
Message string `json:"errorMessage"`
}
// LogProject defines the Ali Project detail
type LogProject struct {
Name string // Project name
Endpoint string // IP or hostname of SLS endpoint
AccessKeyID string
AccessKeySecret string
}
// NewLogProject creates a new SLS project.
func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) {
p = &LogProject{
Name: name,
Endpoint: endpoint,
AccessKeyID: AccessKeyID,
AccessKeySecret: accessKeySecret,
}
return p, nil
}
// ListLogStore returns all logstore names of project p.
func (p *LogProject) ListLogStore() (storeNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores")
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Count int
LogStores []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
storeNames = body.LogStores
return
}
// GetLogStore returns logstore according by logstore name.
func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/logstores/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
s = &LogStore{}
err = json.Unmarshal(buf, s)
if err != nil {
return
}
s.project = p
return
}
// CreateLogStore creates a new logstore in SLS,
// where name is logstore name,
// and ttl is time-to-live(in day) of logs,
// and shardCnt is the number of shards.
func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) {
type Body struct {
Name string `json:"logstoreName"`
TTL int `json:"ttl"`
ShardCount int `json:"shardCount"`
}
store := &Body{
Name: name,
TTL: ttl,
ShardCount: shardCnt,
}
body, err := json.Marshal(store)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/logstores", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to create logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteLogStore deletes a logstore according by logstore name.
func (p *LogProject) DeleteLogStore(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/logstores/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// UpdateLogStore updates a logstore according by logstore name,
// obviously we can't modify the logstore name itself.
func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) {
type Body struct {
Name string `json:"logstoreName"`
TTL int `json:"ttl"`
ShardCount int `json:"shardCount"`
}
store := &Body{
Name: name,
TTL: ttl,
ShardCount: shardCnt,
}
body, err := json.Marshal(store)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/logstores", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// ListMachineGroup returns machine group name list and the total number of machine groups.
// The offset starts from 0 and the size is the max number of machine groups could be returned.
func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
if size <= 0 {
size = 500
}
uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
MachineGroups []string
Count int
Total int
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
m = body.MachineGroups
total = body.Total
return
}
// GetMachineGroup retruns machine group according by machine group name.
func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/machinegroups/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get machine group:%v", name)
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
m = &MachineGroup{}
err = json.Unmarshal(buf, m)
if err != nil {
return
}
m.project = p
return
}
// CreateMachineGroup creates a new machine group in SLS.
func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) {
body, err := json.Marshal(m)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/machinegroups", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to create machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// UpdateMachineGroup updates a machine group.
func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) {
body, err := json.Marshal(m)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteMachineGroup deletes machine group according machine group name.
func (p *LogProject) DeleteMachineGroup(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// ListConfig returns config names list and the total number of configs.
// The offset starts from 0 and the size is the max number of configs could be returned.
func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
if size <= 0 {
size = 100
}
uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Total int
Configs []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
cfgNames = body.Configs
total = body.Total
return
}
// GetConfig returns config according by config name.
func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/configs/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
c = &LogConfig{}
err = json.Unmarshal(buf, c)
if err != nil {
return
}
c.project = p
return
}
// UpdateConfig updates a config.
func (p *LogProject) UpdateConfig(c *LogConfig) (err error) {
body, err := json.Marshal(c)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/configs/"+c.Name, h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// CreateConfig creates a new config in SLS.
func (p *LogProject) CreateConfig(c *LogConfig) (err error) {
body, err := json.Marshal(c)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/configs", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteConfig deletes a config according by config name.
func (p *LogProject) DeleteConfig(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/configs/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// GetAppliedMachineGroups returns applied machine group names list according config name.
func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/configs/%v/machinegroups", confName)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get applied machine groups")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Count int
Machinegroups []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
groupNames = body.Machinegroups
return
}
// GetAppliedConfigs returns applied config names list according machine group name groupName.
func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs", groupName)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to applied configs")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Cfg struct {
Count int `json:"count"`
Configs []string `json:"configs"`
}
body := &Cfg{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
confNames = body.Configs
return
}
// ApplyConfigToMachineGroup applies config to machine group.
func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
r, err := request(p, "PUT", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to apply config to machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// RemoveConfigFromMachineGroup removes config from machine group.
func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
r, err := request(p, "DELETE", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to remove config from machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}

View File

@ -0,0 +1,271 @@
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"strconv"
lz4 "github.com/cloudflare/golz4"
"github.com/gogo/protobuf/proto"
)
// LogStore stores the logs
type LogStore struct {
Name string `json:"logstoreName"`
TTL int
ShardCount int
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// Shard defines the Log Shard
type Shard struct {
ShardID int `json:"shardID"`
}
// ListShards returns shard id list of this logstore.
func (s *LogStore) ListShards() (shardIDs []int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
var shards []*Shard
err = json.Unmarshal(buf, &shards)
if err != nil {
return
}
for _, v := range shards {
shardIDs = append(shardIDs, v.ShardID)
}
return
}
// PutLogs puts logs into logstore.
// The callers should transform user logs into LogGroup.
func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
body, err := proto.Marshal(lg)
if err != nil {
return
}
// Compresse body with lz4
out := make([]byte, lz4.CompressBound(body))
n, err := lz4.Compress(body, out)
if err != nil {
return
}
h := map[string]string{
"x-sls-compresstype": "lz4",
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/x-protobuf",
}
uri := fmt.Sprintf("/logstores/%v", s.Name)
r, err := request(s.project, "POST", uri, h, out[:n])
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to put logs")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// GetCursor gets log cursor of one shard specified by shardID.
// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end".
// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
s.Name, shardID, from)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Cursor string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
cursor = body.Cursor
return
}
// GetLogsBytes gets logs binary data from shard specified by shardID according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogsBytes(shardID int, cursor string,
logGroupMaxCount int) (out []byte, nextCursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
"Accept": "application/x-protobuf",
"Accept-Encoding": "lz4",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
s.Name, shardID, cursor, logGroupMaxCount)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
v, ok := r.Header["X-Sls-Compresstype"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-compresstype' header")
return
}
if v[0] != "lz4" {
err = fmt.Errorf("unexpected compress type:%v", v[0])
return
}
v, ok = r.Header["X-Sls-Cursor"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-cursor' header")
return
}
nextCursor = v[0]
v, ok = r.Header["X-Sls-Bodyrawsize"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
return
}
bodyRawSize, err := strconv.Atoi(v[0])
if err != nil {
return
}
out = make([]byte, bodyRawSize)
err = lz4.Uncompress(buf, out)
if err != nil {
return
}
return
}
// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API
func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
gl = &LogGroupList{}
err = proto.Unmarshal(data, gl)
if err != nil {
return
}
return
}
// GetLogs gets logs from shard specified by shardID according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogs(shardID int, cursor string,
logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount)
if err != nil {
return
}
gl, err = LogsBytesDecode(out)
if err != nil {
return
}
return
}

View File

@ -0,0 +1,91 @@
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
)
// MachineGroupAttribute defines the Attribute
type MachineGroupAttribute struct {
ExternalName string `json:"externalName"`
TopicName string `json:"groupTopic"`
}
// MachineGroup defines the machine Group
type MachineGroup struct {
Name string `json:"groupName"`
Type string `json:"groupType"`
MachineIDType string `json:"machineIdentifyType"`
MachineIDList []string `json:"machineList"`
Attribute MachineGroupAttribute `json:"groupAttribute"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// Machine defines the Machine
type Machine struct {
IP string
UniqueID string `json:"machine-uniqueid"`
UserdefinedID string `json:"userdefined-id"`
}
// MachineList defines the Machine List
type MachineList struct {
Total int
Machines []*Machine
}
// ListMachines returns the machine list of this machine group.
func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
r, err := request(m.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to remove config from machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
body := &MachineList{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
ms = body.Machines
total = body.Total
return
}
// GetAppliedConfigs returns applied configs of this machine group.
func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
confNames, err = m.project.GetAppliedConfigs(m.Name)
return
}

View File

@ -0,0 +1,62 @@
package alils
import (
"bytes"
"crypto/md5"
"fmt"
"net/http"
)
// request sends a request to SLS.
func request(project *LogProject, method, uri string, headers map[string]string,
body []byte) (resp *http.Response, err error) {
// The caller should provide 'x-sls-bodyrawsize' header
if _, ok := headers["x-sls-bodyrawsize"]; !ok {
err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header")
return
}
// SLS public request headers
headers["Host"] = project.Name + "." + project.Endpoint
headers["Date"] = nowRFC1123()
headers["x-sls-apiversion"] = version
headers["x-sls-signaturemethod"] = signatureMethod
if body != nil {
bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
headers["Content-MD5"] = bodyMD5
if _, ok := headers["Content-Type"]; !ok {
err = fmt.Errorf("Can't find 'Content-Type' header")
return
}
}
// Calc Authorization
// Authorization = "SLS <AccessKeyID>:<Signature>"
digest, err := signature(project, method, uri, headers)
if err != nil {
return
}
auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest)
headers["Authorization"] = auth
// Initialize http request
reader := bytes.NewReader(body)
urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
req, err := http.NewRequest(method, urlStr, reader)
if err != nil {
return
}
for k, v := range headers {
req.Header.Add(k, v)
}
// Get ready to do request
resp, err = http.DefaultClient.Do(req)
if err != nil {
return
}
return
}

View File

@ -0,0 +1,111 @@
package alils
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/url"
"sort"
"strings"
"time"
)
// GMT location
var gmtLoc = time.FixedZone("GMT", 0)
// NowRFC1123 returns now time in RFC1123 format with GMT timezone,
// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
func nowRFC1123() string {
return time.Now().In(gmtLoc).Format(time.RFC1123)
}
// signature calculates a request's signature digest.
func signature(project *LogProject, method, uri string,
headers map[string]string) (digest string, err error) {
var contentMD5, contentType, date, canoHeaders, canoResource string
var slsHeaderKeys sort.StringSlice
// SignString = VERB + "\n"
// + CONTENT-MD5 + "\n"
// + CONTENT-TYPE + "\n"
// + DATE + "\n"
// + CanonicalizedSLSHeaders + "\n"
// + CanonicalizedResource
if val, ok := headers["Content-MD5"]; ok {
contentMD5 = val
}
if val, ok := headers["Content-Type"]; ok {
contentType = val
}
date, ok := headers["Date"]
if !ok {
err = fmt.Errorf("Can't find 'Date' header")
return
}
// Calc CanonicalizedSLSHeaders
slsHeaders := make(map[string]string, len(headers))
for k, v := range headers {
l := strings.TrimSpace(strings.ToLower(k))
if strings.HasPrefix(l, "x-sls-") {
slsHeaders[l] = strings.TrimSpace(v)
slsHeaderKeys = append(slsHeaderKeys, l)
}
}
sort.Sort(slsHeaderKeys)
for i, k := range slsHeaderKeys {
canoHeaders += k + ":" + slsHeaders[k]
if i+1 < len(slsHeaderKeys) {
canoHeaders += "\n"
}
}
// Calc CanonicalizedResource
u, err := url.Parse(uri)
if err != nil {
return
}
canoResource += url.QueryEscape(u.Path)
if u.RawQuery != "" {
var keys sort.StringSlice
vals := u.Query()
for k := range vals {
keys = append(keys, k)
}
sort.Sort(keys)
canoResource += "?"
for i, k := range keys {
if i > 0 {
canoResource += "&"
}
for _, v := range vals[k] {
canoResource += k + "=" + v
}
}
}
signStr := method + "\n" +
contentMD5 + "\n" +
contentType + "\n" +
date + "\n" +
canoHeaders + "\n" +
canoResource
// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString)AccessKeySecret))
mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
_, err = mac.Write([]byte(signStr))
if err != nil {
return
}
digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
return
}

View File

@ -0,0 +1,144 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"encoding/json"
"io"
"net"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// connWriter implements LoggerInterface.
// Writes messages in keep-live tcp connection.
type connWriter struct {
lg *logWriter
innerWriter io.WriteCloser
customFormatter func(*LogMsg) string
ReconnectOnMsg bool `json:"reconnectOnMsg"`
Reconnect bool `json:"reconnect"`
Net string `json:"net"`
Addr string `json:"addr"`
Level int `json:"level"`
}
// NewConn creates new ConnWrite returning as LoggerInterface.
func NewConn() Logger {
conn := new(connWriter)
conn.Level = LevelTrace
return conn
}
func (c *connWriter) Format(lm *LogMsg) string {
return lm.Msg
}
// Init initializes a connection writer with json config.
// json config only needs they "level" key
func (c *connWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
c.customFormatter = formatter
}
}
return json.Unmarshal([]byte(jsonConfig), c)
}
// WriteMsg writes message in connection.
// If connection is down, try to re-connect.
func (c *connWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > c.Level {
return nil
}
if c.needToConnectOnMsg() {
err := c.connect()
if err != nil {
return err
}
}
if c.ReconnectOnMsg {
defer c.innerWriter.Close()
}
msg := ""
if c.customFormatter != nil {
msg = c.customFormatter(lm)
} else {
msg = c.Format(lm)
}
_, err := c.lg.writeln(msg)
if err != nil {
return err
}
return nil
}
// Flush implementing method. empty.
func (c *connWriter) Flush() {
}
// Destroy destroy connection writer and close tcp listener.
func (c *connWriter) Destroy() {
if c.innerWriter != nil {
c.innerWriter.Close()
}
}
func (c *connWriter) connect() error {
if c.innerWriter != nil {
c.innerWriter.Close()
c.innerWriter = nil
}
conn, err := net.Dial(c.Net, c.Addr)
if err != nil {
return err
}
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
}
c.innerWriter = conn
c.lg = newLogWriter(conn)
return nil
}
func (c *connWriter) needToConnectOnMsg() bool {
if c.Reconnect {
return true
}
if c.innerWriter == nil {
return true
}
return c.ReconnectOnMsg
}
func init() {
Register(AdapterConn, NewConn)
}

View File

@ -0,0 +1,79 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"net"
"os"
"testing"
)
// ConnTCPListener takes a TCP listener and accepts n TCP connections
// Returns connections using connChan
func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) {
// Listen and accept n incoming connections
for i := 0; i < n; i++ {
conn, err := ln.Accept()
if err != nil {
t.Log("Error accepting connection: ", err.Error())
os.Exit(1)
}
// Send accepted connection to channel
connChan <- conn
}
ln.Close()
close(connChan)
}
func TestConn(t *testing.T) {
log := NewLogger(1000)
log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
log.Informational("informational")
}
func TestReconnect(t *testing.T) {
// Setup connection listener
newConns := make(chan net.Conn)
connNum := 2
ln, err := net.Listen("tcp", ":6002")
if err != nil {
t.Log("Error listening:", err.Error())
os.Exit(1)
}
go connTCPListener(t, connNum, ln, newConns)
// Setup logger
log := NewLogger(1000)
log.SetPrefix("test")
log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`)
log.Informational("informational 1")
// Refuse first connection
first := <-newConns
first.Close()
// Send another log after conn closed
log.Informational("informational 2")
// Check if there was a second connection attempt
select {
case second := <-newConns:
second.Close()
default:
t.Error("Did not reconnect")
}
}

View File

@ -0,0 +1,134 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"encoding/json"
"os"
"strings"
"github.com/astaxie/beego/pkg/infrastructure/utils"
"github.com/shiena/ansicolor"
)
// brush is a color join function
type brush func(string) string
// newBrush returns a fix color Brush
func newBrush(color string) brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
return pre + color + "m" + text + reset
}
}
var colors = []brush{
newBrush("1;37"), // Emergency white
newBrush("1;36"), // Alert cyan
newBrush("1;35"), // Critical magenta
newBrush("1;31"), // Error red
newBrush("1;33"), // Warning yellow
newBrush("1;32"), // Notice green
newBrush("1;34"), // Informational blue
newBrush("1;44"), // Debug Background blue
}
// consoleWriter implements LoggerInterface and writes messages to terminal.
type consoleWriter struct {
lg *logWriter
customFormatter func(*LogMsg) string
Level int `json:"level"`
Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color
}
func (c *consoleWriter) Format(lm *LogMsg) string {
msg := lm.Msg
h, _, _ := formatTimeHeader(lm.When)
bytes := append(append(h, msg...), '\n')
return string(bytes)
}
// NewConsole creates ConsoleWriter returning as LoggerInterface.
func NewConsole() Logger {
cw := &consoleWriter{
lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)),
Level: LevelDebug,
Colorful: true,
}
return cw
}
// Init initianlizes the console logger.
// jsonConfig must be in the format '{"level":LevelTrace}'
func (c *consoleWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
c.customFormatter = formatter
}
}
if len(jsonConfig) == 0 {
return nil
}
return json.Unmarshal([]byte(jsonConfig), c)
}
// WriteMsg writes message in console.
func (c *consoleWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > c.Level {
return nil
}
msg := ""
if c.Colorful {
lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1)
}
if c.customFormatter != nil {
msg = c.customFormatter(lm)
} else {
msg = c.Format(lm)
}
c.lg.writeln(msg)
return nil
}
// Destroy implementing method. empty.
func (c *consoleWriter) Destroy() {
}
// Flush implementing method. empty.
func (c *consoleWriter) Flush() {
}
func init() {
Register(AdapterConsole, NewConsole)
}

View File

@ -0,0 +1,64 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"testing"
"time"
)
// Try each log level in decreasing order of priority.
func testConsoleCalls(bl *BeeLogger) {
bl.Emergency("emergency")
bl.Alert("alert")
bl.Critical("critical")
bl.Error("error")
bl.Warning("warning")
bl.Notice("notice")
bl.Informational("informational")
bl.Debug("debug")
}
// Test console logging by visually comparing the lines being output with and
// without a log level specification.
func TestConsole(t *testing.T) {
log1 := NewLogger(10000)
log1.EnableFuncCallDepth(true)
log1.SetLogger("console", "")
testConsoleCalls(log1)
log2 := NewLogger(100)
log2.SetLogger("console", `{"level":3}`)
testConsoleCalls(log2)
}
// Test console without color
func TestConsoleNoColor(t *testing.T) {
log := NewLogger(100)
log.SetLogger("console", `{"color":false}`)
testConsoleCalls(log)
}
// Test console async
func TestConsoleAsync(t *testing.T) {
log := NewLogger(100)
log.SetLogger("console")
log.Async()
//log.Close()
testConsoleCalls(log)
for len(log.msgChan) != 0 {
time.Sleep(1 * time.Millisecond)
}
}

View File

@ -0,0 +1,126 @@
package es
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/elastic/go-elasticsearch/v6"
"github.com/elastic/go-elasticsearch/v6/esapi"
"github.com/astaxie/beego/pkg/infrastructure/logs"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// NewES returns a LoggerInterface
func NewES() logs.Logger {
cw := &esLogger{
Level: logs.LevelDebug,
}
return cw
}
// esLogger will log msg into ES
// before you using this implementation,
// please import this package
// usually means that you can import this package in your main package
// for example, anonymous:
// import _ "github.com/astaxie/beego/logs/es"
type esLogger struct {
*elasticsearch.Client
DSN string `json:"dsn"`
Level int `json:"level"`
customFormatter func(*logs.LogMsg) string
}
func (el *esLogger) Format(lm *logs.LogMsg) string {
return lm.Msg
}
// {"dsn":"http://localhost:9200/","level":1}
func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := logs.GetFormatter(elem)
if err != nil {
return err
}
el.customFormatter = formatter
}
}
err := json.Unmarshal([]byte(jsonConfig), el)
if err != nil {
return err
}
if el.DSN == "" {
return errors.New("empty dsn")
} else if u, err := url.Parse(el.DSN); err != nil {
return err
} else if u.Path == "" {
return errors.New("missing prefix")
} else {
conn, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{el.DSN},
})
if err != nil {
return err
}
el.Client = conn
}
return nil
}
// WriteMsg writes the msg and level into es
func (el *esLogger) WriteMsg(lm *logs.LogMsg) error {
if lm.Level > el.Level {
return nil
}
msg := ""
if el.customFormatter != nil {
msg = el.customFormatter(lm)
} else {
msg = el.Format(lm)
}
idx := LogDocument{
Timestamp: lm.When.Format(time.RFC3339),
Msg: msg,
}
body, err := json.Marshal(idx)
if err != nil {
return err
}
req := esapi.IndexRequest{
Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()),
DocumentType: "logs",
Body: strings.NewReader(string(body)),
}
_, err = req.Do(context.Background(), el.Client)
return err
}
// Destroy is a empty method
func (el *esLogger) Destroy() {
}
// Flush is a empty method
func (el *esLogger) Flush() {
}
type LogDocument struct {
Timestamp string `json:"timestamp"`
Msg string `json:"msg"`
}
func init() {
logs.Register(logs.AdapterEs, NewES)
}

View File

@ -0,0 +1,436 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// fileLogWriter implements LoggerInterface.
// Writes messages by lines limit, file size limit, or time frequency.
type fileLogWriter struct {
sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file
Filename string `json:"filename"`
fileWriter *os.File
// Rotate at line
MaxLines int `json:"maxlines"`
maxLinesCurLines int
MaxFiles int `json:"maxfiles"`
MaxFilesCurFiles int
// Rotate at size
MaxSize int `json:"maxsize"`
maxSizeCurSize int
// Rotate daily
Daily bool `json:"daily"`
MaxDays int64 `json:"maxdays"`
dailyOpenDate int
dailyOpenTime time.Time
// Rotate hourly
Hourly bool `json:"hourly"`
MaxHours int64 `json:"maxhours"`
hourlyOpenDate int
hourlyOpenTime time.Time
customFormatter func(*LogMsg) string
Rotate bool `json:"rotate"`
Level int `json:"level"`
Perm string `json:"perm"`
RotatePerm string `json:"rotateperm"`
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
}
// newFileWriter creates a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger {
w := &fileLogWriter{
Daily: true,
MaxDays: 7,
Hourly: false,
MaxHours: 168,
Rotate: true,
RotatePerm: "0440",
Level: LevelTrace,
Perm: "0660",
MaxLines: 10000000,
MaxFiles: 999,
MaxSize: 1 << 28,
}
return w
}
func (w *fileLogWriter) Format(lm *LogMsg) string {
return lm.Msg
}
// Init file logger with json config.
// jsonConfig like:
// {
// "filename":"logs/beego.log",
// "maxLines":10000,
// "maxsize":1024,
// "daily":true,
// "maxDays":15,
// "rotate":true,
// "perm":"0600"
// }
func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
w.customFormatter = formatter
}
}
err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil {
return err
}
if len(w.Filename) == 0 {
return errors.New("jsonconfig must have filename")
}
w.suffix = filepath.Ext(w.Filename)
w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix)
if w.suffix == "" {
w.suffix = ".log"
}
err = w.startLogger()
return err
}
// start file logger. create log file and set to locker-inside file writer.
func (w *fileLogWriter) startLogger() error {
file, err := w.createLogFile()
if err != nil {
return err
}
if w.fileWriter != nil {
w.fileWriter.Close()
}
w.fileWriter = file
return w.initFd()
}
func (w *fileLogWriter) needRotateDaily(size int, day int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Daily && day != w.dailyOpenDate)
}
func (w *fileLogWriter) needRotateHourly(size int, hour int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Hourly && hour != w.hourlyOpenDate)
}
// WriteMsg writes logger message into file.
func (w *fileLogWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > w.Level {
return nil
}
hd, d, h := formatTimeHeader(lm.When)
msg := ""
if w.customFormatter != nil {
msg = w.customFormatter(lm)
} else {
msg = w.Format(lm)
}
msg = fmt.Sprintf("%s %s\n", string(hd), msg)
if w.Rotate {
w.RLock()
if w.needRotateHourly(len(lm.Msg), h) {
w.RUnlock()
w.Lock()
if w.needRotateHourly(len(lm.Msg), h) {
if err := w.doRotate(lm.When); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
} else if w.needRotateDaily(len(lm.Msg), d) {
w.RUnlock()
w.Lock()
if w.needRotateDaily(len(lm.Msg), d) {
if err := w.doRotate(lm.When); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
} else {
w.RUnlock()
}
}
w.Lock()
_, err := w.fileWriter.Write([]byte(msg))
if err == nil {
w.maxLinesCurLines++
w.maxSizeCurSize += len(msg)
}
w.Unlock()
return err
}
func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
perm, err := strconv.ParseInt(w.Perm, 8, 64)
if err != nil {
return nil, err
}
filepath := path.Dir(w.Filename)
os.MkdirAll(filepath, os.FileMode(perm))
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
if err == nil {
// Make sure file perm is user set perm cause of `os.OpenFile` will obey umask
os.Chmod(w.Filename, os.FileMode(perm))
}
return fd, err
}
func (w *fileLogWriter) initFd() error {
fd := w.fileWriter
fInfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("get stat err: %s", err)
}
w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenTime = time.Now()
w.dailyOpenDate = w.dailyOpenTime.Day()
w.hourlyOpenTime = time.Now()
w.hourlyOpenDate = w.hourlyOpenTime.Hour()
w.maxLinesCurLines = 0
if w.Hourly {
go w.hourlyRotate(w.hourlyOpenTime)
} else if w.Daily {
go w.dailyRotate(w.dailyOpenTime)
}
if fInfo.Size() > 0 && w.MaxLines > 0 {
count, err := w.lines()
if err != nil {
return err
}
w.maxLinesCurLines = count
}
return nil
}
func (w *fileLogWriter) dailyRotate(openTime time.Time) {
y, m, d := openTime.Add(24 * time.Hour).Date()
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
<-tm.C
w.Lock()
if w.needRotateDaily(0, time.Now().Day()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
func (w *fileLogWriter) hourlyRotate(openTime time.Time) {
y, m, d := openTime.Add(1 * time.Hour).Date()
h, _, _ := openTime.Add(1 * time.Hour).Clock()
nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100))
<-tm.C
w.Lock()
if w.needRotateHourly(0, time.Now().Hour()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename)
if err != nil {
return 0, err
}
defer fd.Close()
buf := make([]byte, 32768) // 32k
count := 0
lineSep := []byte{'\n'}
for {
c, err := fd.Read(buf)
if err != nil && err != io.EOF {
return count, err
}
count += bytes.Count(buf[:c], lineSep)
if err == io.EOF {
break
}
}
return count, nil
}
// DoRotate means it needs to write logs into a new file.
// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size)
func (w *fileLogWriter) doRotate(logTime time.Time) error {
// file exists
// Find the next available number
num := w.MaxFilesCurFiles + 1
fName := ""
format := ""
var openTime time.Time
rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
if err != nil {
return err
}
_, err = os.Lstat(w.Filename)
if err != nil {
// even if the file is not exist or other ,we should RESTART the logger
goto RESTART_LOGGER
}
if w.Hourly {
format = "2006010215"
openTime = w.hourlyOpenTime
} else if w.Daily {
format = "2006-01-02"
openTime = w.dailyOpenTime
}
// only when one of them be setted, then the file would be splited
if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= w.MaxFiles; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName)
}
} else {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName)
w.MaxFilesCurFiles = num
}
// return error if the last file checked still existed
if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
}
// close fileWriter before rename
w.fileWriter.Close()
// Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger
err = os.Rename(w.Filename, fName)
if err != nil {
goto RESTART_LOGGER
}
err = os.Chmod(fName, os.FileMode(rotatePerm))
RESTART_LOGGER:
startLoggerErr := w.startLogger()
go w.deleteOldLog()
if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr)
}
if err != nil {
return fmt.Errorf("Rotate: %s", err)
}
return nil
}
func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename)
absolutePath, err := filepath.EvalSymlinks(w.Filename)
if err == nil {
dir = filepath.Dir(absolutePath)
}
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
}
}()
if info == nil {
return
}
if w.Hourly {
if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
} else if w.Daily {
if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
}
return
})
}
// Destroy close the file description, close file writer.
func (w *fileLogWriter) Destroy() {
w.fileWriter.Close()
}
// Flush flushes file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *fileLogWriter) Flush() {
w.fileWriter.Sync()
}
func init() {
Register(AdapterFile, newFileWriter)
}

View File

@ -0,0 +1,425 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bufio"
"fmt"
"io/ioutil"
"os"
"strconv"
"testing"
"time"
)
func TestFilePerm(t *testing.T) {
log := NewLogger(10000)
// use 0666 as test perm cause the default umask is 022
log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`)
log.Debug("debug")
log.Informational("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
file, err := os.Stat("test.log")
if err != nil {
t.Fatal(err)
}
if file.Mode() != 0666 {
t.Fatal("unexpected log file permission")
}
os.Remove("test.log")
}
func TestFile1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
log.Debug("debug")
log.Informational("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
f, err := os.Open("test.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
lineNum++
}
}
var expected = LevelDebug + 1
if lineNum != expected {
t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test.log")
}
func TestFile2(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError))
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
f, err := os.Open("test2.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
lineNum++
}
}
var expected = LevelError + 1
if lineNum != expected {
t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test2.log")
}
func TestFileDailyRotate_01(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
b, err := exists(rotateName)
if !b || err != nil {
os.Remove("test3.log")
t.Fatal("rotate not generated")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func TestFileDailyRotate_02(t *testing.T) {
fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2, true, false)
}
func TestFileDailyRotate_03(t *testing.T) {
fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2, true, false)
os.Remove(fn)
}
func TestFileDailyRotate_04(t *testing.T) {
fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileDailyRotate(t, fn1, fn2)
}
func TestFileDailyRotate_05(t *testing.T) {
fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileDailyRotate(t, fn1, fn2)
os.Remove(fn)
}
func TestFileDailyRotate_06(t *testing.T) { //test file mode
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
s, _ := os.Lstat(rotateName)
if s.Mode() != 0440 {
os.Remove(rotateName)
os.Remove("test3.log")
t.Fatal("rotate file mode error")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func TestFileHourlyRotate_01(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
b, err := exists(rotateName)
if !b || err != nil {
os.Remove("test3.log")
t.Fatal("rotate not generated")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func TestFileHourlyRotate_02(t *testing.T) {
fn1 := "rotate_hour.log"
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileRotate(t, fn1, fn2, false, true)
}
func TestFileHourlyRotate_03(t *testing.T) {
fn1 := "rotate_hour.log"
fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
os.Create(fn)
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileRotate(t, fn1, fn2, false, true)
os.Remove(fn)
}
func TestFileHourlyRotate_04(t *testing.T) {
fn1 := "rotate_hour.log"
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileHourlyRotate(t, fn1, fn2)
}
func TestFileHourlyRotate_05(t *testing.T) {
fn1 := "rotate_hour.log"
fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
os.Create(fn)
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileHourlyRotate(t, fn1, fn2)
os.Remove(fn)
}
func TestFileHourlyRotate_06(t *testing.T) { //test file mode
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
s, _ := os.Lstat(rotateName)
if s.Mode() != 0440 {
os.Remove(rotateName)
os.Remove("test3.log")
t.Fatal("rotate file mode error")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) {
fw := &fileLogWriter{
Daily: daily,
MaxDays: 7,
Hourly: hourly,
MaxHours: 168,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
RotatePerm: "0440",
}
if daily {
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
}
if hourly {
fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
fw.hourlyOpenDate = fw.hourlyOpenTime.Day()
}
lm := &LogMsg{
Msg: "Test message",
Level: LevelDebug,
When: time.Now(),
}
fw.WriteMsg(lm)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.Log(err)
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Daily: true,
MaxDays: 7,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
RotatePerm: "0440",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location())
today = today.Add(-1 * time.Second)
fw.dailyRotate(today)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
content, err := ioutil.ReadFile(file)
if err != nil {
t.FailNow()
}
if len(content) > 0 {
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func testFileHourlyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Hourly: true,
MaxHours: 168,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
RotatePerm: "0440",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
fw.hourlyOpenDate = fw.hourlyOpenTime.Hour()
hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location())
hour = hour.Add(-1 * time.Second)
fw.hourlyRotate(hour)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
content, err := ioutil.ReadFile(file)
if err != nil {
t.FailNow()
}
if len(content) > 0 {
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
func BenchmarkFile(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileAsynchronous(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileOnGoroutine(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
for i := 0; i < b.N; i++ {
go log.Debug("debug")
}
os.Remove("test4.log")
}

View File

@ -0,0 +1,95 @@
package logs
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook
type JLWriter struct {
AuthorName string `json:"authorname"`
Title string `json:"title"`
WebhookURL string `json:"webhookurl"`
RedirectURL string `json:"redirecturl,omitempty"`
ImageURL string `json:"imageurl,omitempty"`
Level int `json:"level"`
customFormatter func(*LogMsg) string
}
// newJLWriter creates jiaoliao writer.
func newJLWriter() Logger {
return &JLWriter{Level: LevelTrace}
}
// Init JLWriter with json config string
func (s *JLWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
s.customFormatter = formatter
}
}
return json.Unmarshal([]byte(jsonConfig), s)
}
func (s *JLWriter) Format(lm *LogMsg) string {
return lm.Msg
}
// WriteMsg writes message in smtp writer.
// Sends an email with subject and only this message.
func (s *JLWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > s.Level {
return nil
}
text := ""
if s.customFormatter != nil {
text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.customFormatter(lm))
} else {
text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm))
}
form := url.Values{}
form.Add("authorName", s.AuthorName)
form.Add("title", s.Title)
form.Add("text", text)
if s.RedirectURL != "" {
form.Add("redirectUrl", s.RedirectURL)
}
if s.ImageURL != "" {
form.Add("imageUrl", s.ImageURL)
}
resp, err := http.PostForm(s.WebhookURL, form)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
}
return nil
}
// Flush implementing method. empty.
func (s *JLWriter) Flush() {
}
// Destroy implementing method. empty.
func (s *JLWriter) Destroy() {
}
func init() {
Register(AdapterJianLiao, newJLWriter)
}

View File

@ -0,0 +1,896 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package logs provide a general log interface
// Usage:
//
// import "github.com/astaxie/beego/logs"
//
// log := NewLogger(10000)
// log.SetLogger("console", "")
//
// > the first params stand for how many channel
//
// Use it like this:
//
// log.Trace("trace")
// log.Info("info")
// log.Warn("warning")
// log.Debug("debug")
// log.Critical("critical")
//
// more docs http://beego.me/docs/module/logs.md
package logs
import (
"fmt"
"log"
"os"
"path"
"reflect"
"runtime"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// RFC5424 log message levels.
const (
LevelEmergency = iota
LevelAlert
LevelCritical
LevelError
LevelWarning
LevelNotice
LevelInformational
LevelDebug
)
// levelLogLogger is defined to implement log.Logger
// the real log level will be LevelEmergency
const levelLoggerImpl = -1
// Name for adapter with beego official support
const (
AdapterConsole = "console"
AdapterFile = "file"
AdapterMultiFile = "multifile"
AdapterMail = "smtp"
AdapterConn = "conn"
AdapterEs = "es"
AdapterJianLiao = "jianliao"
AdapterSlack = "slack"
AdapterAliLS = "alils"
)
// Legacy log level constants to ensure backwards compatibility.
const (
LevelInfo = LevelInformational
LevelTrace = LevelDebug
LevelWarn = LevelWarning
)
type newLoggerFunc func() Logger
// Logger defines the behavior of a log provider.
type Logger interface {
Init(config string, opts ...utils.KV) error
WriteMsg(lm *LogMsg) error
Format(lm *LogMsg) string
Destroy()
Flush()
}
var adapters = make(map[string]newLoggerFunc)
var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"}
// Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, log newLoggerFunc) {
if log == nil {
panic("logs: Register provide is nil")
}
if _, dup := adapters[name]; dup {
panic("logs: Register called twice for provider " + name)
}
adapters[name] = log
}
// BeeLogger is default logger in beego application.
// Can contain several providers and log message into all providers.
type BeeLogger struct {
lock sync.Mutex
level int
init bool
enableFuncCallDepth bool
loggerFuncCallDepth int
globalFormatter func(*LogMsg) string
enableFullFilePath bool
asynchronous bool
prefix string
msgChanLen int64
msgChan chan *LogMsg
signalChan chan string
wg sync.WaitGroup
outputs []*nameLogger
}
const defaultAsyncMsgLen = 1e3
type nameLogger struct {
Logger
name string
}
type LogMsg struct {
Level int
Msg string
When time.Time
FilePath string
LineNumber int
}
type LogFormatter interface {
Format(lm *LogMsg) string
}
var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger.
// channelLen: the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channelLens ...int64) *BeeLogger {
bl := new(BeeLogger)
bl.level = LevelDebug
bl.loggerFuncCallDepth = 2
bl.msgChanLen = append(channelLens, 0)[0]
if bl.msgChanLen <= 0 {
bl.msgChanLen = defaultAsyncMsgLen
}
bl.signalChan = make(chan string, 1)
bl.setLogger(AdapterConsole)
return bl
}
// Async sets the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger {
bl.lock.Lock()
defer bl.lock.Unlock()
if bl.asynchronous {
return bl
}
bl.asynchronous = true
if len(msgLen) > 0 && msgLen[0] > 0 {
bl.msgChanLen = msgLen[0]
}
bl.msgChan = make(chan *LogMsg, bl.msgChanLen)
logMsgPool = &sync.Pool{
New: func() interface{} {
return &LogMsg{}
},
}
bl.wg.Add(1)
go bl.startLogger()
return bl
}
func Format(lm *LogMsg) string {
return lm.Msg
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config must in in JSON format like {"interval":360}}
func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
config := append(configs, "{}")[0]
for _, l := range bl.outputs {
if l.name == adapterName {
return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
}
}
logAdapter, ok := adapters[adapterName]
if !ok {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
lg := logAdapter()
var err error
// Global formatter overrides the default set formatter
// but not adapter specific formatters set with logs.SetLoggerWithOpts()
if bl.globalFormatter != nil {
err = lg.Init(config, &utils.SimpleKV{Key: "formatter", Value: bl.globalFormatter})
} else {
err = lg.Init(config)
}
if err != nil {
fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err
}
bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
return nil
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config must in in JSON format like {"interval":360}}
func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if !bl.init {
bl.outputs = []*nameLogger{}
bl.init = true
}
return bl.setLogger(adapterName, configs...)
}
// DelLogger removes a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
outputs := []*nameLogger{}
for _, lg := range bl.outputs {
if lg.name == adapterName {
lg.Destroy()
} else {
outputs = append(outputs, lg)
}
}
if len(outputs) == len(bl.outputs) {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
bl.outputs = outputs
return nil
}
func (bl *BeeLogger) writeToLoggers(lm *LogMsg) {
for _, l := range bl.outputs {
err := l.WriteMsg(lm)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
}
}
}
func (bl *BeeLogger) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
// writeMsg will always add a '\n' character
if p[len(p)-1] == '\n' {
p = p[0 : len(p)-1]
}
lm := &LogMsg{
Msg: string(p),
Level: levelLoggerImpl,
}
// set levelLoggerImpl to ensure all log message will be write out
err = bl.writeMsg(lm)
if err == nil {
return len(p), err
}
return 0, err
}
func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error {
if !bl.init {
bl.lock.Lock()
bl.setLogger(AdapterConsole)
bl.lock.Unlock()
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
lm.Msg = bl.prefix + " " + lm.Msg
var (
file string
line int
ok bool
)
if bl.enableFuncCallDepth {
_, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth)
if !ok {
file = "???"
line = 0
}
if !bl.enableFullFilePath {
_, file = path.Split(file)
}
lm.FilePath = file
lm.LineNumber = line
lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg)
}
// set level info in front of filename info
if lm.Level == levelLoggerImpl {
// set to emergency to ensure all log will be print out correctly
lm.Level = LevelEmergency
} else {
lm.Msg = levelPrefix[lm.Level] + " " + lm.Msg
}
if bl.asynchronous {
logM := logMsgPool.Get().(*LogMsg)
logM.Level = lm.Level
logM.Msg = lm.Msg
logM.When = lm.When
if bl.outputs != nil {
bl.msgChan <- lm
} else {
logMsgPool.Put(lm)
}
} else {
bl.writeToLoggers(lm)
}
return nil
}
// SetLevel sets log message level.
// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
// log providers will not be sent the message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// GetLevel Get Current log message level.
func (bl *BeeLogger) GetLevel() int {
return bl.level
}
// SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
// GetLogFuncCallDepth return log funcCallDepth for wrapper
func (bl *BeeLogger) GetLogFuncCallDepth() int {
return bl.loggerFuncCallDepth
}
// EnableFuncCallDepth enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
// set prefix
func (bl *BeeLogger) SetPrefix(s string) {
bl.prefix = s
}
// start logger chan reading.
// when chan is not empty, write logs.
func (bl *BeeLogger) startLogger() {
gameOver := false
for {
select {
case bm := <-bl.msgChan:
bl.writeToLoggers(bm)
logMsgPool.Put(bm)
case sg := <-bl.signalChan:
// Now should only send "flush" or "close" to bl.signalChan
bl.flush()
if sg == "close" {
for _, l := range bl.outputs {
l.Destroy()
}
bl.outputs = nil
gameOver = true
}
bl.wg.Done()
}
if gameOver {
break
}
}
}
// Get the formatter from the opts common.SimpleKV structure
// Looks for a key: "formatter" with value: func(*LogMsg) string
func GetFormatter(opts utils.KV) (func(*LogMsg) string, error) {
if strings.ToLower(opts.GetKey().(string)) == "formatter" {
formatterInterface := reflect.ValueOf(opts.GetValue()).Interface()
formatterFunc := formatterInterface.(func(*LogMsg) string)
return formatterFunc, nil
}
return nil, fmt.Errorf("no \"formatter\" key given in simpleKV")
}
// SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON
// such as: {"interval":360}
func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error {
config := append(configs, "{}")[0]
for _, l := range bl.outputs {
if l.name == adapterName {
return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
}
}
logAdapter, ok := adapters[adapterName]
if !ok {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
if opts.GetKey() == nil {
return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName)
}
lg := logAdapter()
err := lg.Init(config, opts)
if err != nil {
fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err
}
bl.outputs = append(bl.outputs, &nameLogger{
name: adapterName,
Logger: lg,
})
return nil
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if !bl.init {
bl.outputs = []*nameLogger{}
bl.init = true
}
return bl.setLoggerWithOpts(adapterName, opts, configs...)
}
// SetLoggerWIthOpts sets a given log adapter with a custom log adapter.
// Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc}
// where FormatFunc has the signature func(*LogMsg) string
// func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error {
func SetLoggerWithOpts(adapter string, config []string, opts utils.KV) error {
err := beeLogger.SetLoggerWithOpts(adapter, opts, config...)
if err != nil {
log.Fatal(err)
}
return nil
}
func (bl *BeeLogger) setGlobalFormatter(fmtter func(*LogMsg) string) error {
bl.globalFormatter = fmtter
return nil
}
// SetGlobalFormatter sets the global formatter for all log adapters
// This overrides and other individually set adapter
func SetGlobalFormatter(fmtter func(*LogMsg) string) error {
return beeLogger.setGlobalFormatter(fmtter)
}
// Emergency Log EMERGENCY level message.
func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level {
return
}
lm := &LogMsg{
Level: LevelEmergency,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Alert Log ALERT level message.
func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level {
return
}
lm := &LogMsg{
Level: LevelAlert,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Critical Log CRITICAL level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level {
return
}
lm := &LogMsg{
Level: LevelCritical,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Error Log ERROR level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level {
return
}
lm := &LogMsg{
Level: LevelError,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarn > bl.level {
return
}
lm := &LogMsg{
Level: LevelWarn,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Notice Log NOTICE level message.
func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level {
return
}
lm := &LogMsg{
Level: LevelNotice,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInfo > bl.level {
return
}
lm := &LogMsg{
Level: LevelInfo,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Debug Log DEBUG level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
lm := &LogMsg{
Level: LevelDebug,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Warn Log WARN level message.
// compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
if LevelWarn > bl.level {
return
}
lm := &LogMsg{
Level: LevelWarn,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Info Log INFO level message.
// compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) {
if LevelInfo > bl.level {
return
}
lm := &LogMsg{
Level: LevelInfo,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Trace Log TRACE level message.
// compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
lm := &LogMsg{
Level: LevelDebug,
Msg: format,
When: time.Now(),
}
if len(v) > 0 {
lm.Msg = fmt.Sprintf(lm.Msg, v...)
}
bl.writeMsg(lm)
}
// Flush flush all chan data.
func (bl *BeeLogger) Flush() {
if bl.asynchronous {
bl.signalChan <- "flush"
bl.wg.Wait()
bl.wg.Add(1)
return
}
bl.flush()
}
// Close close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
if bl.asynchronous {
bl.signalChan <- "close"
bl.wg.Wait()
close(bl.msgChan)
} else {
bl.flush()
for _, l := range bl.outputs {
l.Destroy()
}
bl.outputs = nil
}
close(bl.signalChan)
}
// Reset close all outputs, and set bl.outputs to nil
func (bl *BeeLogger) Reset() {
bl.Flush()
for _, l := range bl.outputs {
l.Destroy()
}
bl.outputs = nil
}
func (bl *BeeLogger) flush() {
if bl.asynchronous {
for {
if len(bl.msgChan) > 0 {
bm := <-bl.msgChan
bl.writeToLoggers(bm)
logMsgPool.Put(bm)
continue
}
break
}
}
for _, l := range bl.outputs {
l.Flush()
}
}
// beeLogger references the used application logger.
var beeLogger = NewLogger()
// GetBeeLogger returns the default BeeLogger
func GetBeeLogger() *BeeLogger {
return beeLogger
}
var beeLoggerMap = struct {
sync.RWMutex
logs map[string]*log.Logger
}{
logs: map[string]*log.Logger{},
}
// GetLogger returns the default BeeLogger
func GetLogger(prefixes ...string) *log.Logger {
prefix := append(prefixes, "")[0]
if prefix != "" {
prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix))
}
beeLoggerMap.RLock()
l, ok := beeLoggerMap.logs[prefix]
if ok {
beeLoggerMap.RUnlock()
return l
}
beeLoggerMap.RUnlock()
beeLoggerMap.Lock()
defer beeLoggerMap.Unlock()
l, ok = beeLoggerMap.logs[prefix]
if !ok {
l = log.New(beeLogger, prefix, 0)
beeLoggerMap.logs[prefix] = l
}
return l
}
// EnableFullFilePath enables full file path logging. Disabled by default
// e.g "/home/Documents/GitHub/beego/mainapp/" instead of "mainapp"
func EnableFullFilePath(b bool) {
beeLogger.enableFullFilePath = b
}
// Reset will remove all the adapter
func Reset() {
beeLogger.Reset()
}
// Async set the beelogger with Async mode and hold msglen messages
func Async(msgLen ...int64) *BeeLogger {
return beeLogger.Async(msgLen...)
}
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
beeLogger.SetLevel(l)
}
// SetPrefix sets the prefix
func SetPrefix(s string) {
beeLogger.SetPrefix(s)
}
// EnableFuncCallDepth enable log funcCallDepth
func EnableFuncCallDepth(b bool) {
beeLogger.enableFuncCallDepth = b
}
// SetLogFuncCall set the CallDepth, default is 4
func SetLogFuncCall(b bool) {
beeLogger.EnableFuncCallDepth(b)
beeLogger.SetLogFuncCallDepth(4)
}
// SetLogFuncCallDepth set log funcCallDepth
func SetLogFuncCallDepth(d int) {
beeLogger.loggerFuncCallDepth = d
}
// SetLogger sets a new logger.
func SetLogger(adapter string, config ...string) error {
return beeLogger.SetLogger(adapter, config...)
}
// Emergency logs a message at emergency level.
func Emergency(f interface{}, v ...interface{}) {
beeLogger.Emergency(formatLog(f, v...))
}
// Alert logs a message at alert level.
func Alert(f interface{}, v ...interface{}) {
beeLogger.Alert(formatLog(f, v...))
}
// Critical logs a message at critical level.
func Critical(f interface{}, v ...interface{}) {
beeLogger.Critical(formatLog(f, v...))
}
// Error logs a message at error level.
func Error(f interface{}, v ...interface{}) {
beeLogger.Error(formatLog(f, v...))
}
// Warning logs a message at warning level.
func Warning(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Warn compatibility alias for Warning()
func Warn(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Notice logs a message at notice level.
func Notice(f interface{}, v ...interface{}) {
beeLogger.Notice(formatLog(f, v...))
}
// Informational logs a message at info level.
func Informational(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Info compatibility alias for Warning()
func Info(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Debug logs a message at debug level.
func Debug(f interface{}, v ...interface{}) {
beeLogger.Debug(formatLog(f, v...))
}
// Trace logs a message at trace level.
// compatibility alias for Warning()
func Trace(f interface{}, v ...interface{}) {
beeLogger.Trace(formatLog(f, v...))
}
func formatLog(f interface{}, v ...interface{}) string {
var msg string
switch f.(type) {
case string:
msg = f.(string)
if len(v) == 0 {
return msg
}
if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") {
// format string
} else {
// do not contain format char
msg += strings.Repeat(" %v", len(v))
}
default:
msg = fmt.Sprint(f)
if len(v) == 0 {
return msg
}
msg += strings.Repeat(" %v", len(v))
}
return fmt.Sprintf(msg, v...)
}

View File

@ -0,0 +1,35 @@
package logs
import (
"fmt"
"testing"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
func customFormatter(lm *LogMsg) string {
return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg)
}
func globalFormatter(lm *LogMsg) string {
return fmt.Sprintf("[GLOBAL] %s", lm.Msg)
}
func TestCustomLoggingFormatter(t *testing.T) {
// beego.BConfig.Log.AccessLogs = true
SetLoggerWithOpts("console", []string{`{"color":true}`}, &utils.SimpleKV{Key: "formatter", Value: customFormatter})
// Message will be formatted by the customFormatter with colorful text set to true
Informational("Test message")
}
func TestGlobalLoggingFormatter(t *testing.T) {
SetGlobalFormatter(globalFormatter)
SetLogger("console", `{"color":true}`)
// Message will be formatted by globalFormatter
Informational("Test message")
}

View File

@ -0,0 +1,176 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"io"
"runtime"
"sync"
"time"
)
type logWriter struct {
sync.Mutex
writer io.Writer
}
func newLogWriter(wr io.Writer) *logWriter {
return &logWriter{writer: wr}
}
func (lg *logWriter) writeln(msg string) (int, error) {
lg.Lock()
msg += "\n"
n, err := lg.writer.Write([]byte(msg))
lg.Unlock()
return n, err
}
const (
y1 = `0123456789`
y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
mo1 = `000000000111`
mo2 = `123456789012`
d1 = `0000000001111111111222222222233`
d2 = `1234567890123456789012345678901`
h1 = `000000000011111111112222`
h2 = `012345678901234567890123`
mi1 = `000000000011111111112222222222333333333344444444445555555555`
mi2 = `012345678901234567890123456789012345678901234567890123456789`
s1 = `000000000011111111112222222222333333333344444444445555555555`
s2 = `012345678901234567890123456789012345678901234567890123456789`
ns1 = `0123456789`
)
func formatTimeHeader(when time.Time) ([]byte, int, int) {
y, mo, d := when.Date()
h, mi, s := when.Clock()
ns := when.Nanosecond() / 1000000
//len("2006/01/02 15:04:05.123 ")==24
var buf [24]byte
buf[0] = y1[y/1000%10]
buf[1] = y2[y/100]
buf[2] = y3[y-y/100*100]
buf[3] = y4[y-y/100*100]
buf[4] = '/'
buf[5] = mo1[mo-1]
buf[6] = mo2[mo-1]
buf[7] = '/'
buf[8] = d1[d-1]
buf[9] = d2[d-1]
buf[10] = ' '
buf[11] = h1[h]
buf[12] = h2[h]
buf[13] = ':'
buf[14] = mi1[mi]
buf[15] = mi2[mi]
buf[16] = ':'
buf[17] = s1[s]
buf[18] = s2[s]
buf[19] = '.'
buf[20] = ns1[ns/100]
buf[21] = ns1[ns%100/10]
buf[22] = ns1[ns%10]
buf[23] = ' '
return buf[0:], d, h
}
var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
w32Green = string([]byte{27, 91, 52, 50, 109})
w32White = string([]byte{27, 91, 52, 55, 109})
w32Yellow = string([]byte{27, 91, 52, 51, 109})
w32Red = string([]byte{27, 91, 52, 49, 109})
w32Blue = string([]byte{27, 91, 52, 52, 109})
w32Magenta = string([]byte{27, 91, 52, 53, 109})
w32Cyan = string([]byte{27, 91, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)
var once sync.Once
var colorMap map[string]string
func initColor() {
if runtime.GOOS == "windows" {
green = w32Green
white = w32White
yellow = w32Yellow
red = w32Red
blue = w32Blue
magenta = w32Magenta
cyan = w32Cyan
}
colorMap = map[string]string{
//by color
"green": green,
"white": white,
"yellow": yellow,
"red": red,
//by method
"GET": blue,
"POST": cyan,
"PUT": yellow,
"DELETE": red,
"PATCH": green,
"HEAD": magenta,
"OPTIONS": white,
}
}
// ColorByStatus return color by http code
// 2xx return Green
// 3xx return White
// 4xx return Yellow
// 5xx return Red
func ColorByStatus(code int) string {
once.Do(initColor)
switch {
case code >= 200 && code < 300:
return colorMap["green"]
case code >= 300 && code < 400:
return colorMap["white"]
case code >= 400 && code < 500:
return colorMap["yellow"]
default:
return colorMap["red"]
}
}
// ColorByMethod return color by http code
func ColorByMethod(method string) string {
once.Do(initColor)
if c := colorMap[method]; c != "" {
return c
}
return reset
}
// ResetColor return reset color
func ResetColor() string {
return reset
}

View File

@ -0,0 +1,57 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"testing"
"time"
)
func TestFormatHeader_0(t *testing.T) {
tm := time.Now()
if tm.Year() >= 2100 {
t.FailNow()
}
dur := time.Second
for {
if tm.Year() >= 2100 {
break
}
h, _, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
dur *= 2
}
}
func TestFormatHeader_1(t *testing.T) {
tm := time.Now()
year := tm.Year()
dur := time.Second
for {
if tm.Year() >= year+1 {
break
}
h, _, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
}
}

View File

@ -0,0 +1,134 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"encoding/json"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// A filesLogWriter manages several fileLogWriter
// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file
// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log
// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log
// the rotate attribute also acts like fileLogWriter
type multiFileLogWriter struct {
writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter
fullLogWriter *fileLogWriter
Separate []string `json:"separate"`
customFormatter func(*LogMsg) string
}
var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"}
// Init file logger with json config.
// jsonConfig like:
// {
// "filename":"logs/beego.log",
// "maxLines":0,
// "maxsize":0,
// "daily":true,
// "maxDays":15,
// "rotate":true,
// "perm":0600,
// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"],
// }
func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
f.customFormatter = formatter
}
}
writer := newFileWriter().(*fileLogWriter)
err := writer.Init(jsonConfig)
if err != nil {
return err
}
f.fullLogWriter = writer
f.writers[LevelDebug+1] = writer
//unmarshal "separate" field to f.Separate
json.Unmarshal([]byte(jsonConfig), f)
jsonMap := map[string]interface{}{}
json.Unmarshal([]byte(jsonConfig), &jsonMap)
for i := LevelEmergency; i < LevelDebug+1; i++ {
for _, v := range f.Separate {
if v == levelNames[i] {
jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix
jsonMap["level"] = i
bs, _ := json.Marshal(jsonMap)
writer = newFileWriter().(*fileLogWriter)
err := writer.Init(string(bs))
if err != nil {
return err
}
f.writers[i] = writer
}
}
}
return nil
}
func (f *multiFileLogWriter) Format(lm *LogMsg) string {
return lm.Msg
}
func (f *multiFileLogWriter) Destroy() {
for i := 0; i < len(f.writers); i++ {
if f.writers[i] != nil {
f.writers[i].Destroy()
}
}
}
func (f *multiFileLogWriter) WriteMsg(lm *LogMsg) error {
if f.fullLogWriter != nil {
f.fullLogWriter.WriteMsg(lm)
}
for i := 0; i < len(f.writers)-1; i++ {
if f.writers[i] != nil {
if lm.Level == f.writers[i].Level {
f.writers[i].WriteMsg(lm)
}
}
}
return nil
}
func (f *multiFileLogWriter) Flush() {
for i := 0; i < len(f.writers); i++ {
if f.writers[i] != nil {
f.writers[i].Flush()
}
}
}
// newFilesWriter create a FileLogWriter returning as LoggerInterface.
func newFilesWriter() Logger {
return &multiFileLogWriter{}
}
func init() {
Register(AdapterMultiFile, newFilesWriter)
}

View File

@ -0,0 +1,78 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bufio"
"os"
"strconv"
"strings"
"testing"
)
func TestFiles_1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`)
log.Debug("debug")
log.Informational("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
fns := []string{""}
fns = append(fns, levelNames[0:]...)
name := "test"
suffix := ".log"
for _, fn := range fns {
file := name + suffix
if fn != "" {
file = name + "." + fn + suffix
}
f, err := os.Open(file)
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
lineNum := 0
lastLine := ""
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
lastLine = string(line)
lineNum++
}
}
var expected = 1
if fn == "" {
expected = LevelDebug + 1
}
if lineNum != expected {
t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines")
}
if lineNum == 1 {
if !strings.Contains(lastLine, fn) {
t.Fatal(file + " " + lastLine + " not contains the log msg " + fn)
}
}
os.Remove(file)
}
}

View File

@ -0,0 +1,73 @@
package logs
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook
type SLACKWriter struct {
WebhookURL string `json:"webhookurl"`
Level int `json:"level"`
UseCustomFormatter bool
CustomFormatter func(*LogMsg) string
}
// newSLACKWriter creates jiaoliao writer.
func newSLACKWriter() Logger {
return &SLACKWriter{Level: LevelTrace}
}
func (s *SLACKWriter) Format(lm *LogMsg) string {
return lm.Msg
}
// Init SLACKWriter with json config string
func (s *SLACKWriter) Init(jsonConfig string, opts ...utils.KV) error {
// if elem != nil {
// s.UseCustomFormatter = true
// s.CustomFormatter = elem
// }
// }
return json.Unmarshal([]byte(jsonConfig), s)
}
// WriteMsg write message in smtp writer.
// Sends an email with subject and only this message.
func (s *SLACKWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > s.Level {
return nil
}
msg := s.Format(lm)
text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), msg)
form := url.Values{}
form.Add("payload", text)
resp, err := http.PostForm(s.WebhookURL, form)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
}
return nil
}
// Flush implementing method. empty.
func (s *SLACKWriter) Flush() {
}
// Destroy implementing method. empty.
func (s *SLACKWriter) Destroy() {
}
func init() {
Register(AdapterSlack, newSLACKWriter)
}

View File

@ -0,0 +1,168 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/smtp"
"strings"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SMTPWriter struct {
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"host"`
Subject string `json:"subject"`
FromAddress string `json:"fromAddress"`
RecipientAddresses []string `json:"sendTos"`
Level int `json:"level"`
customFormatter func(*LogMsg) string
}
// NewSMTPWriter creates the smtp writer.
func newSMTPWriter() Logger {
return &SMTPWriter{Level: LevelTrace}
}
// Init smtp writer with json config.
// config like:
// {
// "username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "fromAddress":"from@example.com",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SMTPWriter) Init(jsonConfig string, opts ...utils.KV) error {
for _, elem := range opts {
if elem.GetKey() == "formatter" {
formatter, err := GetFormatter(elem)
if err != nil {
return err
}
s.customFormatter = formatter
}
}
return json.Unmarshal([]byte(jsonConfig), s)
}
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
return nil
}
return smtp.PlainAuth(
"",
s.Username,
s.Password,
host,
)
}
func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
client, err := smtp.Dial(hostAddressWithPort)
if err != nil {
return err
}
host, _, _ := net.SplitHostPort(hostAddressWithPort)
tlsConn := &tls.Config{
InsecureSkipVerify: true,
ServerName: host,
}
if err = client.StartTLS(tlsConn); err != nil {
return err
}
if auth != nil {
if err = client.Auth(auth); err != nil {
return err
}
}
if err = client.Mail(fromAddress); err != nil {
return err
}
for _, rec := range recipients {
if err = client.Rcpt(rec); err != nil {
return err
}
}
w, err := client.Data()
if err != nil {
return err
}
_, err = w.Write(msgContent)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return client.Quit()
}
func (s *SMTPWriter) Format(lm *LogMsg) string {
return lm.Msg
}
// WriteMsg writes message in smtp writer.
// Sends an email with subject and only this message.
func (s *SMTPWriter) WriteMsg(lm *LogMsg) error {
if lm.Level > s.Level {
return nil
}
hp := strings.Split(s.Host, ":")
// Set up authentication information.
auth := s.getSMTPAuth(hp[0])
msg := s.Format(lm)
// Connect to the server, authenticate, set the sender and recipient,
// and send the email all in one step.
contentType := "Content-Type: text/plain" + "; charset=UTF-8"
mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + msg)
return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
}
// Flush implementing method. empty.
func (s *SMTPWriter) Flush() {
}
// Destroy implementing method. empty.
func (s *SMTPWriter) Destroy() {
}
func init() {
Register(AdapterMail, newSMTPWriter)
}

View File

@ -0,0 +1,27 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"testing"
"time"
)
func TestSmtp(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
log.Critical("sendmail critical")
time.Sleep(time.Second * 30)
}

View File

@ -0,0 +1,114 @@
session
==============
session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`.
## How to install?
go get github.com/astaxie/beego/session
## What providers are supported?
As of now this session manager support memory, file, Redis and MySQL.
## How to use it?
First you must import it
import (
"github.com/astaxie/beego/session"
)
Then in you web app init the global session manager
var globalSessions *session.Manager
* Use **memory** as provider:
func init() {
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC()
}
* Use **file** as provider, the last param is the path where you want file to be stored:
func init() {
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
go globalSessions.GC()
}
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() {
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC()
}
* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name):
func init() {
globalSessions, _ = session.NewManager(
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC()
}
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
username := sess.Get("username")
fmt.Println(username)
if r.Method == "GET" {
t, _ := template.ParseFiles("login.gtpl")
t.Execute(w, nil)
} else {
fmt.Println("username:", r.Form["username"])
sess.Set("username", r.Form["username"])
fmt.Println("password:", r.Form["password"])
}
}
## How to write own provider?
When you develop a web app, maybe you want to write own provider because you must meet the requirements.
Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider is a good example.
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) (bool, error)
SessionRegenerate(oldsid, sid string) (SessionStore, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
}
## LICENSE
BSD License http://creativecommons.org/licenses/BSD/

View File

@ -0,0 +1,248 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package couchbase for session provider
//
// depend on github.com/couchbaselabs/go-couchbasee
//
// go install github.com/couchbaselabs/go-couchbase
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/couchbase"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package couchbase
import (
"context"
"net/http"
"strings"
"sync"
couchbase "github.com/couchbase/go-couchbase"
"github.com/astaxie/beego/pkg/infrastructure/session"
)
var couchbpder = &Provider{}
// SessionStore store each session
type SessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Provider couchabse provided
type Provider struct {
maxlifetime int64
savePath string
pool string
bucket string
b *couchbase.Bucket
}
// Set value to couchabse session
func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
// Get value from couchabse session
func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
}
return nil
}
// Delete value in couchbase session by given key
func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
// Flush Clean all values in couchbase session
func (cs *SessionStore) Flush(context.Context) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get couchbase session store id
func (cs *SessionStore) SessionID(context.Context) string {
return cs.sid
}
// SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
if err != nil {
return
}
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
if err != nil {
return nil
}
return bucket
}
// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
}
return nil
}
// SessionRead read couchbase session by sid
func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var (
kv map[interface{}]interface{}
err error
doc []byte
)
err = cp.b.Get(sid, &doc)
if err != nil {
return nil, err
} else if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionExist Check couchbase session exist.
// it checkes sid exist or not.
func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket()
defer cp.b.Close()
var doc []byte
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false, err
}
return true, nil
}
// SessionRegenerate remove oldsid and use sid to generate new session
func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
cp.b.Set(sid, int(cp.maxlifetime), "")
} else {
err := cp.b.Delete(oldsid)
if err != nil {
return nil, err
}
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
}
err := cp.b.Get(sid, &doc)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionDestroy Remove bucket in this couchbase
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
cp.b.Delete(sid)
return nil
}
// SessionGC Recycle
func (cp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (cp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("couchbase", couchbpder)
}

View File

@ -0,0 +1,174 @@
// Package ledis provide session Provider
package ledis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"github.com/ledisdb/ledisdb/config"
"github.com/ledisdb/ledisdb/ledis"
"github.com/astaxie/beego/pkg/infrastructure/session"
)
var (
ledispder = &Provider{}
c *ledis.DB
)
// SessionStore ledis session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in ledis session
func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
return nil
}
// Get value in ledis session
func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
return v
}
return nil
}
// Delete value in ledis session
func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
return nil
}
// Flush clear all values in ledis session
func (ls *SessionStore) Flush(context.Context) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
return nil
}
// SessionID get ledis session id
func (ls *SessionStore) SessionID(context.Context) string {
return ls.sid
}
// SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
}
c.Set([]byte(ls.sid), b)
c.Expire([]byte(ls.sid), ls.maxlifetime)
}
// Provider ledis session provider
type Provider struct {
maxlifetime int64
savePath string
db int
}
// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) == 1 {
lp.savePath = configs[0]
} else if len(configs) == 2 {
lp.savePath = configs[0]
lp.db, err = strconv.Atoi(configs[1])
if err != nil {
return err
}
}
cfg := new(config.Config)
cfg.DataDir = lp.savePath
var ledisInstance *ledis.Ledis
ledisInstance, err = ledis.Open(cfg)
if err != nil {
return err
}
c, err = ledisInstance.Select(lp.db)
return err
}
// SessionRead read ledis session by sid
func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var (
kv map[interface{}]interface{}
err error
)
kvs, _ := c.Get([]byte(sid))
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob(kvs); err != nil {
return nil, err
}
}
ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
// SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
count, _ := c.Exists([]byte(sid))
return count != 0, nil
}
// SessionRegenerate generate new sid for ledis session
func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set([]byte(sid), []byte(""))
c.Expire([]byte(sid), lp.maxlifetime)
} else {
data, _ := c.Get([]byte(oldsid))
c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime)
}
return lp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete ledis session by id
func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c.Del([]byte(sid))
return nil
}
// SessionGC Impelment method, no used.
func (lp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (lp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("ledis", ledispder)
}

View File

@ -0,0 +1,231 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package memcache for session provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
// go install github.com/bradfitz/gomemcache/memcache
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/memcache"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package memcache
import (
"context"
"net/http"
"strings"
"sync"
"github.com/astaxie/beego/pkg/infrastructure/session"
"github.com/bradfitz/gomemcache/memcache"
)
var mempder = &MemProvider{}
var client *memcache.Client
// SessionStore memcache session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in memcache session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in memcache session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in memcache session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in memcache session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get memcache session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)}
client.Set(&item)
}
// MemProvider memcache session provider
type MemProvider struct {
maxlifetime int64
conninfo []string
poolsize int
password string
}
// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
rp.conninfo = strings.Split(savePath, ";")
client = memcache.New(rp.conninfo...)
return nil
}
// SessionRead read memcache session by sid
func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
item, err := client.Get(sid)
if err != nil {
if err == memcache.ErrCacheMiss {
rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
return nil, err
}
var kv map[interface{}]interface{}
if len(item.Value) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(item.Value)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return false, err
}
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for memcache session
func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
var contain []byte
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
item.Key = sid
item.Value = []byte("")
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
} else {
client.Delete(oldsid)
item.Key = sid
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
contain = item.Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
var err error
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
return err
}
}
return client.Delete(sid)
}
func (rp *MemProvider) connectInit() error {
client = memcache.New(rp.conninfo...)
return nil
}
// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *MemProvider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("memcache", mempder)
}

View File

@ -0,0 +1,235 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mysql for session provider
//
// depends on github.com/go-sql-driver/mysql:
//
// go install github.com/go-sql-driver/mysql
//
// mysql session support need create table as sql:
// CREATE TABLE `session` (
// `session_key` char(64) NOT NULL,
// `session_data` blob,
// `session_expiry` int(11) unsigned NOT NULL,
// PRIMARY KEY (`session_key`)
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/mysql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package mysql
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/session"
// import mysql driver
_ "github.com/go-sql-driver/mysql"
)
var (
// TableName store the session in MySQL
TableName = "session"
mysqlpder = &Provider{}
)
// SessionStore mysql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in mysql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from mysql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in mysql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in mysql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this mysql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save mysql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
// Provider mysql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to mysql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init mysql session.
// savepath is the connection string of mysql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get mysql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check mysql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete mysql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
return nil
}
// SessionGC delete expired values in mysql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
}
// SessionAll count values in mysql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("mysql", mysqlpder)
}

View File

@ -0,0 +1,250 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package postgres for session provider
//
// depends on github.com/lib/pq:
//
// go install github.com/lib/pq
//
//
// needs this table in your database:
//
// CREATE TABLE session (
// session_key char(64) NOT NULL,
// session_data bytea,
// session_expiry timestamp NOT NULL,
// CONSTRAINT session_key PRIMARY KEY(session_key)
// );
//
// will be activated with these settings in app.conf:
//
// SessionOn = true
// SessionProvider = postgresql
// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
// SessionName = session
//
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/postgresql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package postgres
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/session"
// import postgresql Driver
_ "github.com/lib/pq"
)
var postgresqlpder = &Provider{}
// SessionStore postgresql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in postgresql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from postgresql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in postgresql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in postgresql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this postgresql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
b, time.Now().Format(time.RFC3339), st.sid)
}
// Provider postgresql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get postgresql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check postgresql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for postgresql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete postgresql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
// SessionGC delete expired values in postgresql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
}
// SessionAll count values in postgresql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("postgresql", postgresqlpder)
}

View File

@ -0,0 +1,252 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/gomodule/redigo/redis
//
// go install github.com/gomodule/redigo/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/session"
"github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// MaxPoolSize redis max pool size
var MaxPoolSize = 100
// SessionStore redis session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
}
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.savePath,
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, _ := c.Exists(oldsid).Result(); existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis", redispder)
}

View File

@ -0,0 +1,96 @@
package redis
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/astaxie/beego/pkg/infrastructure/session"
)
func TestRedis(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
}
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Fatal("could not create manager:", err)
}
go globalSession.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSession.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,247 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_cluster"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis_cluster
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/session"
rediss "github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// MaxPoolSize redis_cluster max pool size
var MaxPoolSize = 1000
// SessionStore redis_cluster session store
type SessionStore struct {
p *rediss.ClusterClient
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_cluster session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_cluster session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_cluster session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_cluster session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_cluster session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *rediss.ClusterClient
}
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_cluster", redispder)
}

View File

@ -0,0 +1,260 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_sentinel"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
// go globalSessions.GC()
// }
//
// more detail about params: please check the notes on the function SessionInit in this package
package redis_sentinel
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/infrastructure/session"
"github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// DefaultPoolSize redis_sentinel default pool size
var DefaultPoolSize = 100
// SessionStore redis_sentinel session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
masterName string
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
} else {
rp.masterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
}
if len(configs) > 5 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 6 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 7 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_sentinel", redispder)
}

View File

@ -0,0 +1,90 @@
package redis_sentinel
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/pkg/infrastructure/session"
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return
}
//todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,181 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
// CookieSessionStore Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
lock sync.RWMutex
}
// Set value to cookie session.
// the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from cookie session
func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in cookie session
func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock()
if err == nil {
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(encodedCookie),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure,
MaxAge: cookiepder.config.Maxage}
http.SetCookie(w, cookie)
}
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
// CookieProvider Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
// SessionInit Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
// blockKey - gob encode hash string. it's saved as aes crypto.
// securityName - recognized name in encoded cookie string
// cookieName - cookie name
// maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
pder.maxlifetime = maxlifetime
return nil
}
// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
return true, nil
}
// SessionRegenerate Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
return nil, nil
}
// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error {
return nil
}
// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC(context.Context) {
}
// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll(context.Context) int {
return 0
}
// SessionUpdate Implement method, no used.
func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,105 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(nil, w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}
func TestDestorySessionCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
session, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start err,", err)
}
// request again ,will get same sesssion id .
r1, _ := http.NewRequest("GET", "/", nil)
r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err := globalSessions.SessionStart(w, r1)
if err != nil {
t.Fatal("session start err,", err)
}
if newSession.SessionID(nil) != session.SessionID(nil) {
t.Fatal("get cookie session id is not the same again.")
}
// After destroy session , will get a new session id .
globalSessions.SessionDestroy(w, r1)
r2, _ := http.NewRequest("GET", "/", nil)
r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err = globalSessions.SessionStart(w, r2)
if err != nil {
t.Fatal("session start error")
}
if newSession.SessionID(nil) == session.SessionID(nil) {
t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
}
}

View File

@ -0,0 +1,316 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
var (
filepder = &FileProvider{}
gcmaxlifetime int64
)
// FileSessionStore File session store
type FileSessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value to file session
func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values[key] = value
return nil
}
// Get value from file session
func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} {
fs.lock.RLock()
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
return v
}
return nil
}
// Delete value in file session by given key
func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
delete(fs.values, key)
return nil
}
// Flush Clean all values in file session
func (fs *FileSessionStore) Flush(context.Context) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get file session store id
func (fs *FileSessionStore) SessionID(context.Context) string {
return fs.sid
}
// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values)
if err != nil {
SLogger.Println(err)
return
}
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
if err != nil {
SLogger.Println(err)
return
}
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
if err != nil {
SLogger.Println(err)
return
}
} else {
return
}
f.Truncate(0)
f.Seek(0, 0)
f.Write(b)
f.Close()
}
// FileProvider File session provider
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
savePath string
}
// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
fp.savePath = savePath
return nil
}
// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
invalidChars := "./"
if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
}
if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2")
}
filepder.lock.Lock()
defer filepder.lock.Unlock()
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755)
if err != nil {
SLogger.Println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777)
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
defer f.Close()
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// SessionExist Check file session exist.
// it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
if len(sid) < 2 {
SLogger.Println("min length of session id is 2 but got length: ", sid)
return false, errors.New("min length of session id is 2")
}
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return err == nil, nil
}
// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return nil
}
// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC(context.Context) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
gcmaxlifetime = fp.maxlifetime
filepath.Walk(fp.savePath, gcpath)
}
// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll(context.Context) int {
a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
return a.visit(path, f, err)
})
if err != nil {
SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0
}
return a.total
}
// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
oldSidFile := path.Join(oldPath, oldsid)
newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
newSidFile := path.Join(newPath, sid)
// new sid file is exist
_, err := os.Stat(newSidFile)
if err == nil {
return nil, fmt.Errorf("newsid %s exist", newSidFile)
}
err = os.MkdirAll(newPath, 0755)
if err != nil {
SLogger.Println(err.Error())
}
// if old sid file exist
// 1.read and parse file content
// 2.write content to new sid file
// 3.remove old sid file, change new sid file atime and ctime
// 4.return FileSessionStore
_, err = os.Stat(oldSidFile)
if err == nil {
b, err := ioutil.ReadFile(oldSidFile)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ioutil.WriteFile(newSidFile, b, 0777)
os.Remove(oldSidFile)
os.Chtimes(newSidFile, time.Now(), time.Now())
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// if old sid file not exist, just create new sid file and return
newf, err := os.Create(newSidFile)
if err != nil {
return nil, err
}
newf.Close()
ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
return ss, nil
}
// remove file in save path if expired
func gcpath(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() {
os.Remove(path)
}
return nil
}
type activeSession struct {
total int
}
func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
if err != nil {
return err
}
if f.IsDir() {
return nil
}
as.total = as.total + 1
return nil
}
func init() {
Register("file", filepder)
}

View File

@ -0,0 +1,427 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
)
const sid = "Session_id"
const sidNew = "Session_id_new"
const sessionPath = "./_session_runtime"
var (
mutex sync.Mutex
)
func TestFileProvider_SessionInit(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
if fp.maxlifetime != 180 {
t.Error()
}
if fp.savePath != sessionPath {
t.Error()
}
}
func TestFileProvider_SessionExist(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
_, err = fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionExist2(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "1")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
_ = s.Set(nil, "sessionValue", 18975)
v := s.Get(nil, "sessionValue")
if v.(int) != 18975 {
t.Error()
}
}
func TestFileProvider_SessionRead1(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), "")
if err == nil {
t.Error(err)
}
_, err = fp.SessionRead(context.Background(), "1")
if err == nil {
t.Error(err)
}
}
func TestFileProvider_SessionAll(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 546
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
if fp.SessionAll(nil) != sessionCount {
t.Error()
}
}
func TestFileProvider_SessionRegenerate(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
_, err = fp.SessionRegenerate(context.Background(), sid, sidNew)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), sidNew)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionDestroy(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
err = fp.SessionDestroy(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionGC(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 1, sessionPath)
sessionCount := 412
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
time.Sleep(2 * time.Second)
fp.SessionGC(nil)
if fp.SessionAll(nil) != 0 {
t.Error()
}
}
func TestFileSessionStore_Set(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
err := s.Set(nil, i, i)
if err != nil {
t.Error(err)
}
}
}
func TestFileSessionStore_Get(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
v := s.Get(nil, i)
if v.(int) != i {
t.Error()
}
}
}
func TestFileSessionStore_Delete(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(context.Background(), sid)
s.Set(nil, "1", 1)
if s.Get(nil, "1") == nil {
t.Error()
}
s.Delete(nil, "1")
if s.Get(nil, "1") != nil {
t.Error()
}
}
func TestFileSessionStore_Flush(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
}
_ = s.Flush(nil)
for i := 1; i <= sessionCount; i++ {
if s.Get(nil, i) != nil {
t.Error()
}
}
}
func TestFileSessionStore_SessionID(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err)
}
}
}
func TestFileSessionStore_SessionRelease(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
s.Set(nil, i, i)
s.SessionRelease(nil, nil)
}
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.Get(nil, i).(int) != i {
t.Error()
}
}
}

View File

@ -0,0 +1,197 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"container/list"
"context"
"net/http"
"sync"
"time"
)
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
// MemSessionStore memory session store.
// it saved sessions in a map in memory.
type MemSessionStore struct {
sid string //session id
timeAccessed time.Time //last access time
value map[interface{}]interface{} //session store
lock sync.RWMutex
}
// Set value to memory session
func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value[key] = value
return nil
}
// Get value from memory session by key
func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.value[key]; ok {
return v
}
return nil
}
// Delete in memory session by key
func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.value, key)
return nil
}
// Flush clear all values in memory session
func (st *MemSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
return nil
}
// SessionID get this id of memory session store
func (st *MemSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
}
// MemProvider Implement the provider interface
type MemProvider struct {
lock sync.RWMutex // locker
sessions map[string]*list.Element // map in memory
list *list.List // for gc
maxlifetime int64
savePath string
}
// SessionInit init memory session
func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime
pder.savePath = savePath
return nil
}
// SessionRead get memory session store by sid
func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(nil, sid)
pder.lock.RUnlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
pder.lock.RLock()
defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok {
return true, nil
}
return false, nil
}
// SessionRegenerate generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(nil, oldsid)
pder.lock.RUnlock()
pder.lock.Lock()
element.Value.(*MemSessionStore).sid = sid
pder.sessions[sid] = element
delete(pder.sessions, oldsid)
pder.lock.Unlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
delete(pder.sessions, sid)
pder.list.Remove(element)
return nil
}
return nil
}
// SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC(context.Context) {
pder.lock.RLock()
for {
element := pder.list.Back()
if element == nil {
break
}
if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
pder.lock.RUnlock()
pder.lock.Lock()
pder.list.Remove(element)
delete(pder.sessions, element.Value.(*MemSessionStore).sid)
pder.lock.Unlock()
pder.lock.RLock()
} else {
break
}
}
pder.lock.RUnlock()
}
// SessionAll get count number of memory session
func (pder *MemProvider) SessionAll(context.Context) int {
return pder.list.Len()
}
// SessionUpdate expand time of session store by id in memory session
func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
element.Value.(*MemSessionStore).timeAccessed = time.Now()
pder.list.MoveToFront(element)
return nil
}
return nil
}
func init() {
Register("memory", mempder)
}

View File

@ -0,0 +1,58 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, _ := NewManager("memory", conf)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
defer sess.SessionRelease(nil, w)
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -0,0 +1,131 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"crypto/aes"
"encoding/json"
"testing"
)
func Test_gob(t *testing.T) {
a := make(map[interface{}]interface{})
a["username"] = "astaxie"
a[12] = 234
a["user"] = User{"asta", "xie"}
b, err := EncodeGob(a)
if err != nil {
t.Error(err)
}
c, err := DecodeGob(b)
if err != nil {
t.Error(err)
}
if len(c) == 0 {
t.Error("decodeGob empty")
}
if c["username"] != "astaxie" {
t.Error("decode string error")
}
if c[12] != 234 {
t.Error("decode int error")
}
if c["user"].(User).Username != "asta" {
t.Error("decode struct error")
}
}
type User struct {
Username string
NickName string
}
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst, err := decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(ManagerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(ManagerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

View File

@ -0,0 +1,207 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
"github.com/astaxie/beego/pkg/infrastructure/utils"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
// EncodeGob encode the obj to gob
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
}
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
// DecodeGob decode data to map
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil {
return utils.RandomCreateBytes(strength)
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. EncodeGob.
if b, err = EncodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value format")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. DecodeGob.
dst, err := DecodeGob(b)
if err != nil {
return nil, err
}
return dst, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -0,0 +1,384 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package session provider
//
// Usage:
// import(
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package session
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"os"
"time"
)
// Store contains all data for one session process with specific id.
type Store interface {
Set(ctx context.Context, key, value interface{}) error //set session value
Get(ctx context.Context, key interface{}) interface{} //get session value
Delete(ctx context.Context, key interface{}) error //delete session value
SessionID(ctx context.Context) string //back current sessionID
SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush(ctx context.Context) error //delete all data
}
// Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(ctx context.Context, gclifetime int64, config string) error
SessionRead(ctx context.Context, sid string) (Store, error)
SessionExist(ctx context.Context, sid string) (bool, error)
SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error)
SessionDestroy(ctx context.Context, sid string) error
SessionAll(ctx context.Context) int //get all active session
SessionGC(ctx context.Context)
}
var provides = make(map[string]Provider)
// SLogger a helpful variable to log information about session
var SLogger = NewSessionLog(os.Stderr)
// Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, provide Provider) {
if provide == nil {
panic("session: Register provide is nil")
}
if _, dup := provides[name]; dup {
panic("session: Register called twice for provider " + name)
}
provides[name] = provide
}
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
}
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider
config *ManagerConfig
}
// NewManager Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
// 3. memory
// 4. redis
// 5. mysql
// json config:
// 1. is https default false
// 2. hashfunc default sha1
// 3. hashkey default beegosessionkey
// 4. maxage default is none
func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
provider, ok := provides[provideName]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
}
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
if cf.EnableSidInHTTPHeader {
if cf.SessionNameInHTTPHeader == "" {
panic(errors.New("SessionNameInHTTPHeader is empty"))
}
strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader)
if cf.SessionNameInHTTPHeader != strMimeHeader {
strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
if cf.SessionIDLength == 0 {
cf.SessionIDLength = 16
}
return &Manager{
provider,
cf,
}, nil
}
// GetProvider return current manager's provider
func (manager *Manager) GetProvider() Provider {
return manager.provider
}
// getSid retrieves session identifier from HTTP Request.
// First try to retrieve id by reading from cookie, session cookie name is configurable,
// if not exist, then retrieve id from querying parameters.
//
// error is not nil when there is anything wrong.
// sid is empty when need to generate a new session id
// otherwise return an valid session id.
func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" {
var sid string
if manager.config.EnableSidInURLQuery {
errs := r.ParseForm()
if errs != nil {
return "", errs
}
sid = r.FormValue(manager.config.CookieName)
}
// if not found in Cookie / param, then read it from request headers
if manager.config.EnableSidInHTTPHeader && sid == "" {
sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
}
return sid, nil
}
// HTTP Request contains cookie for sessionid info.
return url.QueryUnescape(cookie.Value)
}
// SessionStart generate or read the session id from http request.
// if session id exists, return SessionStore with this id.
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
sid, errs := manager.getSid(r)
if errs != nil {
return nil, errs
}
if sid != "" {
exists, err := manager.provider.SessionExist(nil, sid)
if err != nil {
return nil, err
}
if exists {
return manager.provider.SessionRead(nil, sid)
}
}
// Generate a new session
sid, errs = manager.sessionID()
if errs != nil {
return nil, errs
}
session, err = manager.provider.SessionRead(nil, sid)
if err != nil {
return nil, err
}
cookie := &http.Cookie{
Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
if manager.config.EnableSidInHTTPHeader {
r.Header.Del(manager.config.SessionNameInHTTPHeader)
w.Header().Del(manager.config.SessionNameInHTTPHeader)
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
}
sid, _ := url.QueryUnescape(cookie.Value)
manager.provider.SessionDestroy(nil, sid)
if manager.config.EnableSetCookie {
expiration := time.Now()
cookie = &http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Expires: expiration,
MaxAge: -1,
Domain: manager.config.Domain}
http.SetCookie(w, cookie)
}
}
// GetSessionStore Get SessionStore by its id.
func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(nil, sid)
return
}
// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC(nil)
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request.
func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
sid, err := manager.sessionID()
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(nil, sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll(nil)
}
// SetSecure Set cookie with https.
func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure
}
func (manager *Manager) sessionID() (string, error) {
b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
}
return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
}
// Set cookie with https.
func (manager *Manager) isSecure(req *http.Request) bool {
if !manager.config.Secure {
return false
}
if req.URL.Scheme != "" {
return req.URL.Scheme == "https"
}
if req.TLS == nil {
return false
}
return true
}
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// NewSessionLog set io.Writer to create a Logger for session.
func NewSessionLog(out io.Writer) *Log {
sl := new(Log)
sl.Logger = log.New(out, "[SESSION]", 1e9)
return sl
}

View File

@ -0,0 +1,200 @@
package ssdb
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"sync"
"github.com/astaxie/beego/pkg/infrastructure/session"
"github.com/ssdb/gossdb/ssdb"
)
var ssdbProvider = &Provider{}
// Provider holds ssdb client and configs
type Provider struct {
client *ssdb.Client
host string
port int
maxLifetime int64
}
func (p *Provider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
return err
}
// SessionInit init the ssdb with the config
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
var err error
if p.port, err = strconv.Atoi(address[1]); err != nil {
return err
}
return p.connectInit()
}
// SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
value, err := p.client.Get(sid)
if err != nil {
return nil, err
}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionExist judged whether sid is exist in session
func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return false, err
}
}
value, err := p.client.Get(sid)
if err != nil {
panic(err)
}
if value == nil || len(value.(string)) == 0 {
return false, nil
}
return true, nil
}
// SessionRegenerate regenerate session with new sid and delete oldsid
func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
//conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
value, err := p.client.Get(oldsid)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
_, err = p.client.Del(oldsid)
if err != nil {
return nil, err
}
}
_, e := p.client.Do("setx", sid, value, p.maxLifetime)
if e != nil {
return nil, e
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionDestroy destroy the sid
func (p *Provider) SessionDestroy(ctx context.Context, sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
_, err := p.client.Del(sid)
return err
}
// SessionGC not implemented
func (p *Provider) SessionGC(context.Context) {
}
// SessionAll not implemented
func (p *Provider) SessionAll(context.Context) int {
return 0
}
// SessionStore holds the session information which stored in ssdb
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxLifetime int64
client *ssdb.Client
}
// Set the key and value
func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
return nil
}
// Get return the value by the key
func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
if value, ok := s.values[key]; ok {
return value
}
return nil
}
// Delete the key in session store
func (s *SessionStore) Delete(ctx context.Context, key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
return nil
}
// Flush delete all keys and values
func (s *SessionStore) Flush(context.Context) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
return nil
}
// SessionID return the sessionID
func (s *SessionStore) SessionID(context.Context) string {
return s.sid
}
// SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return
}
s.client.Do("setx", s.sid, string(b), s.maxLifetime)
}
func init() {
session.Register("ssdb", ssdbProvider)
}

View File

@ -0,0 +1,25 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"reflect"
"runtime"
)
// GetFuncName get function name
func GetFuncName(i interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
}

View File

@ -0,0 +1,28 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"strings"
"testing"
)
func TestGetFuncName(t *testing.T) {
name := GetFuncName(TestGetFuncName)
t.Log(name)
if !strings.HasSuffix(name, ".TestGetFuncName") {
t.Error("get func name error")
}
}

View File

@ -0,0 +1,478 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"bytes"
"fmt"
"log"
"reflect"
"runtime"
)
var (
dunno = []byte("???")
centerDot = []byte("·")
dot = []byte(".")
)
type pointerInfo struct {
prev *pointerInfo
n int
addr uintptr
pos int
used []int
}
// Display print the data in console
func Display(data ...interface{}) {
display(true, data...)
}
// GetDisplayString return data print string
func GetDisplayString(data ...interface{}) string {
return display(false, data...)
}
func display(displayed bool, data ...interface{}) string {
var pc, file, line, ok = runtime.Caller(2)
if !ok {
return ""
}
var buf = new(bytes.Buffer)
fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line)
fmt.Fprintf(buf, "\n[Variables]\n")
for i := 0; i < len(data); i += 2 {
var output = fomateinfo(len(data[i].(string))+3, data[i+1])
fmt.Fprintf(buf, "%s = %s", data[i], output)
}
if displayed {
log.Print(buf)
}
return buf.String()
}
// return data dump and format bytes
func fomateinfo(headlen int, data ...interface{}) []byte {
var buf = new(bytes.Buffer)
if len(data) > 1 {
fmt.Fprint(buf, " ")
fmt.Fprint(buf, "[")
fmt.Fprintln(buf)
}
for k, v := range data {
var buf2 = new(bytes.Buffer)
var pointers *pointerInfo
var interfaces = make([]reflect.Value, 0, 10)
printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1)
if k < len(data)-1 {
fmt.Fprint(buf2, ", ")
}
fmt.Fprintln(buf2)
buf.Write(buf2.Bytes())
}
if len(data) > 1 {
fmt.Fprintln(buf)
fmt.Fprint(buf, " ")
fmt.Fprint(buf, "]")
}
return buf.Bytes()
}
// check data is golang basic type
func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool {
switch kind {
case reflect.Bool:
return true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.Complex64, reflect.Complex128:
return true
case reflect.String:
return true
case reflect.Chan:
return true
case reflect.Invalid:
return true
case reflect.Interface:
for _, in := range *interfaces {
if reflect.DeepEqual(in, val) {
return true
}
}
return false
case reflect.UnsafePointer:
if val.IsNil() {
return true
}
var elem = val.Elem()
if isSimpleType(elem, elem.Kind(), pointers, interfaces) {
return true
}
var addr = val.Elem().UnsafeAddr()
for p := *pointers; p != nil; p = p.prev {
if addr == p.addr {
return true
}
}
return false
}
return false
}
// dump value
func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) {
var t = val.Kind()
switch t {
case reflect.Bool:
fmt.Fprint(buf, val.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fmt.Fprint(buf, val.Int())
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
fmt.Fprint(buf, val.Uint())
case reflect.Float32, reflect.Float64:
fmt.Fprint(buf, val.Float())
case reflect.Complex64, reflect.Complex128:
fmt.Fprint(buf, val.Complex())
case reflect.UnsafePointer:
fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer())
case reflect.Ptr:
if val.IsNil() {
fmt.Fprint(buf, "nil")
return
}
var addr = val.Elem().UnsafeAddr()
for p := *pointers; p != nil; p = p.prev {
if addr == p.addr {
p.used = append(p.used, buf.Len())
fmt.Fprintf(buf, "0x%X", addr)
return
}
}
*pointers = &pointerInfo{
prev: *pointers,
addr: addr,
pos: buf.Len(),
used: make([]int, 0),
}
fmt.Fprint(buf, "&")
printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level)
case reflect.String:
fmt.Fprint(buf, "\"", val.String(), "\"")
case reflect.Interface:
var value = val.Elem()
if !value.IsValid() {
fmt.Fprint(buf, "nil")
} else {
for _, in := range *interfaces {
if reflect.DeepEqual(in, val) {
fmt.Fprint(buf, "repeat")
return
}
}
*interfaces = append(*interfaces, val)
printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1)
}
case reflect.Struct:
var t = val.Type()
fmt.Fprint(buf, t)
fmt.Fprint(buf, "{")
for i := 0; i < val.NumField(); i++ {
if formatOutput {
fmt.Fprintln(buf)
} else {
fmt.Fprint(buf, " ")
}
var name = t.Field(i).Name
if formatOutput {
for ind := 0; ind < level; ind++ {
fmt.Fprint(buf, indent)
}
}
fmt.Fprint(buf, name)
fmt.Fprint(buf, ": ")
if structFilter != nil && structFilter(t.String(), name) {
fmt.Fprint(buf, "ignore")
} else {
printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1)
}
fmt.Fprint(buf, ",")
}
if formatOutput {
fmt.Fprintln(buf)
for ind := 0; ind < level-1; ind++ {
fmt.Fprint(buf, indent)
}
} else {
fmt.Fprint(buf, " ")
}
fmt.Fprint(buf, "}")
case reflect.Array, reflect.Slice:
fmt.Fprint(buf, val.Type())
fmt.Fprint(buf, "{")
var allSimple = true
for i := 0; i < val.Len(); i++ {
var elem = val.Index(i)
var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces)
if !isSimple {
allSimple = false
}
if formatOutput && !isSimple {
fmt.Fprintln(buf)
} else {
fmt.Fprint(buf, " ")
}
if formatOutput && !isSimple {
for ind := 0; ind < level; ind++ {
fmt.Fprint(buf, indent)
}
}
printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1)
if i != val.Len()-1 || !allSimple {
fmt.Fprint(buf, ",")
}
}
if formatOutput && !allSimple {
fmt.Fprintln(buf)
for ind := 0; ind < level-1; ind++ {
fmt.Fprint(buf, indent)
}
} else {
fmt.Fprint(buf, " ")
}
fmt.Fprint(buf, "}")
case reflect.Map:
var t = val.Type()
var keys = val.MapKeys()
fmt.Fprint(buf, t)
fmt.Fprint(buf, "{")
var allSimple = true
for i := 0; i < len(keys); i++ {
var elem = val.MapIndex(keys[i])
var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces)
if !isSimple {
allSimple = false
}
if formatOutput && !isSimple {
fmt.Fprintln(buf)
} else {
fmt.Fprint(buf, " ")
}
if formatOutput && !isSimple {
for ind := 0; ind <= level; ind++ {
fmt.Fprint(buf, indent)
}
}
printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1)
fmt.Fprint(buf, ": ")
printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1)
if i != val.Len()-1 || !allSimple {
fmt.Fprint(buf, ",")
}
}
if formatOutput && !allSimple {
fmt.Fprintln(buf)
for ind := 0; ind < level-1; ind++ {
fmt.Fprint(buf, indent)
}
} else {
fmt.Fprint(buf, " ")
}
fmt.Fprint(buf, "}")
case reflect.Chan:
fmt.Fprint(buf, val.Type())
case reflect.Invalid:
fmt.Fprint(buf, "invalid")
default:
fmt.Fprint(buf, "unknow")
}
}
// PrintPointerInfo dump pointer value
func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
var anyused = false
var pointerNum = 0
for p := pointers; p != nil; p = p.prev {
if len(p.used) > 0 {
anyused = true
}
pointerNum++
p.n = pointerNum
}
if anyused {
var pointerBufs = make([][]rune, pointerNum+1)
for i := 0; i < len(pointerBufs); i++ {
var pointerBuf = make([]rune, buf.Len()+headlen)
for j := 0; j < len(pointerBuf); j++ {
pointerBuf[j] = ' '
}
pointerBufs[i] = pointerBuf
}
for pn := 0; pn <= pointerNum; pn++ {
for p := pointers; p != nil; p = p.prev {
if len(p.used) > 0 && p.n >= pn {
if pn == p.n {
pointerBufs[pn][p.pos+headlen] = '└'
var maxpos = 0
for i, pos := range p.used {
if i < len(p.used)-1 {
pointerBufs[pn][pos+headlen] = '┴'
} else {
pointerBufs[pn][pos+headlen] = '┘'
}
maxpos = pos
}
for i := 0; i < maxpos-p.pos-1; i++ {
if pointerBufs[pn][i+p.pos+headlen+1] == ' ' {
pointerBufs[pn][i+p.pos+headlen+1] = '─'
}
}
} else {
pointerBufs[pn][p.pos+headlen] = '│'
for _, pos := range p.used {
if pointerBufs[pn][pos+headlen] == ' ' {
pointerBufs[pn][pos+headlen] = '│'
} else {
pointerBufs[pn][pos+headlen] = '┼'
}
}
}
}
}
buf.WriteString(string(pointerBufs[pn]) + "\n")
}
}
}
// Stack get stack bytes
func Stack(skip int, indent string) []byte {
var buf = new(bytes.Buffer)
for i := skip; ; i++ {
var pc, file, line, ok = runtime.Caller(i)
if !ok {
break
}
buf.WriteString(indent)
fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line)
}
return buf.Bytes()
}
// return the name of the function containing the PC if possible,
func function(pc uintptr) []byte {
fn := runtime.FuncForPC(pc)
if fn == nil {
return dunno
}
name := []byte(fn.Name())
// The name includes the path name to the package, which is unnecessary
// since the file name is already included. Plus, it has center dots.
// That is, we see
// runtime/debug.*T·ptrmethod
// and want
// *T.ptrmethod
if period := bytes.Index(name, dot); period >= 0 {
name = name[period+1:]
}
name = bytes.Replace(name, centerDot, dot, -1)
return name
}

View File

@ -0,0 +1,46 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"testing"
)
type mytype struct {
next *mytype
prev *mytype
}
func TestPrint(t *testing.T) {
Display("v1", 1, "v2", 2, "v3", 3)
}
func TestPrintPoint(t *testing.T) {
var v1 = new(mytype)
var v2 = new(mytype)
v1.prev = nil
v1.next = v2
v2.prev = v1
v2.next = nil
Display("v1", v1, "v2", v2)
}
func TestPrintString(t *testing.T) {
str := GetDisplayString("v1", 1, "v2", 2)
println(str)
}

View File

@ -0,0 +1,101 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"bufio"
"errors"
"io"
"os"
"path/filepath"
"regexp"
)
// SelfPath gets compiled executable file absolute path
func SelfPath() string {
path, _ := filepath.Abs(os.Args[0])
return path
}
// SelfDir gets compiled executable file directory
func SelfDir() string {
return filepath.Dir(SelfPath())
}
// FileExists reports whether the named file or directory exists.
func FileExists(name string) bool {
if _, err := os.Stat(name); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// SearchFile Search a file in paths.
// this is often used in search config file in /etc ~/
func SearchFile(filename string, paths ...string) (fullpath string, err error) {
for _, path := range paths {
if fullpath = filepath.Join(path, filename); FileExists(fullpath) {
return
}
}
err = errors.New(fullpath + " not found in paths")
return
}
// GrepFile like command grep -E
// for example: GrepFile(`^hello`, "hello.txt")
// \n is striped while read
func GrepFile(patten string, filename string) (lines []string, err error) {
re, err := regexp.Compile(patten)
if err != nil {
return
}
fd, err := os.Open(filename)
if err != nil {
return
}
lines = make([]string, 0)
reader := bufio.NewReader(fd)
prefix := ""
var isLongLine bool
for {
byteLine, isPrefix, er := reader.ReadLine()
if er != nil && er != io.EOF {
return nil, er
}
if er == io.EOF {
break
}
line := string(byteLine)
if isPrefix {
prefix += line
continue
} else {
isLongLine = true
}
line = prefix + line
if isLongLine {
prefix = ""
}
if re.MatchString(line) {
lines = append(lines, line)
}
}
return lines, nil
}

View File

@ -0,0 +1,75 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"path/filepath"
"reflect"
"testing"
)
var noExistedFile = "/tmp/not_existed_file"
func TestSelfPath(t *testing.T) {
path := SelfPath()
if path == "" {
t.Error("path cannot be empty")
}
t.Logf("SelfPath: %s", path)
}
func TestSelfDir(t *testing.T) {
dir := SelfDir()
t.Logf("SelfDir: %s", dir)
}
func TestFileExists(t *testing.T) {
if !FileExists("./file.go") {
t.Errorf("./file.go should exists, but it didn't")
}
if FileExists(noExistedFile) {
t.Errorf("Weird, how could this file exists: %s", noExistedFile)
}
}
func TestSearchFile(t *testing.T) {
path, err := SearchFile(filepath.Base(SelfPath()), SelfDir())
if err != nil {
t.Error(err)
}
t.Log(path)
_, err = SearchFile(noExistedFile, ".")
if err == nil {
t.Errorf("err shouldnt be nil, got path: %s", SelfDir())
}
}
func TestGrepFile(t *testing.T) {
_, err := GrepFile("", noExistedFile)
if err == nil {
t.Error("expect file-not-existed error, but got nothing")
}
path := filepath.Join(".", "testdata", "grepe.test")
lines, err := GrepFile(`^\s*[^#]+`, path)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(lines, []string{"hello", "world"}) {
t.Errorf("expect [hello world], but receive %v", lines)
}
}

View File

@ -0,0 +1,87 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
type KV interface {
GetKey() interface{}
GetValue() interface{}
}
// SimpleKV is common structure to store key-value pairs.
// When you need something like Pair, you can use this
type SimpleKV struct {
Key interface{}
Value interface{}
}
var _ KV = new(SimpleKV)
func (s *SimpleKV) GetKey() interface{} {
return s.Key
}
func (s *SimpleKV) GetValue() interface{} {
return s.Value
}
// KVs interface
type KVs interface {
GetValueOr(key interface{}, defValue interface{}) interface{}
Contains(key interface{}) bool
IfContains(key interface{}, action func(value interface{})) KVs
}
// SimpleKVs will store SimpleKV collection as map
type SimpleKVs struct {
kvs map[interface{}]interface{}
}
var _ KVs = new(SimpleKVs)
// GetValueOr returns the value for a given key, if non-existant
// it returns defValue
func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
v, ok := kvs.kvs[key]
if ok {
return v
}
return defValue
}
// Contains checks if a key exists
func (kvs *SimpleKVs) Contains(key interface{}) bool {
_, ok := kvs.kvs[key]
return ok
}
// IfContains invokes the action on a key if it exists
func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs {
v, ok := kvs.kvs[key]
if ok {
action(v)
}
return kvs
}
// NewKVs creates the *KVs instance
func NewKVs(kvs ...KV) KVs {
res := &SimpleKVs{
kvs: make(map[interface{}]interface{}, len(kvs)),
}
for _, kv := range kvs {
res.kvs[kv.GetKey()] = kv.GetValue()
}
return res
}

View File

@ -0,0 +1,38 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestKVs(t *testing.T) {
key := "my-key"
kvs := NewKVs(&SimpleKV{
Key: key,
Value: 12,
})
assert.True(t, kvs.Contains(key))
v := kvs.GetValueOr(key, 13)
assert.Equal(t, 12, v)
v = kvs.GetValueOr(`key-not-exists`, 8546)
assert.Equal(t, 8546, v)
}

View File

@ -0,0 +1,424 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/mail"
"net/smtp"
"net/textproto"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
)
const (
maxLineLength = 76
upperhex = "0123456789ABCDEF"
)
// Email is the type used for email messages
type Email struct {
Auth smtp.Auth
Identity string `json:"identity"`
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"host"`
Port int `json:"port"`
From string `json:"from"`
To []string
Bcc []string
Cc []string
Subject string
Text string // Plaintext message (optional)
HTML string // Html message (optional)
Headers textproto.MIMEHeader
Attachments []*Attachment
ReadReceipt []string
}
// Attachment is a struct representing an email attachment.
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
type Attachment struct {
Filename string
Header textproto.MIMEHeader
Content []byte
}
// NewEMail create new Email struct with config json.
// config json is followed from Email struct fields.
func NewEMail(config string) *Email {
e := new(Email)
e.Headers = textproto.MIMEHeader{}
err := json.Unmarshal([]byte(config), e)
if err != nil {
return nil
}
return e
}
// Bytes Make all send information to byte
func (e *Email) Bytes() ([]byte, error) {
buff := &bytes.Buffer{}
w := multipart.NewWriter(buff)
// Set the appropriate headers (overwriting any conflicts)
// Leave out Bcc (only included in envelope headers)
e.Headers.Set("To", strings.Join(e.To, ","))
if e.Cc != nil {
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
}
e.Headers.Set("From", e.From)
e.Headers.Set("Subject", e.Subject)
if len(e.ReadReceipt) != 0 {
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
}
e.Headers.Set("MIME-Version", "1.0")
// Write the envelope headers (including any custom headers)
if err := headerToBytes(buff, e.Headers); err != nil {
return nil, fmt.Errorf("Failed to render message headers: %s", err)
}
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
fmt.Fprintf(buff, "%s:", "Content-Type")
fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
// Start the multipart/mixed part
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
header := textproto.MIMEHeader{}
// Check to see if there is a Text or HTML field
if e.Text != "" || e.HTML != "" {
subWriter := multipart.NewWriter(buff)
// Create the multipart alternative part
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
// Write the header
if err := headerToBytes(buff, header); err != nil {
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
}
// Create the body sections
if e.Text != "" {
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.Text); err != nil {
return nil, err
}
}
if e.HTML != "" {
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.HTML); err != nil {
return nil, err
}
}
if err := subWriter.Close(); err != nil {
return nil, err
}
}
// Create attachment part, if necessary
for _, a := range e.Attachments {
ap, err := w.CreatePart(a.Header)
if err != nil {
return nil, err
}
// Write the base64Wrapped content to the part
base64Wrap(ap, a.Content)
}
if err := w.Close(); err != nil {
return nil, err
}
return buff.Bytes(), nil
}
// AttachFile Add attach file to the send mail
func (e *Email) AttachFile(args ...string) (a *Attachment, err error) {
if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify a file name and number of parameters can not exceed at least two")
return
}
filename := args[0]
id := ""
if len(args) > 1 {
id = args[1]
}
f, err := os.Open(filename)
if err != nil {
return
}
defer f.Close()
ct := mime.TypeByExtension(filepath.Ext(filename))
basename := path.Base(filename)
return e.Attach(f, basename, ct, id)
}
// Attach is used to attach content from an io.Reader to the email.
// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type.
func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) {
if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify the file type and number of parameters can not exceed at least two")
return
}
c := args[0] //Content-Type
id := ""
if len(args) > 1 {
id = args[1] //Content-ID
}
var buffer bytes.Buffer
if _, err = io.Copy(&buffer, r); err != nil {
return
}
at := &Attachment{
Filename: filename,
Header: textproto.MIMEHeader{},
Content: buffer.Bytes(),
}
// Get the Content-Type to be used in the MIMEHeader
if c != "" {
at.Header.Set("Content-Type", c)
} else {
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
at.Header.Set("Content-Type", "application/octet-stream")
}
if id != "" {
at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename))
at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id))
} else {
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
}
at.Header.Set("Content-Transfer-Encoding", "base64")
e.Attachments = append(e.Attachments, at)
return at, nil
}
// Send will send out the mail
func (e *Email) Send() error {
if e.Auth == nil {
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
}
// Merge the To, Cc, and Bcc fields
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
// Check to make sure there is at least one recipient and one "From" address
if len(to) == 0 {
return errors.New("Must specify at least one To address")
}
// Use the username if no From is provided
if len(e.From) == 0 {
e.From = e.Username
}
from, err := mail.ParseAddress(e.From)
if err != nil {
return err
}
// use mail's RFC 2047 to encode any string
e.Subject = qEncode("utf-8", e.Subject)
raw, err := e.Bytes()
if err != nil {
return err
}
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
}
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
func quotePrintEncode(w io.Writer, s string) error {
var buf [3]byte
mc := 0
for i := 0; i < len(s); i++ {
c := s[i]
// We're assuming Unix style text formats as input (LF line break), and
// quoted-printble uses CRLF line breaks. (Literal CRs will become
// "=0D", but probably shouldn't be there to begin with!)
if c == '\n' {
io.WriteString(w, "\r\n")
mc = 0
continue
}
var nextOut []byte
if isPrintable(c) {
nextOut = append(buf[:0], c)
} else {
nextOut = buf[:]
qpEscape(nextOut, c)
}
// Add a soft line break if the next (encoded) byte would push this line
// to or past the limit.
if mc+len(nextOut) >= maxLineLength {
if _, err := io.WriteString(w, "=\r\n"); err != nil {
return err
}
mc = 0
}
if _, err := w.Write(nextOut); err != nil {
return err
}
mc += len(nextOut)
}
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
if mc > 0 {
io.WriteString(w, "=\r\n")
}
return nil
}
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
func isPrintable(c byte) bool {
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
}
// qpEscape is a helper function for quotePrintEncode which escapes a
// non-printable byte. Expects len(dest) == 3.
func qpEscape(dest []byte, c byte) {
const nums = "0123456789ABCDEF"
dest[0] = '='
dest[1] = nums[(c&0xf0)>>4]
dest[2] = nums[(c & 0xf)]
}
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
for k, v := range t {
// Write the header key
_, err := fmt.Fprintf(w, "%s:", k)
if err != nil {
return err
}
// Write each value in the header
for _, c := range v {
_, err := fmt.Fprintf(w, " %s\r\n", c)
if err != nil {
return err
}
}
}
return nil
}
// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
// The output is then written to the specified io.Writer
func base64Wrap(w io.Writer, b []byte) {
// 57 raw bytes per 76-byte base64 line.
const maxRaw = 57
// Buffer for each line, including trailing CRLF.
var buffer [maxLineLength + len("\r\n")]byte
copy(buffer[maxLineLength:], "\r\n")
// Process raw chunks until there's no longer enough to fill a line.
for len(b) >= maxRaw {
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
w.Write(buffer[:])
b = b[maxRaw:]
}
// Handle the last chunk of bytes.
if len(b) > 0 {
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
base64.StdEncoding.Encode(out, b)
out = append(out, "\r\n"...)
w.Write(out)
}
}
// Encode returns the encoded-word form of s. If s is ASCII without special
// characters, it is returned unchanged. The provided charset is the IANA
// charset name of s. It is case insensitive.
// RFC 2047 encoded-word
func qEncode(charset, s string) string {
if !needsEncoding(s) {
return s
}
return encodeWord(charset, s)
}
func needsEncoding(s string) bool {
for _, b := range s {
if (b < ' ' || b > '~') && b != '\t' {
return true
}
}
return false
}
// encodeWord encodes a string into an encoded-word.
func encodeWord(charset, s string) string {
buf := getBuffer()
buf.WriteString("=?")
buf.WriteString(charset)
buf.WriteByte('?')
buf.WriteByte('q')
buf.WriteByte('?')
enc := make([]byte, 3)
for i := 0; i < len(s); i++ {
b := s[i]
switch {
case b == ' ':
buf.WriteByte('_')
case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_':
buf.WriteByte(b)
default:
enc[0] = '='
enc[1] = upperhex[b>>4]
enc[2] = upperhex[b&0x0f]
buf.Write(enc)
}
}
buf.WriteString("?=")
es := buf.String()
putBuffer(buf)
return es
}
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
func getBuffer() *bytes.Buffer {
return bufPool.Get().(*bytes.Buffer)
}
func putBuffer(buf *bytes.Buffer) {
if buf.Len() > 1024 {
return
}
buf.Reset()
bufPool.Put(buf)
}

View File

@ -0,0 +1,41 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "testing"
func TestMail(t *testing.T) {
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
mail := NewEMail(config)
if mail.Username != "astaxie@gmail.com" {
t.Fatal("email parse get username error")
}
if mail.Password != "astaxie" {
t.Fatal("email parse get password error")
}
if mail.Host != "smtp.gmail.com" {
t.Fatal("email parse get host error")
}
if mail.Port != 587 {
t.Fatal("email parse get port error")
}
mail.To = []string{"xiemengjun@gmail.com"}
mail.From = "astaxie@gmail.com"
mail.Subject = "hi, just from beego!"
mail.Text = "Text Body is, of course, supported!"
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
mail.Send()
}

View File

@ -0,0 +1,58 @@
/*
Package pagination provides utilities to setup a paginator within the
context of a http request.
Usage
In your beego.Controller:
package controllers
import "github.com/astaxie/beego/pkg/infrastructure/utils/pagination"
type PostsController struct {
beego.Controller
}
func (this *PostsController) ListAllPosts() {
// sets this.Data["paginator"] with the current offset (from the url query param)
postsPerPage := 20
paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts())
// fetch the next 20 posts
this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage)
}
In your view templates:
{{if .paginator.HasPages}}
<ul class="pagination pagination">
{{if .paginator.HasPrev}}
<li><a href="{{.paginator.PageLinkFirst}}">{{ i18n .Lang "paginator.first_page"}}</a></li>
<li><a href="{{.paginator.PageLinkPrev}}">&laquo;</a></li>
{{else}}
<li class="disabled"><a>{{ i18n .Lang "paginator.first_page"}}</a></li>
<li class="disabled"><a>&laquo;</a></li>
{{end}}
{{range $index, $page := .paginator.Pages}}
<li{{if $.paginator.IsActive .}} class="active"{{end}}>
<a href="{{$.paginator.PageLink $page}}">{{$page}}</a>
</li>
{{end}}
{{if .paginator.HasNext}}
<li><a href="{{.paginator.PageLinkNext}}">&raquo;</a></li>
<li><a href="{{.paginator.PageLinkLast}}">{{ i18n .Lang "paginator.last_page"}}</a></li>
{{else}}
<li class="disabled"><a>&raquo;</a></li>
<li class="disabled"><a>{{ i18n .Lang "paginator.last_page"}}</a></li>
{{end}}
</ul>
{{end}}
See also
http://beego.me/docs/mvc/view/page.md
*/
package pagination

View File

@ -0,0 +1,189 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pagination
import (
"math"
"net/http"
"net/url"
"strconv"
)
// Paginator within the state of a http request.
type Paginator struct {
Request *http.Request
PerPageNums int
MaxPages int
nums int64
pageRange []int
pageNums int
page int
}
// PageNums Returns the total number of pages.
func (p *Paginator) PageNums() int {
if p.pageNums != 0 {
return p.pageNums
}
pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums))
if p.MaxPages > 0 {
pageNums = math.Min(pageNums, float64(p.MaxPages))
}
p.pageNums = int(pageNums)
return p.pageNums
}
// Nums Returns the total number of items (e.g. from doing SQL count).
func (p *Paginator) Nums() int64 {
return p.nums
}
// SetNums Sets the total number of items.
func (p *Paginator) SetNums(nums interface{}) {
p.nums, _ = toInt64(nums)
}
// Page Returns the current page.
func (p *Paginator) Page() int {
if p.page != 0 {
return p.page
}
if p.Request.Form == nil {
p.Request.ParseForm()
}
p.page, _ = strconv.Atoi(p.Request.Form.Get("p"))
if p.page > p.PageNums() {
p.page = p.PageNums()
}
if p.page <= 0 {
p.page = 1
}
return p.page
}
// Pages Returns a list of all pages.
//
// Usage (in a view template):
//
// {{range $index, $page := .paginator.Pages}}
// <li{{if $.paginator.IsActive .}} class="active"{{end}}>
// <a href="{{$.paginator.PageLink $page}}">{{$page}}</a>
// </li>
// {{end}}
func (p *Paginator) Pages() []int {
if p.pageRange == nil && p.nums > 0 {
var pages []int
pageNums := p.PageNums()
page := p.Page()
switch {
case page >= pageNums-4 && pageNums > 9:
start := pageNums - 9 + 1
pages = make([]int, 9)
for i := range pages {
pages[i] = start + i
}
case page >= 5 && pageNums > 9:
start := page - 5 + 1
pages = make([]int, int(math.Min(9, float64(page+4+1))))
for i := range pages {
pages[i] = start + i
}
default:
pages = make([]int, int(math.Min(9, float64(pageNums))))
for i := range pages {
pages[i] = i + 1
}
}
p.pageRange = pages
}
return p.pageRange
}
// PageLink Returns URL for a given page index.
func (p *Paginator) PageLink(page int) string {
link, _ := url.ParseRequestURI(p.Request.URL.String())
values := link.Query()
if page == 1 {
values.Del("p")
} else {
values.Set("p", strconv.Itoa(page))
}
link.RawQuery = values.Encode()
return link.String()
}
// PageLinkPrev Returns URL to the previous page.
func (p *Paginator) PageLinkPrev() (link string) {
if p.HasPrev() {
link = p.PageLink(p.Page() - 1)
}
return
}
// PageLinkNext Returns URL to the next page.
func (p *Paginator) PageLinkNext() (link string) {
if p.HasNext() {
link = p.PageLink(p.Page() + 1)
}
return
}
// PageLinkFirst Returns URL to the first page.
func (p *Paginator) PageLinkFirst() (link string) {
return p.PageLink(1)
}
// PageLinkLast Returns URL to the last page.
func (p *Paginator) PageLinkLast() (link string) {
return p.PageLink(p.PageNums())
}
// HasPrev Returns true if the current page has a predecessor.
func (p *Paginator) HasPrev() bool {
return p.Page() > 1
}
// HasNext Returns true if the current page has a successor.
func (p *Paginator) HasNext() bool {
return p.Page() < p.PageNums()
}
// IsActive Returns true if the given page index points to the current page.
func (p *Paginator) IsActive(page int) bool {
return p.Page() == page
}
// Offset Returns the current offset.
func (p *Paginator) Offset() int {
return (p.Page() - 1) * p.PerPageNums
}
// HasPages Returns true if there is more than one page.
func (p *Paginator) HasPages() bool {
return p.PageNums() > 1
}
// NewPaginator Instantiates a paginator struct for the current http request.
func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator {
p := Paginator{}
p.Request = req
if per <= 0 {
per = 10
}
p.PerPageNums = per
p.SetNums(nums)
return &p
}

View File

@ -0,0 +1,34 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pagination
import (
"fmt"
"reflect"
)
// ToInt64 convert any numeric value to int64
func toInt64(value interface{}) (d int64, err error) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
err = fmt.Errorf("ToInt64 need numeric not `%T`", value)
}
return
}

View File

@ -0,0 +1,44 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"crypto/rand"
r "math/rand"
"time"
)
var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`)
// RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte {
if len(alphabets) == 0 {
alphabets = alphaNum
}
var bytes = make([]byte, n)
var randBy bool
if num, err := rand.Read(bytes); num != n || err != nil {
r.Seed(time.Now().UnixNano())
randBy = true
}
for i, b := range bytes {
if randBy {
bytes[i] = alphabets[r.Intn(len(alphabets))]
} else {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return bytes
}

View File

@ -0,0 +1,33 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "testing"
func TestRand_01(t *testing.T) {
bs0 := RandomCreateBytes(16)
bs1 := RandomCreateBytes(16)
t.Log(string(bs0), string(bs1))
if string(bs0) == string(bs1) {
t.FailNow()
}
bs0 = RandomCreateBytes(4, []byte(`a`)...)
if string(bs0) != "aaaa" {
t.FailNow()
}
}

View File

@ -0,0 +1,91 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"sync"
)
// BeeMap is a map with lock
type BeeMap struct {
lock *sync.RWMutex
bm map[interface{}]interface{}
}
// NewBeeMap return new safemap
func NewBeeMap() *BeeMap {
return &BeeMap{
lock: new(sync.RWMutex),
bm: make(map[interface{}]interface{}),
}
}
// Get from maps return the k's value
func (m *BeeMap) Get(k interface{}) interface{} {
m.lock.RLock()
defer m.lock.RUnlock()
if val, ok := m.bm[k]; ok {
return val
}
return nil
}
// Set Maps the given key and value. Returns false
// if the key is already in the map and changes nothing.
func (m *BeeMap) Set(k interface{}, v interface{}) bool {
m.lock.Lock()
defer m.lock.Unlock()
if val, ok := m.bm[k]; !ok {
m.bm[k] = v
} else if val != v {
m.bm[k] = v
} else {
return false
}
return true
}
// Check Returns true if k is exist in the map.
func (m *BeeMap) Check(k interface{}) bool {
m.lock.RLock()
defer m.lock.RUnlock()
_, ok := m.bm[k]
return ok
}
// Delete the given key and value.
func (m *BeeMap) Delete(k interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.bm, k)
}
// Items returns all items in safemap.
func (m *BeeMap) Items() map[interface{}]interface{} {
m.lock.RLock()
defer m.lock.RUnlock()
r := make(map[interface{}]interface{})
for k, v := range m.bm {
r[k] = v
}
return r
}
// Count returns the number of items within the map.
func (m *BeeMap) Count() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.bm)
}

View File

@ -0,0 +1,89 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "testing"
var safeMap *BeeMap
func TestNewBeeMap(t *testing.T) {
safeMap = NewBeeMap()
if safeMap == nil {
t.Fatal("expected to return non-nil BeeMap", "got", safeMap)
}
}
func TestSet(t *testing.T) {
safeMap = NewBeeMap()
if ok := safeMap.Set("astaxie", 1); !ok {
t.Error("expected", true, "got", false)
}
}
func TestReSet(t *testing.T) {
safeMap := NewBeeMap()
if ok := safeMap.Set("astaxie", 1); !ok {
t.Error("expected", true, "got", false)
}
// set diff value
if ok := safeMap.Set("astaxie", -1); !ok {
t.Error("expected", true, "got", false)
}
// set same value
if ok := safeMap.Set("astaxie", -1); ok {
t.Error("expected", false, "got", true)
}
}
func TestCheck(t *testing.T) {
if exists := safeMap.Check("astaxie"); !exists {
t.Error("expected", true, "got", false)
}
}
func TestGet(t *testing.T) {
if val := safeMap.Get("astaxie"); val.(int) != 1 {
t.Error("expected value", 1, "got", val)
}
}
func TestDelete(t *testing.T) {
safeMap.Delete("astaxie")
if exists := safeMap.Check("astaxie"); exists {
t.Error("expected element to be deleted")
}
}
func TestItems(t *testing.T) {
safeMap := NewBeeMap()
safeMap.Set("astaxie", "hello")
for k, v := range safeMap.Items() {
key := k.(string)
value := v.(string)
if key != "astaxie" {
t.Error("expected the key should be astaxie")
}
if value != "hello" {
t.Error("expected the value should be hello")
}
}
}
func TestCount(t *testing.T) {
if count := safeMap.Count(); count != 0 {
t.Error("expected count to be", 0, "got", count)
}
}

View File

@ -0,0 +1,170 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"math/rand"
"time"
)
type reducetype func(interface{}) interface{}
type filtertype func(interface{}) bool
// InSlice checks given string in string slice or not.
func InSlice(v string, sl []string) bool {
for _, vv := range sl {
if vv == v {
return true
}
}
return false
}
// InSliceIface checks given interface in interface slice.
func InSliceIface(v interface{}, sl []interface{}) bool {
for _, vv := range sl {
if vv == v {
return true
}
}
return false
}
// SliceRandList generate an int slice from min to max.
func SliceRandList(min, max int) []int {
if max < min {
min, max = max, min
}
length := max - min + 1
t0 := time.Now()
rand.Seed(int64(t0.Nanosecond()))
list := rand.Perm(length)
for index := range list {
list[index] += min
}
return list
}
// SliceMerge merges interface slices to one slice.
func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) {
c = append(slice1, slice2...)
return
}
// SliceReduce generates a new slice after parsing every value by reduce function
func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
for _, v := range slice {
dslice = append(dslice, a(v))
}
return
}
// SliceRand returns random one from slice.
func SliceRand(a []interface{}) (b interface{}) {
randnum := rand.Intn(len(a))
b = a[randnum]
return
}
// SliceSum sums all values in int64 slice.
func SliceSum(intslice []int64) (sum int64) {
for _, v := range intslice {
sum += v
}
return
}
// SliceFilter generates a new slice after filter function.
func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
for _, v := range slice {
if a(v) {
ftslice = append(ftslice, v)
}
}
return
}
// SliceDiff returns diff slice of slice1 - slice2.
func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
for _, v := range slice1 {
if !InSliceIface(v, slice2) {
diffslice = append(diffslice, v)
}
}
return
}
// SliceIntersect returns slice that are present in all the slice1 and slice2.
func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
for _, v := range slice1 {
if InSliceIface(v, slice2) {
diffslice = append(diffslice, v)
}
}
return
}
// SliceChunk separates one slice to some sized slice.
func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
if size >= len(slice) {
chunkslice = append(chunkslice, slice)
return
}
end := size
for i := 0; i <= (len(slice) - size); i += size {
chunkslice = append(chunkslice, slice[i:end])
end += size
}
return
}
// SliceRange generates a new slice from begin to end with step duration of int64 number.
func SliceRange(start, end, step int64) (intslice []int64) {
for i := start; i <= end; i += step {
intslice = append(intslice, i)
}
return
}
// SlicePad prepends size number of val into slice.
func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
if size <= len(slice) {
return slice
}
for i := 0; i < (size - len(slice)); i++ {
slice = append(slice, val)
}
return slice
}
// SliceUnique cleans repeated values in slice.
func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
for _, v := range slice {
if !InSliceIface(v, uniqueslice) {
uniqueslice = append(uniqueslice, v)
}
}
return
}
// SliceShuffle shuffles a slice.
func SliceShuffle(slice []interface{}) []interface{} {
for i := 0; i < len(slice); i++ {
a := rand.Intn(len(slice))
b := rand.Intn(len(slice))
slice[a], slice[b] = slice[b], slice[a]
}
return slice
}

View File

@ -0,0 +1,29 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"testing"
)
func TestInSlice(t *testing.T) {
sl := []string{"A", "b"}
if !InSlice("A", sl) {
t.Error("should be true")
}
if InSlice("B", sl) {
t.Error("should be false")
}
}

View File

@ -0,0 +1,7 @@
# empty lines
hello
# comment
world

View File

@ -0,0 +1,48 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"time"
)
// short string format
func ToShortTimeFormat(d time.Duration) string {
u := uint64(d)
if u < uint64(time.Second) {
switch {
case u == 0:
return "0"
case u < uint64(time.Microsecond):
return fmt.Sprintf("%.2fns", float64(u))
case u < uint64(time.Millisecond):
return fmt.Sprintf("%.2fus", float64(u)/1000)
default:
return fmt.Sprintf("%.2fms", float64(u)/1000/1000)
}
} else {
switch {
case u < uint64(time.Minute):
return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000)
case u < uint64(time.Hour):
return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60)
default:
return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60)
}
}
}

View File

@ -0,0 +1,89 @@
package utils
import (
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
)
// GetGOPATHs returns all paths in GOPATH variable.
func GetGOPATHs() []string {
gopath := os.Getenv("GOPATH")
if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 {
gopath = defaultGOPATH()
}
return filepath.SplitList(gopath)
}
func compareGoVersion(a, b string) int {
reg := regexp.MustCompile("^\\d*")
a = strings.TrimPrefix(a, "go")
b = strings.TrimPrefix(b, "go")
versionsA := strings.Split(a, ".")
versionsB := strings.Split(b, ".")
for i := 0; i < len(versionsA) && i < len(versionsB); i++ {
versionA := versionsA[i]
versionB := versionsB[i]
vA, err := strconv.Atoi(versionA)
if err != nil {
str := reg.FindString(versionA)
if str != "" {
vA, _ = strconv.Atoi(str)
} else {
vA = -1
}
}
vB, err := strconv.Atoi(versionB)
if err != nil {
str := reg.FindString(versionB)
if str != "" {
vB, _ = strconv.Atoi(str)
} else {
vB = -1
}
}
if vA > vB {
// vA = 12, vB = 8
return 1
} else if vA < vB {
// vA = 6, vB = 8
return -1
} else if vA == -1 {
// vA = rc1, vB = rc3
return strings.Compare(versionA, versionB)
}
// vA = vB = 8
continue
}
if len(versionsA) > len(versionsB) {
return 1
} else if len(versionsA) == len(versionsB) {
return 0
}
return -1
}
func defaultGOPATH() string {
env := "HOME"
if runtime.GOOS == "windows" {
env = "USERPROFILE"
} else if runtime.GOOS == "plan9" {
env = "home"
}
if home := os.Getenv(env); home != "" {
return filepath.Join(home, "go")
}
return ""
}

View File

@ -0,0 +1,36 @@
package utils
import (
"testing"
)
func TestCompareGoVersion(t *testing.T) {
targetVersion := "go1.8"
if compareGoVersion("go1.12.4", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8.7", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8", targetVersion) != 0 {
t.Error("should be 0")
}
if compareGoVersion("go1.7.6", targetVersion) != -1 {
t.Error("should be -1")
}
if compareGoVersion("go1.12.1rc1", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8rc1", targetVersion) != 0 {
t.Error("should be 0")
}
if compareGoVersion("go1.7rc1", targetVersion) != -1 {
t.Error("should be -1")
}
}

Some files were not shown because too many files have changed in this diff Show More