Skip to content

Commit

Permalink
[program-gen] Fix enum resolution from types of the form Union[string…
Browse files Browse the repository at this point in the history
…, Enum] and emit fully qualified enum cases (#15696)

# Description

This PR improves enum type resolution from strings. When we try to
resolve `Union[string, Enum]` for a string expression, we choose
`string` because it is the more general type since not every string is
assignable to `Enum`. However, here we spacial case strings that are
actually part of that `Enum`.

The result is that `pcl.LowerConversion` will choose `Enum` from
`Union[string, Enum]` when the value of the input string is compatible
with the enum. This greatly improves program-gen for all of typescript,
python, csharp and go which now will emit the fully qualified enum cases
instead of emitting strings.

Closes pulumi/pulumi-dotnet#41 which is
supposed to be a duplicate of
pulumi/pulumi-azure-native#2616 but that is
not the case (the former is about unions of objects, the latter is
unions of enums and strings)

## Checklist

- [ ] I have run `make tidy` to update any new dependencies
- [x] I have run `make lint` to verify my code passes the lint check
  - [x] I have formatted my code using `gofumpt`

<!--- Please provide details if the checkbox below is to be left
unchecked. -->
- [x] I have added tests that prove my fix is effective or that my
feature works
<!--- 
User-facing changes require a CHANGELOG entry.
-->
- [x] I have run `make changelog` and committed the
`changelog/pending/<file>` documenting my change
<!--
If the change(s) in this PR is a modification of an existing call to the
Pulumi Cloud,
then the service should honor older versions of the CLI where this
change would not exist.
You must then bump the API version in
/pkg/backend/httpstate/client/api.go, as well as add
it to the service.
-->
- [ ] Yes, there are changes in this PR that warrants bumping the Pulumi
Cloud API version
<!-- @pulumi employees: If yes, you must submit corresponding changes in
the service repo. -->
  • Loading branch information
Zaid-Ajaj committed Mar 15, 2024
1 parent 84b2dd7 commit 3bdc65c
Show file tree
Hide file tree
Showing 35 changed files with 210 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
changes:
- type: fix
scope: programgen/dotnet,go,nodejs,python
description: Fix enum resolution from types of the form union[string, enum]
6 changes: 6 additions & 0 deletions pkg/codegen/dotnet/gen_program_expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ func (g *generator) genSafeEnum(w io.Writer, to *model.EnumType) func(member *sc
func enumName(enum *model.EnumType) (string, string) {
components := strings.Split(enum.Token, ":")
contract.Assertf(len(components) == 3, "malformed token %v", enum.Token)
modParts := strings.Split(components[1], "/")
// if the token has the format {pkg}:{mod}/{name}:{Name}
// then we simplify into {pkg}:{mod}:{Name}
if len(modParts) == 2 && strings.EqualFold(modParts[1], components[2]) {
components[1] = modParts[0]
}
enumName := tokenToName(enum.Token)
e, ok := pcl.GetSchemaForType(enum)
if !ok {
Expand Down
12 changes: 10 additions & 2 deletions pkg/codegen/go/gen_program_expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (g *generator) GenForExpression(w io.Writer, expr *model.ForExpression) {
g.genNYI(w, "For expression")
}

func (g *generator) genSafeEnum(w io.Writer, to *model.EnumType) func(member *schema.Enum) {
func (g *generator) genSafeEnum(w io.Writer, to *model.EnumType, dest model.Type) func(member *schema.Enum) {
return func(member *schema.Enum) {
// We know the enum value at the call site, so we can directly stamp in a
// valid enum instance. We don't need to convert.
Expand All @@ -173,6 +173,14 @@ func (g *generator) genSafeEnum(w io.Writer, to *model.EnumType) func(member *sc
pkg, mod, _, _ := pcl.DecomposeToken(to.Token, to.SyntaxNode().Range())
mod = g.getModOrAlias(pkg, mod, mod)

if union, isUnion := dest.(*model.UnionType); isUnion && len(union.Annotations) > 0 {
if input, ok := union.Annotations[0].(schema.Type); ok {
if _, ok := codegen.ResolvedType(input).(*schema.UnionType); ok {
g.Fgenf(w, "pulumi.String(%s.%s)", mod, memberTag)
return
}
}
}
g.Fgenf(w, "%s.%s", mod, memberTag)
}
}
Expand Down Expand Up @@ -206,7 +214,7 @@ func (g *generator) GenFunctionCallExpression(w io.Writer, expr *model.FunctionC
from, enumTag, underlyingType)
return
}
diag := pcl.GenEnum(to, from, g.genSafeEnum(w, to), func(from model.Expression) {
diag := pcl.GenEnum(to, from, g.genSafeEnum(w, to, expr.Signature.ReturnType), func(from model.Expression) {
g.Fgenf(w, "%s(%v)", enumTag, from)
})
if diag != nil {
Expand Down
6 changes: 6 additions & 0 deletions pkg/codegen/nodejs/gen_program_expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ func enumNameWithPackage(enumToken string, pkgRef schema.PackageReference) (stri
name := tokenToName(enumToken)
pkg := makeValidIdentifier(components[0])
if mod := components[1]; mod != "" && mod != "index" {
// if the token has the format {pkg}:{mod}/{name}:{Name}
// then we simplify into {pkg}:{mod}:{Name}
modParts := strings.Split(mod, "/")
if len(modParts) == 2 && strings.EqualFold(modParts[1], components[2]) {
mod = modParts[0]
}
if pkgRef != nil {
mod = moduleName(mod, pkgRef)
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/codegen/pcl/rewrite_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,20 @@ func convertLiteralToString(from model.Expression) (string, bool) {
return "", false
}

func literalExprValue(expr model.Expression) (cty.Value, bool) {
if lit, ok := expr.(*model.LiteralValueExpression); ok {
return lit.Value, true
}

if templateExpr, ok := expr.(*model.TemplateExpression); ok {
if len(templateExpr.Parts) == 1 {
return literalExprValue(templateExpr.Parts[0])
}
}

return cty.NilVal, false
}

// lowerConversion performs the main logic of LowerConversion. nil, false is
// returned if there is no conversion (safe or unsafe) between `from` and `to`.
// This can occur when a loosely typed program is converted, or if an other
Expand All @@ -332,6 +346,19 @@ func lowerConversion(from model.Expression, to model.Type) (model.Type, bool) {
case *model.UnionType:
// Assignment: it just works
for _, to := range to.ElementTypes {
// in general, strings are not assignable to enums, but we allow it here
// if the enum has an element that matches the `from` expression
switch enumType := to.(type) {
case *model.EnumType:
if literal, ok := literalExprValue(from); ok {
for _, enumCase := range enumType.Elements {
if enumCase.RawEquals(literal) {
return to, true
}
}
}
}

if to.AssignableFrom(from.Type()) {
return to, true
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/codegen/python/gen_program_expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,14 @@ func (g *generator) GenFunctionCallExpression(w io.Writer, expr *model.FunctionC
moduleNameOverrides = pkg.Language["python"].(PackageInfo).ModuleNameOverrides
}
pkg := strings.ReplaceAll(components[0], "-", "_")
enumName := tokenToName(to.Token)
if m := tokenToModule(to.Token, nil, moduleNameOverrides); m != "" {
pkg += "." + m
modParts := strings.Split(m, "/")
if len(modParts) == 2 && strings.EqualFold(modParts[1], enumName) {
m = modParts[0]
}
pkg += "." + strings.ReplaceAll(m, "/", ".")
}
enumName := tokenToName(to.Token)

if isOutput {
g.Fgenf(w, "%.v.apply(lambda x: %s.%s(x))", from, pkg, enumName)
Expand Down
7 changes: 7 additions & 0 deletions pkg/codegen/testing/test/program_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ var PulumiPulumiProgramTests = []ProgramTest{
Directory: "azure-sa",
Description: "Azure SA",
},
{
Directory: "string-enum-union-list",
Description: "Contains resource which has a property of type List<Union<String, Enum>>",
// skipping compiling on Go because it doesn't know to handle unions in lists
// and instead generates pulumi.StringArray
SkipCompile: codegen.NewStringSet("go"),
},
{
Directory: "kubernetes-operator",
Description: "K8s Operator",
Expand Down
2 changes: 1 addition & 1 deletion tests/testdata/codegen/aws-lambda-pp/python/aws-lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
code=pulumi.FileArchive("lambda_function_payload.zip"),
role=iam_for_lambda.arn,
handler="index.test",
runtime="nodejs12.x",
runtime=aws.lambda_.Runtime.NODE_JS12D_X,
environment=aws.lambda_.FunctionEnvironmentArgs(
variables={
"foo": "bar",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
{
{ "Name", "web-server-www" },
},
InstanceType = "t2.micro",
InstanceType = Aws.Ec2.InstanceType.T2_Micro,
SecurityGroups = new[]
{
securityGroup.Name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {
Tags: pulumi.StringMap{
"Name": pulumi.String("web-server-www"),
},
InstanceType: pulumi.String("t2.micro"),
InstanceType: pulumi.String(ec2.InstanceType_T2_Micro),
SecurityGroups: pulumi.StringArray{
securityGroup.Name,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const server = new aws.ec2.Instance("server", {
tags: {
Name: "web-server-www",
},
instanceType: "t2.micro",
instanceType: aws.ec2.InstanceType.T2_Micro,
securityGroups: [securityGroup.name],
ami: ami.then(ami => ami.id),
userData: `#!/bin/bash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
tags={
"Name": "web-server-www",
},
instance_type="t2.micro",
instance_type=aws.ec2.InstanceType.T2_MICRO,
security_groups=[security_group.name],
ami=ami.id,
user_data="""#!/bin/bash
Expand Down
9 changes: 8 additions & 1 deletion tests/testdata/codegen/azure-native-1.56.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -426994,7 +426994,14 @@
"rights": {
"type": "array",
"items": {
"$ref": "#/types/azure-native:servicebus:AccessRights"
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/types/azure-native:servicebus:AccessRights"
}
]
},
"description": "The rights associated with the rule."
}
Expand Down
10 changes: 5 additions & 5 deletions tests/testdata/codegen/azure-native-pp/dotnet/azure-native.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
Name = "CacheExpiration",
Parameters = new AzureNative.Cdn.Inputs.CacheExpirationActionParametersArgs
{
CacheBehavior = "Override",
CacheBehavior = AzureNative.Cdn.CacheBehavior.Override,
CacheDuration = "10:10:09",
CacheType = "All",
CacheType = AzureNative.Cdn.CacheType.All,
OdataType = "#Microsoft.Azure.Cdn.Models.DeliveryRuleCacheExpirationActionParameters",
},
},
Expand All @@ -51,7 +51,7 @@
Name = "ModifyResponseHeader",
Parameters = new AzureNative.Cdn.Inputs.HeaderActionParametersArgs
{
HeaderAction = "Overwrite",
HeaderAction = AzureNative.Cdn.HeaderAction.Overwrite,
HeaderName = "Access-Control-Allow-Origin",
OdataType = "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
Value = "*",
Expand All @@ -62,7 +62,7 @@
Name = "ModifyRequestHeader",
Parameters = new AzureNative.Cdn.Inputs.HeaderActionParametersArgs
{
HeaderAction = "Overwrite",
HeaderAction = AzureNative.Cdn.HeaderAction.Overwrite,
HeaderName = "Accept-Encoding",
OdataType = "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
Value = "gzip",
Expand All @@ -83,7 +83,7 @@
},
NegateCondition = true,
OdataType = "#Microsoft.Azure.Cdn.Models.DeliveryRuleRemoteAddressConditionParameters",
Operator = "IPMatch",
Operator = AzureNative.Cdn.RemoteAddressOperator.IPMatch,
},
},
},
Expand Down
10 changes: 5 additions & 5 deletions tests/testdata/codegen/azure-native-pp/go/azure-native.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ func main() {
{
Name: "CacheExpiration",
Parameters: {
CacheBehavior: "Override",
CacheBehavior: cdn.CacheBehaviorOverride,
CacheDuration: "10:10:09",
CacheType: "All",
CacheType: cdn.CacheTypeAll,
OdataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleCacheExpirationActionParameters",
},
},
{
Name: "ModifyResponseHeader",
Parameters: {
HeaderAction: "Overwrite",
HeaderAction: cdn.HeaderActionOverwrite,
HeaderName: "Access-Control-Allow-Origin",
OdataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
Value: "*",
Expand All @@ -51,7 +51,7 @@ func main() {
{
Name: "ModifyRequestHeader",
Parameters: {
HeaderAction: "Overwrite",
HeaderAction: cdn.HeaderActionOverwrite,
HeaderName: "Accept-Encoding",
OdataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
Value: "gzip",
Expand All @@ -68,7 +68,7 @@ func main() {
},
NegateCondition: true,
OdataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleRemoteAddressConditionParameters",
Operator: "IPMatch",
Operator: cdn.RemoteAddressOperatorIPMatch,
},
},
},
Expand Down
10 changes: 5 additions & 5 deletions tests/testdata/codegen/azure-native-pp/nodejs/azure-native.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ const endpoint = new azure_native.cdn.Endpoint("endpoint", {
{
name: "CacheExpiration",
parameters: {
cacheBehavior: "Override",
cacheBehavior: azure_native.cdn.CacheBehavior.Override,
cacheDuration: "10:10:09",
cacheType: "All",
cacheType: azure_native.cdn.CacheType.All,
odataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleCacheExpirationActionParameters",
},
},
{
name: "ModifyResponseHeader",
parameters: {
headerAction: "Overwrite",
headerAction: azure_native.cdn.HeaderAction.Overwrite,
headerName: "Access-Control-Allow-Origin",
odataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
value: "*",
Expand All @@ -38,7 +38,7 @@ const endpoint = new azure_native.cdn.Endpoint("endpoint", {
{
name: "ModifyRequestHeader",
parameters: {
headerAction: "Overwrite",
headerAction: azure_native.cdn.HeaderAction.Overwrite,
headerName: "Accept-Encoding",
odataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
value: "gzip",
Expand All @@ -54,7 +54,7 @@ const endpoint = new azure_native.cdn.Endpoint("endpoint", {
],
negateCondition: true,
odataType: "#Microsoft.Azure.Cdn.Models.DeliveryRuleRemoteAddressConditionParameters",
operator: "IPMatch",
operator: azure_native.cdn.RemoteAddressOperator.IPMatch,
},
}],
name: "rule1",
Expand Down
10 changes: 5 additions & 5 deletions tests/testdata/codegen/azure-native-pp/python/azure-native.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
azure_native.cdn.DeliveryRuleCacheExpirationActionArgs(
name="CacheExpiration",
parameters=azure_native.cdn.CacheExpirationActionParametersArgs(
cache_behavior="Override",
cache_behavior=azure_native.cdn.CacheBehavior.OVERRIDE,
cache_duration="10:10:09",
cache_type="All",
cache_type=azure_native.cdn.CacheType.ALL,
odata_type="#Microsoft.Azure.Cdn.Models.DeliveryRuleCacheExpirationActionParameters",
),
),
azure_native.cdn.DeliveryRuleResponseHeaderActionArgs(
name="ModifyResponseHeader",
parameters=azure_native.cdn.HeaderActionParametersArgs(
header_action="Overwrite",
header_action=azure_native.cdn.HeaderAction.OVERWRITE,
header_name="Access-Control-Allow-Origin",
odata_type="#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
value="*",
Expand All @@ -37,7 +37,7 @@
azure_native.cdn.DeliveryRuleRequestHeaderActionArgs(
name="ModifyRequestHeader",
parameters=azure_native.cdn.HeaderActionParametersArgs(
header_action="Overwrite",
header_action=azure_native.cdn.HeaderAction.OVERWRITE,
header_name="Accept-Encoding",
odata_type="#Microsoft.Azure.Cdn.Models.DeliveryRuleHeaderActionParameters",
value="gzip",
Expand All @@ -53,7 +53,7 @@
],
negate_condition=True,
odata_type="#Microsoft.Azure.Cdn.Models.DeliveryRuleRemoteAddressConditionParameters",
operator="IPMatch",
operator=azure_native.cdn.RemoteAddressOperator.IP_MATCH,
),
)],
name="rule1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Capacity = 2,
Family = "Gen5",
Name = "B_Gen5_2",
Tier = "Basic",
Tier = AzureNative.DBforPostgreSQL.SkuTier.Basic,
},
Tags =
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const server = new azure_native.dbforpostgresql.Server("server", {
capacity: 2,
family: "Gen5",
name: "B_Gen5_2",
tier: "Basic",
tier: azure_native.dbforpostgresql.SkuTier.Basic,
},
tags: {
ElasticServer: "1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
capacity=2,
family="Gen5",
name="B_Gen5_2",
tier="Basic",
tier=azure_native.dbforpostgresql.SkuTier.BASIC,
),
tags={
"ElasticServer": "1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
var storageAccounts = new AzureNative.Storage.StorageAccount("storageAccounts", new()
{
AccountName = "sto4445",
Kind = "BlockBlobStorage",
Kind = AzureNative.Storage.Kind.BlockBlobStorage,
Location = "eastus",
ResourceGroupName = "res9101",
Sku = new AzureNative.Storage.Inputs.SkuArgs
{
Name = "Premium_LRS",
Name = AzureNative.Storage.SkuName.Premium_LRS,
},
NetworkRuleSet = new AzureNative.Storage.Inputs.NetworkRuleSetArgs
{
Expand Down

0 comments on commit 3bdc65c

Please sign in to comment.