// Copyright 2011 Aaron Jacobs. All Rights Reserved. // Author: aaronjjacobs@gmail.com (Aaron Jacobs) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package oglematchers import ( "errors" "fmt" "math" "reflect" ) // Equals(x) returns a matcher that matches values v such that v and x are // equivalent. This includes the case when the comparison v == x using Go's // built-in comparison operator is legal (except for structs, which this // matcher does not support), but for convenience the following rules also // apply: // // * Type checking is done based on underlying types rather than actual // types, so that e.g. two aliases for string can be compared: // // type stringAlias1 string // type stringAlias2 string // // a := "taco" // b := stringAlias1("taco") // c := stringAlias2("taco") // // ExpectTrue(a == b) // Legal, passes // ExpectTrue(b == c) // Illegal, doesn't compile // // ExpectThat(a, Equals(b)) // Passes // ExpectThat(b, Equals(c)) // Passes // // * Values of numeric type are treated as if they were abstract numbers, and // compared accordingly. Therefore Equals(17) will match int(17), // int16(17), uint(17), float32(17), complex64(17), and so on. // // If you want a stricter matcher that contains no such cleverness, see // IdenticalTo instead. // // Arrays are supported by this matcher, but do not participate in the // exceptions above. Two arrays compared with this matcher must have identical // types, and their element type must itself be comparable according to Go's == // operator. func Equals(x interface{}) Matcher { v := reflect.ValueOf(x) // This matcher doesn't support structs. if v.Kind() == reflect.Struct { panic(fmt.Sprintf("oglematchers.Equals: unsupported kind %v", v.Kind())) } // The == operator is not defined for non-nil slices. if v.Kind() == reflect.Slice && v.Pointer() != uintptr(0) { panic(fmt.Sprintf("oglematchers.Equals: non-nil slice")) } return &equalsMatcher{v} } type equalsMatcher struct { expectedValue reflect.Value } //////////////////////////////////////////////////////////////////////// // Numeric types //////////////////////////////////////////////////////////////////////// func isSignedInteger(v reflect.Value) bool { k := v.Kind() return k >= reflect.Int && k <= reflect.Int64 } func isUnsignedInteger(v reflect.Value) bool { k := v.Kind() return k >= reflect.Uint && k <= reflect.Uintptr } func isInteger(v reflect.Value) bool { return isSignedInteger(v) || isUnsignedInteger(v) } func isFloat(v reflect.Value) bool { k := v.Kind() return k == reflect.Float32 || k == reflect.Float64 } func isComplex(v reflect.Value) bool { k := v.Kind() return k == reflect.Complex64 || k == reflect.Complex128 } func checkAgainstInt64(e int64, c reflect.Value) (err error) { err = errors.New("") switch { case isSignedInteger(c): if c.Int() == e { err = nil } case isUnsignedInteger(c): u := c.Uint() if u <= math.MaxInt64 && int64(u) == e { err = nil } // Turn around the various floating point types so that the checkAgainst* // functions for them can deal with precision issues. case isFloat(c), isComplex(c): return Equals(c.Interface()).Matches(e) default: err = NewFatalError("which is not numeric") } return } func checkAgainstUint64(e uint64, c reflect.Value) (err error) { err = errors.New("") switch { case isSignedInteger(c): i := c.Int() if i >= 0 && uint64(i) == e { err = nil } case isUnsignedInteger(c): if c.Uint() == e { err = nil } // Turn around the various floating point types so that the checkAgainst* // functions for them can deal with precision issues. case isFloat(c), isComplex(c): return Equals(c.Interface()).Matches(e) default: err = NewFatalError("which is not numeric") } return } func checkAgainstFloat32(e float32, c reflect.Value) (err error) { err = errors.New("") switch { case isSignedInteger(c): if float32(c.Int()) == e { err = nil } case isUnsignedInteger(c): if float32(c.Uint()) == e { err = nil } case isFloat(c): // Compare using float32 to avoid a false sense of precision; otherwise // e.g. Equals(float32(0.1)) won't match float32(0.1). if float32(c.Float()) == e { err = nil } case isComplex(c): comp := c.Complex() rl := real(comp) im := imag(comp) // Compare using float32 to avoid a false sense of precision; otherwise // e.g. Equals(float32(0.1)) won't match (0.1 + 0i). if im == 0 && float32(rl) == e { err = nil } default: err = NewFatalError("which is not numeric") } return } func checkAgainstFloat64(e float64, c reflect.Value) (err error) { err = errors.New("") ck := c.Kind() switch { case isSignedInteger(c): if float64(c.Int()) == e { err = nil } case isUnsignedInteger(c): if float64(c.Uint()) == e { err = nil } // If the actual value is lower precision, turn the comparison around so we // apply the low-precision rules. Otherwise, e.g. Equals(0.1) may not match // float32(0.1). case ck == reflect.Float32 || ck == reflect.Complex64: return Equals(c.Interface()).Matches(e) // Otherwise, compare with double precision. case isFloat(c): if c.Float() == e { err = nil } case isComplex(c): comp := c.Complex() rl := real(comp) im := imag(comp) if im == 0 && rl == e { err = nil } default: err = NewFatalError("which is not numeric") } return } func checkAgainstComplex64(e complex64, c reflect.Value) (err error) { err = errors.New("") realPart := real(e) imaginaryPart := imag(e) switch { case isInteger(c) || isFloat(c): // If we have no imaginary part, then we should just compare against the // real part. Otherwise, we can't be equal. if imaginaryPart != 0 { return } return checkAgainstFloat32(realPart, c) case isComplex(c): // Compare using complex64 to avoid a false sense of precision; otherwise // e.g. Equals(0.1 + 0i) won't match float32(0.1). if complex64(c.Complex()) == e { err = nil } default: err = NewFatalError("which is not numeric") } return } func checkAgainstComplex128(e complex128, c reflect.Value) (err error) { err = errors.New("") realPart := real(e) imaginaryPart := imag(e) switch { case isInteger(c) || isFloat(c): // If we have no imaginary part, then we should just compare against the // real part. Otherwise, we can't be equal. if imaginaryPart != 0 { return } return checkAgainstFloat64(realPart, c) case isComplex(c): if c.Complex() == e { err = nil } default: err = NewFatalError("which is not numeric") } return } //////////////////////////////////////////////////////////////////////// // Other types //////////////////////////////////////////////////////////////////////// func checkAgainstBool(e bool, c reflect.Value) (err error) { if c.Kind() != reflect.Bool { err = NewFatalError("which is not a bool") return } err = errors.New("") if c.Bool() == e { err = nil } return } func checkAgainstChan(e reflect.Value, c reflect.Value) (err error) { // Create a description of e's type, e.g. "chan int". typeStr := fmt.Sprintf("%s %s", e.Type().ChanDir(), e.Type().Elem()) // Make sure c is a chan of the correct type. if c.Kind() != reflect.Chan || c.Type().ChanDir() != e.Type().ChanDir() || c.Type().Elem() != e.Type().Elem() { err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkAgainstFunc(e reflect.Value, c reflect.Value) (err error) { // Make sure c is a function. if c.Kind() != reflect.Func { err = NewFatalError("which is not a function") return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkAgainstMap(e reflect.Value, c reflect.Value) (err error) { // Make sure c is a map. if c.Kind() != reflect.Map { err = NewFatalError("which is not a map") return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkAgainstPtr(e reflect.Value, c reflect.Value) (err error) { // Create a description of e's type, e.g. "*int". typeStr := fmt.Sprintf("*%v", e.Type().Elem()) // Make sure c is a pointer of the correct type. if c.Kind() != reflect.Ptr || c.Type().Elem() != e.Type().Elem() { err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkAgainstSlice(e reflect.Value, c reflect.Value) (err error) { // Create a description of e's type, e.g. "[]int". typeStr := fmt.Sprintf("[]%v", e.Type().Elem()) // Make sure c is a slice of the correct type. if c.Kind() != reflect.Slice || c.Type().Elem() != e.Type().Elem() { err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkAgainstString(e reflect.Value, c reflect.Value) (err error) { // Make sure c is a string. if c.Kind() != reflect.String { err = NewFatalError("which is not a string") return } err = errors.New("") if c.String() == e.String() { err = nil } return } func checkAgainstArray(e reflect.Value, c reflect.Value) (err error) { // Create a description of e's type, e.g. "[2]int". typeStr := fmt.Sprintf("%v", e.Type()) // Make sure c is the correct type. if c.Type() != e.Type() { err = NewFatalError(fmt.Sprintf("which is not %s", typeStr)) return } // Check for equality. if e.Interface() != c.Interface() { err = errors.New("") return } return } func checkAgainstUnsafePointer(e reflect.Value, c reflect.Value) (err error) { // Make sure c is a pointer. if c.Kind() != reflect.UnsafePointer { err = NewFatalError("which is not a unsafe.Pointer") return } err = errors.New("") if c.Pointer() == e.Pointer() { err = nil } return } func checkForNil(c reflect.Value) (err error) { err = errors.New("") // Make sure it is legal to call IsNil. switch c.Kind() { case reflect.Invalid: case reflect.Chan: case reflect.Func: case reflect.Interface: case reflect.Map: case reflect.Ptr: case reflect.Slice: default: err = NewFatalError("which cannot be compared to nil") return } // Ask whether the value is nil. Handle a nil literal (kind Invalid) // specially, since it's not legal to call IsNil there. if c.Kind() == reflect.Invalid || c.IsNil() { err = nil } return } //////////////////////////////////////////////////////////////////////// // Public implementation //////////////////////////////////////////////////////////////////////// func (m *equalsMatcher) Matches(candidate interface{}) error { e := m.expectedValue c := reflect.ValueOf(candidate) ek := e.Kind() switch { case ek == reflect.Bool: return checkAgainstBool(e.Bool(), c) case isSignedInteger(e): return checkAgainstInt64(e.Int(), c) case isUnsignedInteger(e): return checkAgainstUint64(e.Uint(), c) case ek == reflect.Float32: return checkAgainstFloat32(float32(e.Float()), c) case ek == reflect.Float64: return checkAgainstFloat64(e.Float(), c) case ek == reflect.Complex64: return checkAgainstComplex64(complex64(e.Complex()), c) case ek == reflect.Complex128: return checkAgainstComplex128(complex128(e.Complex()), c) case ek == reflect.Chan: return checkAgainstChan(e, c) case ek == reflect.Func: return checkAgainstFunc(e, c) case ek == reflect.Map: return checkAgainstMap(e, c) case ek == reflect.Ptr: return checkAgainstPtr(e, c) case ek == reflect.Slice: return checkAgainstSlice(e, c) case ek == reflect.String: return checkAgainstString(e, c) case ek == reflect.Array: return checkAgainstArray(e, c) case ek == reflect.UnsafePointer: return checkAgainstUnsafePointer(e, c) case ek == reflect.Invalid: return checkForNil(c) } panic(fmt.Sprintf("equalsMatcher.Matches: unexpected kind: %v", ek)) } func (m *equalsMatcher) Description() string { // Special case: handle nil. if !m.expectedValue.IsValid() { return "is nil" } return fmt.Sprintf("%v", m.expectedValue.Interface()) }