Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Swap and CompareAndSwap for Value wrappers #130

Merged
merged 3 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion bool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
64 changes: 64 additions & 0 deletions bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,67 @@ func TestBool(t *testing.T) {
})
})
}

func TestBool_InitializeDefaults(t *testing.T) {
tests := []struct {
msg string
newBool func() *Bool
}{
{
msg: "Uninitialized",
newBool: func() *Bool {
var b Bool
return &b
},
},
{
msg: "NewBool with default",
newBool: func() *Bool {
return NewBool(false)
},
},
{
msg: "Bool swapped with default",
newBool: func() *Bool {
b := NewBool(true)
b.Swap(false)
return b
},
},
{
msg: "Bool CAS'd with default",
newBool: func() *Bool {
b := NewBool(true)
b.CompareAndSwap(true, false)
return b
},
},
}

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
t.Run("MarshalJSON", func(t *testing.T) {
b := tt.newBool()
marshalled, err := b.MarshalJSON()
require.NoError(t, err)
assert.Equal(t, "false", string(marshalled))
})

t.Run("String", func(t *testing.T) {
b := tt.newBool()
assert.Equal(t, "false", b.String())
})

t.Run("CompareAndSwap", func(t *testing.T) {
b := tt.newBool()
require.True(t, b.CompareAndSwap(false, true))
assert.Equal(t, true, b.Load())
})

t.Run("Swap", func(t *testing.T) {
b := tt.newBool()
assert.Equal(t, false, b.Swap(true))
})
})
}
}
2 changes: 1 addition & 1 deletion duration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
14 changes: 12 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -52,7 +52,17 @@ func (x *Error) Store(val error) {

// CompareAndSwap is an atomic compare-and-swap for error values.
func (x *Error) CompareAndSwap(old, new error) (swapped bool) {
return x.v.CompareAndSwap(packError(old), packError(new))
if x.v.CompareAndSwap(packError(old), packError(new)) {
return true
}

if old == _zeroError {
// If the old value is the empty value, then it's possible the
// underlying Value hasn't been set and is nil, so retry with nil.
return x.v.CompareAndSwap(nil, packError(new))
}

return false
}

// Swap atomically stores the given error and returns the old
Expand Down
53 changes: 53 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -81,3 +82,55 @@ func TestErrorCompareAndSwap(t *testing.T) {
require.True(t, swapped, "Expected swapped to be true")
require.Equal(t, err2, atom.Load(), "Expected Load to return overridden value")
}

func TestError_InitializeDefaults(t *testing.T) {
tests := []struct {
msg string
newError func() *Error
}{
{
msg: "Uninitialized",
newError: func() *Error {
var e Error
return &e
},
},
{
msg: "NewError with default",
newError: func() *Error {
return NewError(nil)
},
},
{
msg: "Error swapped with default",
newError: func() *Error {
e := NewError(assert.AnError)
e.Swap(nil)
return e
},
},
{
msg: "Error CAS'd with default",
newError: func() *Error {
e := NewError(assert.AnError)
e.CompareAndSwap(assert.AnError, nil)
return e
},
},
}

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
t.Run("CompareAndSwap", func(t *testing.T) {
e := tt.newError()
require.True(t, e.CompareAndSwap(nil, assert.AnError))
assert.Equal(t, assert.AnError, e.Load())
})

t.Run("Swap", func(t *testing.T) {
e := tt.newError()
assert.Equal(t, nil, e.Swap(assert.AnError))
})
})
}
}
2 changes: 1 addition & 1 deletion float32.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion float64.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion int32.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicint.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion int64.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicint.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
15 changes: 6 additions & 9 deletions internal/gen-atomicwrapper/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
//
// The packing/unpacking logic allows the stored value to be different from
// the user-facing value.
//
// Without -pack and -unpack, the output will be cast to the target type,
// defaulting to the zero value.
package main

import (
Expand Down Expand Up @@ -143,12 +140,12 @@ func run(args []string) error {
return err
}

if len(opts.Name) == 0 || len(opts.Wrapped) == 0 || len(opts.Type) == 0 {
return errors.New("flags -name, -wrapped, and -type are required")
}

if (len(opts.Pack) == 0) != (len(opts.Unpack) == 0) {
return errors.New("either both, or neither of -pack and -unpack must be specified")
if len(opts.Name) == 0 ||
len(opts.Wrapped) == 0 ||
len(opts.Type) == 0 ||
len(opts.Pack) == 0 ||
len(opts.Unpack) == 0 {
return errors.New("flags -name, -wrapped, -pack, -unpack and -type are required")
}

if opts.CAS {
Expand Down
28 changes: 15 additions & 13 deletions internal/gen-atomicwrapper/wrapper.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ func (x *{{ .Name }}) Load() {{ .Type }} {

// Store atomically stores the passed {{ .Type }}.
func (x *{{ .Name }}) Store(val {{ .Type }}) {
{{ if .Pack -}}
x.v.Store({{ .Pack }}(val))
{{- else -}}
x.v.Store(val)
{{- end }}
x.v.Store({{ .Pack }}(val))
}

{{ if .CAS -}}
Expand All @@ -80,10 +76,20 @@ func (x *{{ .Name }}) Store(val {{ .Type }}) {
{{ if .CompareAndSwap -}}
// CompareAndSwap is an atomic compare-and-swap for {{ .Type }} values.
func (x *{{ .Name }}) CompareAndSwap(old, new {{ .Type }}) (swapped bool) {
{{ if .Pack -}}
{{ if eq .Wrapped "Value" -}}
if x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new)) {
return true
}

if old == _zero{{ .Name }} {
// If the old value is the empty value, then it's possible the
// underlying Value hasn't been set and is nil, so retry with nil.
return x.v.CompareAndSwap(nil, {{ .Pack }}(new))
}

return false
{{- else -}}
return x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new))
{{- else -}}{{- /* assume go.uber.org/atomic.Value */ -}}
return x.v.CompareAndSwap(old, new)
{{- end }}
}
{{- end }}
Expand All @@ -92,11 +98,7 @@ func (x *{{ .Name }}) Store(val {{ .Type }}) {
// Swap atomically stores the given {{ .Type }} and returns the old
// value.
func (x *{{ .Name }}) Swap(val {{ .Type }}) (old {{ .Type }}) {
{{ if .Pack -}}
return {{ .Unpack }}(x.v.Swap({{ .Pack }}(val)))
{{- else -}}{{- /* assume go.uber.org/atomic.Value */ -}}
return x.v.Swap(val).({{ .Type }})
{{- end }}
return {{ .Unpack }}(x.v.Swap({{ .Pack }}(val)))
}
{{- end }}

Expand Down
23 changes: 15 additions & 8 deletions string.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @generated Code generated by gen-atomicwrapper.

// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -42,24 +42,31 @@ func NewString(val string) *String {

// Load atomically loads the wrapped string.
func (x *String) Load() string {
if v := x.v.Load(); v != nil {
return v.(string)
}
return _zeroString
return unpackString(x.v.Load())
}

// Store atomically stores the passed string.
func (x *String) Store(val string) {
x.v.Store(val)
x.v.Store(packString(val))
}

// CompareAndSwap is an atomic compare-and-swap for string values.
func (x *String) CompareAndSwap(old, new string) (swapped bool) {
return x.v.CompareAndSwap(old, new)
if x.v.CompareAndSwap(packString(old), packString(new)) {
return true
}

if old == _zeroString {
// If the old value is the empty value, then it's possible the
// underlying Value hasn't been set and is nil, so retry with nil.
return x.v.CompareAndSwap(nil, packString(new))
}

return false
}

// Swap atomically stores the given string and returns the old
// value.
func (x *String) Swap(val string) (old string) {
return x.v.Swap(val).(string)
return unpackString(x.v.Swap(packString(val)))
}
15 changes: 13 additions & 2 deletions string_ext.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022 Uber Technologies, Inc.
// Copyright (c) 2020-2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -20,7 +20,18 @@

package atomic

//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped=Value -compareandswap -swap -file=string.go
//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped Value -pack packString -unpack unpackString -compareandswap -swap -file=string.go

func packString(s string) interface{} {
return s
}

func unpackString(v interface{}) string {
if s, ok := v.(string); ok {
return s
}
return ""
}

// String returns the wrapped value.
func (s *String) String() string {
Expand Down