Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sane conversions for As*Map* and As*Array* conversions #11351

Merged
merged 1 commit into from Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,4 @@
changes:
- type: feat
scope: sdk/go
description: Allow sane conversions for `As*Map*` and `As*Array*` conversions.
4 changes: 1 addition & 3 deletions sdk/go/pulumi/generate/main.go
Expand Up @@ -213,6 +213,7 @@ func makeBuiltins(primitives []*builtin) []*builtin {
// Augment primitives with array and map types.
var builtins []*builtin
for _, p := range primitives {
p.Strategy = "primitive"
name := ""
if p.Name != "Input" {
builtins = append(builtins, p)
Expand Down Expand Up @@ -242,9 +243,6 @@ func makeBuiltins(primitives []*builtin) []*builtin {

InnerElementType: p.Type,
}
if p.Type != "interface{}" {
arrType.Strategy = "array-contravariance"
}
builtins = append(builtins, arrType)
mapType := &builtin{
Name: name + "Map",
Expand Down
37 changes: 13 additions & 24 deletions sdk/go/pulumi/generate/templates/types_builtins.go.template
Expand Up @@ -19,7 +19,6 @@ package pulumi

import (
"context"
"fmt"
"reflect"
)

Expand Down Expand Up @@ -207,36 +206,26 @@ func getResolvedValue(input Input) (reflect.Value, bool) {
}

{{range .Builtins}}
{{ if eq .Strategy "array-contravariance" }}
{{ if eq .Strategy "primitive" }}
// As{{.Name}}Output asserts that the type of the AnyOutput's underlying interface{} value is
// {{.ElementType}} or []interface{} and returns a `{{.Name}}Output` with that value.
// As{{.Name}}Output panics if the value was not the expected type or a compatible array.
// {{.ElementType}} and returns a `{{.Name}}Output` with that value. As{{.Name}}Output panics if the value
// was not the expected type.
func (a AnyOutput) As{{.Name}}Output() {{.Name}}Output {
return a.ApplyT(func(i interface{}) ({{.ElementType}}, error) {
if array, ok := i.([]interface{}); ok {
if len(array) == 0 {
return nil, nil
}
out := make([]{{.InnerElementType}}, len(array))
for i, v := range array {
value, ok := v.({{.InnerElementType}})
if !ok {
return nil, fmt.Errorf("[%d]: expected value of type {{.InnerElementType}}, got %T", i, v)
}
out[i] = value
}
return out, nil
}
return i.({{.ElementType}}), nil
return a.ApplyT(func(i interface{}) {{.ElementType}} {
return i.({{.ElementType}})
}).({{.Name}}Output)
}
{{else}}
// As{{.Name}}Output asserts that the type of the AnyOutput's underlying interface{} value is
// {{.ElementType}} and returns a `{{.Name}}Output` with that value. As{{.Name}}Output panics if the value
// was not the expected type.
// {{.ElementType}} or a compatible type and returns a `{{.Name}}Output` with that value.
// As{{.Name}}Output panics if the value was not the expected type or a compatible type.
func (a AnyOutput) As{{.Name}}Output() {{.Name}}Output {
return a.ApplyT(func(i interface{}) {{.ElementType}} {
return i.({{.ElementType}})
return a.ApplyT(func(i interface{}) ({{.ElementType}}, error) {
v, err := coerceTypeConversion(i, reflect.TypeOf((*{{.ElementType}})(nil)).Elem())
if err != nil {
return nil, err
}
return v.({{.ElementType}}), nil
}).({{.Name}}Output)
}
{{end}}
Expand Down
63 changes: 62 additions & 1 deletion sdk/go/pulumi/types.go
@@ -1,4 +1,4 @@
// Copyright 2016-2020, Pulumi Corporation.
// Copyright 2016-2022, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1248,3 +1248,64 @@ func init() {
RegisterOutputType(ResourceOutput{})
RegisterOutputType(ResourceArrayOutput{})
}

// coerceTypeConversion assigns src to dst, performing deep type coercion as necessary.
func coerceTypeConversion(src interface{}, dst reflect.Type) (interface{}, error) {
makeError := func(src, dst reflect.Value) error {
return fmt.Errorf("expected value of type %s, not %s", dst.Type(), src.Type())
}
var coerce func(reflect.Value, reflect.Value) error
coerce = func(src, dst reflect.Value) error {
Comment on lines +1252 to +1258
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good recursive algorithm 👍

Contravariance implemented in ApplyT is extremely general.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a query, can any not be used in place of interface{}? Looks cleaner, but I'm guessing it is for backward compatibility.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're just using go's built in reflect.Type fmt.Stringer implementation. I was going for an error message that was similar to what we had before (i.(newType)). I would assume that the go dev team prefers interface {} over any for backwards compatibility, but you'd need to ask them to be sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Graham-Beer we could use any here if we updated go.mod to require 1.18 and above. I believe we'll be doing that soon anyhow as 1.17 is no longer supported.

if src.Type().Kind() == reflect.Interface && !src.IsNil() {
src = src.Elem()
}
if src.Type().AssignableTo(dst.Type()) {
dst.Set(src)
return nil
}
switch dst.Type().Kind() {
case reflect.Map:
if src.Kind() != reflect.Map {
return makeError(src, dst)
}

dst.Set(reflect.MakeMapWithSize(dst.Type(), src.Len()))

for iter := src.MapRange(); iter.Next(); {
dstKey := reflect.New(dst.Type().Key()).Elem()
dstVal := reflect.New(dst.Type().Elem()).Elem()
if err := coerce(iter.Key(), dstKey); err != nil {
return fmt.Errorf("invalid key: %w", err)
}
if err := coerce(iter.Value(), dstVal); err != nil {
return fmt.Errorf("[%#v]: %w", dstKey.Interface(), err)
}
dst.SetMapIndex(dstKey, dstVal)
}

return nil
case reflect.Slice:
if src.Kind() != reflect.Slice {
return makeError(src, dst)
}
dst.Set(reflect.MakeSlice(dst.Type(), src.Len(), src.Cap()))
for i := 0; i < src.Len(); i++ {
dstVal := reflect.New(dst.Type().Elem()).Elem()
if err := coerce(src.Index(i), dstVal); err != nil {
return fmt.Errorf("[%d]: %w", i, err)
}
dst.Index(i).Set(dstVal)
}
return nil
default:
return makeError(src, dst)
}
}

srcV, dstV := reflect.ValueOf(src), reflect.New(dst).Elem()

if err := coerce(srcV, dstV); err != nil {
return nil, err
}
return dstV.Interface(), nil
}