diff --git a/deep.go b/deep.go index dbc89c0..9f6c62e 100644 --- a/deep.go +++ b/deep.go @@ -1,6 +1,5 @@ -// Package deep provides function deep.Equal which is like reflect.DeepEqual but -// returns a list of differences. This is helpful when comparing complex types -// like structures and maps. +// Package deep provides function deep.Equal which is like reflect.DeepEqual but returns a list of differences. +// This is helpful when comparing complex types like structures and maps. package deep import ( @@ -9,6 +8,7 @@ import ( "log" "reflect" "strings" + "time" ) var ( @@ -16,6 +16,9 @@ var ( // to when comparing. FloatPrecision = 10 + // TimePrecision is a precision used for time.Time.Truncate(), if it is non-zero. + TimePrecision time.Duration + // MaxDiff specifies the maximum number of differences to return. MaxDiff = 10 @@ -30,6 +33,10 @@ var ( // T{s int}, to be compared when true. CompareUnexportedFields = false + // CompareFunctions causes functions to be compared according to reflect.DeepEqual rules: + // that is, Func values are equal if both are nil; otherwise they are not equal. + CompareFunctions = false + // NilSlicesAreEmpty causes a nil slice to be equal to an empty slice. NilSlicesAreEmpty = false @@ -49,12 +56,18 @@ var ( ) type cmp struct { - diff []string - buff []string + diff []string + buff []string + seen map[uintptr]struct{} + floatFormat string } -var errorType = reflect.TypeOf((*error)(nil)).Elem() +var ( + errorType = reflect.TypeOf((*error)(nil)).Elem() + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Nanosecond) +) // Equal compares variables a and b, recursing into their structure up to // MaxDepth levels deep (if greater than zero), and returns a list of differences, @@ -67,32 +80,34 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem() // When comparing a struct, if a field has the tag `deep:"-"` then it will be // ignored. func Equal(a, b interface{}) []string { - aVal := reflect.ValueOf(a) - bVal := reflect.ValueOf(b) c := &cmp{ - diff: []string{}, - buff: []string{}, + seen: make(map[uintptr]struct{}), + floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), } - if a == nil && b == nil { - return nil - } else if a == nil && b != nil { - c.saveDiff("", b) - } else if a != nil && b == nil { - c.saveDiff(a, "") - } - if len(c.diff) > 0 { + + if a == nil || b == nil { + switch { + case b != nil: + c.saveDiff("", b) + + case a != nil: + c.saveDiff(a, "") + } + return c.diff } - c.equals(aVal, bVal, 0) - if len(c.diff) > 0 { - return c.diff // diffs - } - return nil // no diffs + c.equals(reflect.ValueOf(a), reflect.ValueOf(b), 0) + + return c.diff } func (c *cmp) equals(a, b reflect.Value, level int) { + if len(c.diff) >= MaxDiff { + return + } + if MaxDepth > 0 && level > MaxDepth { logError(ErrMaxRecursion) return @@ -100,11 +115,14 @@ func (c *cmp) equals(a, b reflect.Value, level int) { // Check if one value is nil, e.g. T{x: *X} and T.x is nil if !a.IsValid() || !b.IsValid() { - if a.IsValid() && !b.IsValid() { - c.saveDiff(a.Type(), "") - } else if !a.IsValid() && b.IsValid() { - c.saveDiff("", b.Type()) + switch { + case a.IsValid(): + c.saveDiff(a.Type(), "") + + case b.IsValid(): + c.saveDiff("", b.Type()) } + return } @@ -112,58 +130,127 @@ func (c *cmp) equals(a, b reflect.Value, level int) { aType := a.Type() bType := b.Type() if aType != bType { + logError(ErrTypeMismatch) + // Built-in types don't have a name, so don't report [3]int != [2]int as " != " if aType.Name() == "" || aType.Name() != bType.Name() { c.saveDiff(aType, bType) - } else { - // Type names can be the same, e.g. pkg/v1.Error and pkg/v2.Error - // are both exported as pkg, so unless we include the full pkg path - // the diff will be "pkg.Error != pkg.Error" - // https://github.com/go-test/deep/issues/39 - aFullType := aType.PkgPath() + "." + aType.Name() - bFullType := bType.PkgPath() + "." + bType.Name() - c.saveDiff(aFullType, bFullType) + return } - logError(ErrTypeMismatch) + + // Type names can be the same, e.g. pkg/v1.Error and pkg/v2.Error + // are both exported as pkg, so unless we include the full pkg path + // the diff will be "pkg.Error != pkg.Error" + // https://github.com/go-test/deep/issues/39 + aFullType := aType.PkgPath() + "." + aType.Name() + bFullType := bType.PkgPath() + "." + bType.Name() + + c.saveDiff(aFullType, bFullType) return } // Primitive https://golang.org/pkg/reflect/#Kind - aKind := a.Kind() - bKind := b.Kind() + kind := a.Kind() // We know aType == bType, so a.Kind() == b.Kind() - // Do a and b have underlying elements? Yes if they're ptr or interface. - aElem := aKind == reflect.Ptr || aKind == reflect.Interface - bElem := bKind == reflect.Ptr || bKind == reflect.Interface + // Do a and b have underlying elements? Yes, if they're ptr or interface. + elem := kind == reflect.Ptr || kind == reflect.Interface // If both types implement the error interface, compare the error strings. - // This must be done before dereferencing because the interface is on a - // pointer receiver. Re https://github.com/go-test/deep/issues/31, a/b might - // be primitive kinds; see TestErrorPrimitiveKind. - if aType.Implements(errorType) && bType.Implements(errorType) { - if (!aElem || !a.IsNil()) && (!bElem || !b.IsNil()) { - aString := a.MethodByName("Error").Call(nil)[0].String() - bString := b.MethodByName("Error").Call(nil)[0].String() - if aString != bString { - c.saveDiff(aString, bString) + // This must be done before dereferencing because the interface may be on a pointer receiver. + // Re https://github.com/go-test/deep/issues/31, a/b might be primitive kinds; see TestErrorPrimitiveKind. + if aType.Implements(errorType) { + if !elem || (!a.IsNil() && !b.IsNil()) { + aFunc := a.MethodByName("Error") + bFunc := b.MethodByName("Error") + + if aFunc.CanInterface() && bFunc.CanInterface() { + aString := aFunc.Call(nil)[0].String() + bString := bFunc.Call(nil)[0].String() + if aString != bString { + c.saveDiff(aString, bString) + } return } } } - // Dereference pointers and interface{} - if aElem || bElem { - if aElem { - a = a.Elem() + if TimePrecision > 0 { + switch aType { + case timeType, durationType: + aFunc := a.MethodByName("Truncate") + bFunc := a.MethodByName("Truncate") + + if aFunc.CanInterface() && bFunc.CanInterface() { + precision := reflect.ValueOf(TimePrecision) + + a = aFunc.Call([]reflect.Value{precision})[0] + b = bFunc.Call([]reflect.Value{precision})[0] + } + } + } + + // For types with an `Equal(bType) bool` method like time.Time, we want to use that. + // But not if it is from an unexported struct field (CanInterface). + if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() { + // Handle https://github.com/go-test/deep/issues/15: + // Don't call a.Equal if the method is from an embedded struct, like: + // type Foo struct { time.Time } + // First, we'll encounter Equal(Foo, time.Time), + // but if we pass b as the 2nd argument, then we'll panic: "Call using pkg.Foo as type time.Time" + // As far as I can tell, there's no way to see that the method is from time.Time not Foo. + // So we check the type of the 1st (0) arg and skip unless it's b type. + // Later, we'll encounter the time.Time anonymous/embedded field, + // and then we'll have Equal(time.Time, time.Time). + typ := eqFunc.Type() + switch { + case typ.NumIn() != 1, typ.In(0) != bType: + // Equal does not take one argument of the same type. + case typ.NumOut() != 1, typ.Out(0).Kind() != reflect.Bool: + // Equal does not return only one value of kind bool. + default: + retVals := eqFunc.Call([]reflect.Value{b}) + if !retVals[0].Bool() { + c.saveDiff(a, b) + } + return + } + } + + // Dereference pointers and interfaces + if elem { + if a.IsNil() || b.IsNil() { + if !a.IsNil() { + for a.Kind() == reflect.Interface { + // resolve a to its concrete value. + a = a.Elem() + } + c.saveDiff(a.Type(), "") + } + + if !b.IsNil() { + for b.Kind() == reflect.Interface { + // resolve b to its concrete value. + b = b.Elem() + } + c.saveDiff("", b.Type()) + } + + return } - if bElem { - b = b.Elem() + + if kind == reflect.Ptr { + if c.haveSeen(a.Pointer(), b.Pointer()) { + return + } + + c.saw(a.Pointer(), b.Pointer()) } - c.equals(a, b, level+1) + + c.equals(a.Elem(), b.Elem(), level+1) return } - switch aKind { + switch kind { ///////////////////////////////////////////////////////////////////// // Iterable kinds @@ -181,29 +268,11 @@ func (c *cmp) equals(a, b reflect.Value, level int) { Iterate through the fields (FirstName, LastName), recurse into their values. */ - // Types with an Equal() method, like time.Time, only if struct field - // is exported (CanInterface) - if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() { - // Handle https://github.com/go-test/deep/issues/15: - // Don't call T.Equal if the method is from an embedded struct, like: - // type Foo struct { time.Time } - // First, we'll encounter Equal(Ttime, time.Time) but if we pass b - // as the 2nd arg we'll panic: "Call using pkg.Foo as type time.Time" - // As far as I can tell, there's no way to see that the method is from - // time.Time not Foo. So we check the type of the 1st (0) arg and skip - // unless it's b type. Later, we'll encounter the time.Time anonymous/ - // embedded field and then we'll have Equal(time.Time, time.Time). - funcType := eqFunc.Type() - if funcType.NumIn() == 1 && funcType.In(0) == bType { - retVals := eqFunc.Call([]reflect.Value{b}) - if !retVals[0].Bool() { - c.saveDiff(a, b) - } + for i := 0; i < a.NumField(); i++ { + if len(c.diff) >= MaxDiff { return } - } - for i := 0; i < a.NumField(); i++ { if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { continue // skip unexported field, e.g. s in type T struct {s string} } @@ -212,22 +281,11 @@ func (c *cmp) equals(a, b reflect.Value, level int) { continue // field wants to be ignored } - c.push(aType.Field(i).Name) // push field name to buff - - // Get the Value for each field, e.g. FirstName has Type = string, - // Kind = reflect.String. - af := a.Field(i) - bf := b.Field(i) - - // Recurse to compare the field values - c.equals(af, bf, level+1) - - c.pop() // pop field name from buff - - if len(c.diff) >= MaxDiff { - break - } + c.push(aType.Field(i).Name) + c.equals(a.Field(i), b.Field(i), level+1) + c.pop() } + case reflect.Map: /* The variables are maps like: @@ -246,20 +304,25 @@ func (c *cmp) equals(a, b reflect.Value, level int) { if a.IsNil() || b.IsNil() { if NilMapsAreEmpty { - if a.IsNil() && b.Len() != 0 { + if b.Len() != 0 { c.saveDiff("", b) - return - } else if a.Len() != 0 && b.IsNil() { - c.saveDiff(a, "") - return } - } else { - if a.IsNil() && !b.IsNil() { - c.saveDiff("", b) - } else if !a.IsNil() && b.IsNil() { + + if a.Len() != 0 { c.saveDiff(a, "") } + + return } + + if !b.IsNil() { + c.saveDiff("", b) + } + + if !a.IsNil() { + c.saveDiff(a, "") + } + return } @@ -267,89 +330,113 @@ func (c *cmp) equals(a, b reflect.Value, level int) { return } + prefix := func(key reflect.Value) string { return fmt.Sprintf("map[%v]", key) } + for _, key := range a.MapKeys() { - c.push(fmt.Sprintf("map[%v]", key)) + if len(c.diff) >= MaxDiff { + return + } aVal := a.MapIndex(key) bVal := b.MapIndex(key) - if bVal.IsValid() { - c.equals(aVal, bVal, level+1) - } else { - c.saveDiff(aVal, "") + + if !bVal.IsValid() { + c.prefixDiff(prefix(key), aVal, "") + continue } + c.push(prefix(key)) + c.equals(aVal, bVal, level+1) c.pop() + } + for _, key := range b.MapKeys() { if len(c.diff) >= MaxDiff { return } - } - for _, key := range b.MapKeys() { if aVal := a.MapIndex(key); aVal.IsValid() { continue } - c.push(fmt.Sprintf("map[%v]", key)) - c.saveDiff("", b.MapIndex(key)) - c.pop() - if len(c.diff) >= MaxDiff { - return - } + c.prefixDiff(prefix(key), "", b.MapIndex(key)) } + case reflect.Array: n := a.Len() for i := 0; i < n; i++ { + if len(c.diff) >= MaxDiff { + return + } + c.push(fmt.Sprintf("array[%d]", i)) c.equals(a.Index(i), b.Index(i), level+1) c.pop() - if len(c.diff) >= MaxDiff { - break - } } + case reflect.Slice: - if NilSlicesAreEmpty { - if a.IsNil() && b.Len() != 0 { - c.saveDiff("", b) - return - } else if a.Len() != 0 && b.IsNil() { - c.saveDiff(a, "") + if a.IsNil() || b.IsNil() { + if NilSlicesAreEmpty { + if b.Len() != 0 { + c.saveDiff("", b) + } + + if a.Len() != 0 { + c.saveDiff(a, "") + } + return } - } else { - if a.IsNil() && !b.IsNil() { + + if !b.IsNil() { c.saveDiff("", b) - return - } else if !a.IsNil() && b.IsNil() { + } + if !a.IsNil() { c.saveDiff(a, "") - return } + + return } aLen := a.Len() bLen := b.Len() - if a.Pointer() == b.Pointer() && aLen == bLen { - return - } + prefix := func(i int) string { return fmt.Sprintf("slice[%d]", i) } - n := aLen - if bLen > aLen { - n = bLen - } - for i := 0; i < n; i++ { - c.push(fmt.Sprintf("slice[%d]", i)) - if i < aLen && i < bLen { + if a.Pointer() != b.Pointer() { + // These values can only be different if they have different backing store arrays. + // So, there is no need to check them if a.Pointer() == b.Pointer(). + + n := aLen + if n > bLen { + n = bLen + } + + for i := 0; i < n; i++ { + if len(c.diff) >= MaxDiff { + return + } + + c.push(prefix(i)) c.equals(a.Index(i), b.Index(i), level+1) - } else if i < aLen { - c.saveDiff(a.Index(i), "") - } else { - c.saveDiff("", b.Index(i)) + c.pop() } - c.pop() + } + + for i := bLen; i < aLen; i++ { if len(c.diff) >= MaxDiff { - break + return } + + c.prefixDiff(prefix(i), a.Index(i), "") + } + + for i := aLen; i < bLen; i++ { + if len(c.diff) >= MaxDiff { + return + } + + c.prefixDiff(prefix(i), "", b.Index(i)) } ///////////////////////////////////////////////////////////////////// @@ -357,33 +444,67 @@ func (c *cmp) equals(a, b reflect.Value, level int) { ///////////////////////////////////////////////////////////////////// case reflect.Float32, reflect.Float64: - // Round floats to FloatPrecision decimal places to compare with - // user-defined precision. As is commonly know, floats have "imprecision" - // such that 0.1 becomes 0.100000001490116119384765625. This cannot - // be avoided; it can only be handled. Issue 30 suggested that floats - // be compared using an epsilon: equal = |a-b| < epsilon. - // In many cases the result is the same, but I think epsilon is a little - // less clear for users to reason about. See issue 30 for details. + // Zero and negative-zero format to different strings. + // The equality test here short-circuits all cases where values are equal by definition. + // Strictly, this test is only necessary for the case of zero and negative-zero, + // but this actual-equality short-circuit is useful for all cases. + if a.Float() == b.Float() { + return + } + + // Round floats to FloatPrecision decimal places to compare with user-defined precision. + // As is commonly known, floats have "imprecision" such that 0.1 becomes 0.100000001490116119384765625. + // This cannot be avoided; it can only be handled. + // Issue 30 suggested that floats be compared using an epsilon: equal = |a-b| < epsilon. + // In many cases the result is the same, + // but I think epsilon is a little less clear for users to reason about. + // See issue 30 for details. + aval := fmt.Sprintf(c.floatFormat, a.Float()) bval := fmt.Sprintf(c.floatFormat, b.Float()) if aval != bval { - c.saveDiff(a.Float(), b.Float()) + c.saveDiff(a, b) } + case reflect.Bool: if a.Bool() != b.Bool() { - c.saveDiff(a.Bool(), b.Bool()) + c.saveDiff(a, b) } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if a.Int() != b.Int() { - c.saveDiff(a.Int(), b.Int()) + c.saveDiff(a, b) } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if a.Uint() != b.Uint() { - c.saveDiff(a.Uint(), b.Uint()) + c.saveDiff(a, b) } + case reflect.String: if a.String() != b.String() { - c.saveDiff(a.String(), b.String()) + c.saveDiff(a, b) + } + + ///////////////////////////////////////////////////////////////////// + // Edge-cases + ///////////////////////////////////////////////////////////////////// + + case reflect.Func: + if CompareFunctions { + if a.IsNil() || b.IsNil() { + if !a.IsNil() { + c.saveDiff("", "") + } + + if !b.IsNil() { + c.saveDiff("", "") + } + + return + } + + c.saveDiff("", "") } default: @@ -391,6 +512,22 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } } +func (c *cmp) saw(ptrs ...uintptr) { + for _, ptr := range ptrs { + c.seen[ptr] = struct{}{} + } +} + +func (c *cmp) haveSeen(ptrs ...uintptr) bool { + for _, ptr := range ptrs { + if _, ok := c.seen[ptr]; ok { + return true + } + } + + return false +} + func (c *cmp) push(name string) { c.buff = append(c.buff, name) } @@ -401,13 +538,21 @@ func (c *cmp) pop() { } } -func (c *cmp) saveDiff(aval, bval interface{}) { - if len(c.buff) > 0 { - varName := strings.Join(c.buff, ".") - c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval)) - } else { - c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval)) +func formatDiff(prefixes []string, aval, bval interface{}) string { + if len(prefixes) > 0 { + prefix := strings.Join(prefixes, ".") + return fmt.Sprintf("%s: %v != %v", prefix, aval, bval) } + + return fmt.Sprintf("%v != %v", aval, bval) +} + +func (c *cmp) saveDiff(aval, bval interface{}) { + c.diff = append(c.diff, formatDiff(c.buff, aval, bval)) +} + +func (c *cmp) prefixDiff(prefix string, aval, bval interface{}) { + c.diff = append(c.diff, formatDiff(append(c.buff, prefix), aval, bval)) } func logError(err error) { diff --git a/deep_test.go b/deep_test.go index 2a03392..a2d2b02 100644 --- a/deep_test.go +++ b/deep_test.go @@ -3,6 +3,8 @@ package deep_test import ( "errors" "fmt" + "math" + "net" "reflect" "testing" "time" @@ -12,255 +14,233 @@ import ( v2 "github.com/go-test/deep/test/v2" ) -func TestString(t *testing.T) { - diff := deep.Equal("foo", "foo") +func shouldBeEqual(t testing.TB, diff []string) { + t.Helper() + if len(diff) > 0 { - t.Error("should be equal:", diff) + t.Errorf("should be equal: %q", diff) } +} - diff = deep.Equal("foo", "bar") - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) +func shouldBeMaxDiff(t testing.TB, diff []string) { + t.Helper() + + if len(diff) == 0 { + t.Fatal("no diffs") } - if diff[0] != "foo != bar" { - t.Error("wrong diff:", diff[0]) + + if len(diff) != deep.MaxDiff { + t.Logf("diff: %q", diff) + t.Errorf("wrong number of diffs: got %d, expected %d", len(diff), deep.MaxDiff) } } -func TestFloat(t *testing.T) { - diff := deep.Equal(1.1, 1.1) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } +const ( + multilineTestError = `wrong diff: + got: %q + expected: %q` +) + +func reportWrongDiff(t testing.TB, got, expect string) { + t.Helper() - diff = deep.Equal(1.1234561, 1.1234562) - if diff == nil { - t.Error("no diff") + output := fmt.Sprintf("wrong diff: got %q, expected %q", got, expect) + if len(output) > 120 { + output = fmt.Sprintf(multilineTestError, got, expect) } - defaultFloatPrecision := deep.FloatPrecision - deep.FloatPrecision = 6 - defer func() { deep.FloatPrecision = defaultFloatPrecision }() + t.Error(output) +} - diff = deep.Equal(1.1234561, 1.1234562) - if len(diff) > 0 { - t.Error("should be equal:", diff) +func shouldBeDiffs(t testing.TB, diff []string, head string, tail ...string) { + t.Helper() + + if len(diff) == 0 { + t.Fatal("no diffs") } - diff = deep.Equal(1.123456, 1.123457) - if diff == nil { - t.Fatal("no diff") + if len(diff) != len(tail)+1 { + t.Log("diff:", diff) + t.Errorf("wrong number of diffs: got %d, expected %d", len(diff), len(tail)+1) } - if len(diff) != 1 { - t.Error("too many diff:", diff) + + if expect := head; diff[0] != expect { + reportWrongDiff(t, diff[0], expect) } - if diff[0] != "1.123456 != 1.123457" { - t.Error("wrong diff:", diff[0]) + + for i, expect := range tail { + if i+1 >= len(diff) { + t.Errorf("missing diff: %q", expect) + continue + } + + if got := diff[i+1]; got != expect { + reportWrongDiff(t, got, expect) + } + } +} +func TestString(t *testing.T) { + shouldBeEqual(t, deep.Equal("foo", "foo")) + + shouldBeDiffs(t, deep.Equal("foo", "bar"), "foo != bar") +} + +func TestFloat(t *testing.T) { + shouldBeEqual(t, deep.Equal(1.1, 1.1)) + + shouldBeDiffs(t, deep.Equal(1.1234561, 1.1234562), "1.1234561 != 1.1234562") + + shouldBeEqual(t, deep.Equal(float32(0.3), float32(0.1)+float32(0.2))) + shouldBeEqual(t, deep.Equal(float64(0.3), float64(0.1)+float64(0.2))) + + restoreFloatPrecision := deep.FloatPrecision + t.Cleanup(func() { deep.FloatPrecision = restoreFloatPrecision }) + + deep.FloatPrecision = 6 + + shouldBeEqual(t, deep.Equal(1.1234561, 1.1234562)) + + shouldBeDiffs(t, deep.Equal(1.123456, 1.123457), "1.123456 != 1.123457") + + // Since we compare string representations, NaN should compare equal to NaN + shouldBeEqual(t, deep.Equal(math.NaN(), math.NaN())) + + shouldBeEqual(t, deep.Equal(math.Inf(1), math.Inf(1))) + shouldBeEqual(t, deep.Equal(math.Inf(-1), math.Inf(-1))) + + shouldBeDiffs(t, deep.Equal(math.Inf(1), math.Inf(-1)), "+Inf != -Inf") + shouldBeDiffs(t, deep.Equal(math.Inf(-1), math.Inf(1)), "-Inf != +Inf") + + var zero float64 + + shouldBeEqual(t, deep.Equal(zero, zero)) + shouldBeEqual(t, deep.Equal(-zero, -zero)) + + shouldBeEqual(t, deep.Equal(zero, -zero)) + shouldBeEqual(t, deep.Equal(-zero, zero)) } func TestInt(t *testing.T) { - diff := deep.Equal(1, 1) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(1, 1)) - diff = deep.Equal(1, 2) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "1 != 2" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(1, 2), "1 != 2") } func TestUint(t *testing.T) { - diff := deep.Equal(uint(2), uint(2)) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(uint(2), uint(2))) - diff = deep.Equal(uint(2), uint(3)) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "2 != 3" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(uint(2), uint(3)), "2 != 3") } func TestBool(t *testing.T) { - diff := deep.Equal(true, true) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(true, true)) + shouldBeEqual(t, deep.Equal(false, false)) - diff = deep.Equal(false, false) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } - - diff = deep.Equal(true, false) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "true != false" { // unless you're fipar - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(true, false), "true != false") } func TestTypeMismatch(t *testing.T) { type T1 int // same type kind (int) type T2 int // but different type - var t1 T1 = 1 - var t2 T2 = 1 - diff := deep.Equal(t1, t2) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "deep_test.T1 != deep_test.T2" { - t.Error("wrong diff:", diff[0]) - } + + t1 := T1(1) + t2 := T2(2) + + shouldBeDiffs(t, deep.Equal(t1, t2), "deep_test.T1 != deep_test.T2") // Same pkg name but differnet full paths // https://github.com/go-test/deep/issues/39 err1 := v1.Error{} err2 := v2.Error{} - diff = deep.Equal(err1, err2) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "github.com/go-test/deep/test/v1.Error != github.com/go-test/deep/test/v2.Error" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, + deep.Equal(err1, err2), + "github.com/go-test/deep/test/v1.Error != github.com/go-test/deep/test/v2.Error", + ) } func TestKindMismatch(t *testing.T) { - deep.LogErrors = true + restoreLogErrors := deep.LogErrors + t.Cleanup(func() { deep.LogErrors = restoreLogErrors }) - var x int = 100 - var y float64 = 100 - diff := deep.Equal(x, y) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "int != float64" { - t.Error("wrong diff:", diff[0]) - } + deep.LogErrors = true - deep.LogErrors = false + shouldBeDiffs(t, deep.Equal(int(100), float64(100)), "int != float64") } func TestDeepRecursion(t *testing.T) { - deep.MaxDepth = 2 - defer func() { deep.MaxDepth = 10 }() + restoreMaxDepth := deep.MaxDepth + t.Cleanup(func() { deep.MaxDepth = restoreMaxDepth }) - type s3 struct { - S int - } - type s2 struct { - S s3 - } - type s1 struct { - S s2 - } - foo := map[string]s1{ - "foo": { // 1 + type ( + s1 struct { + S int + } + s2 struct { + S s1 + } + s3 struct { + S s2 + } + ) + + foo := map[string]s3{ + "foo": s3{ // 1 S: s2{ // 2 - S: s3{ // 3 + S: s1{ // 3 S: 42, // 4 }, }, }, } - bar := map[string]s1{ - "foo": { + bar := map[string]s3{ + "foo": s3{ S: s2{ - S: s3{ + S: s1{ S: 100, }, }, }, } + // No diffs because MaxDepth=2 prevents seeing the diff at 3rd level down - diff := deep.Equal(foo, bar) - if diff != nil { - t.Errorf("got %d diffs, expected none: %v", len(diff), diff) - } + deep.MaxDepth = 2 + shouldBeEqual(t, deep.Equal(foo, bar)) - defaultMaxDepth := deep.MaxDepth deep.MaxDepth = 4 - defer func() { deep.MaxDepth = defaultMaxDepth }() - diff = deep.Equal(foo, bar) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[foo].S.S.S: 42 != 100" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(foo, bar), "map[foo].S.S.S: 42 != 100") } func TestMaxDiff(t *testing.T) { - a := []int{1, 2, 3, 4, 5, 6, 7} - b := []int{0, 0, 0, 0, 0, 0, 0} + restoreMaxDiff := deep.MaxDiff + t.Cleanup(func() { deep.MaxDiff = restoreMaxDiff }) - defaultMaxDiff := deep.MaxDiff deep.MaxDiff = 3 - defer func() { deep.MaxDiff = defaultMaxDiff }() - diff := deep.Equal(a, b) - if diff == nil { - t.Fatal("no diffs") - } - if len(diff) != deep.MaxDiff { - t.Errorf("got %d diffs, expected %d", len(diff), deep.MaxDiff) - } + a1 := []int{1, 2, 3, 4, 5, 6, 7} + a2 := []int{0, 0, 0, 0, 0, 0, 0} + + shouldBeMaxDiff(t, deep.Equal(a1, a2)) + + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) - defaultCompareUnexportedFields := deep.CompareUnexportedFields deep.CompareUnexportedFields = true - defer func() { deep.CompareUnexportedFields = defaultCompareUnexportedFields }() + type fiveFields struct { - a int // unexported fields require ^ + a int // unexported fields require: deep.CompareUnexportedFields = true b int c int d int e int } - t1 := fiveFields{1, 2, 3, 4, 5} - t2 := fiveFields{0, 0, 0, 0, 0} - diff = deep.Equal(t1, t2) - if diff == nil { - t.Fatal("no diffs") - } - if len(diff) != deep.MaxDiff { - t.Errorf("got %d diffs, expected %d", len(diff), deep.MaxDiff) - } + + s1 := fiveFields{1, 2, 3, 4, 5} + s2 := fiveFields{0, 0, 0, 0, 0} + + shouldBeMaxDiff(t, deep.Equal(s1, s2)) // Same keys, too many diffs m1 := map[int]int{ @@ -277,14 +257,8 @@ func TestMaxDiff(t *testing.T) { 4: 0, 5: 0, } - diff = deep.Equal(m1, m2) - if diff == nil { - t.Fatal("no diffs") - } - if len(diff) != deep.MaxDiff { - t.Log(diff) - t.Errorf("got %d diffs, expected %d", len(diff), deep.MaxDiff) - } + + shouldBeMaxDiff(t, deep.Equal(m1, m2)) // Too many missing keys m1 = map[int]int{ @@ -300,83 +274,45 @@ func TestMaxDiff(t *testing.T) { 6: 0, 7: 0, } - diff = deep.Equal(m1, m2) - if diff == nil { - t.Fatal("no diffs") - } - if len(diff) != deep.MaxDiff { - t.Log(diff) - t.Errorf("got %d diffs, expected %d", len(diff), deep.MaxDiff) - } + + shouldBeMaxDiff(t, deep.Equal(m1, m2)) } func TestNotHandled(t *testing.T) { - a := func(int) {} - b := func(int) {} - diff := deep.Equal(a, b) - if len(diff) > 0 { - t.Error("got diffs:", diff) - } + shouldBeEqual(t, deep.Equal(func(int) {}, func(int) {})) } func TestStruct(t *testing.T) { - type s1 struct { + type s struct { id int Name string Number int } - sa := s1{ + + s1 := s{ id: 1, Name: "foo", Number: 2, } - sb := sa - diff := deep.Equal(sa, sb) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + s2 := s1 - sb.Name = "bar" - diff = deep.Equal(sa, sb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "Name: foo != bar" { - t.Error("wrong diff:", diff[0]) - } + shouldBeEqual(t, deep.Equal(s1, s2)) - sb.Number = 22 - diff = deep.Equal(sa, sb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 2 { - t.Error("too many diff:", diff) - } - if diff[0] != "Name: foo != bar" { - t.Error("wrong diff:", diff[0]) - } - if diff[1] != "Number: 2 != 22" { - t.Error("wrong diff:", diff[1]) - } + s2.Name = "bar" + shouldBeDiffs(t, deep.Equal(s1, s2), "Name: foo != bar") - sb.id = 11 - diff = deep.Equal(sa, sb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 2 { - t.Error("too many diff:", diff) - } - if diff[0] != "Name: foo != bar" { - t.Error("wrong diff:", diff[0]) - } - if diff[1] != "Number: 2 != 22" { - t.Error("wrong diff:", diff[1]) - } + s2.Number = 22 + shouldBeDiffs(t, deep.Equal(s1, s2), + "Name: foo != bar", + "Number: 2 != 22", + ) + + s2.id = 11 + shouldBeDiffs(t, deep.Equal(s1, s2), + "Name: foo != bar", + "Number: 2 != 22", + // should skip unexported fields + ) } func TestStructWithTags(t *testing.T) { @@ -484,12 +420,12 @@ func TestStructWithTags(t *testing.T) { }, } - orig := deep.CompareUnexportedFields + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) + deep.CompareUnexportedFields = true - diff := deep.Equal(sa, sb) - deep.CompareUnexportedFields = orig - want := []string{ + shouldBeDiffs(t, deep.Equal(sa, sb), "s1.modified: 1 != 10", "s1.ExportedModified: 5 != 50", "modified: 1 != 10", @@ -498,10 +434,7 @@ func TestStructWithTags(t *testing.T) { "recurseInline.ExportedModified: 5 != 50", "recursePtr.modified: 1 != 10", "recursePtr.ExportedModified: 5 != 50", - } - if !reflect.DeepEqual(want, diff) { - t.Errorf("got %v, want %v", diff, want) - } + ) } func TestNestedStruct(t *testing.T) { @@ -512,27 +445,18 @@ func TestNestedStruct(t *testing.T) { Name string Alias s2 } + sa := s1{ Name: "Robert", Alias: s2{Nickname: "Bob"}, } sb := sa - diff := deep.Equal(sa, sb) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + + shouldBeEqual(t, deep.Equal(sa, sb)) sb.Alias.Nickname = "Bobby" - diff = deep.Equal(sa, sb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "Alias.Nickname: Bob != Bobby" { - t.Error("wrong diff:", diff[0]) - } + + shouldBeDiffs(t, deep.Equal(sa, sb), "Alias.Nickname: Bob != Bobby") } func TestMap(t *testing.T) { @@ -540,217 +464,88 @@ func TestMap(t *testing.T) { "foo": 1, "bar": 2, } + + shouldBeEqual(t, deep.Equal(ma, ma)) + mb := map[string]int{ "foo": 1, "bar": 2, } - diff := deep.Equal(ma, mb) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } - diff = deep.Equal(ma, ma) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(ma, mb)) mb["foo"] = 111 - diff = deep.Equal(ma, mb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[foo]: 1 != 111" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(ma, mb), "map[foo]: 1 != 111") delete(mb, "foo") - diff = deep.Equal(ma, mb) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[foo]: 1 != " { - t.Error("wrong diff:", diff[0]) - } - - diff = deep.Equal(mb, ma) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[foo]: != 1" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(ma, mb), "map[foo]: 1 != ") + shouldBeDiffs(t, deep.Equal(mb, ma), "map[foo]: != 1") var mc map[string]int - diff = deep.Equal(ma, mc) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - // handle hash order randomness - if diff[0] != "map[foo:1 bar:2] != " && diff[0] != "map[bar:2 foo:1] != " { - t.Error("wrong diff:", diff[0]) - } - - diff = deep.Equal(mc, ma) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != map[foo:1 bar:2]" && diff[0] != " != map[bar:2 foo:1]" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(mb, mc), "map[bar:2] != ") + shouldBeDiffs(t, deep.Equal(mc, mb), " != map[bar:2]") } func TestArray(t *testing.T) { a := [3]int{1, 2, 3} - b := [3]int{1, 2, 3} - diff := deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, a)) - diff = deep.Equal(a, a) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + b := [3]int{1, 2, 3} + + shouldBeEqual(t, deep.Equal(a, b)) b[2] = 333 - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "array[2]: 3 != 333" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "array[2]: 3 != 333") c := [3]int{1, 2, 2} - diff = deep.Equal(a, c) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "array[2]: 3 != 2" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, c), "array[2]: 3 != 2") var d [2]int - diff = deep.Equal(a, d) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "[3]int != [2]int" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, d), "[3]int != [2]int") e := [12]int{} - f := [12]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - diff = deep.Equal(e, f) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != deep.MaxDiff { - t.Error("not enough diffs:", diff) - } - for i := 0; i < deep.MaxDiff; i++ { - if diff[i] != fmt.Sprintf("array[%d]: 0 != %d", i+1, i+1) { - t.Error("wrong diff:", diff[i]) - } - } + f := [12]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + + restoreMaxDiff := deep.MaxDiff + t.Cleanup(func() { deep.MaxDiff = restoreMaxDiff }) + + deep.MaxDiff = 10 + + shouldBeDiffs(t, deep.Equal(e, f), + "array[1]: 0 != 1", + "array[2]: 0 != 2", + "array[3]: 0 != 3", + "array[4]: 0 != 4", + "array[5]: 0 != 5", + "array[6]: 0 != 6", + "array[7]: 0 != 7", + "array[8]: 0 != 8", + "array[9]: 0 != 9", + "array[10]: 0 != 10", + ) } func TestSlice(t *testing.T) { a := []int{1, 2, 3} - b := []int{1, 2, 3} - diff := deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, a)) - diff = deep.Equal(a, a) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + b := []int{1, 2, 3} + + shouldBeEqual(t, deep.Equal(a, b)) b[2] = 333 - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[2]: 3 != 333" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "slice[2]: 3 != 333") b = b[0:2] - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[2]: 3 != " { - t.Error("wrong diff:", diff[0]) - } - - diff = deep.Equal(b, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[2]: != 3" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "slice[2]: 3 != ") + shouldBeDiffs(t, deep.Equal(b, a), "slice[2]: != 3") var c []int - diff = deep.Equal(a, c) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "[1 2 3] != " { - t.Error("wrong diff:", diff[0]) - } - diff = deep.Equal(c, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != [1 2 3]" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, c), "[1 2 3] != ") + shouldBeDiffs(t, deep.Equal(c, a), " != [1 2 3]") } func TestSiblingSlices(t *testing.T) { @@ -758,336 +553,130 @@ func TestSiblingSlices(t *testing.T) { a := father[0:3] b := father[0:3] - diff := deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } - diff = deep.Equal(b, a) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, b)) a = father[0:3] b = father[0:2] - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[2]: 3 != " { - t.Error("wrong diff:", diff[0]) - } - - a = father[0:2] - b = father[0:3] - - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[2]: != 3" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "slice[2]: 3 != ") + shouldBeDiffs(t, deep.Equal(b, a), "slice[2]: != 3") a = father[0:2] b = father[2:4] - - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 2 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[0]: 1 != 3" { - t.Error("wrong diff:", diff[0]) - } - if diff[1] != "slice[1]: 2 != 4" { - t.Error("wrong diff:", diff[1]) - } + shouldBeDiffs(t, deep.Equal(a, b), + "slice[0]: 1 != 3", + "slice[1]: 2 != 4", + ) a = father[0:0] b = father[1:1] - diff = deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } - diff = deep.Equal(b, a) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, b)) + shouldBeEqual(t, deep.Equal(b, a)) } func TestEmptySlice(t *testing.T) { a := []int{1} b := []int{} - var c []int // Non-empty is not equal to empty. - diff := deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[0]: 1 != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "slice[0]: 1 != ") // Empty is not equal to non-empty. - diff = deep.Equal(b, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[0]: != 1" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(b, a), "slice[0]: != 1") + + var c []int // Empty is not equal to nil. - diff = deep.Equal(b, c) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "[] != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(b, c), "[] != ") // Nil is not equal to empty. - diff = deep.Equal(c, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != []" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(c, b), " != []") } func TestNilSlicesAreEmpty(t *testing.T) { - defaultNilSlicesAreEmpty := deep.NilSlicesAreEmpty + restoreNilSlicesAreEmpty := deep.NilSlicesAreEmpty + t.Cleanup(func() { deep.NilSlicesAreEmpty = restoreNilSlicesAreEmpty }) + deep.NilSlicesAreEmpty = true - defer func() { deep.NilSlicesAreEmpty = defaultNilSlicesAreEmpty }() a := []int{1} b := []int{} + var c []int // Empty is equal to nil. - diff := deep.Equal(b, c) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(b, c)) // Nil is equal to empty. - diff = deep.Equal(c, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(c, b)) // Non-empty is not equal to nil. - diff = deep.Equal(a, c) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "[1] != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, c), "[1] != ") // Nil is not equal to non-empty. - diff = deep.Equal(c, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != [1]" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(c, a), " != [1]") // Non-empty is not equal to empty. - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[0]: 1 != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "slice[0]: 1 != ") // Empty is not equal to non-empty. - diff = deep.Equal(b, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "slice[0]: != 1" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(b, a), "slice[0]: != 1") } func TestNilMapsAreEmpty(t *testing.T) { - defaultNilMapsAreEmpty := deep.NilSlicesAreEmpty + restoreNilMapsAreEmpty := deep.NilMapsAreEmpty + t.Cleanup(func() { deep.NilMapsAreEmpty = restoreNilMapsAreEmpty }) + deep.NilMapsAreEmpty = true - defer func() { deep.NilSlicesAreEmpty = defaultNilMapsAreEmpty }() a := map[int]int{1: 1} - b := map[int]int{} + b := make(map[int]int) var c map[int]int // Empty is equal to nil. - diff := deep.Equal(b, c) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(b, c)) // Nil is equal to empty. - diff = deep.Equal(c, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(c, b)) // Non-empty is not equal to nil. - diff = deep.Equal(a, c) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[1:1] != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, c), "map[1:1] != ") // Nil is not equal to non-empty. - diff = deep.Equal(c, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != map[1:1]" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(c, a), " != map[1:1]") // Non-empty is not equal to empty. - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[1]: 1 != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "map[1]: 1 != ") // Empty is not equal to non-empty. - diff = deep.Equal(b, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "map[1]: != 1" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(b, a), "map[1]: != 1") } func TestNilInterface(t *testing.T) { type T struct{ i int } a := &T{i: 1} - diff := deep.Equal(nil, a) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != &{1}" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(nil, a), " != &{1}") - diff = deep.Equal(a, nil) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "&{1} != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, nil), "&{1} != ") - diff = deep.Equal(nil, nil) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(nil, nil)) } func TestPointer(t *testing.T) { type T struct{ i int } a, b := &T{i: 1}, &T{i: 1} - diff := deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, b)) a, b = nil, &T{} - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != " != deep_test.T" { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), " != *deep_test.T") a, b = &T{}, nil - diff = deep.Equal(a, b) - if diff == nil { - t.Fatal("no diff") - } - if len(diff) != 1 { - t.Error("too many diff:", diff) - } - if diff[0] != "deep_test.T != " { - t.Error("wrong diff:", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "*deep_test.T != ") a, b = nil, nil - diff = deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a, b)) } func TestTime(t *testing.T) { @@ -1095,21 +684,18 @@ func TestTime(t *testing.T) { type sTime struct { T time.Time } - now := time.Now() - got := sTime{T: now} - expect := sTime{T: now.Add(1 * time.Second)} - diff := deep.Equal(got, expect) - if len(diff) != 1 { - t.Error("expected 1 diff:", diff) - } + + now := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC) + later := now.Add(1 * time.Second) + + s1 := sTime{T: now} + s2 := sTime{T: later} + shouldBeDiffs(t, deep.Equal(s1, s2), + "T: 2009-11-10 23:00:00 +0000 UTC != 2009-11-10 23:00:01 +0000 UTC", + ) // Directly - a := now - b := now - diff = deep.Equal(a, b) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(now, now)) // https://github.com/go-test/deep/issues/15 type Time15 struct { @@ -1117,17 +703,12 @@ func TestTime(t *testing.T) { } a15 := Time15{now} b15 := Time15{now} - diff = deep.Equal(a15, b15) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + shouldBeEqual(t, deep.Equal(a15, b15)) - later := now.Add(1 * time.Second) b15 = Time15{later} - diff = deep.Equal(a15, b15) - if len(diff) != 1 { - t.Errorf("got %d diffs, expected 1: %s", len(diff), diff) - } + shouldBeDiffs(t, deep.Equal(a15, b15), + "Time: 2009-11-10 23:00:00 +0000 UTC != 2009-11-10 23:00:01 +0000 UTC", + ) // No diff in Equal should not affect diff of other fields (Foo) type Time17 struct { @@ -1136,47 +717,52 @@ func TestTime(t *testing.T) { } a17 := Time17{Time: now, Foo: 1} b17 := Time17{Time: now, Foo: 2} - diff = deep.Equal(a17, b17) - if len(diff) != 1 { - t.Errorf("got %d diffs, expected 1: %s", len(diff), diff) - } + shouldBeDiffs(t, deep.Equal(a17, b17), "Foo: 1 != 2") } func TestTimeUnexported(t *testing.T) { - // https://github.com/go-test/deep/issues/18 - // Can't call Call() on exported Value func - defaultCompareUnexportedFields := deep.CompareUnexportedFields + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) + deep.CompareUnexportedFields = true - defer func() { deep.CompareUnexportedFields = defaultCompareUnexportedFields }() - now := time.Now() + now := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC) + later := now.Add(1 * time.Second) + + // https://github.com/go-test/deep/issues/18 + // Can't call Call() on unexported Value func + type hiddenTime struct { t time.Time } htA := &hiddenTime{t: now} htB := &hiddenTime{t: now} - diff := deep.Equal(htA, htB) - if len(diff) > 0 { - t.Error("should be equal:", diff) - } + + shouldBeEqual(t, deep.Equal(htA, htB)) + + htB.t = later + shouldBeDiffs(t, deep.Equal(htA, htB), "t.ext: 63393490800 != 63393490801") // This doesn't call time.Time.Equal(), it compares the unexported fields // in time.Time, causing a diff like: // [t.wall: 13740788835924462040 != 13740788836998203864 t.ext: 1447549 != 1001447549] - later := now.Add(1 * time.Second) - htC := &hiddenTime{t: later} - diff = deep.Equal(htA, htC) + htA.t = time.Now() + htB.t = htA.t.Add(1 * time.Second) + diff := deep.Equal(htA, htB) - expected := 1 - if _, ok := reflect.TypeOf(htA.t).FieldByName("ext"); ok { - expected = 2 - } + expected := reflect.TypeOf(htA.t).NumField() - 1 // loc *Location will always be the same. if len(diff) != expected { t.Errorf("got %d diffs, expected %d: %s", len(diff), expected, diff) } } func TestInterface(t *testing.T) { + defer func() { + if val := recover(); val != nil { + t.Fatal("panic:", val) + } + }() + a := map[string]interface{}{ "foo": map[string]string{ "bar": "a", @@ -1187,90 +773,40 @@ func TestInterface(t *testing.T) { "bar": "b", }, } - diff := deep.Equal(a, b) - if len(diff) == 0 { - t.Fatalf("expected 1 diff, got zero") - } - if len(diff) != 1 { - t.Errorf("expected 1 diff, got %d: %s", len(diff), diff) - } -} + shouldBeDiffs(t, deep.Equal(a, b), "map[foo].map[bar]: a != b") -func TestInterface2(t *testing.T) { - defer func() { - if val := recover(); val != nil { - t.Fatalf("panic: %v", val) - } - }() + a["foo"] = 1 + b["foo"] = 1.23 - a := map[string]interface{}{ - "bar": 1, - } - b := map[string]interface{}{ - "bar": 1.23, - } - diff := deep.Equal(a, b) - if len(diff) == 0 { - t.Fatalf("expected 1 diff, got zero") - } - if len(diff) != 1 { - t.Errorf("expected 1 diff, got %d: %s", len(diff), diff) - } -} + shouldBeDiffs(t, deep.Equal(a, b), "map[foo]: int != float64") -func TestInterface3(t *testing.T) { type Value struct{ int } - a := map[string]interface{}{ - "foo": &Value{}, - } - b := map[string]interface{}{ - "foo": 1.23, - } - diff := deep.Equal(a, b) - if len(diff) == 0 { - t.Fatalf("expected 1 diff, got zero") - } + a["foo"] = &Value{} - if len(diff) != 1 { - t.Errorf("expected 1 diff, got: %s", diff) - } + shouldBeDiffs(t, deep.Equal(a, b), "map[foo]: *deep_test.Value != float64") } func TestError(t *testing.T) { a := errors.New("it broke") b := errors.New("it broke") - diff := deep.Equal(a, b) - if len(diff) != 0 { - t.Fatalf("expected zero diffs, got %d: %s", len(diff), diff) - } + shouldBeEqual(t, deep.Equal(a, b)) b = errors.New("it fell apart") - diff = deep.Equal(a, b) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "it broke != it fell apart" { - t.Errorf("got '%s', expected 'it broke != it fell apart'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "it broke != it fell apart") // Both errors set type tWithError struct { Error error } + t1 := tWithError{ Error: a, } t2 := tWithError{ Error: b, } - diff = deep.Equal(t1, t2) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "Error: it broke != it fell apart" { - t.Errorf("got '%s', expected 'Error: it broke != it fell apart'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(t1, t2), "Error: it broke != it fell apart") // Both errors nil t1 = tWithError{ @@ -1279,11 +815,7 @@ func TestError(t *testing.T) { t2 = tWithError{ Error: nil, } - diff = deep.Equal(t1, t2) - if len(diff) != 0 { - t.Log(diff) - t.Fatalf("expected 0 diff, got %d: %s", len(diff), diff) - } + shouldBeEqual(t, deep.Equal(t1, t2)) // One error is nil t1 = tWithError{ @@ -1292,39 +824,24 @@ func TestError(t *testing.T) { t2 = tWithError{ Error: nil, } - diff = deep.Equal(t1, t2) - if len(diff) != 1 { - t.Log(diff) - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "Error: *errors.errorString != " { - t.Errorf("got '%s', expected 'Error: *errors.errorString != '", diff[0]) - } + shouldBeDiffs(t, deep.Equal(t1, t2), "Error: *errors.errorString != ") } func TestErrorWithOtherFields(t *testing.T) { a := errors.New("it broke") b := errors.New("it broke") - diff := deep.Equal(a, b) - if len(diff) != 0 { - t.Fatalf("expected zero diffs, got %d: %s", len(diff), diff) - } + shouldBeEqual(t, deep.Equal(a, b)) b = errors.New("it fell apart") - diff = deep.Equal(a, b) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "it broke != it fell apart" { - t.Errorf("got '%s', expected 'it broke != it fell apart'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(a, b), "it broke != it fell apart") // Both errors set type tWithError struct { Error error Other string } + t1 := tWithError{ Error: a, Other: "ok", @@ -1333,13 +850,7 @@ func TestErrorWithOtherFields(t *testing.T) { Error: b, Other: "ok", } - diff = deep.Equal(t1, t2) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "Error: it broke != it fell apart" { - t.Errorf("got '%s', expected 'Error: it broke != it fell apart'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(t1, t2), "Error: it broke != it fell apart") // Both errors nil t1 = tWithError{ @@ -1350,11 +861,7 @@ func TestErrorWithOtherFields(t *testing.T) { Error: nil, Other: "ok", } - diff = deep.Equal(t1, t2) - if len(diff) != 0 { - t.Log(diff) - t.Fatalf("expected 0 diff, got %d: %s", len(diff), diff) - } + shouldBeEqual(t, deep.Equal(t1, t2)) // Different Other value t1 = tWithError{ @@ -1365,13 +872,7 @@ func TestErrorWithOtherFields(t *testing.T) { Error: nil, Other: "nope", } - diff = deep.Equal(t1, t2) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "Other: ok != nope" { - t.Errorf("got '%s', expected 'Other: ok != nope'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(t1, t2), "Other: ok != nope") // Different Other value, same error t1 = tWithError{ @@ -1382,13 +883,7 @@ func TestErrorWithOtherFields(t *testing.T) { Error: a, Other: "nope", } - diff = deep.Equal(t1, t2) - if len(diff) != 1 { - t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) - } - if diff[0] != "Other: ok != nope" { - t.Errorf("got '%s', expected 'Other: ok != nope'", diff[0]) - } + shouldBeDiffs(t, deep.Equal(t1, t2), "Other: ok != nope") } type primKindError string @@ -1403,12 +898,9 @@ func TestErrorPrimitiveKind(t *testing.T) { // (https://github.com/go-test/deep/issues/31), we presumed a and b // were ptr or interface (and not nil), so a.Elem() worked. But when // a/b are primitive kinds, Elem() causes a panic. - var err1 primKindError = "abc" - var err2 primKindError = "abc" - diff := deep.Equal(err1, err2) - if len(diff) != 0 { - t.Fatalf("expected zero diffs, got %d: %s", len(diff), diff) - } + a := primKindError("abc") + b := primKindError("abc") + shouldBeEqual(t, deep.Equal(a, b)) } func TestNil(t *testing.T) { @@ -1419,16 +911,171 @@ func TestNil(t *testing.T) { mark := student{"mark", 10} var someNilThing interface{} = nil - diff := deep.Equal(someNilThing, mark) - if diff == nil { - t.Error("Nil value to comparison should not be equal") + shouldBeDiffs(t, deep.Equal(someNilThing, mark), " != {mark 10}") + + shouldBeDiffs(t, deep.Equal(mark, someNilThing), "{mark 10} != ") + + shouldBeEqual(t, deep.Equal(someNilThing, someNilThing)) +} + +type equalReturnsNothing int + +func (equalReturnsNothing) Equal(_ equalReturnsWrongType) {} + +func TestEqualReturnsNothing(t *testing.T) { + a := equalReturnsNothing(13) + b := equalReturnsNothing(42) + shouldBeDiffs(t, deep.Equal(a, b), "13 != 42") +} + +type equalReturnsWrongType int + +func (equalReturnsWrongType) Equal(_ equalReturnsWrongType) int { + return 1 +} + +func TestEqualReturnsWrongType(t *testing.T) { + a := equalReturnsWrongType(13) + b := equalReturnsWrongType(42) + shouldBeDiffs(t, deep.Equal(a, b), "13 != 42") +} + +type boolKind bool + +type equalReturnsBoolKind int + +func (equalReturnsBoolKind) Equal(_ equalReturnsBoolKind) boolKind { + return true +} + +func TestEqualReturnsBoolKind(t *testing.T) { + a := equalReturnsBoolKind(13) + b := equalReturnsBoolKind(42) + shouldBeEqual(t, deep.Equal(a, b)) // Equal should have overriden the comparison. +} + +type ring struct { + Prev, Next *ring +} + +func newRing() *ring { + r := new(ring) + r.Prev = r + r.Next = r + return r +} + +func TestRingList(t *testing.T) { + a := newRing() + b := newRing() + shouldBeEqual(t, deep.Equal(a, b)) +} + +type oroborous struct { + Any interface{} +} + +func newOroborous() *oroborous { + o := new(oroborous) + o.Any = o + return o +} + +func TestOroborous(t *testing.T) { + a := newOroborous() + b := newOroborous() + shouldBeEqual(t, deep.Equal(a, b)) +} + +func TestUnexportedErrorField(t *testing.T) { + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) + + deep.CompareUnexportedFields = true + + type S struct { + err error } - diff = deep.Equal(mark, someNilThing) - if diff == nil { - t.Error("Nil value to comparison should not be equal") + + a := &S{err: errors.New("it broke")} + b := &S{err: errors.New("it broke")} + shouldBeEqual(t, deep.Equal(a, b)) +} + +func TestIPAddresses(t *testing.T) { + a := net.ParseIP("1.2.3.4") + b := net.ParseIP("1.2.3.4") + shouldBeEqual(t, deep.Equal(a, b)) +} + +type hasStringer int + +func (hs hasStringer) String() string { + switch hs { + case 0: + return "VALUE_ZERO" + case 1: + return "VALUE_ONE" } - diff = deep.Equal(someNilThing, someNilThing) - if diff != nil { - t.Error("Nil value to comparison should not be equal") + + return fmt.Sprintf("VALUE(%d)", hs) +} + +func TestHasStringer(t *testing.T) { + a := hasStringer(0) + b := hasStringer(1) + shouldBeDiffs(t, deep.Equal(a, b), "VALUE_ZERO != VALUE_ONE") +} + +func TestTimePrecision(t *testing.T) { + restoreTimePrecision := deep.TimePrecision + t.Cleanup(func() { deep.TimePrecision = restoreTimePrecision }) + + deep.TimePrecision = 1 * time.Microsecond + + now := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC) + later := now.Add(123 * time.Nanosecond) + + shouldBeEqual(t, deep.Equal(now, later)) + + d1 := 1 * time.Microsecond + d2 := d1 + 123*time.Nanosecond + + shouldBeEqual(t, deep.Equal(d1, d2)) + + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) + + deep.CompareUnexportedFields = true + + type S struct { + t time.Time + d time.Duration } + + s1 := &S{t: now, d: d1} + s2 := &S{t: later, d: d2} + + // Since we cannot call `Truncate` on the unexported fields, + // we will show differences here. + shouldBeDiffs(t, deep.Equal(s1, s2), + "t.wall: 0 != 123", + "d: 1000 != 1123", + ) +} + +func TestCompareFuncs(t *testing.T) { + restoreCompareFunctions := deep.CompareFunctions + t.Cleanup(func() { deep.CompareFunctions = restoreCompareFunctions }) + + deep.CompareFunctions = true + + var f1, f2 func() + + shouldBeEqual(t, deep.Equal(f1, f2)) + + f2 = func() {} + shouldBeDiffs(t, deep.Equal(f1, f2), " != ") + shouldBeDiffs(t, deep.Equal(f2, f1), " != ") + shouldBeDiffs(t, deep.Equal(f2, f2), " != ") }