diff --git a/assert/assertions.go b/assert/assertions.go index 0357b2231..5f194c0b4 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -815,7 +815,6 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok return true // we consider nil to be equal to the nil set } - subsetValue := reflect.ValueOf(subset) defer func() { if e := recover(); e != nil { ok = false @@ -825,14 +824,32 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok listKind := reflect.TypeOf(list).Kind() subsetKind := reflect.TypeOf(subset).Kind() - if listKind != reflect.Array && listKind != reflect.Slice { + if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } - if subsetKind != reflect.Array && subsetKind != reflect.Slice { + if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } + subsetValue := reflect.ValueOf(subset) + if subsetKind == reflect.Map && listKind == reflect.Map { + listValue := reflect.ValueOf(list) + subsetKeys := subsetValue.MapKeys() + + for i := 0; i < len(subsetKeys); i++ { + subsetKey := subsetKeys[i] + subsetElement := subsetValue.MapIndex(subsetKey).Interface() + listElement := listValue.MapIndex(subsetKey).Interface() + + if !ObjectsAreEqual(subsetElement, listElement) { + return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...) + } + } + + return true + } + for i := 0; i < subsetValue.Len(); i++ { element := subsetValue.Index(i).Interface() ok, found := containsElement(list, element) @@ -859,7 +876,6 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...) } - subsetValue := reflect.ValueOf(subset) defer func() { if e := recover(); e != nil { ok = false @@ -869,14 +885,32 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) listKind := reflect.TypeOf(list).Kind() subsetKind := reflect.TypeOf(subset).Kind() - if listKind != reflect.Array && listKind != reflect.Slice { + if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } - if subsetKind != reflect.Array && subsetKind != reflect.Slice { + if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } + subsetValue := reflect.ValueOf(subset) + if subsetKind == reflect.Map && listKind == reflect.Map { + listValue := reflect.ValueOf(list) + subsetKeys := subsetValue.MapKeys() + + for i := 0; i < len(subsetKeys); i++ { + subsetKey := subsetKeys[i] + subsetElement := subsetValue.MapIndex(subsetKey).Interface() + listElement := listValue.MapIndex(subsetKey).Interface() + + if !ObjectsAreEqual(subsetElement, listElement) { + return true + } + } + + return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...) + } + for i := 0; i < subsetValue.Len(); i++ { element := subsetValue.Index(i).Interface() ok, found := containsElement(list, element) diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 0f5de92ff..d28c30050 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -685,11 +685,27 @@ func TestSubsetNotSubset(t *testing.T) { {[]int{1, 2, 3}, []int{1, 2}, true, "[1, 2, 3] contains [1, 2]"}, {[]int{1, 2, 3}, []int{1, 2, 3}, true, "[1, 2, 3] contains [1, 2, 3"}, {[]string{"hello", "world"}, []string{"hello"}, true, "[\"hello\", \"world\"] contains [\"hello\"]"}, + {map[string]string{ + "a": "x", + "c": "z", + "b": "y", + }, map[string]string{ + "a": "x", + "b": "y", + }, true, `{ "a": "x", "b": "y", "c": "z"} contains { "a": "x", "b": "y"}`}, // cases that are expected not to contain {[]string{"hello", "world"}, []string{"hello", "testify"}, false, "[\"hello\", \"world\"] does not contain [\"hello\", \"testify\"]"}, {[]int{1, 2, 3}, []int{4, 5}, false, "[1, 2, 3] does not contain [4, 5"}, {[]int{1, 2, 3}, []int{1, 5}, false, "[1, 2, 3] does not contain [1, 5]"}, + {map[string]string{ + "a": "x", + "c": "z", + "b": "y", + }, map[string]string{ + "a": "x", + "b": "z", + }, false, `{ "a": "x", "b": "y", "c": "z"} does not contain { "a": "x", "b": "z"}`}, } for _, c := range cases {