Skip to content

Commit

Permalink
Few performance improvements
Browse files Browse the repository at this point in the history
replaceEscapes() got called for every string, and key.String() gets
called a lot in the parser, so small improvements add up. Also figured
that calling replaceEscapes() for every string isn't really needed.

It's about 20 to 30% faster (depending on the TOML file).
  • Loading branch information
arp242 committed Oct 10, 2023
1 parent c0a26cb commit 6fb5266
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 1,286 deletions.
2 changes: 1 addition & 1 deletion bench_test.go
Expand Up @@ -52,7 +52,7 @@ func BenchmarkDecode(b *testing.B) {
}

b.Run("large-doc", func(b *testing.B) {
d, err := os.ReadFile("testdata/ja-JP.toml")
d, err := os.ReadFile("testdata/Cargo.toml")
if err != nil {
b.Fatal(err)
}
Expand Down
20 changes: 19 additions & 1 deletion decode_test.go
Expand Up @@ -1290,7 +1290,7 @@ func TestMetaKeys(t *testing.T) {
}

func TestDecodeParallel(t *testing.T) {
doc, err := os.ReadFile("testdata/ja-JP.toml")
doc, err := os.ReadFile("testdata/Cargo.toml")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1323,3 +1323,21 @@ func errorContains(have error, want string) bool {
}
return strings.Contains(have.Error(), want)
}

func BenchmarkEscapes(b *testing.B) {
p := new(parser)
it := item{}
str := strings.Repeat("hello, world!\n", 10)
b.ResetTimer()
for n := 0; n < b.N; n++ {
p.replaceEscapes(it, str)
}
}

func BenchmarkKey(b *testing.B) {
k := Key{"cargo-credential-macos-keychain", "version"}
b.ResetTimer()
for n := 0; n < b.N; n++ {
k.String()
}
}
13 changes: 11 additions & 2 deletions lex.go
Expand Up @@ -17,6 +17,7 @@ const (
itemEOF
itemText
itemString
itemStringEsc
itemRawString
itemMultilineString
itemRawMultilineString
Expand Down Expand Up @@ -53,6 +54,7 @@ type lexer struct {
state stateFn
items chan item
tomlNext bool
esc bool

// Allow for backing up up to 4 runes. This is necessary because TOML
// contains 3-rune tokens (""" and ''').
Expand Down Expand Up @@ -696,7 +698,12 @@ func lexString(lx *lexer) stateFn {
return lexStringEscape
case r == '"':
lx.backup()
lx.emit(itemString)
if lx.esc {
lx.esc = false
lx.emit(itemStringEsc)
} else {
lx.emit(itemString)
}
lx.next()
lx.ignore()
return lx.pop()
Expand Down Expand Up @@ -746,6 +753,7 @@ func lexMultilineString(lx *lexer) stateFn {
lx.backup() /// backup: don't include the """ in the item.
lx.backup()
lx.backup()
lx.esc = false
lx.emit(itemMultilineString)
lx.next() /// Read over ''' again and discard it.
lx.next()
Expand Down Expand Up @@ -835,6 +843,7 @@ func lexMultilineStringEscape(lx *lexer) stateFn {
}

func lexStringEscape(lx *lexer) stateFn {
lx.esc = true
r := lx.next()
switch r {
case 'e':
Expand Down Expand Up @@ -1199,7 +1208,7 @@ func (itype itemType) String() string {
return "EOF"
case itemText:
return "Text"
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
case itemString, itemStringEsc, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
Expand Down
34 changes: 27 additions & 7 deletions meta.go
Expand Up @@ -94,21 +94,41 @@ func (md *MetaData) Undecoded() []Key {
type Key []string

func (k Key) String() string {
ss := make([]string, len(k))
for i := range k {
ss[i] = k.maybeQuoted(i)
// This is called quite often, so it's a bit funky to make it faster.
var b strings.Builder
b.Grow(len(k) * 25)
outer:
for i, kk := range k {
if i > 0 {
b.WriteByte('.')
}
if kk == "" {
b.WriteString(`""`)
} else {
for _, r := range kk {
// "Inline" isBareKeyChar
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-') {
b.WriteByte('"')
b.WriteString(dblQuotedReplacer.Replace(kk))
b.WriteByte('"')
continue outer
}
}
b.WriteString(kk)
}
}
return strings.Join(ss, ".")
return b.String()
}

func (k Key) maybeQuoted(i int) string {
if k[i] == "" {
return `""`
}
for _, c := range k[i] {
if !isBareKeyChar(c, false) {
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
for _, r := range k[i] {
if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
continue
}
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
}
return k[i]
}
Expand Down
109 changes: 56 additions & 53 deletions parse.go
Expand Up @@ -224,7 +224,7 @@ func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
case itemString, itemStringEsc, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it, false)
return s.(string)
Expand All @@ -244,6 +244,8 @@ var datetimeRepl = strings.NewReplacer(
func (p *parser) value(it item, parentIsArray bool) (any, tomlType) {
switch it.typ {
case itemString:
return it.val, p.typeOfPrimitive(it)
case itemStringEsc:
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
case itemMultilineString:
return p.replaceEscapes(it, p.stripEscapedNewlines(stripFirstNewline(it.val))), p.typeOfPrimitive(it)
Expand Down Expand Up @@ -707,8 +709,11 @@ func stripFirstNewline(s string) string {
// the next newline. After a line-ending backslash, all whitespace is removed
// until the next non-whitespace character.
func (p *parser) stripEscapedNewlines(s string) string {
var b strings.Builder
var i int
var (
b strings.Builder
i int
)
b.Grow(len(s))
for {
ix := strings.Index(s[i:], `\`)
if ix < 0 {
Expand Down Expand Up @@ -738,9 +743,8 @@ func (p *parser) stripEscapedNewlines(s string) string {
continue
}
if !strings.Contains(s[i:j], "\n") {
// This is not a line-ending backslash.
// (It's a bad escape sequence, but we can let
// replaceEscapes catch it.)
// This is not a line-ending backslash. (It's a bad escape sequence,
// but we can let replaceEscapes catch it.)
i++
continue
}
Expand All @@ -751,79 +755,78 @@ func (p *parser) stripEscapedNewlines(s string) string {
}

func (p *parser) replaceEscapes(it item, str string) string {
replaced := make([]rune, 0, len(str))
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
var (
b strings.Builder
skip = 0
)
b.Grow(len(str))
for i, c := range str {
if skip > 0 {
skip--
continue
}
if c != '\\' {
b.WriteRune(c)
continue
}
r += 1
if r >= len(s) {

if i >= len(str) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
switch str[i+1] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
p.bug("Expected valid escape code after \\, but got %q.", str[i+1])
case ' ', '\t':
p.panicItemf(it, "invalid escape: '\\%c'", s[r])
p.panicItemf(it, "invalid escape: '\\%c'", str[i+1])
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
b.WriteByte(0x08)
skip = 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
b.WriteByte(0x09)
skip = 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
b.WriteByte(0x0a)
skip = 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
b.WriteByte(0x0c)
skip = 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
b.WriteByte(0x0d)
skip = 1
case 'e':
if p.tomlNext {
replaced = append(replaced, rune(0x001B))
r += 1
b.WriteByte(0x1b)
skip = 1
}
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
b.WriteByte(0x22)
skip = 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
b.WriteByte(0x5c)
skip = 1
// The lexer guarantees the correct number of characters are present;
// don't need to check here.
case 'x':
if p.tomlNext {
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+3])
replaced = append(replaced, escaped)
r += 3
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+4])
b.WriteRune(escaped)
skip = 3
}
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5])
replaced = append(replaced, escaped)
r += 5
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+6])
b.WriteRune(escaped)
skip = 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9])
replaced = append(replaced, escaped)
r += 9
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+10])
b.WriteRune(escaped)
skip = 9
}
}
return string(replaced)
return b.String()
}

func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune {
s := string(bs)
func (p *parser) asciiEscapeToUnicode(it item, s string) rune {
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err)
Expand Down

0 comments on commit 6fb5266

Please sign in to comment.