From 65c855784a290345ec3452c33d2ecc08ca267010 Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Fri, 11 Nov 2022 11:37:38 +0400 Subject: [PATCH] more consistent line splitting change the scanner to split lines according to expected behavior, e.g: split lines on LF, CR or CRLF remove need for extra function to guess that after the facts Co-authored-by: Luis Davim --- gotenv.go | 50 +++++++++++++++++++------------ scanner_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 18 deletions(-) create mode 100644 scanner_test.go diff --git a/gotenv.go b/gotenv.go index 7b1186e..769d820 100644 --- a/gotenv.go +++ b/gotenv.go @@ -3,6 +3,7 @@ package gotenv import ( "bufio" + "bytes" "fmt" "io" "os" @@ -174,9 +175,38 @@ func Write(env Env, filename string) error { return file.Sync() } -func strictParse(r io.Reader, override bool) (Env, error) { - env := make(Env) +// splitLines is a valid SplitFunc for a bufio.Scanner. It will split lines on CR ('\r'), LF ('\n') or CRLF (any of the three sequences). +// If a CR is immediately followed by a LF, it is treated as a CRLF (one single line break). +func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, bufio.ErrFinalToken + } + + idx := bytes.IndexAny(data, "\r\n") + switch { + case atEOF && idx < 0: + return len(data), data, bufio.ErrFinalToken + + case idx < 0: + return 0, nil, nil + } + + // consume CR or LF + eol := idx + 1 + // detect CRLF + if len(data) > eol && data[eol-1] == '\r' && data[eol] == '\n' { + eol++ + } + + return eol, data[:idx], nil +} + +func strictParse(r io.Reader, env Env, override bool) (Env, error) { + if env == nil { + env = make(Env) + } scanner := bufio.NewScanner(r) + scanner.Split(splitLines) firstLine := true @@ -283,7 +313,6 @@ func parseLine(s string, env Env, override bool) error { return varReplacement(s, hsq, env, override) } val = varRgx.ReplaceAllStringFunc(val, fv) - val = parseVal(val, env, hdq, override) } env[key] = val @@ -352,18 +381,3 @@ func checkFormat(s string, env Env) error { return fmt.Errorf("line `%s` doesn't match format", s) } - -func parseVal(val string, env Env, ignoreNewlines bool, override bool) string { - if strings.Contains(val, "=") && !ignoreNewlines { - kv := strings.Split(val, "\r") - - if len(kv) > 1 { - val = kv[0] - for _, l := range kv[1:] { - _ = parseLine(l, env, override) - } - } - } - - return val -} diff --git a/scanner_test.go b/scanner_test.go new file mode 100644 index 0000000..977dc6c --- /dev/null +++ b/scanner_test.go @@ -0,0 +1,78 @@ +package gotenv + +import ( + "bufio" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestScanner(t *testing.T) { + + type testCase struct { + name string + in string + exp []string + } + + testCases := []testCase{ + { + "regular LF split with trailing LF", + "aa\nbb\ncc\n", + []string{"aa", "bb", "cc", ""}, + }, + { + "regular LF split with no trailing LF", + "aa\nbb\ncc", + []string{"aa", "bb", "cc"}, + }, + + { + "regular CR split with trailing CR", + "aa\rbb\rcc\r", + []string{"aa", "bb", "cc", ""}, + }, + { + "regular CR split with no trailing CR", + "aa\rbb\rcc", + []string{"aa", "bb", "cc"}, + }, + + { + "regular CRLF split with trailing CRLF", + "aa\r\nbb\r\ncc\r\n", + []string{"aa", "bb", "cc", ""}, + }, + { + "regular CRLF split with no trailing CRLF", + "aa\r\nbb\r\ncc", + []string{"aa", "bb", "cc"}, + }, + + { + "mix of possible line endings", + "aa\r\nbb\ncc\rdd", + []string{"aa", "bb", "cc", "dd"}, + }, + } + + for _, tc := range testCases { + s := bufio.NewScanner(strings.NewReader(tc.in)) + s.Split(splitLines) + + i := 0 + for s.Scan() { + if i >= len(tc.exp) { + assert.Fail(t, "unexpected line", "testCase: %s - got extra line: %q", tc.name, s.Text()) + } else { + got := s.Text() + assert.Equal(t, tc.exp[i], got, "testCase: %s - line %d", tc.name, i) + } + i++ + } + + assert.NoError(t, s.Err(), "testCase: %s", tc.name) + assert.Equal(t, len(tc.exp), i, "testCase: %s - expected to have the correct line count", tc.name) + } +}