Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement for enums #1400

Merged
merged 5 commits into from Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
551 changes: 551 additions & 0 deletions const.go

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions enums_test.go
Expand Up @@ -20,4 +20,13 @@ func TestParseGlobalEnums(t *testing.T) {
b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
constsPath := "github.com/swaggo/swag/testdata/enums/consts"
assert.Equal(t, 64, p.packages.packages[constsPath].ConstTable["uintSize"].Value)
assert.Equal(t, int32(62), p.packages.packages[constsPath].ConstTable["maxBase"].Value)
assert.Equal(t, 8, p.packages.packages[constsPath].ConstTable["shlByLen"].Value)
assert.Equal(t, 255, p.packages.packages[constsPath].ConstTable["hexnum"].Value)
assert.Equal(t, 15, p.packages.packages[constsPath].ConstTable["octnum"].Value)
assert.Equal(t, `aa\nbb\u8888cc`, p.packages.packages[constsPath].ConstTable["nonescapestr"].Value)
assert.Equal(t, "aa\nbb\u8888cc", p.packages.packages[constsPath].ConstTable["escapestr"].Value)
assert.Equal(t, '\u8888', p.packages.packages[constsPath].ConstTable["escapechar"].Value)
}
115 changes: 69 additions & 46 deletions package.go
Expand Up @@ -3,6 +3,7 @@ package swag
import (
"go/ast"
"go/token"
"reflect"
"strconv"
)

Expand Down Expand Up @@ -31,6 +32,7 @@ type PackageDefinitions struct {
type ConstVariableGlobalEvaluator interface {
EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr)
EvaluateConstValueByName(file *ast.File, pkgPath, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr)
FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef
}

// NewPackageDefinitions new a PackageDefinitions object
Expand Down Expand Up @@ -92,68 +94,89 @@ func (pkg *PackageDefinitions) evaluateConstValue(file *ast.File, iota int, expr
case *ast.BasicLit:
switch valueExpr.Kind {
case token.INT:
x, err := strconv.ParseInt(valueExpr.Value, 10, 64)
if err != nil {
return nil, nil
// hexadecimal
if len(valueExpr.Value) > 2 && valueExpr.Value[0] == '0' && valueExpr.Value[1] == 'x' {
if x, err := strconv.ParseInt(valueExpr.Value[2:], 16, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value[2:], 16, 64); err == nil {
return x, nil
} else {
panic(err)
}
}

//octet
if len(valueExpr.Value) > 1 && valueExpr.Value[0] == '0' {
if x, err := strconv.ParseInt(valueExpr.Value[1:], 8, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value[1:], 8, 64); err == nil {
return x, nil
} else {
panic(err)
}
}

//a basic literal integer is int type in default, or must have an explicit converting type in front
if x, err := strconv.ParseInt(valueExpr.Value, 10, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value, 10, 64); err == nil {
return x, nil
} else {
panic(err)
}
return int(x), nil
case token.STRING, token.CHAR:
return valueExpr.Value[1 : len(valueExpr.Value)-1], nil
case token.STRING:
if valueExpr.Value[0] == '`' {
return valueExpr.Value[1 : len(valueExpr.Value)-1], nil
}
return EvaluateEscapedString(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil
case token.CHAR:
return EvaluateEscapedChar(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil
}
case *ast.UnaryExpr:
x, evalType := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
if x == nil {
return nil, nil
}
switch valueExpr.Op {
case token.SUB:
return -x.(int), evalType
case token.XOR:
return ^(x.(int)), evalType
return x, evalType
}
return EvaluateUnary(x, valueExpr.Op, evalType)
case *ast.BinaryExpr:
x, evalTypex := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
y, evalTypey := pkg.evaluateConstValue(file, iota, valueExpr.Y, globalEvaluator, recursiveStack)
if x == nil || y == nil {
return nil, nil
}
evalType := evalTypex
if evalType == nil {
evalType = evalTypey
}
switch valueExpr.Op {
case token.ADD:
if ix, ok := x.(int); ok {
return ix + y.(int), evalType
} else if sx, ok := x.(string); ok {
return sx + y.(string), evalType
}
case token.SUB:
return x.(int) - y.(int), evalType
case token.MUL:
return x.(int) * y.(int), evalType
case token.QUO:
return x.(int) / y.(int), evalType
case token.REM:
return x.(int) % y.(int), evalType
case token.AND:
return x.(int) & y.(int), evalType
case token.OR:
return x.(int) | y.(int), evalType
case token.XOR:
return x.(int) ^ y.(int), evalType
case token.SHL:
return x.(int) << y.(int), evalType
case token.SHR:
return x.(int) >> y.(int), evalType
}
return EvaluateBinary(x, y, valueExpr.Op, evalTypex, evalTypey)
case *ast.ParenExpr:
return pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
case *ast.CallExpr:
//data conversion
if ident, ok := valueExpr.Fun.(*ast.Ident); ok && len(valueExpr.Args) == 1 && IsGolangPrimitiveType(ident.Name) {
arg, _ := pkg.evaluateConstValue(file, iota, valueExpr.Args[0], globalEvaluator, recursiveStack)
return arg, nil
if len(valueExpr.Args) != 1 {
return nil, nil
}
arg := valueExpr.Args[0]
if ident, ok := valueExpr.Fun.(*ast.Ident); ok {
name := ident.Name
if name == "uintptr" {
name = "uint"
}
if IsGolangPrimitiveType(name) {
value, _ := pkg.evaluateConstValue(file, iota, arg, globalEvaluator, recursiveStack)
value = EvaluateDataConversion(value, name)
return value, nil
} else if name == "len" {
value, _ := pkg.evaluateConstValue(file, iota, arg, globalEvaluator, recursiveStack)
return reflect.ValueOf(value).Len(), nil
}
typeDef := globalEvaluator.FindTypeSpec(name, file)
if typeDef == nil {
return nil, nil
}
return arg, valueExpr.Fun
} else if selector, ok := valueExpr.Fun.(*ast.SelectorExpr); ok {
typeDef := globalEvaluator.FindTypeSpec(fullTypeName(selector.X.(*ast.Ident).Name, selector.Sel.Name), file)
if typeDef == nil {
return nil, nil
}
return arg, typeDef.TypeSpec.Type
}
}
return nil, nil
Expand Down
9 changes: 9 additions & 0 deletions testdata/enums/consts/const.go
@@ -1,3 +1,12 @@
package consts

const Base = 1

const uintSize = 32 << (^uint(uintptr(0)) >> 63)
const maxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1)
const shlByLen = 1 << len("aaa")
const hexnum = 0xFF
const octnum = 017
const nonescapestr = `aa\nbb\u8888cc`
const escapestr = "aa\nbb\u8888cc"
const escapechar = '\u8888'
1 change: 0 additions & 1 deletion testdata/enums/main.go
Expand Up @@ -14,5 +14,4 @@ package main

// @BasePath /v2
func main() {

}
21 changes: 15 additions & 6 deletions testdata/enums/types/model.go
Expand Up @@ -11,8 +11,8 @@ const (
A Class = consts.Base + (iota+1-1)*2/2%100 - (1&1 | 1) + (2 ^ 2) // AAA
B /* BBB */
C
D
F = D + 1
D = C + 1
F = Class(5)
//G is not enum
G = H + 10
//H is not enum
Expand All @@ -21,13 +21,15 @@ const (
I = int(F + 2)
)

const J = 1 << uint16(I)

type Mask int

const (
Mask1 Mask = 2 << iota >> 1 // Mask1
Mask2 /* Mask2 */
Mask3 // Mask3
Mask4 // Mask4
Mask1 Mask = 0x02 << iota >> 1 // Mask1
Mask2 /* Mask2 */
Mask3 // Mask3
Mask4 // Mask4
)

type Type string
Expand All @@ -40,6 +42,13 @@ const (
OtherUnknown = string(Other + Unknown)
)

type Sex rune

const (
Male Sex = 'M'
Female = 'F'
)

type Person struct {
Name string
Class Class
Expand Down
31 changes: 31 additions & 0 deletions utils_go18.go
@@ -0,0 +1,31 @@
//go:build go1.18
// +build go1.18

package swag

import (
"reflect"
"unicode/utf8"
)

// AppendUtf8Rune appends the UTF-8 encoding of r to the end of p and
// returns the extended buffer. If the rune is out of range,
// it appends the encoding of RuneError.
func AppendUtf8Rune(p []byte, r rune) []byte {
return utf8.AppendRune(p, r)
}

// CanIntegerValue a wrapper of reflect.Value
type CanIntegerValue struct {
reflect.Value
}

// CanInt reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanInt() bool {
return v.Value.CanInt()
}

// CanUint reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanUint() bool {
return v.Value.CanUint()
}
47 changes: 47 additions & 0 deletions utils_other.go
@@ -0,0 +1,47 @@
//go:build !go1.18
// +build !go1.18

package swag

import (
"reflect"
"unicode/utf8"
)

// AppendUtf8Rune appends the UTF-8 encoding of r to the end of p and
// returns the extended buffer. If the rune is out of range,
// it appends the encoding of RuneError.
func AppendUtf8Rune(p []byte, r rune) []byte {
length := utf8.RuneLen(rune(r))
if length > 0 {
utf8Slice := make([]byte, length)
utf8.EncodeRune(utf8Slice, rune(r))
p = append(p, utf8Slice...)
}
return p
}

// CanIntegerValue a wrapper of reflect.Value
type CanIntegerValue struct {
reflect.Value
}

// CanInt reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanInt() bool {
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
default:
return false
}
}

// CanUint reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanUint() bool {
switch v.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
default:
return false
}
}