diff --git a/client/cache/cache_test.go b/client/cache/cache_test.go index bd9b0801..85f83fc4 100644 --- a/client/cache/cache_test.go +++ b/client/cache/cache_test.go @@ -72,21 +72,9 @@ func TestCache(t *testing.T) { t.Error("set Error", err) } - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + // test different integer type for incr & decr + testMultiIncrDecr(t, bm, timeoutDuration) - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { - t.Error("get err") - } bm.Delete(context.Background(), "astaxie") if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") @@ -153,21 +141,9 @@ func TestFileCache(t *testing.T) { t.Error("get err") } - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + // test different integer type for incr & decr + testMultiIncrDecr(t, bm, timeoutDuration) - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { - t.Error("get err") - } bm.Delete(context.Background(), "astaxie") if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") @@ -219,3 +195,41 @@ func TestFileCache(t *testing.T) { os.RemoveAll("cache") } + +func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { + testIncrDecr(t, c, 1, 2, timeout) + testIncrDecr(t, c, int32(1), int32(2), timeout) + testIncrDecr(t, c, int64(1), int64(2), timeout) + testIncrDecr(t, c, uint(1), uint(2), timeout) + testIncrDecr(t, c, uint32(1), uint32(2), timeout) + testIncrDecr(t, c, uint64(1), uint64(2), timeout) +} + +func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr interface{}, timeout time.Duration) { + var err error + ctx := context.Background() + key := "incDecKey" + if err = c.Put(ctx, key, beforeIncr, timeout); err != nil { + t.Error("Get Error", err) + } + + if err = c.Incr(ctx, key); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := c.Get(ctx, key); v != afterIncr { + t.Error("Get Error") + } + + if err = c.Decr(ctx, key); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := c.Get(ctx, key); v != beforeIncr { + t.Error("Get Error") + } + + if err := c.Delete(ctx, key); err != nil { + t.Error("Delete Error") + } +} diff --git a/client/cache/file.go b/client/cache/file.go index 84ac03c8..043c4650 100644 --- a/client/cache/file.go +++ b/client/cache/file.go @@ -26,7 +26,6 @@ import ( "io/ioutil" "os" "path/filepath" - "reflect" "strconv" "strings" "time" @@ -195,28 +194,70 @@ func (fc *FileCache) Delete(ctx context.Context, key string) error { // Incr increases cached int value. // fc value is saved forever unless deleted. func (fc *FileCache) Incr(ctx context.Context, key string) error { - data, _ := fc.Get(context.Background(), key) - var incr int - if reflect.TypeOf(data).Name() != "int" { - incr = 0 - } else { - incr = data.(int) + 1 + data, err := fc.Get(context.Background(), key) + if err != nil { + return err } - fc.Put(context.Background(), key, incr, time.Duration(fc.EmbedExpiry)) - return nil + + var res interface{} + switch val := data.(type) { + case int: + res = val + 1 + case int32: + res = val + 1 + case int64: + res = val + 1 + case uint: + res = val + 1 + case uint32: + res = val + 1 + case uint64: + res = val + 1 + default: + return errors.Errorf("data is not (u)int (u)int32 (u)int64") + } + + return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) } // Decr decreases cached int value. func (fc *FileCache) Decr(ctx context.Context, key string) error { - data, _ := fc.Get(context.Background(), key) - var decr int - if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { - decr = 0 - } else { - decr = data.(int) - 1 + data, err := fc.Get(context.Background(), key) + if err != nil { + return err } - fc.Put(context.Background(), key, decr, time.Duration(fc.EmbedExpiry)) - return nil + + var res interface{} + switch val := data.(type) { + case int: + res = val - 1 + case int32: + res = val - 1 + case int64: + res = val - 1 + case uint: + if val > 0 { + res = val - 1 + } else { + return errors.New("data val is less than 0") + } + case uint32: + if val > 0 { + res = val - 1 + } else { + return errors.New("data val is less than 0") + } + case uint64: + if val > 0 { + res = val - 1 + } else { + return errors.New("data val is less than 0") + } + default: + return errors.Errorf("data is not (u)int (u)int32 (u)int64") + } + + return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) } // IsExist checks if value exists.