Skip to content

Commit

Permalink
fix: assert.MapSubset (or just support maps in assert.Subset) (#1178)
Browse files Browse the repository at this point in the history
* WIP: added map key value check in subset

* upgraded subset & notsubset to check handle maps
  • Loading branch information
OladapoAjala committed Jun 28, 2022
1 parent 2fab6df commit 66eef0e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
46 changes: 40 additions & 6 deletions assert/assertions.go
Expand Up @@ -816,7 +816,6 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true // we consider nil to be equal to the nil set
}

subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
Expand All @@ -826,14 +825,32 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()

if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}

if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}

subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()

for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}

return true
}

for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := containsElement(list, element)
Expand All @@ -860,7 +877,6 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}

subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
Expand All @@ -870,14 +886,32 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()

if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}

if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}

subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()

for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
return true
}
}

return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...)
}

for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := containsElement(list, element)
Expand Down
16 changes: 16 additions & 0 deletions assert/assertions_test.go
Expand Up @@ -685,11 +685,27 @@ func TestSubsetNotSubset(t *testing.T) {
{[]int{1, 2, 3}, []int{1, 2}, true, "[1, 2, 3] contains [1, 2]"},
{[]int{1, 2, 3}, []int{1, 2, 3}, true, "[1, 2, 3] contains [1, 2, 3"},
{[]string{"hello", "world"}, []string{"hello"}, true, "[\"hello\", \"world\"] contains [\"hello\"]"},
{map[string]string{
"a": "x",
"c": "z",
"b": "y",
}, map[string]string{
"a": "x",
"b": "y",
}, true, `{ "a": "x", "b": "y", "c": "z"} contains { "a": "x", "b": "y"}`},

// cases that are expected not to contain
{[]string{"hello", "world"}, []string{"hello", "testify"}, false, "[\"hello\", \"world\"] does not contain [\"hello\", \"testify\"]"},
{[]int{1, 2, 3}, []int{4, 5}, false, "[1, 2, 3] does not contain [4, 5"},
{[]int{1, 2, 3}, []int{1, 5}, false, "[1, 2, 3] does not contain [1, 5]"},
{map[string]string{
"a": "x",
"c": "z",
"b": "y",
}, map[string]string{
"a": "x",
"b": "z",
}, false, `{ "a": "x", "b": "y", "c": "z"} does not contain { "a": "x", "b": "z"}`},
}

for _, c := range cases {
Expand Down

0 comments on commit 66eef0e

Please sign in to comment.