diff --git a/decode_test.go b/decode_test.go index fb13b9a..c929195 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2144,6 +2144,29 @@ b: *a t.Fatal("failed to unmarshal") } }) + t.Run("quoted map keys", func(t *testing.T) { + t.Parallel() + yml := ` +a: + "b": 2 + 'c': true +` + var v struct { + A struct { + B int + C bool + } + } + if err := yaml.Unmarshal([]byte(yml), &v); err != nil { + t.Fatalf("failed to unmarshal %v", err) + } + if v.A.B != 2 { + t.Fatalf("expected a.b to equal 2 but was %d", v.A.B) + } + if !v.A.C { + t.Fatal("expected a.c to be true but was false") + } + }) } type unmarshalablePtrStringContainer struct { diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index bbe4e2f..03b741f 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -56,6 +56,10 @@ func TestTokenize(t *testing.T) { "a: 'Hello #comment'\n", "a: 100.5\n", "a: bogus\n", + "\"a\": double quoted map key", + "'a': single quoted map key", + "a: \"double quoted\"\nb: \"value map\"", + "a: 'single quoted'\nb: 'value map'", } for _, src := range sources { lexer.Tokenize(src).Dump() @@ -231,6 +235,60 @@ func TestSingleLineToken_ValueLineColumnPosition(t *testing.T) { 15: "]", }, }, + { + name: "double quote key", + src: `"a": b`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, + { + name: "single quote key", + src: `'a': b`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, + { + name: "double quote key and value", + src: `"a": "b"`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, + { + name: "single quote key and value", + src: `'a': 'b'`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, + { + name: "double quote key, single quote value", + src: `"a": 'b'`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, + { + name: "single quote key, double quote value", + src: `'a': "b"`, + expect: map[int]string{ + 1: "a", + 4: ":", + 6: "b", + }, + }, } for _, tc := range tests { @@ -432,6 +490,59 @@ foo2: 'bar2'`, }, }, }, + { + name: "single and double quote map keys", + src: `"a": test +'b': 1 +c: true`, + expect: []testToken{ + { + line: 1, + column: 1, + value: "a", + }, + { + line: 1, + column: 4, + value: ":", + }, + { + line: 1, + column: 6, + value: "test", + }, + { + line: 2, + column: 1, + value: "b", + }, + { + line: 2, + column: 4, + value: ":", + }, + { + line: 2, + column: 6, + value: "1", + }, + { + line: 3, + column: 1, + value: "c", + }, + { + line: 3, + column: 2, + value: ":", + }, + { + line: 3, + column: 4, + value: "true", + }, + }, + }, } for _, tc := range tests { diff --git a/parser/parser_test.go b/parser/parser_test.go index 595e0a3..f284f45 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -80,6 +80,8 @@ func TestParser(t *testing.T) { ? !!str "implicit" : !!str "entry", ? !!null "" : !!null "", }`, + "\"a\": a\n\"b\": b", + "'a': a\n'b': b", } for _, src := range sources { if _, err := parser.Parse(lexer.Tokenize(src), 0); err != nil { @@ -562,6 +564,22 @@ b: c ` - key1: val key2: ( foo + bar ) +`, + }, + { + ` +"a": b +'c': d +"e": "f" +g: "h" +i: 'j' +`, + ` +"a": b +'c': d +"e": "f" +g: "h" +i: 'j' `, }, } diff --git a/scanner/scanner.go b/scanner/scanner.go index 1e09190..08bdffe 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/goccy/go-yaml/token" + "golang.org/x/xerrors" ) @@ -110,6 +111,15 @@ func (s *Scanner) isNewLineChar(c rune) bool { return false } +func (s *Scanner) isWhitespaceBuffer(ctx *Context) bool { + allWhitespace := true + bufContent := ctx.bufferedSrc() + for _, r := range bufContent { + allWhitespace = allWhitespace && r == ' ' + } + return allWhitespace && len(bufContent) > 0 +} + func (s *Scanner) newLineCount(src []rune) int { size := len(src) cnt := 0 @@ -221,9 +231,11 @@ func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { value := []rune{} isFirstLineChar := false isNewLine := false + columnSkip := 0 for idx := startIndex; idx < size; idx++ { if !isNewLine { - s.progressColumn(ctx, 1) + s.progressColumn(ctx, 1+columnSkip) + columnSkip = 0 } else { isNewLine = false } @@ -248,6 +260,7 @@ func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { value = append(value, c) ctx.addOriginBuf(c) idx++ + columnSkip = 1 continue } s.progressColumn(ctx, 1) @@ -285,9 +298,11 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { value := []rune{} isFirstLineChar := false isNewLine := false + columnSkip := 0 for idx := startIndex; idx < size; idx++ { if !isNewLine { - s.progressColumn(ctx, 1) + s.progressColumn(ctx, 1+columnSkip) + columnSkip = 0 } else { isNewLine = false } @@ -311,51 +326,61 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { ctx.addOriginBuf(nextChar) value = append(value, '\b') idx++ + columnSkip = 1 continue case 'e': ctx.addOriginBuf(nextChar) value = append(value, '\x1B') idx++ + columnSkip = 1 continue case 'f': ctx.addOriginBuf(nextChar) value = append(value, '\f') idx++ + columnSkip = 1 continue case 'n': ctx.addOriginBuf(nextChar) value = append(value, '\n') idx++ + columnSkip = 1 continue case 'v': ctx.addOriginBuf(nextChar) value = append(value, '\v') idx++ + columnSkip = 1 continue case 'L': // LS (#x2028) ctx.addOriginBuf(nextChar) value = append(value, []rune{'\xE2', '\x80', '\xA8'}...) idx++ + columnSkip = 1 continue case 'N': // NEL (#x85) ctx.addOriginBuf(nextChar) value = append(value, []rune{'\xC2', '\x85'}...) idx++ + columnSkip = 1 continue case 'P': // PS (#x2029) ctx.addOriginBuf(nextChar) value = append(value, []rune{'\xE2', '\x80', '\xA9'}...) idx++ + columnSkip = 1 continue case '_': // #xA0 ctx.addOriginBuf(nextChar) value = append(value, []rune{'\xC2', '\xA0'}...) idx++ + columnSkip = 1 continue case '"': ctx.addOriginBuf(nextChar) value = append(value, nextChar) idx++ + columnSkip = 1 continue case 'x': if idx+3 >= size { @@ -366,6 +391,7 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { codeNum := hexRunesToInt(src[idx+2 : idx+4]) value = append(value, rune(codeNum)) idx += 3 + columnSkip = 3 continue case 'u': if idx+5 >= size { @@ -376,6 +402,7 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { codeNum := hexRunesToInt(src[idx+2 : idx+6]) value = append(value, rune(codeNum)) idx += 5 + columnSkip = 5 continue case 'U': if idx+9 >= size { @@ -386,10 +413,12 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { codeNum := hexRunesToInt(src[idx+2 : idx+10]) value = append(value, rune(codeNum)) idx += 9 + columnSkip = 9 continue case '\\': ctx.addOriginBuf(nextChar) idx++ + columnSkip = 1 } } value = append(value, c) @@ -409,9 +438,16 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { func (s *Scanner) scanQuote(ctx *Context, ch rune) (tk *token.Token, pos int) { if ch == '\'' { - return s.scanSingleQuote(ctx) + tk, pos = s.scanSingleQuote(ctx) + } else if ch == '"' { + tk, pos = s.scanDoubleQuote(ctx) + } else { + // TODO return an error object here when scan supports returning errors + return } - return s.scanDoubleQuote(ctx) + // The origin buffer needs to be reset so it's back in sync with the main buffer. + ctx.resetBuffer() + return } func (s *Scanner) isMergeKey(ctx *Context) bool { @@ -729,11 +765,28 @@ func (s *Scanner) scan(ctx *Context) (pos int) { nc := ctx.nextChar() if s.startedFlowMapNum > 0 || nc == ' ' || s.isNewLineChar(nc) || ctx.isNextEOS() { // mapping value - tk := s.bufferedToken(ctx) - if tk != nil { + + // If there's a token in the context, we need to check if it's a quote token. + if len(ctx.tokens) > 0 { + tk := ctx.tokens[len(ctx.tokens)-1] + if tk.Type == token.SingleQuoteType || tk.Type == token.DoubleQuoteType { + if len(ctx.buf) > 0 && s.isWhitespaceBuffer(ctx) { + // Spaces after quote map keys are valid, add the whitespace characters + // to the token and reset the buffer; we consider them part of the map key. + tk.Value += string(ctx.bufferedSrc()) + ctx.resetBuffer() + } + + // Set the previous indent column to the beginning of the quote token. + s.prevIndentColumn = tk.Position.Column + } + } + if tk := s.bufferedToken(ctx); tk != nil { + // If there's anything in the buffer at this point, we'll treat that as the map key. s.prevIndentColumn = tk.Position.Column ctx.addToken(tk) } + ctx.addToken(token.MappingValue(s.pos())) s.progressColumn(ctx, 1) return @@ -805,7 +858,7 @@ func (s *Scanner) scan(ctx *Context) (pos int) { token, progress := s.scanQuote(ctx, c) ctx.addToken(token) pos += progress - return + continue } case '\r', '\n': // There is no problem that we ignore CR which followed by LF and normalize it to LF, because of following YAML1.2 spec.