Skip to content

Commit

Permalink
fix(mongodb): Fix resource update errors
Browse files Browse the repository at this point in the history
  • Loading branch information
youngmn committed Apr 2, 2024
1 parent cbd4f0a commit 734b5e5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 23 deletions.
9 changes: 9 additions & 0 deletions internal/common/convert_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,18 @@ func ExpandStringList(configured []interface{}) []*string {
return vs
}

// value is nil. set null.
func Int64ValueFromInt32(value *int32) basetypes.Int64Value {
if value == nil {
return basetypes.NewInt64Null()
}
return basetypes.NewInt64Value(int64(*value))
}

// value is nil. set 0.
func Int64ZeroFromInt32(value *int32) basetypes.Int64Value {
if value == nil {
return basetypes.NewInt64Value(0)
}
return basetypes.NewInt64Value(int64(*value))
}
116 changes: 93 additions & 23 deletions internal/service/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,33 @@ func (m *mongodbResource) Schema(_ context.Context, _ resource.SchemaRequest, re
},
"member_product_code": schema.StringAttribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.String{
stringplanmodifier.UseStateForUnknown(),
stringplanmodifier.RequiresReplace(),
},
},
"arbiter_product_code": schema.StringAttribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.String{
stringplanmodifier.RequiresReplace(),
stringplanmodifier.RequiresReplaceIfConfigured(),
},
},
"mongos_product_code": schema.StringAttribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.String{
stringplanmodifier.RequiresReplace(),
stringplanmodifier.UseStateForUnknown(),
stringplanmodifier.RequiresReplaceIfConfigured(),
},
},
"config_product_code": schema.StringAttribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.String{
stringplanmodifier.RequiresReplace(),
stringplanmodifier.UseStateForUnknown(),
stringplanmodifier.RequiresReplaceIfConfigured(),
},
},
"shard_count": schema.Int64Attribute{
Expand All @@ -177,24 +184,40 @@ func (m *mongodbResource) Schema(_ context.Context, _ resource.SchemaRequest, re
},
"member_server_count": schema.Int64Attribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
},
Validators: []validator.Int64{
int64validator.Between(2, 7),
},
},
"arbiter_server_count": schema.Int64Attribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
},
Validators: []validator.Int64{
int64validator.Between(0, 1),
},
},
"mongos_server_count": schema.Int64Attribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
},
Validators: []validator.Int64{
int64validator.Between(2, 5),
},
},
"config_server_count": schema.Int64Attribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
},
Validators: []validator.Int64{
int64validator.Between(3, 7),
},
Expand Down Expand Up @@ -237,7 +260,7 @@ func (m *mongodbResource) Schema(_ context.Context, _ resource.SchemaRequest, re
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
int64planmodifier.RequiresReplace(),
int64planmodifier.RequiresReplaceIfConfigured(),
},
Validators: []validator.Int64{
int64validator.Any(
Expand All @@ -247,16 +270,13 @@ func (m *mongodbResource) Schema(_ context.Context, _ resource.SchemaRequest, re
},
"arbiter_port": schema.Int64Attribute{
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
},
},
"mongos_port": schema.Int64Attribute{
Optional: true,
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
int64planmodifier.RequiresReplace(),
int64planmodifier.RequiresReplaceIfConfigured(),
},
Validators: []validator.Int64{
int64validator.Any(
Expand All @@ -269,7 +289,7 @@ func (m *mongodbResource) Schema(_ context.Context, _ resource.SchemaRequest, re
Computed: true,
PlanModifiers: []planmodifier.Int64{
int64planmodifier.UseStateForUnknown(),
int64planmodifier.RequiresReplace(),
int64planmodifier.RequiresReplaceIfConfigured(),
},
Validators: []validator.Int64{
int64validator.Any(
Expand Down Expand Up @@ -601,6 +621,7 @@ func (m *mongodbResource) Read(ctx context.Context, req resource.ReadRequest, re
func (m *mongodbResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) {
var plan, state mongodbResourceModel

// plan is NEW, state is OLD
resp.Diagnostics.Append(req.Plan.Get(ctx, &plan)...)
resp.Diagnostics.Append(req.State.Get(ctx, &state)...)

Expand Down Expand Up @@ -636,7 +657,7 @@ func (m *mongodbResource) Update(ctx context.Context, req resource.UpdateRequest
return
}

plan.refreshFromOutput(ctx, output)
state.refreshFromOutput(ctx, output)
}

if !plan.MongosServerCount.Equal(state.MongosServerCount) {
Expand Down Expand Up @@ -667,7 +688,7 @@ func (m *mongodbResource) Update(ctx context.Context, req resource.UpdateRequest
return
}

plan.refreshFromOutput(ctx, output)
state.refreshFromOutput(ctx, output)
}

if !plan.MemberServerCount.Equal(state.MemberServerCount) ||
Expand Down Expand Up @@ -700,7 +721,7 @@ func (m *mongodbResource) Update(ctx context.Context, req resource.UpdateRequest
return
}

plan.refreshFromOutput(ctx, output)
state.refreshFromOutput(ctx, output)
}

if !plan.ShardCount.Equal(state.ShardCount) {
Expand Down Expand Up @@ -731,10 +752,10 @@ func (m *mongodbResource) Update(ctx context.Context, req resource.UpdateRequest
return
}

plan.refreshFromOutput(ctx, output)
state.refreshFromOutput(ctx, output)
}

resp.Diagnostics.Append(resp.State.Set(ctx, plan)...)
resp.Diagnostics.Append(resp.State.Set(ctx, state)...)
}

func (m *mongodbResource) Delete(ctx context.Context, req resource.DeleteRequest, resp *resource.DeleteResponse) {
Expand Down Expand Up @@ -984,31 +1005,37 @@ func (m *mongodbResourceModel) refreshFromOutput(ctx context.Context, output *vm
m.ServiceName = types.StringPointerValue(output.CloudMongoDbServiceName)
m.VpcNo = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].VpcNo)
m.SubnetNo = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].SubnetNo)
m.ClusterTypeCode = types.StringPointerValue(output.ClusterType.Code)
m.ImageProductCode = types.StringPointerValue(output.CloudMongoDbImageProductCode)
m.ShardCount = common.Int64ValueFromInt32(output.ShardCount)
m.BackupFileRetentionPeriod = common.Int64ValueFromInt32(output.BackupFileRetentionPeriod)
m.BackupTime = types.StringPointerValue(output.BackupTime)
m.ArbiterPort = common.Int64ValueFromInt32(output.ArbiterPort)
m.MemberPort = common.Int64ValueFromInt32(output.MemberPort)
m.MongosPort = common.Int64ValueFromInt32(output.MongosPort)
m.ConfigPort = common.Int64ValueFromInt32(output.ConfigPort)
m.DataStorageType = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].DataStorageType.Code)
m.CompressCode = types.StringPointerValue(output.Compress.Code)
m.ArbiterPort = common.Int64ZeroFromInt32(output.ArbiterPort)
m.MemberPort = common.Int64ZeroFromInt32(output.MemberPort)
m.MongosPort = common.Int64ZeroFromInt32(output.MongosPort)
m.ConfigPort = common.Int64ZeroFromInt32(output.ConfigPort)
m.EngineVersion = types.StringPointerValue(output.EngineVersion)
m.RegionCode = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].RegionCode)
m.ZoneCode = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].ZoneCode)

if output.CloudMongoDbServerInstanceList[0].DataStorageType != nil {
m.DataStorageType = types.StringPointerValue(output.CloudMongoDbServerInstanceList[0].DataStorageType.Code)
}
if output.Compress != nil {
m.CompressCode = types.StringPointerValue(output.Compress.Code)
}

acgList, _ := types.ListValueFrom(ctx, types.StringType, output.AccessControlGroupNoList)
m.AccessControlGroupNoList = acgList

var memberCount int64
var arbiterCount int64
var mongosCount int64
var configCount int64
var serverList []mongoServer
for _, server := range output.CloudMongoDbServerInstanceList {
mongoServerInstance := mongoServer{
ServerNo: types.StringPointerValue(server.CloudMongoDbServerInstanceNo),
ServerName: types.StringPointerValue(server.CloudMongoDbServerName),
ServerRole: types.StringPointerValue(server.CloudMongoDbServerRole.CodeName),
ClusterRole: types.StringPointerValue(server.ClusterRole.Code),
ProductCode: types.StringPointerValue(server.CloudMongoDbProductCode),
PrivateDomain: types.StringPointerValue(server.PrivateDomain),
PublicDomain: types.StringPointerValue(server.PublicDomain),
Expand All @@ -1019,9 +1046,52 @@ func (m *mongodbResourceModel) refreshFromOutput(ctx context.Context, output *vm
Uptime: types.StringPointerValue(server.Uptime),
CreateDate: types.StringPointerValue(server.CreateDate),
}

if server.CloudMongoDbServerRole != nil {
mongoServerInstance.ServerRole = types.StringPointerValue(server.CloudMongoDbServerRole.CodeName)
if *server.CloudMongoDbServerRole.Code == "A" || *server.CloudMongoDbServerRole.Code == "MB" {
m.MemberProductCode = types.StringPointerValue(server.CloudMongoDbProductCode)
memberCount++
} else if *server.CloudMongoDbServerRole.Code == "AB" {
m.ArbiterProductCode = types.StringPointerValue(server.CloudMongoDbProductCode)
arbiterCount++
} else if *server.CloudMongoDbServerRole.Code == "RT" {
m.MongosProductCode = types.StringPointerValue(server.CloudMongoDbProductCode)
mongosCount++
} else if *server.CloudMongoDbServerRole.Code == "C" {
m.ConfigProductCode = types.StringPointerValue(server.CloudMongoDbProductCode)
configCount++
}
}
if server.ClusterRole != nil {
mongoServerInstance.ClusterRole = types.StringPointerValue(server.ClusterRole.Code)
}
serverList = append(serverList, mongoServerInstance)
}

if output.ClusterType != nil {
m.ClusterTypeCode = types.StringPointerValue(output.ClusterType.Code)
if *output.ClusterType.Code == "SHARDED_CLUSTER" {
memberCount = memberCount / int64(*output.ShardCount)
if arbiterCount > 0 {
arbiterCount = arbiterCount / int64(*output.ShardCount)
}
}
}
m.MemberServerCount = types.Int64Value(memberCount)
m.ArbiterServerCount = types.Int64Value(arbiterCount)
m.MongosServerCount = types.Int64Value(mongosCount)
m.ConfigServerCount = types.Int64Value(configCount)
if arbiterCount == 0 {
m.ArbiterProductCode = types.StringValue("not allocated")
}
if mongosCount == 0 {
m.MongosProductCode = types.StringValue("not allocated")
}
if configCount == 0 {
m.ConfigProductCode = types.StringValue("not allocated")
}

mongoServers, _ := types.ListValueFrom(ctx, types.ObjectType{AttrTypes: mongoServer{}.attrTypes()}, serverList)

m.MongoDbServerList = mongoServers
Expand Down

0 comments on commit 734b5e5

Please sign in to comment.