From c31ea0312f8a96ca55801db5ebdf62119800fb70 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 21 Jun 2022 12:54:58 +0800 Subject: [PATCH] Support comparing byte slice (#1202) * support comparing byte slice Signed-off-by: Ryan Leung * address the comment Signed-off-by: Ryan Leung --- assert/assertion_compare.go | 24 ++++- assert/assertion_compare_go1.17_test.go | 128 ++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/assert/assertion_compare.go b/assert/assertion_compare.go index 3bb22a971..95d8e59da 100644 --- a/assert/assertion_compare.go +++ b/assert/assertion_compare.go @@ -1,6 +1,7 @@ package assert import ( + "bytes" "fmt" "reflect" "time" @@ -32,7 +33,8 @@ var ( stringType = reflect.TypeOf("") - timeType = reflect.TypeOf(time.Time{}) + timeType = reflect.TypeOf(time.Time{}) + bytesType = reflect.TypeOf([]byte{}) ) func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { @@ -323,6 +325,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) } + case reflect.Slice: + { + // We only care about the []byte type. + if !canConvert(obj1Value, bytesType) { + break + } + + // []byte can be compared! + bytesObj1, ok := obj1.([]byte) + if !ok { + bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte) + + } + bytesObj2, ok := obj2.([]byte) + if !ok { + bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) + } + + return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true + } } return compareEqual, false diff --git a/assert/assertion_compare_go1.17_test.go b/assert/assertion_compare_go1.17_test.go index d3ea935fa..53e01ed46 100644 --- a/assert/assertion_compare_go1.17_test.go +++ b/assert/assertion_compare_go1.17_test.go @@ -8,6 +8,7 @@ package assert import ( + "bytes" "reflect" "testing" "time" @@ -15,6 +16,7 @@ import ( func TestCompare17(t *testing.T) { type customTime time.Time + type customBytes []byte for _, currCase := range []struct { less interface{} greater interface{} @@ -22,6 +24,8 @@ func TestCompare17(t *testing.T) { }{ {less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"}, {less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"}, + {less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"}, + {less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"}, } { resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) if !isComparable { @@ -52,3 +56,127 @@ func TestCompare17(t *testing.T) { } } } + +func TestGreater17(t *testing.T) { + mockT := new(testing.T) + + if !Greater(mockT, 2, 1) { + t.Error("Greater should return true") + } + + if Greater(mockT, 1, 1) { + t.Error("Greater should return false") + } + + if Greater(mockT, 1, 2) { + t.Error("Greater should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Greater(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "github.com/stretchr/testify/assert.Greater") + } +} + +func TestGreaterOrEqual17(t *testing.T) { + mockT := new(testing.T) + + if !GreaterOrEqual(mockT, 2, 1) { + t.Error("GreaterOrEqual should return true") + } + + if !GreaterOrEqual(mockT, 1, 1) { + t.Error("GreaterOrEqual should return true") + } + + if GreaterOrEqual(mockT, 1, 2) { + t.Error("GreaterOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, GreaterOrEqual(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "github.com/stretchr/testify/assert.GreaterOrEqual") + } +} + +func TestLess17(t *testing.T) { + mockT := new(testing.T) + + if !Less(mockT, 1, 2) { + t.Error("Less should return true") + } + + if Less(mockT, 1, 1) { + t.Error("Less should return false") + } + + if Less(mockT, 2, 1) { + t.Error("Less should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Less(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "github.com/stretchr/testify/assert.Less") + } +} + +func TestLessOrEqual17(t *testing.T) { + mockT := new(testing.T) + + if !LessOrEqual(mockT, 1, 2) { + t.Error("LessOrEqual should return true") + } + + if !LessOrEqual(mockT, 1, 1) { + t.Error("LessOrEqual should return true") + } + + if LessOrEqual(mockT, 2, 1) { + t.Error("LessOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, LessOrEqual(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "github.com/stretchr/testify/assert.LessOrEqual") + } +}