diff --git a/sdk/go/pulumi/types.go b/sdk/go/pulumi/types.go index 913c164daabd..2b8ce0a10234 100644 --- a/sdk/go/pulumi/types.go +++ b/sdk/go/pulumi/types.go @@ -1,4 +1,4 @@ -// Copyright 2016-2020, Pulumi Corporation. +// Copyright 2016-2022, Pulumi Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -1248,3 +1248,64 @@ func init() { RegisterOutputType(ResourceOutput{}) RegisterOutputType(ResourceArrayOutput{}) } + +// coerceTypeConversion assigns src to dst, performing deep type coercion as necessary. +func coerceTypeConversion(src interface{}, dst reflect.Type) (interface{}, error) { + makeError := func(src, dst reflect.Value) error { + return fmt.Errorf("expected value of type %s, not %s", dst.Type(), src.Type()) + } + var coerce func(reflect.Value, reflect.Value) error + coerce = func(src, dst reflect.Value) error { + if src.Type().Kind() == reflect.Interface && !src.IsNil() { + src = src.Elem() + } + if src.Type().AssignableTo(dst.Type()) { + dst.Set(src) + return nil + } + switch dst.Type().Kind() { + case reflect.Map: + if src.Kind() != reflect.Map { + return makeError(src, dst) + } + + dst.Set(reflect.MakeMapWithSize(dst.Type(), src.Len())) + + for iter := src.MapRange(); iter.Next(); { + dstKey := reflect.New(dst.Type().Key()).Elem() + dstVal := reflect.New(dst.Type().Elem()).Elem() + if err := coerce(iter.Key(), dstKey); err != nil { + return fmt.Errorf("invalid key: %w", err) + } + if err := coerce(iter.Value(), dstVal); err != nil { + return fmt.Errorf("[%#v]: %w", dstKey.Interface(), err) + } + dst.SetMapIndex(dstKey, dstVal) + } + + return nil + case reflect.Slice: + if src.Kind() != reflect.Slice { + return makeError(src, dst) + } + dst.Set(reflect.MakeSlice(dst.Type(), src.Len(), src.Cap())) + for i := 0; i < src.Len(); i++ { + dstVal := reflect.New(dst.Type().Elem()).Elem() + if err := coerce(src.Index(i), dstVal); err != nil { + return fmt.Errorf("[%d]: %w", i, err) + } + dst.Index(i).Set(dstVal) + } + return nil + default: + return makeError(src, dst) + } + } + + srcV, dstV := reflect.ValueOf(src), reflect.New(dst).Elem() + + if err := coerce(srcV, dstV); err != nil { + return nil, err + } + return dstV.Interface(), nil +} diff --git a/sdk/go/pulumi/types_test.go b/sdk/go/pulumi/types_test.go index 68db440891f9..89f4491f3b34 100644 --- a/sdk/go/pulumi/types_test.go +++ b/sdk/go/pulumi/types_test.go @@ -1,4 +1,4 @@ -// Copyright 2016-2018, Pulumi Corporation. +// Copyright 2016-2022, Pulumi Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func await(out Output) (interface{}, bool, bool, []Resource, error) { @@ -988,3 +989,99 @@ func TestApplyTOutputJoin(t *testing.T) { assertResult(t, out31, 2, true, true, r3, r1) assertResult(t, out312, nil, false, true, r3, r1, r2) /* out2 is unknown, hiding the output */ } + +func TestTypeCoersion(t *testing.T) { + + tests := []struct { + input interface{} + expected interface{} + err string + }{ + {"foo", "foo", ""}, + {"foo", 0, "expected value of type int, not string"}, + { + map[string]interface{}{ + "foo": "bar", + "fizz": "buzz", + }, + map[string]string{ + "foo": "bar", + "fizz": "buzz", + }, + "", + }, + { + map[string]interface{}{ + "foo": "bar", + "fizz": 8, + }, + map[string]string{ + "foo": "bar", + "fizz": "buzz", + }, + `["fizz"]: expected value of type string, not int`, + }, + { + []interface{}{1, 2, 3}, + []int{1, 2, 3}, + "", + }, + { + []interface{}{1, "two", 3}, + []int{1, 2, 3}, + `[1]: expected value of type int, not string`, + }, + { + []interface{}{ + map[string]interface{}{ + "fizz": []interface{}{3, 15}, + "buzz": []interface{}{5, 15}, + "fizzbuzz": []interface{}{15}, + }, + map[string]interface{}{}, + }, + []map[string][]int{ + { + "fizz": {3, 15}, + "buzz": {5, 15}, + "fizzbuzz": {15}, + }, + {}, + }, + "", + }, + { + []interface{}{ + map[string]interface{}{ + "fizz": []interface{}{3, 15}, + "buzz": []interface{}{"5", 15}, + "fizzbuzz": []interface{}{15}, + }, + map[string]interface{}{}, + }, + []map[string][]int{ + { + "fizz": {3, 15}, + "buzz": {5, 15}, + "fizzbuzz": {15}, + }, + {}, + }, + `[0]: ["buzz"]: [0]: expected value of type int, not string`, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%v->%v", tt.input, tt.expected), func(t *testing.T) { + dstT := reflect.TypeOf(tt.expected) + val, err := coerceTypeConversion(tt.input, dstT) + if tt.err == "" { + require.NoError(t, err) + assert.Equal(t, tt.expected, val) + } else { + assert.EqualError(t, err, tt.err) + } + }) + } +}