Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): add UnsafeCast function #3316

Merged
merged 8 commits into from Jan 12, 2022
53 changes: 53 additions & 0 deletions packages/@jsii/go-runtime/jsii-runtime-go/cast.go
@@ -0,0 +1,53 @@
package jsii

import (
"fmt"
"reflect"

"github.com/aws/jsii-runtime-go/internal/kernel"
)

// UnsafeCast converts the given interface value to the desired target interface
// pointer. Panics if the from value is not a jsii proxy object, or if the to
// value is not a pointer to an interface type.
func UnsafeCast(from interface{}, into interface{}) {
rinto := reflect.ValueOf(into)
if rinto.Kind() != reflect.Ptr {
panic(fmt.Errorf("Second argument to UnsafeCast must be a pointer to an interface. Received %s", rinto.Type().String()))
}
rinto = rinto.Elem()
if rinto.Kind() != reflect.Interface {
panic(fmt.Errorf("Second argument to UnsafeCast must be a pointer to an interface. Received pointer to %s", rinto.Type().String()))
}

rfrom := reflect.ValueOf(from)

// If rfrom is essentially nil, set into to nil and return.
if !rfrom.IsValid() || rfrom.IsZero() {
null := reflect.Zero(rinto.Type())
rinto.Set(null)
return
}
// Interfaces may present as a pointer to an implementing struct, and that's fine...
if rfrom.Kind() != reflect.Interface && rfrom.Kind() != reflect.Ptr {
panic(fmt.Errorf("First argument to UnsafeCast must be an interface value. Received %s", rfrom.Type().String()))
}

// If rfrom can be directly converted to rinto, just do it.
if rfrom.CanConvert(rinto.Type()) {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
rfrom = rfrom.Convert(rinto.Type())
rinto.Set(rfrom)
return
}

client := kernel.GetClient()
if objID, found := client.FindObjectRef(rfrom); found {
// Ensures the value is initialized properly. Panics if the target value is not a jsii interface type.
client.Types().InitJsiiProxy(rinto)
// Make the new value an alias to the old value.
client.RegisterInstance(rinto, objID)
return
}

panic(fmt.Errorf("First argument to UnsafeCast must be a jsii proxy value. Received %s", rfrom.String()))
}
101 changes: 101 additions & 0 deletions packages/@jsii/go-runtime/jsii-runtime-go/cast_test.go
@@ -0,0 +1,101 @@
package jsii

import (
"reflect"
"testing"

"github.com/aws/jsii-runtime-go/internal/api"
"github.com/aws/jsii-runtime-go/internal/kernel"
)

type MockInterfaceABase interface {
MockMethodABase(_ float64)
}

type mockABase struct {
_ int // padding
}

func (m *mockABase) MockMethodABase(_ float64) {}

type MockInterfaceA interface {
MockInterfaceABase
MockMethodA(_ string)
}

func NewMockInterfaceA() MockInterfaceA {
return &mockA{mockABase{}}
}

type mockA struct {
mockABase
}

func (m *mockA) MockMethodA(_ string) {}

type MockInterfaceB interface {
MockMethodB(_ int)
}

func NewMockInterfaceB() MockInterfaceB {
return &mockB{}
}

type mockB struct {
_ int // Padding
}

func (m *mockB) MockMethodB(_ int) {}

func TestNilSource(t *testing.T) {
// Make "into" not nil to ensure the cast function overwrites it.
into := NewMockInterfaceB()
UnsafeCast(nil, &into)

if into != nil {
t.Fail()
}
}

func TestSourceAndTargetAreTheSame(t *testing.T) {
into := NewMockInterfaceB()
original := into
UnsafeCast(into, &into)

if into != original {
t.Fail()
}
}

func TestTargetIsSubclassOfSource(t *testing.T) {
from := NewMockInterfaceA()
var into MockInterfaceABase
UnsafeCast(from, &into)

if into != from {
t.Fail()
}
}

func TestRegistersAlias(t *testing.T) {
client := kernel.GetClient()

objid := "Object@1337#42"
from := NewMockInterfaceA()
client.RegisterInstance(reflect.ValueOf(from), objid)

var into MockInterfaceB
client.Types().RegisterInterface(api.FQN("mock.InterfaceB"), reflect.TypeOf(&into).Elem(), []api.Override{}, func() interface{} { return NewMockInterfaceB() })

UnsafeCast(from, &into)

if into == nil {
t.Fail()
}

if refid, found := client.FindObjectRef(reflect.ValueOf(into)); !found {
t.Fail()
} else if refid != objid {
t.Fail()
}
}