From 19223bd8e68685cfeec1a87ff75655ba8c6a50fb Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 13 Jun 2022 17:33:41 +0800 Subject: [PATCH] support comparing byte slice Signed-off-by: Ryan Leung --- assert/assertion_compare.go | 24 +++++++++++++++++++++++- assert/assertion_compare_test.go | 7 +++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/assert/assertion_compare.go b/assert/assertion_compare.go index 3bb22a971..79d87418b 100644 --- a/assert/assertion_compare.go +++ b/assert/assertion_compare.go @@ -1,6 +1,7 @@ package assert import ( + "bytes" "fmt" "reflect" "time" @@ -32,7 +33,8 @@ var ( stringType = reflect.TypeOf("") - timeType = reflect.TypeOf(time.Time{}) + timeType = reflect.TypeOf(time.Time{}) + bytesType = reflect.TypeOf([]byte{}) ) func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { @@ -323,6 +325,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) } + case reflect.Slice: + { + // We only care about the []byte type. + if !canConvert(obj1Value, bytesType) { + break + } + + // []byte can be compared! + byteObj1 := obj1Value.Bytes() + byteObj2 := obj2Value.Bytes() + if bytes.Compare(byteObj1, byteObj2) > 0 { + return compareGreater, true + } + if bytes.Equal(byteObj1, byteObj2) { + return compareEqual, true + } + if bytes.Compare(byteObj1, byteObj2) < 0 { + return compareLess, true + } + } } return compareEqual, false diff --git a/assert/assertion_compare_test.go b/assert/assertion_compare_test.go index a38d88060..2e8701483 100644 --- a/assert/assertion_compare_test.go +++ b/assert/assertion_compare_test.go @@ -22,6 +22,7 @@ func TestCompare(t *testing.T) { type customFloat32 float32 type customFloat64 float64 type customString string + type customBytes []byte for _, currCase := range []struct { less interface{} greater interface{} @@ -52,6 +53,8 @@ func TestCompare(t *testing.T) { {less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"}, {less: float64(1.23), greater: float64(2.34), cType: "float64"}, {less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"}, + {less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"}, + {less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"}, } { resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) if !isComparable { @@ -148,6 +151,7 @@ func TestGreater(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`}, {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`}, {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`}, + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, Greater(out, currCase.less, currCase.greater)) @@ -189,6 +193,7 @@ func TestGreaterOrEqual(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`}, {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, GreaterOrEqual(out, currCase.less, currCase.greater)) @@ -230,6 +235,7 @@ func TestLess(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`}, {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`}, {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`}, + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, Less(out, currCase.greater, currCase.less)) @@ -271,6 +277,7 @@ func TestLessOrEqual(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`}, {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, LessOrEqual(out, currCase.greater, currCase.less))