From 583c3341c5747bc021cd223fc98a4045c8a7b139 Mon Sep 17 00:00:00 2001
From: Jacob Ryan McCollum
Date: Mon, 13 Jun 2022 02:11:08 -0400
Subject: [PATCH] Allow custom & union productions (#233)

Implements #229
---
 _examples/expr3/main.go      | 134 ++++++++++++++++++++++++++++++++++
 _examples/expr3/main_test.go |  49 +++++++++++++
 _examples/expr4/main.go      | 136 +++++++++++++++++++++++++++++++++++
 _examples/expr4/main_test.go |  61 ++++++++++++++++
 context.go                   |   4 +-
 ebnf.go                      |  22 ++++++
 ebnf_test.go                 |  38 ++++++++++
 grammar.go                   |  32 +++++++++
 nodes.go                     |  43 ++++++++++-
 options.go                   |  67 +++++++++++++++++
 parser.go                    |  19 +++++
 parser_test.go               | 121 +++++++++++++++++++++++++++++++
 trace.go                     |   5 ++
 visit.go                     |   9 +++
 14 files changed, 738 insertions(+), 2 deletions(-)
 create mode 100644 _examples/expr3/main.go
 create mode 100644 _examples/expr3/main_test.go
 create mode 100644 _examples/expr4/main.go
 create mode 100644 _examples/expr4/main_test.go

diff --git a/_examples/expr3/main.go b/_examples/expr3/main.go
new file mode 100644
index 00000000..f49cc7b2
--- /dev/null
+++ b/_examples/expr3/main.go
@@ -0,0 +1,134 @@
+package main
+
+import (
+	"strings"
+
+	"github.com/alecthomas/kong"
+	"github.com/alecthomas/participle/v2"
+	"github.com/alecthomas/repr"
+)
+
+type (
+	ExprString struct {
+		Value string `@String`
+	}
+
+	ExprNumber struct {
+		Value float64 `@Int | @Float`
+	}
+
+	ExprIdent struct {
+		Name string `@Ident`
+	}
+
+	ExprParens struct {
+		Inner ExprPrecAll `"(" @@ ")"`
+	}
+
+	ExprUnary struct {
+		Op string `@("-" | "!")`
+		Expr ExprOperand `@@`
+	}
+
+	ExprAddSub struct {
+		Head ExprPrec2 `@@`
+		Tail []ExprAddSubExt `@@+`
+	}
+
+	ExprAddSubExt struct {
+		Op string `@("+" | "-")`
+		Expr ExprPrec2 `@@`
+	}
+
+	ExprMulDiv struct {
+		Head ExprPrec3 `@@`
+		Tail []ExprMulDivExt `@@+`
+	}
+
+	ExprMulDivExt struct {
+		Op string `@("*" | "/")`
+		Expr ExprPrec3 `@@`
+	}
+
+	ExprRem struct {
+		Head ExprOperand `@@`
+		Tail []ExprRemExt `@@+`
+	}
+
+	ExprRemExt struct {
+		Op string `@"%"`
+		Expr ExprOperand `@@`
+	}
+
+	ExprPrecAll interface{ exprPrecAll() }
+	ExprPrec2 interface{ exprPrec2() }
+	ExprPrec3 interface{ exprPrec3() }
+	ExprOperand interface{ exprOperand() }
+)
+
+// These expression types can be matched as individual operands
+func (ExprIdent) exprOperand() {}
+func (ExprNumber) exprOperand() {}
+func (ExprString) exprOperand() {}
+func (ExprParens) exprOperand() {}
+func (ExprUnary) exprOperand() {}
+
+// These expression types can be matched at precedence level 3
+func (ExprIdent) exprPrec3() {}
+func (ExprNumber) exprPrec3() {}
+func (ExprString) exprPrec3() {}
+func (ExprParens) exprPrec3() {}
+func (ExprUnary) exprPrec3() {}
+func (ExprRem) exprPrec3() {}
+
+// These expression types can be matched at precedence level 2
+func (ExprIdent) exprPrec2() {}
+func (ExprNumber) exprPrec2() {}
+func (ExprString) exprPrec2() {}
+func (ExprParens) exprPrec2() {}
+func (ExprUnary) exprPrec2() {}
+func (ExprRem) exprPrec2() {}
+func (ExprMulDiv) exprPrec2() {}
+
+// These expression types can be matched at the minimum precedence level
+func (ExprIdent) exprPrecAll() {}
+func (ExprNumber) exprPrecAll() {}
+func (ExprString) exprPrecAll() {}
+func (ExprParens) exprPrecAll() {}
+func (ExprUnary) exprPrecAll() {}
+func (ExprRem) exprPrecAll() {}
+func (ExprMulDiv) exprPrecAll() {}
+func (ExprAddSub) exprPrecAll() {}
+
+type Expression struct {
+	X ExprPrecAll `@@`
+}
+
+var parser = participle.MustBuild(&Expression{},
+	// This grammar requires enough lookahead to see the entire expression before
+	// it can
select the proper binary expression type - in other words, we only + // know that `1 * 2 * 3 * 4` isn't the left-hand side of an addition or subtraction + // expression until we know for sure that no `+` or `-` operator follows it + participle.UseLookahead(99999), + // Register the ExprOperand union so we can parse individual operands + participle.ParseUnion[ExprOperand](ExprUnary{}, ExprIdent{}, ExprNumber{}, ExprString{}, ExprParens{}), + // Register the ExprPrec3 union so we can parse expressions at precedence level 3 + participle.ParseUnion[ExprPrec3](ExprRem{}, ExprUnary{}, ExprIdent{}, ExprNumber{}, ExprString{}, ExprParens{}), + // Register the ExprPrec2 union so we can parse expressions at precedence level 2 + participle.ParseUnion[ExprPrec2](ExprMulDiv{}, ExprRem{}, ExprUnary{}, ExprIdent{}, ExprNumber{}, ExprString{}, ExprParens{}), + // Register the ExprPrecAll union so we can parse expressions at the minimum precedence level + participle.ParseUnion[ExprPrecAll](ExprAddSub{}, ExprMulDiv{}, ExprRem{}, ExprUnary{}, ExprIdent{}, ExprNumber{}, ExprString{}, ExprParens{}), +) + +func main() { + var cli struct { + Expr []string `arg required help:"Expression to parse."` + } + ctx := kong.Parse(&cli) + + expr := &Expression{} + err := parser.ParseString("", strings.Join(cli.Expr, " "), expr) + ctx.FatalIfErrorf(err) + + repr.Println(expr) +} diff --git a/_examples/expr3/main_test.go b/_examples/expr3/main_test.go new file mode 100644 index 00000000..8e2e2e97 --- /dev/null +++ b/_examples/expr3/main_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "testing" + + require "github.com/alecthomas/assert/v2" +) + +func TestExpressionParser(t *testing.T) { + type testCase struct { + src string + expected ExprPrecAll + } + + for _, c := range []testCase{ + {`1`, ExprNumber{1}}, + {`1.5`, ExprNumber{1.5}}, + {`"a"`, ExprString{`"a"`}}, + {`(1)`, ExprParens{ExprNumber{1}}}, + {`1 + 1`, ExprAddSub{ExprNumber{1}, []ExprAddSubExt{{"+", ExprNumber{1}}}}}, + {`1 - 1`, ExprAddSub{ExprNumber{1}, []ExprAddSubExt{{"-", ExprNumber{1}}}}}, + {`1 * 1`, ExprMulDiv{ExprNumber{1}, []ExprMulDivExt{{"*", ExprNumber{1}}}}}, + {`1 / 1`, ExprMulDiv{ExprNumber{1}, []ExprMulDivExt{{"/", ExprNumber{1}}}}}, + {`1 % 1`, ExprRem{ExprNumber{1}, []ExprRemExt{{"%", ExprNumber{1}}}}}, + { + `a + b - c * d / e % f`, + ExprAddSub{ + ExprIdent{"a"}, + []ExprAddSubExt{ + {"+", ExprIdent{"b"}}, + {"-", ExprMulDiv{ + ExprIdent{"c"}, + []ExprMulDivExt{ + {"*", ExprIdent{Name: "d"}}, + {"/", ExprRem{ + ExprIdent{"e"}, + []ExprRemExt{{"%", ExprIdent{"f"}}}, + }}, + }, + }}, + }, + }, + }, + } { + var actual Expression + require.NoError(t, parser.ParseString("", c.src, &actual)) + require.Equal(t, c.expected, actual.X) + } +} diff --git a/_examples/expr4/main.go b/_examples/expr4/main.go new file mode 100644 index 00000000..450f1652 --- /dev/null +++ b/_examples/expr4/main.go @@ -0,0 +1,136 @@ +package main + +import ( + "fmt" + "strconv" + "strings" + "text/scanner" + + "github.com/alecthomas/kong" + "github.com/alecthomas/participle/v2" + "github.com/alecthomas/participle/v2/lexer" + "github.com/alecthomas/repr" +) + +type operatorPrec struct{ Left, Right int } + +var operatorPrecs = map[string]operatorPrec{ + "+": {1, 1}, + "-": {1, 1}, + "*": {3, 2}, + "/": {5, 4}, + "%": {7, 6}, +} + +type ( + Expr interface{ expr() } + + ExprIdent struct{ Name string } + ExprString struct{ Value string } + ExprNumber struct{ Value float64 } + ExprParens struct{ Sub Expr } + + ExprUnary struct { + Op string + Sub Expr + } + + ExprBinary struct 
{ + Lhs Expr + Op string + Rhs Expr + } +) + +func (ExprIdent) expr() {} +func (ExprString) expr() {} +func (ExprNumber) expr() {} +func (ExprParens) expr() {} +func (ExprUnary) expr() {} +func (ExprBinary) expr() {} + +func parseExprAny(lex *lexer.PeekingLexer) (Expr, error) { return parseExprPrec(lex, 0) } + +func parseExprAtom(lex *lexer.PeekingLexer) (Expr, error) { + switch peek := lex.Peek(); { + case peek.Type == scanner.Ident: + return ExprIdent{lex.Next().Value}, nil + case peek.Type == scanner.String: + val, err := strconv.Unquote(lex.Next().Value) + if err != nil { + return nil, err + } + return ExprString{val}, nil + case peek.Type == scanner.Int || peek.Type == scanner.Float: + val, err := strconv.ParseFloat(lex.Next().Value, 64) + if err != nil { + return nil, err + } + return ExprNumber{val}, nil + case peek.Value == "(": + _ = lex.Next() + inner, err := parseExprAny(lex) + if err != nil { + return nil, err + } + if lex.Peek().Value != ")" { + return nil, fmt.Errorf("expected closing ')'") + } + _ = lex.Next() + return ExprParens{inner}, nil + default: + return nil, participle.NextMatch + } +} + +func parseExprPrec(lex *lexer.PeekingLexer, minPrec int) (Expr, error) { + var lhs Expr + if peeked := lex.Peek(); peeked.Value == "-" || peeked.Value == "!" { + op := lex.Next().Value + atom, err := parseExprAtom(lex) + if err != nil { + return nil, err + } + lhs = ExprUnary{op, atom} + } else { + atom, err := parseExprAtom(lex) + if err != nil { + return nil, err + } + lhs = atom + } + + for { + peek := lex.Peek() + prec, isOp := operatorPrecs[peek.Value] + if !isOp || prec.Left < minPrec { + break + } + op := lex.Next().Value + rhs, err := parseExprPrec(lex, prec.Right) + if err != nil { + return nil, err + } + lhs = ExprBinary{lhs, op, rhs} + } + return lhs, nil +} + +type Expression struct { + X Expr `@@` +} + +var parser = participle.MustBuild(&Expression{}, participle.ParseTypeWith(parseExprAny)) + +func main() { + var cli struct { + Expr []string `arg required help:"Expression to parse."` + } + ctx := kong.Parse(&cli) + + expr := &Expression{} + err := parser.ParseString("", strings.Join(cli.Expr, " "), expr) + ctx.FatalIfErrorf(err) + + repr.Println(expr) +} diff --git a/_examples/expr4/main_test.go b/_examples/expr4/main_test.go new file mode 100644 index 00000000..e2444920 --- /dev/null +++ b/_examples/expr4/main_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "testing" + + require "github.com/alecthomas/assert/v2" +) + +func TestCustomExprParser(t *testing.T) { + type testCase struct { + src string + expected Expr + } + + for _, c := range []testCase{ + {`1`, ExprNumber{1}}, + {`1.5`, ExprNumber{1.5}}, + {`"a"`, ExprString{"a"}}, + {`(1)`, ExprParens{ExprNumber{1}}}, + {`1+1`, ExprBinary{ExprNumber{1}, "+", ExprNumber{1}}}, + {`1-1`, ExprBinary{ExprNumber{1}, "-", ExprNumber{1}}}, + {`1*1`, ExprBinary{ExprNumber{1}, "*", ExprNumber{1}}}, + {`1/1`, ExprBinary{ExprNumber{1}, "/", ExprNumber{1}}}, + {`1%1`, ExprBinary{ExprNumber{1}, "%", ExprNumber{1}}}, + {`a - -b`, ExprBinary{ExprIdent{"a"}, "-", ExprUnary{"-", ExprIdent{"b"}}}}, + { + `a + b - c * d / e % f`, + ExprBinary{ + ExprIdent{"a"}, "+", ExprBinary{ + ExprIdent{"b"}, "-", ExprBinary{ + ExprIdent{"c"}, "*", ExprBinary{ + ExprIdent{"d"}, "/", ExprBinary{ + ExprIdent{"e"}, "%", ExprIdent{"f"}, + }, + }, + }, + }, + }, + }, + { + `a * b + c * d`, + ExprBinary{ + ExprBinary{ExprIdent{"a"}, "*", ExprIdent{"b"}}, + "+", + ExprBinary{ExprIdent{"c"}, "*", ExprIdent{"d"}}, + }, + }, + { + `(a + b) * (c + d)`, + 
ExprBinary{ + ExprParens{ExprBinary{ExprIdent{"a"}, "+", ExprIdent{"b"}}}, + "*", + ExprParens{ExprBinary{ExprIdent{"c"}, "+", ExprIdent{"d"}}}, + }, + }, + } { + var actual Expression + require.NoError(t, parser.ParseString("", c.src, &actual)) + require.Equal(t, c.expected, actual.X) + } +} diff --git a/context.go b/context.go index e31650fb..17484dd6 100644 --- a/context.go +++ b/context.go @@ -94,13 +94,15 @@ func (p *parseContext) Stop(err error, branch *parseContext) bool { p.deepestError = err p.deepestErrorDepth = maxInt(branch.PeekingLexer.Cursor(), branch.deepestErrorDepth) } - if branch.PeekingLexer.Cursor() > p.PeekingLexer.Cursor()+p.lookahead { + if !p.hasInfiniteLookahead() && branch.PeekingLexer.Cursor() > p.PeekingLexer.Cursor()+p.lookahead { p.Accept(branch) return true } return false } +func (p *parseContext) hasInfiniteLookahead() bool { return p.lookahead < 0 } + func maxInt(a, b int) int { if a > b { return a diff --git a/ebnf.go b/ebnf.go index 049a1690..2df523e0 100644 --- a/ebnf.go +++ b/ebnf.go @@ -51,6 +51,28 @@ func buildEBNF(root bool, n node, seen map[node]bool, p *ebnfp, outp *[]*ebnfp) p.out += ")" } + case *union: + name := strings.ToUpper(n.typ.Name()[:1]) + n.typ.Name()[1:] + if p != nil { + p.out += name + } + if seen[n] { + return + } + p = &ebnfp{name: name} + *outp = append(*outp, p) + seen[n] = true + for i, next := range n.members { + if i > 0 { + p.out += " | " + } + buildEBNF(false, next, seen, p, outp) + } + + case *custom: + name := strings.ToUpper(n.typ.Name()[:1]) + n.typ.Name()[1:] + p.out += name + case *strct: name := strings.ToUpper(n.typ.Name()[:1]) + n.typ.Name()[1:] if p != nil { diff --git a/ebnf_test.go b/ebnf_test.go index ae5ce439..0054a8cb 100644 --- a/ebnf_test.go +++ b/ebnf_test.go @@ -5,6 +5,7 @@ import ( "testing" require "github.com/alecthomas/assert/v2" + "github.com/alecthomas/participle/v2" ) func TestEBNF(t *testing.T) { @@ -37,3 +38,40 @@ func TestEBNF_Other(t *testing.T) { expected := `Grammar = ((?= "good") ) | ((?! "bad" | "worse") ) | ~("anything" | "but") .` require.Equal(t, expected, parser.String()) } + +type ( + EBNFUnion interface{ ebnfUnion() } + + EBNFUnionA struct { + A string `@Ident` + } + + EBNFUnionB struct { + B string `@String` + } + + EBNFUnionC struct { + C string `@Float` + } +) + +func (EBNFUnionA) ebnfUnion() {} +func (EBNFUnionB) ebnfUnion() {} +func (EBNFUnionC) ebnfUnion() {} + +func TestEBNF_Union(t *testing.T) { + type Grammar struct { + TheUnion EBNFUnion `@@` + } + + parser := mustTestParser(t, &Grammar{}, participle.ParseUnion[EBNFUnion](EBNFUnionA{}, EBNFUnionB{}, EBNFUnionC{})) + require.Equal(t, + strings.TrimSpace(` +Grammar = EBNFUnion . +EBNFUnion = EBNFUnionA | EBNFUnionB | EBNFUnionC . +EBNFUnionA = . +EBNFUnionB = . +EBNFUnionC = . 
+`), + parser.String()) +} diff --git a/grammar.go b/grammar.go index a9944e78..43f6122c 100644 --- a/grammar.go +++ b/grammar.go @@ -22,6 +22,38 @@ func newGeneratorContext(lex lexer.Definition) *generatorContext { } } +func (g *generatorContext) addUnionDefs(defs []unionDef) error { + unionNodes := make([]*union, len(defs)) + for i, def := range defs { + if _, exists := g.typeNodes[def.typ]; exists { + return fmt.Errorf("duplicate definition for interface or union type %s", def.typ) + } + unionNode := &union{def.typ, make([]node, 0, len(def.members))} + g.typeNodes[def.typ], unionNodes[i] = unionNode, unionNode + } + for i, def := range defs { + unionNode := unionNodes[i] + for _, memberType := range def.members { + memberNode, err := g.parseType(memberType) + if err != nil { + return err + } + unionNode.members = append(unionNode.members, memberNode) + } + } + return nil +} + +func (g *generatorContext) addCustomDefs(defs []customDef) error { + for _, def := range defs { + if _, exists := g.typeNodes[def.typ]; exists { + return fmt.Errorf("duplicate definition for interface or union type %s", def.typ) + } + g.typeNodes[def.typ] = &custom{typ: def.typ, parseFn: def.parseFn} + } + return nil +} + // Takes a type and builds a tree of nodes out of it. func (g *generatorContext) parseType(t reflect.Type) (_ node, returnedError error) { t = indirectType(t) diff --git a/nodes.go b/nodes.go index 60635c61..a7f1a25b 100644 --- a/nodes.go +++ b/nodes.go @@ -72,6 +72,47 @@ func (p *parseable) Parse(ctx *parseContext, parent reflect.Value) (out []reflec return []reflect.Value{rv.Elem()}, nil } +// @@ (but for a custom production) +type custom struct { + typ reflect.Type + parseFn reflect.Value +} + +func (c *custom) String() string { return ebnf(c) } +func (c *custom) GoString() string { return c.typ.Name() } + +func (c *custom) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) { + results := c.parseFn.Call([]reflect.Value{reflect.ValueOf(ctx.PeekingLexer)}) + if err, _ := results[1].Interface().(error); err != nil { + if err == NextMatch { + return nil, nil + } + return nil, err + } + return []reflect.Value{results[0]}, nil +} + +// @@ (for a union) +type union struct { + typ reflect.Type + members []node +} + +func (u *union) String() string { return ebnf(u) } +func (u *union) GoString() string { return u.typ.Name() } + +func (u *union) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) { + temp := disjunction{u.members} + vals, err := temp.Parse(ctx, parent) + if err != nil { + return nil, err + } + for i := range vals { + vals[i] = vals[i].Convert(u.typ) + } + return vals, nil +} + // @@ type strct struct { typ reflect.Type @@ -710,7 +751,7 @@ func setField(tokens []lexer.Token, strct reflect.Value, field structLexerField, f.Set(fv) } - case reflect.Bool, reflect.Struct: + case reflect.Bool, reflect.Struct, reflect.Interface: if f.Kind() == reflect.Bool && fv.Kind() == reflect.Bool { f.SetBool(fv.Bool()) break diff --git a/options.go b/options.go index 4ac0a30f..8a4e23cc 100644 --- a/options.go +++ b/options.go @@ -1,9 +1,17 @@ package participle import ( + "fmt" + "reflect" + "github.com/alecthomas/participle/v2/lexer" ) +// MaxLookahead can be used with UseLookahead to get pseudo-infinite +// lookahead without the risk of pathological cases causing a stack +// overflow. +const MaxLookahead = 99999 + // An Option to modify the behaviour of the Parser. 
 type Option func(p *Parser) error
@@ -21,6 +29,14 @@ func Lexer(def lexer.Definition) Option {
 //
 // Note that increasing lookahead has a minor performance impact, but also
 // reduces the accuracy of error reporting.
+//
+// If "n" is negative, it will be treated as "infinite" lookahead.
+// This can have a large impact on performance, and does not provide any
+// protection against stack overflow during parsing.
+// In most cases, using MaxLookahead will achieve the same results in practice,
+// but with a concrete upper bound that prevents pathological behaviour in the parser.
+// Using infinite lookahead can be useful for testing, or for parsing especially
+// ambiguous grammars. Use at your own risk!
 func UseLookahead(n int) Option {
 	return func(p *Parser) error {
 		p.useLookahead = n
@@ -41,6 +57,57 @@ func CaseInsensitive(tokens ...string) Option {
 	}
 }
 
+// ParseTypeWith associates a custom parsing function with some interface type T.
+// When the parser encounters a value of type T, it will use the given parse function to
+// parse a value from the input.
+//
+// The parse function may return anything it wishes as long as that value satisfies the interface T.
+// However, only a single function can be defined for any type T.
+// If you want to have multiple parse functions returning types that satisfy the same interface, you'll
+// need to define new wrapper types for each one.
+//
+// This can be useful if you want to parse a DSL within the larger grammar, or if you want
+// to implement an optimized parsing scheme for some portion of the grammar.
+func ParseTypeWith[T any](parseFn func(*lexer.PeekingLexer) (T, error)) Option {
+	return func(p *Parser) error {
+		parseFnVal := reflect.ValueOf(parseFn)
+		parseFnType := parseFnVal.Type()
+		if parseFnType.Out(0).Kind() != reflect.Interface {
+			return fmt.Errorf("ParseTypeWith: T must be an interface type (got %s)", parseFnType.Out(0))
+		}
+		prodType := parseFnType.Out(0)
+		p.customDefs = append(p.customDefs, customDef{prodType, parseFnVal})
+		return nil
+	}
+}
+
+// ParseUnion associates several member productions with some interface type T.
+// Given members X, Y, Z, and W for a union type T, the EBNF rule is:
+// T = X | Y | Z | W .
+// When the parser encounters a field of type T, it will attempt to parse each member
+// in sequence and take the first match. Because of this, the order in which the
+// members are defined is significant, so order them carefully.
+//
+// An example of a bad parse that can happen if members are out of order:
+//
+// If the first member matches A, and the second member matches A B,
+// and the source string is "AB", then the parser will only match A, and will never
+// try to parse the second member at all.
+func ParseUnion[T any](members ...T) Option {
+	return func(p *Parser) error {
+		unionType := reflect.TypeOf((*T)(nil)).Elem()
+		if unionType.Kind() != reflect.Interface {
+			return fmt.Errorf("ParseUnion: union type must be an interface (got %s)", unionType)
+		}
+		memberTypes := make([]reflect.Type, 0, len(members))
+		for _, m := range members {
+			memberTypes = append(memberTypes, reflect.TypeOf(m))
+		}
+		p.unionDefs = append(p.unionDefs, unionDef{unionType, memberTypes})
+		return nil
+	}
+}
+
 // ParseOption modifies how an individual parse is applied.
type ParseOption func(p *parseContext) diff --git a/parser.go b/parser.go index 3d74d67c..0a1a71dd 100644 --- a/parser.go +++ b/parser.go @@ -10,6 +10,16 @@ import ( "github.com/alecthomas/participle/v2/lexer" ) +type unionDef struct { + typ reflect.Type + members []reflect.Type +} + +type customDef struct { + typ reflect.Type + parseFn reflect.Value +} + // A Parser for a particular grammar and lexer. type Parser struct { root node @@ -19,6 +29,8 @@ type Parser struct { useLookahead int caseInsensitive map[string]bool mappers []mapperByToken + unionDefs []unionDef + customDefs []customDef elide []string } @@ -83,6 +95,13 @@ func Build(grammar interface{}, options ...Option) (parser *Parser, err error) { } context := newGeneratorContext(p.lex) + if err := context.addCustomDefs(p.customDefs); err != nil { + return nil, err + } + if err := context.addUnionDefs(p.unionDefs); err != nil { + return nil, err + } + v := reflect.ValueOf(grammar) if v.Kind() == reflect.Interface { v = v.Elem() diff --git a/parser_test.go b/parser_test.go index 8168720d..d06d2eba 100644 --- a/parser_test.go +++ b/parser_test.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" "testing" + "text/scanner" require "github.com/alecthomas/assert/v2" "github.com/alecthomas/participle/v2" @@ -1727,3 +1728,123 @@ func TestRootParseableFail(t *testing.T) { err := p.ParseString("", "blah", &RootParseableFail{}) require.EqualError(t, err, ":1:1: always fail immediately") } + +type ( + TestCustom interface{ isTestCustom() } + + CustomIdent string + CustomNumber float64 + CustomBoolean bool +) + +func (CustomIdent) isTestCustom() {} +func (CustomNumber) isTestCustom() {} +func (CustomBoolean) isTestCustom() {} + +func TestParserWithCustomProduction(t *testing.T) { + type grammar struct { + Custom TestCustom `@@` + } + + p := mustTestParser(t, &grammar{}, participle.ParseTypeWith(func(lex *lexer.PeekingLexer) (TestCustom, error) { + switch peek := lex.Peek(); { + case peek.Type == scanner.Int || peek.Type == scanner.Float: + v, err := strconv.ParseFloat(lex.Next().Value, 64) + if err != nil { + return nil, err + } + return CustomNumber(v), nil + case peek.Type == scanner.Ident: + name := lex.Next().Value + if name == "true" || name == "false" { + return CustomBoolean(name == "true"), nil + } + return CustomIdent(name), nil + default: + return nil, participle.NextMatch + } + })) + + type testCase struct { + src string + expected TestCustom + } + + for _, c := range []testCase{ + {"a", CustomIdent("a")}, + {"12.5", CustomNumber(12.5)}, + {"true", CustomBoolean(true)}, + {"false", CustomBoolean(false)}, + } { + var actual grammar + require.NoError(t, p.ParseString("", c.src, &actual)) + require.Equal(t, c.expected, actual.Custom) + } + + require.Equal(t, `Grammar = TestCustom .`, p.String()) +} + +type ( + TestUnionA interface{ isTestUnionA() } + TestUnionB interface{ isTestUnionB() } + + AMember1 struct { + V string `@Ident` + } + + AMember2 struct { + V TestUnionB `"[" @@ "]"` + } + + BMember1 struct { + V float64 `@Int | @Float` + } + + BMember2 struct { + V TestUnionA `"{" @@ "}"` + } +) + +func (AMember1) isTestUnionA() {} +func (AMember2) isTestUnionA() {} + +func (BMember1) isTestUnionB() {} +func (BMember2) isTestUnionB() {} + +func TestParserWithUnion(t *testing.T) { + type grammar struct { + A TestUnionA `@@` + B TestUnionB `| @@` + } + + parser := mustTestParser(t, &grammar{}, participle.UseLookahead(10), + participle.ParseUnion[TestUnionA](AMember1{}, AMember2{}), + participle.ParseUnion[TestUnionB](BMember1{}, BMember2{})) + 
+ type testCase struct { + src string + expected grammar + } + + for _, c := range []testCase{ + {`a`, grammar{A: AMember1{"a"}}}, + {`1.5`, grammar{B: BMember1{1.5}}}, + {`[2.5]`, grammar{A: AMember2{BMember1{2.5}}}}, + {`{x}`, grammar{B: BMember2{AMember1{"x"}}}}, + {`{ [ { [12] } ] }`, grammar{B: BMember2{AMember2{BMember2{AMember2{BMember1{12}}}}}}}, + } { + var actual grammar + require.NoError(t, parser.ParseString("", c.src, &actual)) + require.Equal(t, c.expected, actual) + } + + require.Equal(t, strings.TrimSpace(` +Grammar = TestUnionA | TestUnionB . +TestUnionA = AMember1 | AMember2 . +AMember1 = . +AMember2 = "[" TestUnionB "]" . +TestUnionB = BMember1 | BMember2 . +BMember1 = | . +BMember2 = "{" TestUnionA "}" . + `), parser.String()) +} diff --git a/trace.go b/trace.go index 959a3995..4f1d239e 100644 --- a/trace.go +++ b/trace.go @@ -34,12 +34,17 @@ func injectTrace(w io.Writer, indent int, n node) node { for i, child := range n.nodes { n.nodes[i] = injectTrace(w, indent+2, child) } + case *union: + for i, child := range n.members { + n.members[i] = injectTrace(w, indent+2, child) + } case *strct: n.expr = injectTrace(w, indent+2, n.expr) case *sequence: n.node = injectTrace(w, indent+2, n.node) // injectTrace(w, indent, n.next) case *parseable: + case *custom: case *capture: n.node = injectTrace(w, indent+2, n.node) case *reference: diff --git a/visit.go b/visit.go index e3254d62..9371d0d0 100644 --- a/visit.go +++ b/visit.go @@ -17,6 +17,15 @@ func visit(n node, visitor func(n node, next func() error) error) error { return nil case *strct: return visit(n.expr, visitor) + case *custom: + return nil + case *union: + for _, member := range n.members { + if err := visit(member, visitor); err != nil { + return err + } + } + return nil case *sequence: if err := visit(n.node, visitor); err != nil { return err