From 904160207285530185fea39b1b106e817420952b Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Tue, 17 May 2022 14:14:02 +0200 Subject: [PATCH 1/5] Refactor config package and add functionality --- gh.go | 35 +- gh_test.go | 7 +- internal/config/config.go | 345 ------------- internal/config/config_map.go | 131 ----- internal/config/config_map_test.go | 212 -------- internal/yamlmap/yaml_map.go | 200 ++++++++ internal/yamlmap/yaml_map_test.go | 225 +++++++++ pkg/api/client.go | 2 +- pkg/auth/auth.go | 80 +++ pkg/auth/auth_test.go | 255 ++++++++++ pkg/config/config.go | 269 ++++++++++ {internal => pkg}/config/config_test.go | 619 +++++++++++------------- pkg/config/errors.go | 27 ++ pkg/repository/repository.go | 7 +- pkg/repository/repository_test.go | 5 +- 15 files changed, 1364 insertions(+), 1055 deletions(-) delete mode 100644 internal/config/config.go delete mode 100644 internal/config/config_map.go delete mode 100644 internal/config/config_map_test.go create mode 100644 internal/yamlmap/yaml_map.go create mode 100644 internal/yamlmap/yaml_map_test.go create mode 100644 pkg/auth/auth.go create mode 100644 pkg/auth/auth_test.go create mode 100644 pkg/config/config.go rename {internal => pkg}/config/config_test.go (53%) create mode 100644 pkg/config/errors.go diff --git a/gh.go b/gh.go index 05dc749..ba4a042 100644 --- a/gh.go +++ b/gh.go @@ -14,11 +14,11 @@ import ( "os/exec" iapi "github.com/cli/go-gh/internal/api" - "github.com/cli/go-gh/internal/config" - iconfig "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" "github.com/cli/go-gh/pkg/api" + "github.com/cli/go-gh/pkg/auth" + "github.com/cli/go-gh/pkg/config" repo "github.com/cli/go-gh/pkg/repository" "github.com/cli/go-gh/pkg/ssh" "github.com/cli/safeexec" @@ -62,7 +62,7 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Load() + cfg, err := config.Read() if err != nil { return nil, err } @@ -83,7 +83,7 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Load() + cfg, err := config.Read() if err != nil { return nil, err } @@ -109,7 +109,7 @@ func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Load() + cfg, err := config.Read() if err != nil { return nil, err } @@ -141,12 +141,12 @@ func CurrentRepository() (repo.Repository, error) { translator := ssh.NewTranslator() translateRemotes(remotes, translator) - cfg, err := config.Load() + cfg, err := config.Read() if err != nil { return nil, err } - hosts := cfg.Hosts() + hosts := auth.KnownHosts(cfg) filteredRemotes := remotes.FilterByHosts(hosts) if len(filteredRemotes) == 0 { @@ -170,27 +170,18 @@ func optionsNeedResolution(opts *api.ClientOptions) bool { return false } -func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { - var token string - var err error +func resolveOptions(opts *api.ClientOptions, cfg *config.Config) error { if opts.Host == "" { - opts.Host = cfg.Host() + opts.Host, _ = auth.DefaultHost(cfg) } if opts.AuthToken == "" { - token, err = cfg.AuthToken(opts.Host) - if err != nil { - var notFoundError iconfig.NotFoundError - if errors.As(err, ¬FoundError) { - return fmt.Errorf("authentication token not found for host %s", opts.Host) - } else { - return err - } + opts.AuthToken, _ = auth.TokenForHost(cfg, opts.Host) + if opts.AuthToken == "" { + return fmt.Errorf("authentication token not found for host %s", opts.Host) } - opts.AuthToken = token } if opts.UnixDomainSocket == "" { - unixSocket, _ := cfg.Get("http_unix_socket") - opts.UnixDomainSocket = unixSocket + opts.UnixDomainSocket, _ = config.Get(cfg, []string{"http_unix_socket"}) } return nil } diff --git a/gh_test.go b/gh_test.go index 9e6f62d..7338bfe 100644 --- a/gh_test.go +++ b/gh_test.go @@ -7,8 +7,8 @@ import ( "strings" "testing" - "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/pkg/api" + "github.com/cli/go-gh/pkg/config" "github.com/stretchr/testify/assert" "gopkg.in/h2non/gock.v1" ) @@ -303,7 +303,7 @@ func TestOptionsNeedResolution(t *testing.T) { } } -func testConfig() config.Config { +func testConfig() *config.Config { var data = ` hosts: github.com: @@ -312,8 +312,7 @@ hosts: git_protocol: ssh http_unix_socket: socket ` - cfg, _ := config.FromString(data) - return cfg + return config.ReadFromString(data) } func printPendingMocks(mocks []gock.Mock) string { diff --git a/internal/config/config.go b/internal/config/config.go deleted file mode 100644 index e595c23..0000000 --- a/internal/config/config.go +++ /dev/null @@ -1,345 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "io" - "io/fs" - "os" - "path/filepath" - "runtime" - "strings" - - "github.com/cli/go-gh/internal/set" - "gopkg.in/yaml.v3" -) - -const ( - appData = "AppData" - defaultHost = "github.com" - ghConfigDir = "GH_CONFIG_DIR" - ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" - ghHost = "GH_HOST" - ghToken = "GH_TOKEN" - githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN" - githubToken = "GITHUB_TOKEN" - localAppData = "LocalAppData" - oauthToken = "oauth_token" - xdgConfigHome = "XDG_CONFIG_HOME" - xdgDataHome = "XDG_DATA_HOME" - xdgStateHome = "XDG_STATE_HOME" -) - -type Config interface { - Get(key string) (string, error) - GetForHost(host string, key string) (string, error) - Host() string - Hosts() []string - AuthToken(host string) (string, error) -} - -type config struct { - global configMap - hosts configMap -} - -func (c config) Get(key string) (string, error) { - return c.global.getStringValue(key) -} - -func (c config) GetForHost(host, key string) (string, error) { - hostEntry, err := c.hosts.findEntry(host) - if err != nil { - return "", err - } - hostMap := configMap{Root: hostEntry.ValueNode} - return hostMap.getStringValue(key) -} - -func (c config) Host() string { - if host := os.Getenv(ghHost); host != "" { - return host - } - entries := c.hosts.keys() - if len(entries) == 1 { - return entries[0] - } - return defaultHost -} - -func (c config) Hosts() []string { - hosts := set.NewStringSet() - if host := os.Getenv(ghHost); host != "" { - hosts.Add(host) - } - if token, _ := c.AuthToken(defaultHost); token != "" { - hosts.Add(defaultHost) - } - entries := c.hosts.keys() - hosts.AddValues(entries) - return hosts.ToSlice() -} - -func (c config) AuthToken(host string) (string, error) { - hostname := normalizeHostname(host) - if isEnterprise(hostname) { - if token := os.Getenv(ghEnterpriseToken); token != "" { - return token, nil - } - if token := os.Getenv(githubEnterpriseToken); token != "" { - return token, nil - } - if token, err := c.GetForHost(hostname, oauthToken); err == nil { - return token, nil - } - return "", NotFoundError{errors.New("not found")} - } - - if token := os.Getenv(ghToken); token != "" { - return token, nil - } - if token := os.Getenv(githubToken); token != "" { - return token, nil - } - if token, err := c.GetForHost(hostname, oauthToken); err == nil { - return token, nil - } - return "", NotFoundError{errors.New("not found")} -} - -func isEnterprise(host string) bool { - return host != defaultHost -} - -func normalizeHostname(host string) string { - hostname := strings.ToLower(host) - if strings.HasSuffix(hostname, "."+defaultHost) { - return defaultHost - } - return hostname -} - -func FromString(str string) (Config, error) { - root, err := parseData([]byte(str)) - if err != nil { - return nil, err - } - cfg := config{} - globalMap := configMap{Root: root} - cfg.global = globalMap - hostsEntry, err := globalMap.findEntry("hosts") - if err == nil { - cfg.hosts = configMap{Root: hostsEntry.ValueNode} - } - return cfg, nil -} - -func defaultConfig() Config { - return config{global: configMap{Root: defaultGlobal().Content[0]}} -} - -//TODO: Add caching so as not to load config multiple times. -func Load() (Config, error) { - return load(configFile(), hostsConfigFile()) -} - -func load(globalFilePath, hostsFilePath string) (Config, error) { - var readErr error - var parseErr error - globalData, readErr := readFile(globalFilePath) - if readErr != nil && !errors.Is(readErr, fs.ErrNotExist) { - return nil, readErr - } - - // Use defaultGlobal node if globalFile does not exist or is empty. - global := defaultGlobal().Content[0] - if len(globalData) > 0 { - global, parseErr = parseData(globalData) - } - if parseErr != nil { - return nil, parseErr - } - - hostsData, readErr := readFile(hostsFilePath) - if readErr != nil && !os.IsNotExist(readErr) { - return nil, readErr - } - - // Use nil if hostsFile does not exist or is empty. - var hosts *yaml.Node - if len(hostsData) > 0 { - hosts, parseErr = parseData(hostsData) - } - if parseErr != nil { - return nil, parseErr - } - - cfg := config{ - global: configMap{Root: global}, - hosts: configMap{Root: hosts}, - } - - return cfg, nil -} - -// Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME. -func configDir() string { - var path string - if a := os.Getenv(ghConfigDir); a != "" { - path = a - } else if b := os.Getenv(xdgConfigHome); b != "" { - path = filepath.Join(b, "gh") - } else if c := os.Getenv(appData); runtime.GOOS == "windows" && c != "" { - path = filepath.Join(c, "GitHub CLI") - } else { - d, _ := os.UserHomeDir() - path = filepath.Join(d, ".config", "gh") - } - return path -} - -// State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME. -func stateDir() string { - var path string - if a := os.Getenv(xdgStateHome); a != "" { - path = filepath.Join(a, "gh") - } else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" { - path = filepath.Join(b, "GitHub CLI") - } else { - c, _ := os.UserHomeDir() - path = filepath.Join(c, ".local", "state", "gh") - } - return path -} - -// Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME. -func dataDir() string { - var path string - if a := os.Getenv(xdgDataHome); a != "" { - path = filepath.Join(a, "gh") - } else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" { - path = filepath.Join(b, "GitHub CLI") - } else { - c, _ := os.UserHomeDir() - path = filepath.Join(c, ".local", "share", "gh") - } - return path -} - -func configFile() string { - return filepath.Join(configDir(), "config.yml") -} - -func hostsConfigFile() string { - return filepath.Join(configDir(), "hosts.yml") -} - -func readFile(filename string) ([]byte, error) { - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - data, err := io.ReadAll(f) - if err != nil { - return nil, err - } - return data, nil -} - -func parseData(data []byte) (*yaml.Node, error) { - var root yaml.Node - err := yaml.Unmarshal(data, &root) - if err != nil { - return nil, fmt.Errorf("invalid config file: %w", err) - } - if len(root.Content) == 0 || root.Content[0].Kind != yaml.MappingNode { - return nil, fmt.Errorf("invalid config file") - } - return root.Content[0], nil -} - -func defaultGlobal() *yaml.Node { - return &yaml.Node{ - Kind: yaml.DocumentNode, - Content: []*yaml.Node{ - { - Kind: yaml.MappingNode, - Content: []*yaml.Node{ - { - HeadComment: "What protocol to use when performing git operations. Supported values: ssh, https", - Kind: yaml.ScalarNode, - Value: "git_protocol", - }, - { - Kind: yaml.ScalarNode, - Value: "https", - }, - { - HeadComment: "What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.", - Kind: yaml.ScalarNode, - Value: "editor", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled", - Kind: yaml.ScalarNode, - Value: "prompt", - }, - { - Kind: yaml.ScalarNode, - Value: "enabled", - }, - { - HeadComment: "A pager program to send command output to, e.g. \"less\". Set the value to \"cat\" to disable the pager.", - Kind: yaml.ScalarNode, - Value: "pager", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "Aliases allow you to create nicknames for gh commands", - Kind: yaml.ScalarNode, - Value: "aliases", - }, - { - Kind: yaml.MappingNode, - Content: []*yaml.Node{ - { - Kind: yaml.ScalarNode, - Value: "co", - }, - { - Kind: yaml.ScalarNode, - Value: "pr checkout", - }, - }, - }, - { - HeadComment: "The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.", - Kind: yaml.ScalarNode, - Value: "http_unix_socket", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "What web browser gh should use when opening URLs. If blank, will refer to environment.", - Kind: yaml.ScalarNode, - Value: "browser", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - }, - }, - }, - } -} diff --git a/internal/config/config_map.go b/internal/config/config_map.go deleted file mode 100644 index caa373a..0000000 --- a/internal/config/config_map.go +++ /dev/null @@ -1,131 +0,0 @@ -package config - -import ( - "errors" - - "gopkg.in/yaml.v3" -) - -// This type implements a low-level get/set config that is backed by an in-memory tree of yaml -// nodes. It allows us to interact with a yaml-based config programmatically, preserving any -// comments that were present when the yaml was parsed. -type configMap struct { - Root *yaml.Node -} - -type configEntry struct { - KeyNode *yaml.Node - ValueNode *yaml.Node - Index int -} - -type NotFoundError struct { - error -} - -func (cm *configMap) empty() bool { - return cm.Root == nil || len(cm.Root.Content) == 0 -} - -func (cm *configMap) getStringValue(key string) (string, error) { - entry, err := cm.findEntry(key) - if err != nil { - return "", err - } - return entry.ValueNode.Value, nil -} - -func (cm *configMap) setStringValue(key, value string) error { - entry, err := cm.findEntry(key) - if err == nil { - entry.ValueNode.Value = value - return nil - } - - var notFound *NotFoundError - if err != nil && !errors.As(err, ¬Found) { - return err - } - - keyNode := &yaml.Node{ - Kind: yaml.ScalarNode, - Value: key, - } - valueNode := &yaml.Node{ - Kind: yaml.ScalarNode, - Tag: "!!str", - Value: value, - } - - cm.Root.Content = append(cm.Root.Content, keyNode, valueNode) - return nil -} - -func (cm *configMap) findEntry(key string) (*configEntry, error) { - if cm.empty() { - return nil, &NotFoundError{errors.New("not found")} - } - - ce := &configEntry{} - - // Content slice goes [key1, value1, key2, value2, ...]. - topLevelPairs := cm.Root.Content - for i, v := range topLevelPairs { - // Skip every other slice item since we only want to check against keys. - if i%2 != 0 { - continue - } - if v.Value == key { - ce.KeyNode = v - ce.Index = i - if i+1 < len(topLevelPairs) { - ce.ValueNode = topLevelPairs[i+1] - } - return ce, nil - } - } - - return nil, &NotFoundError{errors.New("not found")} -} - -func (cm *configMap) removeEntry(key string) { - if cm.empty() { - return - } - - newContent := []*yaml.Node{} - - var skipNext bool - for i, v := range cm.Root.Content { - if skipNext { - skipNext = false - continue - } - if i%2 != 0 || v.Value != key { - newContent = append(newContent, v) - } else { - // Don't append current node and skip the next which is this key's value. - skipNext = true - } - } - - cm.Root.Content = newContent -} - -func (cm *configMap) keys() []string { - keys := []string{} - if cm.empty() { - return keys - } - - // Content slice goes [key1, value1, key2, value2, ...]. - for i, v := range cm.Root.Content { - // Skip every other slice item since we only want keys. - if i%2 != 0 { - continue - } - keys = append(keys, v.Value) - } - - return keys -} diff --git a/internal/config/config_map_test.go b/internal/config/config_map_test.go deleted file mode 100644 index aae1da6..0000000 --- a/internal/config/config_map_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" -) - -func TestFindEntry(t *testing.T) { - tests := []struct { - name string - key string - output string - wantErr bool - }{ - { - name: "find key", - key: "valid", - output: "present", - }, - { - name: "find key that is not present", - key: "invalid", - wantErr: true, - }, - { - name: "find key with blank value", - key: "blank", - output: "", - }, - { - name: "find key that has same content as a value", - key: "same", - output: "logical", - }, - } - - for _, tt := range tests { - cm := configMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - out, err := cm.findEntry(tt.key) - if tt.wantErr { - assert.EqualError(t, err, "not found") - return - } - assert.NoError(t, err) - assert.Equal(t, tt.output, out.ValueNode.Value) - }) - } -} - -func TestEmpty(t *testing.T) { - cm := configMap{} - assert.Equal(t, true, cm.empty()) - cm.Root = &yaml.Node{ - Content: []*yaml.Node{ - { - Value: "test", - }, - }, - } - assert.Equal(t, false, cm.empty()) -} - -func TestGetStringValue(t *testing.T) { - tests := []struct { - name string - key string - wantValue string - wantErr bool - }{ - { - name: "get key", - key: "valid", - wantValue: "present", - }, - { - name: "get key that is not present", - key: "invalid", - wantErr: true, - }, - { - name: "get key that has same content as a value", - key: "same", - wantValue: "logical", - }, - } - - for _, tt := range tests { - cm := configMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - val, err := cm.getStringValue(tt.key) - if tt.wantErr { - assert.EqualError(t, err, "not found") - return - } - assert.Equal(t, tt.wantValue, val) - }) - } -} - -func TestSetStringValue(t *testing.T) { - tests := []struct { - name string - key string - value string - }{ - { - name: "set key that is not present", - key: "notPresent", - value: "test1", - }, - { - name: "set key that is present", - key: "erroneous", - value: "test2", - }, - { - name: "set key that is blank", - key: "blank", - value: "test3", - }, - { - name: "set key that has same content as a value", - key: "present", - value: "test4", - }, - } - - for _, tt := range tests { - cm := configMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - err := cm.setStringValue(tt.key, tt.value) - assert.NoError(t, err) - val, err := cm.getStringValue(tt.key) - assert.NoError(t, err) - assert.Equal(t, tt.value, val) - }) - } -} - -func TestRemoveEntry(t *testing.T) { - tests := []struct { - name string - key string - wantLength int - }{ - { - name: "remove key", - key: "erroneous", - wantLength: 6, - }, - { - name: "remove key that is not present", - key: "invalid", - wantLength: 8, - }, - { - name: "remove key that has same content as a value", - key: "same", - wantLength: 6, - }, - } - - for _, tt := range tests { - cm := configMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - cm.removeEntry(tt.key) - assert.Equal(t, tt.wantLength, len(cm.Root.Content)) - _, err := cm.findEntry(tt.key) - assert.EqualError(t, err, "not found") - }) - } -} - -func TestKeys(t *testing.T) { - tests := []struct { - name string - cm configMap - wantKeys []string - }{ - { - name: "keys for full map", - cm: configMap{Root: testYaml()}, - wantKeys: []string{"valid", "erroneous", "blank", "same"}, - }, - { - name: "keys for empty map", - cm: configMap{Root: nil}, - wantKeys: []string{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - keys := tt.cm.keys() - assert.Equal(t, tt.wantKeys, keys) - }) - } -} - -func testYaml() *yaml.Node { - var root yaml.Node - var data = ` -valid: present -erroneous: same -blank: -same: logical -` - _ = yaml.Unmarshal([]byte(data), &root) - return root.Content[0] -} diff --git a/internal/yamlmap/yaml_map.go b/internal/yamlmap/yaml_map.go new file mode 100644 index 0000000..ac2d54d --- /dev/null +++ b/internal/yamlmap/yaml_map.go @@ -0,0 +1,200 @@ +// Package yamlmap is a wrapper of gopkg.in/yaml.v3 for interacting +// with yaml data as if it were a map. +package yamlmap + +import ( + "errors" + + "gopkg.in/yaml.v3" +) + +const ( + modified = "modifed" +) + +type Map struct { + *yaml.Node +} + +var ErrNotFound = errors.New("not found") +var ErrInvalidYaml = errors.New("invalid yaml") +var ErrInvalidFormat = errors.New("invalid format") + +func StringValue(value string) *Map { + return &Map{&yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: value, + }} +} + +func MapValue() *Map { + return &Map{&yaml.Node{ + Kind: yaml.MappingNode, + Tag: "!!map", + }} +} + +func Unmarshal(data []byte) (*Map, error) { + var root yaml.Node + err := yaml.Unmarshal(data, &root) + if err != nil { + return nil, ErrInvalidYaml + } + if len(root.Content) == 0 || root.Content[0].Kind != yaml.MappingNode { + return nil, ErrInvalidFormat + } + return &Map{root.Content[0]}, nil +} + +func Marshal(m *Map) ([]byte, error) { + return yaml.Marshal(m.Node) +} + +func (m *Map) AddEntry(key string, value *Map) { + keyNode := &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: key, + } + m.Content = append(m.Content, keyNode, value.Node) + m.SetModified() +} + +func (m *Map) Empty() bool { + return m.Content == nil || len(m.Content) == 0 +} + +func (m *Map) FindEntry(key string) (*Map, error) { + // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. + // When iterating over the content slice we only want to compare the keys of the yamlMap. + for i, v := range m.Content { + if i%2 != 0 { + continue + } + if v.Value == key { + if i+1 < len(m.Content) { + return &Map{m.Content[i+1]}, nil + } + } + } + return nil, ErrNotFound +} + +func (m *Map) Keys() []string { + // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. + // When iterating over the content slice we only want to select the keys of the yamlMap. + keys := []string{} + for i, v := range m.Content { + if i%2 != 0 { + continue + } + keys = append(keys, v.Value) + } + return keys +} + +func (m *Map) RemoveEntry(key string) error { + // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. + // When iterating over the content slice we only want to compare the keys of the yamlMap. + // If we find they key to remove, remove the key and its value from the content slice. + found, skipNext := false, false + newContent := []*yaml.Node{} + for i, v := range m.Content { + if skipNext { + skipNext = false + continue + } + if i%2 != 0 || v.Value != key { + newContent = append(newContent, v) + } else { + found = true + skipNext = true + m.SetModified() + } + } + if !found { + return ErrNotFound + } + m.Content = newContent + return nil +} + +func (m *Map) SetEntry(key string, value *Map) { + // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. + // When iterating over the content slice we only want to compare the keys of the yamlMap. + // If we find they key to set, set the next item in the content slice to the new value. + m.SetModified() + for i, v := range m.Content { + if i%2 != 0 || v.Value != key { + continue + } + if v.Value == key { + if i+1 < len(m.Content) { + m.Content[i+1] = value.Node + return + } + } + } + m.AddEntry(key, value) +} + +// Note: This is a hack to introduce the concept of modified/unmodified +// on top of gopkg.in/yaml.v3. This works by setting the Value property +// of a MappingNode to a specific value and then later checking if the +// node's Value property is that specific value. When a MappingNode gets +// output as a string the Value property is not used, thus changing it +// has no impact for our purposes. +func (m *Map) SetModified() { + // Can not mark a non-mapping node as modified + if m.Node.Kind == yaml.MappingNode { + m.Node.Value = modified + } +} + +// Traverse map using BFS to set all nodes as unmodified. +func (m *Map) SetUnmodified() { + i := 0 + queue := []*yaml.Node{m.Node} + for { + if i > (len(queue) - 1) { + break + } + q := queue[i] + i = i + 1 + if q.Kind != yaml.MappingNode { + continue + } + q.Value = "" + queue = append(queue, q.Content...) + } +} + +// Traverse map using BFS to searach for any nodes that have been modified. +func (m *Map) IsModified() bool { + i := 0 + queue := []*yaml.Node{m.Node} + for { + if i > (len(queue) - 1) { + break + } + q := queue[i] + i = i + 1 + if q.Kind != yaml.MappingNode { + continue + } + if q.Value == modified { + return true + } + queue = append(queue, q.Content...) + } + return false +} + +func (m *Map) String() string { + data, err := Marshal(m) + if err != nil { + return "" + } + return string(data) +} diff --git a/internal/yamlmap/yaml_map_test.go b/internal/yamlmap/yaml_map_test.go new file mode 100644 index 0000000..89a312f --- /dev/null +++ b/internal/yamlmap/yaml_map_test.go @@ -0,0 +1,225 @@ +package yamlmap + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMapAddEntry(t *testing.T) { + tests := []struct { + name string + key string + value string + wantValue string + wantLength int + }{ + { + name: "add entry with key that is not present", + key: "notPresent", + value: "test1", + wantValue: "test1", + wantLength: 10, + }, + { + name: "add entry with key that is already present", + key: "erroneous", + value: "test2", + wantValue: "same", + wantLength: 10, + }, + } + + for _, tt := range tests { + m := testMap() + t.Run(tt.name, func(t *testing.T) { + m.AddEntry(tt.key, StringValue(tt.value)) + entry, err := m.FindEntry(tt.key) + assert.NoError(t, err) + assert.Equal(t, tt.wantValue, entry.Value) + assert.Equal(t, tt.wantLength, len(m.Content)) + assert.True(t, m.IsModified()) + }) + } +} + +func TestMapEmpty(t *testing.T) { + m := blankMap() + assert.Equal(t, true, m.Empty()) + m.AddEntry("test", StringValue("test")) + assert.Equal(t, false, m.Empty()) +} + +func TestMapFindEntry(t *testing.T) { + tests := []struct { + name string + key string + output string + wantErr bool + }{ + { + name: "find key", + key: "valid", + output: "present", + }, + { + name: "find key that is not present", + key: "invalid", + wantErr: true, + }, + { + name: "find key with blank value", + key: "blank", + output: "", + }, + { + name: "find key that has same content as a value", + key: "same", + output: "logical", + }, + } + + for _, tt := range tests { + m := testMap() + t.Run(tt.name, func(t *testing.T) { + out, err := m.FindEntry(tt.key) + if tt.wantErr { + assert.EqualError(t, err, "not found") + assert.False(t, m.IsModified()) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.output, out.Value) + assert.False(t, m.IsModified()) + }) + } +} + +func TestMapFindEntryModified(t *testing.T) { + m := testMap() + entry, err := m.FindEntry("valid") + assert.NoError(t, err) + assert.Equal(t, "present", entry.Value) + entry.Value = "test" + assert.Equal(t, "test", entry.Value) + entry2, err := m.FindEntry("valid") + assert.NoError(t, err) + assert.Equal(t, "test", entry2.Value) +} + +func TestMapKeys(t *testing.T) { + tests := []struct { + name string + m *Map + wantKeys []string + }{ + { + name: "keys for full map", + m: testMap(), + wantKeys: []string{"valid", "erroneous", "blank", "same"}, + }, + { + name: "keys for empty map", + m: blankMap(), + wantKeys: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keys := tt.m.Keys() + assert.Equal(t, tt.wantKeys, keys) + assert.False(t, tt.m.IsModified()) + }) + } +} + +func TestMapRemoveEntry(t *testing.T) { + tests := []struct { + name string + key string + wantLength int + wantErr bool + }{ + { + name: "remove key", + key: "erroneous", + wantLength: 6, + }, + { + name: "remove key that is not present", + key: "invalid", + wantLength: 8, + wantErr: true, + }, + { + name: "remove key that has same content as a value", + key: "same", + wantLength: 6, + }, + } + + for _, tt := range tests { + m := testMap() + t.Run(tt.name, func(t *testing.T) { + err := m.RemoveEntry(tt.key) + if tt.wantErr { + assert.EqualError(t, err, "not found") + assert.False(t, m.IsModified()) + } else { + assert.NoError(t, err) + assert.True(t, m.IsModified()) + } + assert.Equal(t, tt.wantLength, len(m.Content)) + _, err = m.FindEntry(tt.key) + assert.EqualError(t, err, "not found") + }) + } +} + +func TestMapSetEntry(t *testing.T) { + tests := []struct { + name string + key string + value *Map + wantLength int + }{ + { + name: "sets key that is not present", + key: "not", + value: StringValue("present"), + wantLength: 10, + }, + { + name: "sets key that is present", + key: "erroneous", + value: StringValue("not same"), + wantLength: 8, + }, + } + for _, tt := range tests { + m := testMap() + t.Run(tt.name, func(t *testing.T) { + m.SetEntry(tt.key, tt.value) + assert.True(t, m.IsModified()) + assert.Equal(t, tt.wantLength, len(m.Content)) + e, err := m.FindEntry(tt.key) + assert.NoError(t, err) + assert.Equal(t, tt.value.Value, e.Value) + }) + } +} + +func testMap() *Map { + var data = ` +valid: present +erroneous: same +blank: +same: logical +` + m, _ := Unmarshal([]byte(data)) + return m +} + +func blankMap() *Map { + return MapValue() +} diff --git a/pkg/api/client.go b/pkg/api/client.go index f4feb0f..a57feba 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -1,4 +1,4 @@ -// Package api is a set of types for GitHub API. +// Package api is a set of types for interacting with the GitHub API. package api import ( diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..fc1b10b --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,80 @@ +package auth + +import ( + "os" + "strings" + + "github.com/cli/go-gh/internal/set" + "github.com/cli/go-gh/pkg/config" +) + +const ( + defaultHost = "github.com" + ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" + ghHost = "GH_HOST" + ghToken = "GH_TOKEN" + githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN" + githubToken = "GITHUB_TOKEN" + oauthToken = "oauth_token" + hostsKey = "hosts" +) + +func TokenForHost(cfg *config.Config, host string) (string, string) { + host = normalizeHostname(host) + if isEnterprise(host) { + if token := os.Getenv(ghEnterpriseToken); token != "" { + return token, ghEnterpriseToken + } + if token := os.Getenv(githubEnterpriseToken); token != "" { + return token, githubEnterpriseToken + } + token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + return token, oauthToken + } + if token := os.Getenv(ghToken); token != "" { + return token, ghToken + } + if token := os.Getenv(githubToken); token != "" { + return token, githubToken + } + token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + return token, oauthToken +} + +func KnownHosts(cfg *config.Config) []string { + hosts := set.NewStringSet() + if host := os.Getenv(ghHost); host != "" { + hosts.Add(host) + } + if token, _ := TokenForHost(cfg, defaultHost); token != "" { + hosts.Add(defaultHost) + } + keys, err := config.Keys(cfg, []string{hostsKey}) + if err == nil { + hosts.AddValues(keys) + } + return hosts.ToSlice() +} + +func DefaultHost(cfg *config.Config) (string, string) { + if host := os.Getenv(ghHost); host != "" { + return host, ghHost + } + keys, err := config.Keys(cfg, []string{hostsKey}) + if err == nil && len(keys) == 1 { + return keys[0], hostsKey + } + return defaultHost, "default" +} + +func isEnterprise(host string) bool { + return host != defaultHost +} + +func normalizeHostname(host string) string { + hostname := strings.ToLower(host) + if strings.HasSuffix(hostname, "."+defaultHost) { + return defaultHost + } + return hostname +} diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go new file mode 100644 index 0000000..42a80a7 --- /dev/null +++ b/pkg/auth/auth_test.go @@ -0,0 +1,255 @@ +package auth + +import ( + "os" + "testing" + + "github.com/cli/go-gh/pkg/config" + "github.com/stretchr/testify/assert" +) + +func TestTokenForHost(t *testing.T) { + orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") + orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN") + orig_GH_TOKEN := os.Getenv("GH_TOKEN") + orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN") + t.Cleanup(func() { + os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN) + os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN) + os.Setenv("GH_TOKEN", orig_GH_TOKEN) + os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN) + }) + + tests := []struct { + name string + host string + githubToken string + githubEnterpriseToken string + ghToken string + ghEnterpriseToken string + config *config.Config + wantToken string + wantSource string + wantNotFound bool + }{ + { + name: "token for github.com with no env tokens and no config token", + host: "github.com", + config: testNoHostsConfig(), + wantToken: "", + wantSource: "oauth_token", + wantNotFound: true, + }, + { + name: "token for enterprise.com with no env tokens and no config token", + host: "enterprise.com", + config: testNoHostsConfig(), + wantToken: "", + wantSource: "oauth_token", + wantNotFound: true, + }, + { + name: "token for github.com with GH_TOKEN, GITHUB_TOKEN, and config token", + host: "github.com", + ghToken: "GH_TOKEN", + githubToken: "GITHUB_TOKEN", + config: testHostsConfig(), + wantToken: "GH_TOKEN", + wantSource: "GH_TOKEN", + }, + { + name: "token for github.com with GITHUB_TOKEN, and config token", + host: "github.com", + githubToken: "GITHUB_TOKEN", + config: testHostsConfig(), + wantToken: "GITHUB_TOKEN", + wantSource: "GITHUB_TOKEN", + }, + { + name: "token for github.com with config token", + host: "github.com", + config: testHostsConfig(), + wantToken: "xxxxxxxxxxxxxxxxxxxx", + wantSource: "oauth_token", + }, + { + name: "token for enterprise.com with GH_ENTERPRISE_TOKEN, GITHUB_ENTERPRISE_TOKEN, and config token", + host: "enterprise.com", + ghEnterpriseToken: "GH_ENTERPRISE_TOKEN", + githubEnterpriseToken: "GITHUB_ENTERPRISE_TOKEN", + config: testHostsConfig(), + wantToken: "GH_ENTERPRISE_TOKEN", + wantSource: "GH_ENTERPRISE_TOKEN", + }, + { + name: "token for enterprise.com with GITHUB_ENTERPRISE_TOKEN, and config token", + host: "enterprise.com", + githubEnterpriseToken: "GITHUB_ENTERPRISE_TOKEN", + config: testHostsConfig(), + wantToken: "GITHUB_ENTERPRISE_TOKEN", + wantSource: "GITHUB_ENTERPRISE_TOKEN", + }, + { + name: "token for enterprise.com with config token", + host: "enterprise.com", + config: testHostsConfig(), + wantToken: "yyyyyyyyyyyyyyyyyyyy", + wantSource: "oauth_token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv("GITHUB_TOKEN", tt.githubToken) + os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.githubEnterpriseToken) + os.Setenv("GH_TOKEN", tt.ghToken) + os.Setenv("GH_ENTERPRISE_TOKEN", tt.ghEnterpriseToken) + token, source := TokenForHost(tt.config, tt.host) + assert.Equal(t, tt.wantToken, token) + assert.Equal(t, tt.wantSource, source) + }) + } +} + +func TestDefaultHost(t *testing.T) { + tests := []struct { + name string + config *config.Config + ghHost string + wantHost string + wantSource string + wantNotFound bool + }{ + { + name: "GH_HOST if set", + config: testHostsConfig(), + ghHost: "test.com", + wantHost: "test.com", + wantSource: "GH_HOST", + }, + { + name: "authenticated host if only one", + config: testSingleHostConfig(), + wantHost: "enterprise.com", + wantSource: "hosts", + }, + { + name: "default host if more than one authenticated host", + config: testHostsConfig(), + wantHost: "github.com", + wantSource: "default", + wantNotFound: true, + }, + { + name: "default host if no authenticated host", + config: testNoHostsConfig(), + wantHost: "github.com", + wantSource: "default", + wantNotFound: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.ghHost != "" { + k := "GH_HOST" + old := os.Getenv(k) + os.Setenv(k, tt.ghHost) + defer os.Setenv(k, old) + } + host, source := DefaultHost(tt.config) + assert.Equal(t, tt.wantHost, host) + assert.Equal(t, tt.wantSource, source) + }) + } +} + +func TestKnownHosts(t *testing.T) { + tests := []struct { + name string + config *config.Config + ghHost string + ghToken string + wantHosts []string + }{ + { + name: "no known hosts", + config: testNoHostsConfig(), + wantHosts: []string{}, + }, + { + name: "includes GH_HOST", + config: testNoHostsConfig(), + ghHost: "test.com", + wantHosts: []string{"test.com"}, + }, + { + name: "includes authenticated hosts", + config: testHostsConfig(), + wantHosts: []string{"github.com", "enterprise.com"}, + }, + { + name: "includes default host if environment auth token", + config: testNoHostsConfig(), + ghToken: "TOKEN", + wantHosts: []string{"github.com"}, + }, + { + name: "deduplicates hosts", + config: testHostsConfig(), + ghHost: "test.com", + ghToken: "TOKEN", + wantHosts: []string{"test.com", "github.com", "enterprise.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.ghHost != "" { + k := "GH_HOST" + old := os.Getenv(k) + os.Setenv(k, tt.ghHost) + defer os.Setenv(k, old) + } + if tt.ghToken != "" { + k := "GH_TOKEN" + old := os.Getenv(k) + os.Setenv(k, tt.ghToken) + defer os.Setenv(k, old) + } + hosts := KnownHosts(tt.config) + assert.Equal(t, tt.wantHosts, hosts) + }) + } +} + +func testNoHostsConfig() *config.Config { + var data = `` + return config.ReadFromString(data) +} + +func testSingleHostConfig() *config.Config { + var data = ` +hosts: + enterprise.com: + user: user2 + oauth_token: yyyyyyyyyyyyyyyyyyyy + git_protocol: https +` + return config.ReadFromString(data) +} + +func testHostsConfig() *config.Config { + var data = ` +hosts: + github.com: + user: user1 + oauth_token: xxxxxxxxxxxxxxxxxxxx + git_protocol: ssh + enterprise.com: + user: user2 + oauth_token: yyyyyyyyyyyyyyyyyyyy + git_protocol: https +` + return config.ReadFromString(data) +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..56962ad --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,269 @@ +// Package config is a set of types for interacting with the gh configuration files. +// Note: This package is intended for use only in `gh`, any other use cases are subject +// to breakage and non-backwards compatible updates. +package config + +import ( + "errors" + "io" + "os" + "path/filepath" + "runtime" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/go-gh/internal/yamlmap" +) + +const ( + appData = "AppData" + ghConfigDir = "GH_CONFIG_DIR" + localAppData = "LocalAppData" + xdgConfigHome = "XDG_CONFIG_HOME" + xdgDataHome = "XDG_DATA_HOME" + xdgStateHome = "XDG_STATE_HOME" +) + +type Config struct { + entries *yamlmap.Map +} + +func Get(c *Config, keys []string) (string, error) { + m := c.entries + for _, key := range keys { + var err error + m, err = m.FindEntry(key) + if err != nil { + return "", NotFoundError{key} + } + } + return m.Value, nil +} + +func Keys(c *Config, keys []string) ([]string, error) { + m := c.entries + for _, key := range keys { + var err error + m, err = m.FindEntry(key) + if err != nil { + return nil, NotFoundError{key} + } + } + return m.Keys(), nil +} + +func Remove(c *Config, keys []string) error { + m := c.entries + for i := 0; i < len(keys)-1; i++ { + var err error + key := keys[i] + m, err = m.FindEntry(key) + if err != nil { + return NotFoundError{key} + } + } + err := m.RemoveEntry(keys[len(keys)-1]) + if err != nil { + return NotFoundError{keys[len(keys)-1]} + } + return nil +} + +func Set(c *Config, keys []string, value string) { + m := c.entries + for i := 0; i < len(keys)-1; i++ { + key := keys[i] + entry, err := m.FindEntry(key) + if err != nil { + entry = yamlmap.MapValue() + m.AddEntry(key, entry) + } + m = entry + } + m.SetEntry(keys[len(keys)-1], yamlmap.StringValue(value)) +} + +func Read() (*Config, error) { + // TODO: Make global config singleton using sync.Once + // so as not to read from file every time. + return load(generalConfigFile(), hostsConfigFile()) +} + +// ReadFromString takes a yaml string and returns a Config. +// Note: This is only used for testing, and should not be +// relied upon in production. +func ReadFromString(str string) *Config { + m, _ := mapFromString(str) + if m == nil { + m = yamlmap.MapValue() + } + return &Config{m} +} + +func Write(config *Config) error { + hosts, err := config.entries.FindEntry("hosts") + if err == nil && hosts.IsModified() { + err := writeFile(hostsConfigFile(), []byte(hosts.String())) + if err != nil { + return err + } + hosts.SetUnmodified() + } + + if config.entries.IsModified() { + // Hosts gets written to a different file above so remove it + // before writing and add it back in after writing. + hostsMap, hostsErr := config.entries.FindEntry("hosts") + if hostsErr == nil { + _ = config.entries.RemoveEntry("hosts") + } + err := writeFile(generalConfigFile(), []byte(config.entries.String())) + if err != nil { + return err + } + config.entries.SetUnmodified() + if hostsErr == nil { + config.entries.AddEntry("hosts", hostsMap) + } + } + + return nil +} + +func load(generalFilePath, hostsFilePath string) (*Config, error) { + generalMap, err := mapFromFile(generalFilePath) + if err != nil && !os.IsNotExist(err) { + if errors.Is(err, yamlmap.ErrInvalidYaml) || + errors.Is(err, yamlmap.ErrInvalidFormat) { + return nil, InvalidConfigFileError{Path: generalFilePath, Err: err} + } + return nil, err + } + + if generalMap == nil || generalMap.Empty() { + generalMap, _ = mapFromString(defaultGeneralEntries) + } + + hostsMap, err := mapFromFile(hostsFilePath) + if err != nil && !os.IsNotExist(err) { + if errors.Is(err, yamlmap.ErrInvalidYaml) || + errors.Is(err, yamlmap.ErrInvalidFormat) { + return nil, InvalidConfigFileError{Path: hostsFilePath, Err: err} + } + return nil, err + } + + if hostsMap != nil && !hostsMap.Empty() { + generalMap.AddEntry("hosts", hostsMap) + } + + return &Config{generalMap}, nil +} + +func generalConfigFile() string { + return filepath.Join(configDir(), "config.yml") +} + +func hostsConfigFile() string { + return filepath.Join(configDir(), "hosts.yml") +} + +func mapFromFile(filename string) (*yamlmap.Map, error) { + data, err := readFile(filename) + if err != nil { + return nil, err + } + return yamlmap.Unmarshal(data) +} + +func mapFromString(str string) (*yamlmap.Map, error) { + return yamlmap.Unmarshal([]byte(str)) +} + +// Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME. +func configDir() string { + var path string + if a := os.Getenv(ghConfigDir); a != "" { + path = a + } else if b := os.Getenv(xdgConfigHome); b != "" { + path = filepath.Join(b, "gh") + } else if c := os.Getenv(appData); runtime.GOOS == "windows" && c != "" { + path = filepath.Join(c, "GitHub CLI") + } else { + d, _ := os.UserHomeDir() + path = filepath.Join(d, ".config", "gh") + } + return path +} + +// State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME. +func stateDir() string { + var path string + if a := os.Getenv(xdgStateHome); a != "" { + path = filepath.Join(a, "gh") + } else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" { + path = filepath.Join(b, "GitHub CLI") + } else { + c, _ := os.UserHomeDir() + path = filepath.Join(c, ".local", "state", "gh") + } + return path +} + +// Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME. +func dataDir() string { + var path string + if a := os.Getenv(xdgDataHome); a != "" { + path = filepath.Join(a, "gh") + } else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" { + path = filepath.Join(b, "GitHub CLI") + } else { + c, _ := os.UserHomeDir() + path = filepath.Join(c, ".local", "share", "gh") + } + return path +} + +func readFile(filename string) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + data, err := io.ReadAll(f) + if err != nil { + return nil, err + } + return data, nil +} + +func writeFile(filename string, data []byte) error { + err := os.MkdirAll(filepath.Dir(filename), 0771) + if err != nil { + return err + } + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer file.Close() + _, err = file.Write(data) + return err +} + +var defaultGeneralEntries = heredoc.Doc(` +# What protocol to use when performing git operations. Supported values: ssh, https +git_protocol: https +# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment. +editor: +# When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled +prompt: enabled +# A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager. +pager: +# Aliases allow you to create nicknames for gh commands +aliases: + co: pr checkout +# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport. +http_unix_socket: +# What web browser gh should use when opening URLs. If blank, will refer to environment. +browser: +`) diff --git a/internal/config/config_test.go b/pkg/config/config_test.go similarity index 53% rename from internal/config/config_test.go rename to pkg/config/config_test.go index c2477cc..130d45e 100644 --- a/internal/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "path/filepath" "runtime" @@ -91,7 +92,6 @@ func TestConfigDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, configDir()) }) } @@ -156,7 +156,6 @@ func TestStateDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, stateDir()) }) } @@ -221,464 +220,434 @@ func TestDataDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, dataDir()) }) } } -func TestConfigGet(t *testing.T) { - cfg := testLoadedConfig() +func TestLoad(t *testing.T) { + tempDir := t.TempDir() + globalFilePath := filepath.Join(tempDir, "config.yml") + invalidGlobalFilePath := filepath.Join(tempDir, "invalid_config.yml") + hostsFilePath := filepath.Join(tempDir, "hosts.yml") + invalidHostsFilePath := filepath.Join(tempDir, "invalid_hosts.yml") + err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755) + assert.NoError(t, err) + err = os.WriteFile(invalidGlobalFilePath, []byte("invalid"), 0755) + assert.NoError(t, err) + err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755) + assert.NoError(t, err) + err = os.WriteFile(invalidHostsFilePath, []byte("invalid"), 0755) + assert.NoError(t, err) tests := []struct { - name string - key string - wantValue string - wantErr bool - wantErrMsg string + name string + globalConfigPath string + hostsConfigPath string + wantGitProtocol string + wantToken string + wantErr bool + wantErrMsg string + wantGetErr bool }{ { - name: "get git_protocol value", - key: "git_protocol", - wantValue: "ssh", + name: "global and hosts files exist", + globalConfigPath: globalFilePath, + hostsConfigPath: hostsFilePath, + wantGitProtocol: "ssh", + wantToken: "yyyyyyyyyyyyyyyyyyyy", }, { - name: "get editor value", - key: "editor", - wantValue: "", + name: "invalid global file", + globalConfigPath: invalidGlobalFilePath, + wantErr: true, + wantErrMsg: fmt.Sprintf("invalid config file %s: invalid format", filepath.Join(tempDir, "invalid_config.yml")), }, { - name: "get prompt value", - key: "prompt", - wantValue: "enabled", + name: "invalid hosts file", + globalConfigPath: globalFilePath, + hostsConfigPath: invalidHostsFilePath, + wantErr: true, + wantErrMsg: fmt.Sprintf("invalid config file %s: invalid format", filepath.Join(tempDir, "invalid_hosts.yml")), }, { - name: "get pager value", - key: "pager", - wantValue: "less", + name: "global file does not exist and hosts file exist", + globalConfigPath: "", + hostsConfigPath: hostsFilePath, + wantGitProtocol: "https", + wantToken: "yyyyyyyyyyyyyyyyyyyy", }, { - name: "unknown key", - key: "unknown", - wantErr: true, - wantErrMsg: "not found", + name: "global file exist and hosts file does not exist", + globalConfigPath: globalFilePath, + hostsConfigPath: "", + wantGitProtocol: "ssh", + wantGetErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - value, err := cfg.Get(tt.key) + cfg, err := load(tt.globalConfigPath, tt.hostsConfigPath) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) return } assert.NoError(t, err) - assert.Equal(t, tt.wantValue, value) + protocol, err := Get(cfg, []string{"git_protocol"}) + assert.NoError(t, err) + assert.Equal(t, tt.wantGitProtocol, protocol) + token, err := Get(cfg, []string{"hosts", "enterprise.com", "oauth_token"}) + if tt.wantGetErr { + assert.EqualError(t, err, `could not find key "hosts"`) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.wantToken, token) }) } } -func TestConfigGetForHost(t *testing.T) { - cfg := testLoadedConfig() - +func TestWrite(t *testing.T) { tests := []struct { - name string - host string - key string - wantValue string - wantErr bool - wantErrMsg string + name string + createConfig func() *Config + wantConfig func() *Config + wantErr bool + wantErrMsg string }{ { - name: "get github user value", - host: "github.com", - key: "user", - wantValue: "user1", - }, - { - name: "get github oauth_token value", - host: "github.com", - key: "oauth_token", - wantValue: "xxxxxxxxxxxxxxxxxxxx", - }, - { - name: "get github git_protocol value", - host: "github.com", - key: "git_protocol", - wantValue: "ssh", - }, - { - name: "get enterprise user value", - host: "enterprise.com", - key: "user", - wantValue: "user2", - }, - { - name: "get enterprise oauth_token value", - host: "enterprise.com", - key: "oauth_token", - wantValue: "yyyyyyyyyyyyyyyyyyyy", + name: "writes config and hosts files", + createConfig: func() *Config { + cfg := ReadFromString(testFullConfig()) + Set(cfg, []string{"editor"}, "vim") + Set(cfg, []string{"hosts", "github.com", "git_protocol"}, "https") + return cfg + }, }, { - name: "get enterprise git_protocol value", - host: "enterprise.com", - key: "git_protocol", - wantValue: "https", + name: "only writes hosts file", + createConfig: func() *Config { + cfg := ReadFromString(testFullConfig()) + Set(cfg, []string{"hosts", "enterprise.com", "git_protocol"}, "ssh") + return cfg + }, + wantConfig: func() *Config { + // The hosts file is writen but not the general config file. + // When we use Read in the test the defaultGeneralEntries are used. + cfg := ReadFromString(defaultGeneralEntries) + Set(cfg, []string{"hosts", "github.com", "user"}, "user1") + Set(cfg, []string{"hosts", "github.com", "oauth_token"}, "xxxxxxxxxxxxxxxxxxxx") + Set(cfg, []string{"hosts", "github.com", "git_protocol"}, "ssh") + Set(cfg, []string{"hosts", "enterprise.com", "user"}, "user2") + Set(cfg, []string{"hosts", "enterprise.com", "oauth_token"}, "yyyyyyyyyyyyyyyyyyyy") + Set(cfg, []string{"hosts", "enterprise.com", "git_protocol"}, "ssh") + return cfg + }, }, { - name: "unknown host", - host: "unknown", - key: "user", - wantErr: true, - wantErrMsg: "not found", + name: "only writes config file", + createConfig: func() *Config { + cfg := ReadFromString(testFullConfig()) + Set(cfg, []string{"editor"}, "vim") + return cfg + }, + wantConfig: func() *Config { + // The general config file is written but not the hosts config file. + // When we use Read in the test there will not be any hosts entries. + cfg := ReadFromString(testFullConfig()) + Set(cfg, []string{"editor"}, "vim") + _ = Remove(cfg, []string{"hosts"}) + return cfg + }, }, { - name: "unknown key", - host: "github.com", - key: "unknown", - wantErr: true, - wantErrMsg: "not found", + name: "write default config file keeps comments", + createConfig: func() *Config { + cfg := ReadFromString(defaultGeneralEntries) + cfg.entries.SetModified() + return cfg + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - value, err := cfg.GetForHost(tt.host, tt.key) - if tt.wantErr { - assert.EqualError(t, err, tt.wantErrMsg) - return - } + tempDir := t.TempDir() + old := os.Getenv("GH_CONFIG_DIR") + os.Setenv("GH_CONFIG_DIR", tempDir) + defer os.Setenv("GH_CONFIG_DIR", old) + cfg := tt.createConfig() + err := Write(cfg) assert.NoError(t, err) - assert.Equal(t, tt.wantValue, value) + loadedCfg, err := Read() + assert.NoError(t, err) + wantCfg := cfg + if tt.wantConfig != nil { + wantCfg = tt.wantConfig() + } + assert.Equal(t, wantCfg.entries.String(), loadedCfg.entries.String()) }) } } -func TestConfigHost(t *testing.T) { +func TestGet(t *testing.T) { tests := []struct { - name string - cfg Config - ghHost string - wantHost string + name string + keys []string + wantValue string + wantErr bool + wantErrMsg string }{ { - name: "GH_HOST if set", - cfg: testLoadedNoHostConfig(), - ghHost: "test.com", - wantHost: "test.com", + name: "get git_protocol value", + keys: []string{"git_protocol"}, + wantValue: "ssh", + }, + { + name: "get editor value", + keys: []string{"editor"}, + wantValue: "", + }, + { + name: "get prompt value", + keys: []string{"prompt"}, + wantValue: "enabled", + }, + { + name: "get pager value", + keys: []string{"pager"}, + wantValue: "less", }, { - name: "authenticated host if only one", - cfg: testLoadedSingleHostConfig(), - wantHost: "enterprise.com", + name: "non-existant key", + keys: []string{"unknown"}, + wantErr: true, + wantErrMsg: `could not find key "unknown"`, + wantValue: "", }, { - name: "default host if more than one authenticated host", - cfg: testLoadedConfig(), - wantHost: "github.com", + name: "nested key", + keys: []string{"nested", "key"}, + wantValue: "value", }, { - name: "default host if no authenticated host", - cfg: testLoadedNoHostConfig(), - wantHost: "github.com", + name: "nested key with same name", + keys: []string{"nested", "pager"}, + wantValue: "more", + }, + { + name: "nested non-existant key", + keys: []string{"nested", "invalid"}, + wantErr: true, + wantErrMsg: `could not find key "invalid"`, + wantValue: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.ghHost != "" { - k := "GH_HOST" - old := os.Getenv(k) - os.Setenv(k, tt.ghHost) - defer os.Setenv(k, old) + cfg := testConfig() + value, err := Get(cfg, tt.keys) + if tt.wantErr { + assert.EqualError(t, err, tt.wantErrMsg) + } else { + assert.NoError(t, err) } - host := tt.cfg.Host() - assert.Equal(t, tt.wantHost, host) + assert.Equal(t, tt.wantValue, value) + assert.False(t, cfg.entries.IsModified()) }) } } -func TestConfigHosts(t *testing.T) { +func TestKeys(t *testing.T) { tests := []struct { - name string - cfg Config - ghHost string - ghToken string - wantHosts []string + name string + findKeys []string + wantKeys []string + wantErr bool + wantErrMsg string }{ { - name: "no known hosts", - cfg: testLoadedNoHostConfig(), - wantHosts: []string{}, - }, - { - name: "includes GH_HOST", - cfg: testLoadedNoHostConfig(), - ghHost: "test.com", - wantHosts: []string{"test.com"}, + name: "top level keys", + findKeys: nil, + wantKeys: []string{"git_protocol", "editor", "prompt", "pager", "nested"}, }, { - name: "includes authenticated hosts", - cfg: testLoadedConfig(), - wantHosts: []string{"github.com", "enterprise.com"}, + name: "nested keys", + findKeys: []string{"nested"}, + wantKeys: []string{"key", "pager"}, }, { - name: "includes default host if environment auth token", - cfg: testLoadedNoHostConfig(), - ghToken: "TOKEN", - wantHosts: []string{"github.com"}, - }, - { - name: "deduplicates hosts", - cfg: testLoadedConfig(), - ghHost: "test.com", - ghToken: "TOKEN", - wantHosts: []string{"test.com", "github.com", "enterprise.com"}, + name: "keys for non-existant nested key", + findKeys: []string{"unknown"}, + wantKeys: nil, + wantErr: true, + wantErrMsg: `could not find key "unknown"`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.ghHost != "" { - k := "GH_HOST" - old := os.Getenv(k) - os.Setenv(k, tt.ghHost) - defer os.Setenv(k, old) - } - if tt.ghToken != "" { - k := "GH_TOKEN" - old := os.Getenv(k) - os.Setenv(k, tt.ghToken) - defer os.Setenv(k, old) + cfg := testConfig() + ks, err := Keys(cfg, tt.findKeys) + if tt.wantErr { + assert.EqualError(t, err, tt.wantErrMsg) + } else { + assert.NoError(t, err) } - hosts := tt.cfg.Hosts() - assert.Equal(t, tt.wantHosts, hosts) + assert.Equal(t, tt.wantKeys, ks) + assert.False(t, cfg.entries.IsModified()) }) } } -func TestConfigAuthToken(t *testing.T) { - orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") - orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN") - t.Cleanup(func() { - os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN) - }) - +func TestRemove(t *testing.T) { tests := []struct { - name string - host string - GITHUB_TOKEN string - GITHUB_ENTERPRISE_TOKEN string - GH_TOKEN string - GH_ENTERPRISE_TOKEN string - cfg Config - wantToken string - wantErr bool - wantErrMsg string + name string + keys []string + wantErr bool + wantErrMsg string }{ { - name: "token for github.com with no env tokens and no config token", - host: "github.com", - cfg: testLoadedNoHostConfig(), - wantErr: true, - wantErrMsg: "not found", - }, - { - name: "token for enterprise.com with no env tokens and no config token", - host: "enterprise.com", - cfg: testLoadedNoHostConfig(), - wantErr: true, - wantErrMsg: "not found", - }, - { - name: "token for github.com with GH_TOKEN, GITHUB_TOKEN, and config token", - host: "github.com", - GH_TOKEN: "GH_TOKEN", - GITHUB_TOKEN: "GITHUB_TOKEN", - cfg: testLoadedConfig(), - wantToken: "GH_TOKEN", - }, - { - name: "token for github.com with GITHUB_TOKEN, and config token", - host: "github.com", - GITHUB_TOKEN: "GITHUB_TOKEN", - cfg: testLoadedConfig(), - wantToken: "GITHUB_TOKEN", + name: "remove top level key", + keys: []string{"pager"}, }, { - name: "token for github.com with config token", - host: "github.com", - cfg: testLoadedConfig(), - wantToken: "xxxxxxxxxxxxxxxxxxxx", + name: "remove nested key", + keys: []string{"nested", "pager"}, }, { - name: "token for enterprise.com with GH_ENTERPRISE_TOKEN, GITHUB_ENTERPRISE_TOKEN, and config token", - host: "enterprise.com", - GH_ENTERPRISE_TOKEN: "GH_ENTERPRISE_TOKEN", - GITHUB_ENTERPRISE_TOKEN: "GITHUB_ENTERPRISE_TOKEN", - cfg: testLoadedConfig(), - wantToken: "GH_ENTERPRISE_TOKEN", + name: "remove top level map", + keys: []string{"nested"}, }, { - name: "token for enterprise.com with GITHUB_ENTERPRISE_TOKEN, and config token", - host: "enterprise.com", - GITHUB_ENTERPRISE_TOKEN: "GITHUB_ENTERPRISE_TOKEN", - cfg: testLoadedConfig(), - wantToken: "GITHUB_ENTERPRISE_TOKEN", + name: "remove non-existant top level key", + keys: []string{"unknown"}, + wantErr: true, + wantErrMsg: `could not find key "unknown"`, }, { - name: "token for enterprise.com with config token", - host: "enterprise.com", - cfg: testLoadedConfig(), - wantToken: "yyyyyyyyyyyyyyyyyyyy", + name: "remove non-existant nested key", + keys: []string{"nested", "invalid"}, + wantErr: true, + wantErrMsg: `could not find key "invalid"`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - os.Setenv("GITHUB_TOKEN", tt.GITHUB_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN) - os.Setenv("GH_TOKEN", tt.GH_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN) - token, err := tt.cfg.AuthToken(tt.host) + cfg := testConfig() + err := Remove(cfg, tt.keys) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) - return + assert.False(t, cfg.entries.IsModified()) + } else { + assert.NoError(t, err) + assert.True(t, cfg.entries.IsModified()) } - assert.NoError(t, err) - assert.Equal(t, tt.wantToken, token) + _, getErr := Get(cfg, tt.keys) + assert.Error(t, getErr) }) } } -func TestLoad(t *testing.T) { - tempDir := t.TempDir() - oldWd, _ := os.Getwd() - assert.NoError(t, os.Chdir(tempDir)) - t.Cleanup(func() { _ = os.Chdir(oldWd) }) - - globalFilePath := filepath.Join(tempDir, "config.yml") - invalidGlobalFilePath := filepath.Join(tempDir, "invalid_config.yml") - hostsFilePath := filepath.Join(tempDir, "hosts.yml") - invalidHostsFilePath := filepath.Join(tempDir, "invalid_hosts.yml") - err := os.WriteFile(globalFilePath, []byte(testGlobalConfig()), 0755) - assert.NoError(t, err) - err = os.WriteFile(invalidGlobalFilePath, []byte("invalid"), 0755) - assert.NoError(t, err) - err = os.WriteFile(hostsFilePath, []byte(testHostsConfig()), 0755) - assert.NoError(t, err) - err = os.WriteFile(invalidHostsFilePath, []byte("invalid"), 0755) - assert.NoError(t, err) - +func TestSet(t *testing.T) { tests := []struct { - name string - globalConfigPath string - hostsConfigPath string - wantGitProtocol string - wantToken string - wantErr bool - wantErrMsg string - wantGetErr bool - wantGetErrMsg string + name string + keys []string + value string }{ { - name: "global and hosts files exist", - globalConfigPath: globalFilePath, - hostsConfigPath: hostsFilePath, - wantGitProtocol: "ssh", - wantToken: "yyyyyyyyyyyyyyyyyyyy", + name: "set top level existing key", + keys: []string{"pager"}, + value: "test pager", }, { - name: "invalid global file", - globalConfigPath: invalidGlobalFilePath, - wantErr: true, - wantErrMsg: "invalid config file", + name: "set nested existing key", + keys: []string{"nested", "pager"}, + value: "new pager", }, { - name: "invalid hosts file", - globalConfigPath: globalFilePath, - hostsConfigPath: invalidHostsFilePath, - wantErr: true, - wantErrMsg: "invalid config file", + name: "set top level map", + keys: []string{"nested"}, + value: "override", }, { - name: "global file does not exist and hosts file exist", - globalConfigPath: "", - hostsConfigPath: hostsFilePath, - wantGitProtocol: "https", - wantToken: "yyyyyyyyyyyyyyyyyyyy", + name: "set non-existant top level key", + keys: []string{"unknown"}, + value: "why not", }, { - name: "global file exist and hosts file does not exist", - globalConfigPath: globalFilePath, - hostsConfigPath: "", - wantGitProtocol: "ssh", - wantGetErr: true, - wantGetErrMsg: "not found", + name: "set non-existant nested key", + keys: []string{"nested", "invalid"}, + value: "sure", + }, + { + name: "set non-existant nest", + keys: []string{"johnny", "test"}, + value: "dukey", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg, err := load(tt.globalConfigPath, tt.hostsConfigPath) - if tt.wantErr { - assert.EqualError(t, err, tt.wantErrMsg) - return - } + cfg := testConfig() + Set(cfg, tt.keys, tt.value) + assert.True(t, cfg.entries.IsModified()) + value, err := Get(cfg, tt.keys) assert.NoError(t, err) - - git_protocol, err := cfg.Get("git_protocol") - assert.NoError(t, err) - assert.Equal(t, tt.wantGitProtocol, git_protocol) - - token, err := cfg.GetForHost("enterprise.com", "oauth_token") - if tt.wantGetErr { - assert.EqualError(t, err, tt.wantGetErrMsg) - return - } - assert.NoError(t, err) - assert.Equal(t, tt.wantToken, token) + assert.Equal(t, tt.value, value) }) } } -func TestDefaultConfig(t *testing.T) { - cfg := defaultConfig() +func TestDefaultGeneralEntries(t *testing.T) { + cfg := ReadFromString(defaultGeneralEntries) - git_protocol, err := cfg.Get("git_protocol") + protocol, err := Get(cfg, []string{"git_protocol"}) assert.NoError(t, err) - assert.Equal(t, "https", git_protocol) + assert.Equal(t, "https", protocol) - editor, err := cfg.Get("editor") + editor, err := Get(cfg, []string{"editor"}) assert.NoError(t, err) assert.Equal(t, "", editor) - prompt, err := cfg.Get("prompt") + prompt, err := Get(cfg, []string{"prompt"}) assert.NoError(t, err) assert.Equal(t, "enabled", prompt) - pager, err := cfg.Get("pager") + pager, err := Get(cfg, []string{"pager"}) assert.NoError(t, err) assert.Equal(t, "", pager) - unix_socket, err := cfg.Get("http_unix_socket") + socket, err := Get(cfg, []string{"http_unix_socket"}) assert.NoError(t, err) - assert.Equal(t, "", unix_socket) + assert.Equal(t, "", socket) - browser, err := cfg.Get("browser") + browser, err := Get(cfg, []string{"browser"}) assert.NoError(t, err) assert.Equal(t, "", browser) - _, err = cfg.Get("unknown") - assert.EqualError(t, err, "not found") + unknown, err := Get(cfg, []string{"unknown"}) + assert.EqualError(t, err, `could not find key "unknown"`) + assert.Equal(t, "", unknown) +} + +func testConfig() *Config { + var data = ` +git_protocol: ssh +editor: +prompt: enabled +pager: less +nested: + key: value + pager: more +` + return ReadFromString(data) } -func testGlobalConfig() string { +func testGlobalData() string { var data = ` git_protocol: ssh editor: @@ -688,7 +657,7 @@ pager: less return data } -func testHostsConfig() string { +func testHostsData() string { var data = ` github.com: user: user1 @@ -702,7 +671,7 @@ enterprise.com: return data } -func testLoadedConfig() Config { +func testFullConfig() string { var data = ` git_protocol: ssh editor: @@ -718,27 +687,5 @@ hosts: oauth_token: yyyyyyyyyyyyyyyyyyyy git_protocol: https ` - cfg, _ := FromString(data) - return cfg -} - -func testLoadedSingleHostConfig() Config { - var data = ` -git_protocol: ssh -editor: -prompt: enabled -pager: less -hosts: - enterprise.com: - user: user2 - oauth_token: yyyyyyyyyyyyyyyyyyyy - git_protocol: https -` - cfg, _ := FromString(data) - return cfg -} - -func testLoadedNoHostConfig() Config { - cfg, _ := FromString(testGlobalConfig()) - return cfg + return data } diff --git a/pkg/config/errors.go b/pkg/config/errors.go new file mode 100644 index 0000000..77de5ef --- /dev/null +++ b/pkg/config/errors.go @@ -0,0 +1,27 @@ +package config + +import ( + "fmt" +) + +// InvalidConfigFileError represents an error when trying to read a config file. +type InvalidConfigFileError struct { + Path string + Err error +} + +// Allow InvalidConfigFileError to satisfy error interface. +func (e InvalidConfigFileError) Error() string { + return fmt.Sprintf("invalid config file %s: %s", e.Path, e.Err) +} + +// NotFoundError represents an error when trying to find a config key +// that does not exist. +type NotFoundError struct { + Key string +} + +// Allow NotFoundError to satisfy error interface. +func (e NotFoundError) Error() string { + return fmt.Sprintf("could not find key %q", e.Key) +} diff --git a/pkg/repository/repository.go b/pkg/repository/repository.go index ea2c170..f5dfa9a 100644 --- a/pkg/repository/repository.go +++ b/pkg/repository/repository.go @@ -6,9 +6,10 @@ import ( "fmt" "strings" - "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" + "github.com/cli/go-gh/pkg/auth" + "github.com/cli/go-gh/pkg/config" ) // Repository is the interface that wraps repository information methods. @@ -48,9 +49,9 @@ func Parse(s string) (Repository, error) { return irepo.New(parts[0], parts[1], parts[2]), nil case 2: host := "github.com" - cfg, err := config.Load() + cfg, err := config.Read() if err == nil { - host = cfg.Host() + host, _ = auth.DefaultHost(cfg) } return irepo.New(host, parts[0], parts[1]), nil default: diff --git a/pkg/repository/repository_test.go b/pkg/repository/repository_test.go index 3af334f..2784226 100644 --- a/pkg/repository/repository_test.go +++ b/pkg/repository/repository_test.go @@ -102,7 +102,10 @@ func TestParse_hostFromConfig(t *testing.T) { tempDir := t.TempDir() old := os.Getenv("GH_CONFIG_DIR") os.Setenv("GH_CONFIG_DIR", tempDir) - defer os.Setenv("GH_CONFIG_DIR", old) + t.Cleanup(func() { + os.Setenv("GH_CONFIG_DIR", old) + }) + var configData = ` git_protocol: ssh editor: From faad1a1134d27bef44ce2e98f82084e3864339e6 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 26 May 2022 10:34:51 +0200 Subject: [PATCH 2/5] Address PR comments --- gh.go | 34 +++++-------------- gh_test.go | 36 ++++++++++++++------ pkg/auth/auth.go | 64 +++++++++++++++++++++++++----------- pkg/auth/auth_test.go | 6 ++-- pkg/config/config.go | 13 ++++---- pkg/config/errors.go | 13 +++++--- pkg/repository/repository.go | 7 +--- 7 files changed, 98 insertions(+), 75 deletions(-) diff --git a/gh.go b/gh.go index ba4a042..d9b3402 100644 --- a/gh.go +++ b/gh.go @@ -62,11 +62,7 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Read() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) + err := resolveOptions(opts) if err != nil { return nil, err } @@ -83,11 +79,7 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Read() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) + err := resolveOptions(opts) if err != nil { return nil, err } @@ -109,11 +101,7 @@ func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { opts = &api.ClientOptions{} } if optionsNeedResolution(opts) { - cfg, err := config.Read() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) + err := resolveOptions(opts) if err != nil { return nil, err } @@ -141,12 +129,7 @@ func CurrentRepository() (repo.Repository, error) { translator := ssh.NewTranslator() translateRemotes(remotes, translator) - cfg, err := config.Read() - if err != nil { - return nil, err - } - - hosts := auth.KnownHosts(cfg) + hosts := auth.KnownHosts() filteredRemotes := remotes.FilterByHosts(hosts) if len(filteredRemotes) == 0 { @@ -170,17 +153,18 @@ func optionsNeedResolution(opts *api.ClientOptions) bool { return false } -func resolveOptions(opts *api.ClientOptions, cfg *config.Config) error { +func resolveOptions(opts *api.ClientOptions) error { + cfg, _ := config.Read() if opts.Host == "" { - opts.Host, _ = auth.DefaultHost(cfg) + opts.Host, _ = auth.DefaultHost() } if opts.AuthToken == "" { - opts.AuthToken, _ = auth.TokenForHost(cfg, opts.Host) + opts.AuthToken, _ = auth.TokenForHost(opts.Host) if opts.AuthToken == "" { return fmt.Errorf("authentication token not found for host %s", opts.Host) } } - if opts.UnixDomainSocket == "" { + if opts.UnixDomainSocket == "" && cfg != nil { opts.UnixDomainSocket, _ = config.Get(cfg, []string{"http_unix_socket"}) } return nil diff --git a/gh_test.go b/gh_test.go index 7338bfe..524ea79 100644 --- a/gh_test.go +++ b/gh_test.go @@ -4,11 +4,11 @@ import ( "fmt" "net/http" "os" + "path/filepath" "strings" "testing" "github.com/cli/go-gh/pkg/api" - "github.com/cli/go-gh/pkg/config" "github.com/stretchr/testify/assert" "gopkg.in/h2non/gock.v1" ) @@ -162,7 +162,18 @@ func TestHTTPClient(t *testing.T) { } func TestResolveOptions(t *testing.T) { - cfg := testConfig() + tempDir := t.TempDir() + orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") + t.Cleanup(func() { + os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) + }) + os.Setenv("GH_CONFIG_DIR", tempDir) + globalFilePath := filepath.Join(tempDir, "config.yml") + hostsFilePath := filepath.Join(tempDir, "hosts.yml") + err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755) + assert.NoError(t, err) + err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755) + assert.NoError(t, err) tests := []struct { name string @@ -193,7 +204,7 @@ func TestResolveOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := resolveOptions(tt.opts, cfg) + err := resolveOptions(tt.opts) assert.NoError(t, err) assert.Equal(t, tt.wantHost, tt.opts.Host) assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken) @@ -303,16 +314,21 @@ func TestOptionsNeedResolution(t *testing.T) { } } -func testConfig() *config.Config { +func testGlobalData() string { var data = ` -hosts: - github.com: - user: user1 - oauth_token: token - git_protocol: ssh http_unix_socket: socket ` - return config.ReadFromString(data) + return data +} + +func testHostsData() string { + var data = ` +github.com: + user: user1 + oauth_token: token + git_protocol: ssh +` + return data } func printPendingMocks(mocks []gock.Mock) string { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index fc1b10b..6939539 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -9,7 +9,7 @@ import ( ) const ( - defaultHost = "github.com" + github = "github.com" ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" ghHost = "GH_HOST" ghToken = "GH_TOKEN" @@ -19,7 +19,12 @@ const ( hostsKey = "hosts" ) -func TokenForHost(cfg *config.Config, host string) (string, string) { +func TokenForHost(host string) (string, string) { + cfg, _ := config.Read() + return tokenForHost(cfg, host) +} + +func tokenForHost(cfg *config.Config, host string) (string, string) { host = normalizeHostname(host) if isEnterprise(host) { if token := os.Getenv(ghEnterpriseToken); token != "" { @@ -28,8 +33,10 @@ func TokenForHost(cfg *config.Config, host string) (string, string) { if token := os.Getenv(githubEnterpriseToken); token != "" { return token, githubEnterpriseToken } - token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) - return token, oauthToken + if cfg != nil { + token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + return token, oauthToken + } } if token := os.Getenv(ghToken); token != "" { return token, ghToken @@ -37,44 +44,61 @@ func TokenForHost(cfg *config.Config, host string) (string, string) { if token := os.Getenv(githubToken); token != "" { return token, githubToken } - token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) - return token, oauthToken + if cfg != nil { + token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + return token, oauthToken + } + return "", "" +} + +func KnownHosts() []string { + cfg, _ := config.Read() + return knownHosts(cfg) } -func KnownHosts(cfg *config.Config) []string { +func knownHosts(cfg *config.Config) []string { hosts := set.NewStringSet() if host := os.Getenv(ghHost); host != "" { hosts.Add(host) } - if token, _ := TokenForHost(cfg, defaultHost); token != "" { - hosts.Add(defaultHost) + if token, _ := tokenForHost(cfg, github); token != "" { + hosts.Add(github) } - keys, err := config.Keys(cfg, []string{hostsKey}) - if err == nil { - hosts.AddValues(keys) + if cfg != nil { + keys, err := config.Keys(cfg, []string{hostsKey}) + if err == nil { + hosts.AddValues(keys) + } } return hosts.ToSlice() } -func DefaultHost(cfg *config.Config) (string, string) { +func DefaultHost() (string, string) { + cfg, _ := config.Read() + return defaultHost(cfg) +} + +func defaultHost(cfg *config.Config) (string, string) { if host := os.Getenv(ghHost); host != "" { return host, ghHost } - keys, err := config.Keys(cfg, []string{hostsKey}) - if err == nil && len(keys) == 1 { - return keys[0], hostsKey + if cfg != nil { + keys, err := config.Keys(cfg, []string{hostsKey}) + if err == nil && len(keys) == 1 { + return keys[0], hostsKey + } } - return defaultHost, "default" + return github, "default" } func isEnterprise(host string) bool { - return host != defaultHost + return host != github } func normalizeHostname(host string) string { hostname := strings.ToLower(host) - if strings.HasSuffix(hostname, "."+defaultHost) { - return defaultHost + if strings.HasSuffix(hostname, "."+github) { + return github } return hostname } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 42a80a7..a07cdbf 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -104,7 +104,7 @@ func TestTokenForHost(t *testing.T) { os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.githubEnterpriseToken) os.Setenv("GH_TOKEN", tt.ghToken) os.Setenv("GH_ENTERPRISE_TOKEN", tt.ghEnterpriseToken) - token, source := TokenForHost(tt.config, tt.host) + token, source := tokenForHost(tt.config, tt.host) assert.Equal(t, tt.wantToken, token) assert.Equal(t, tt.wantSource, source) }) @@ -157,7 +157,7 @@ func TestDefaultHost(t *testing.T) { os.Setenv(k, tt.ghHost) defer os.Setenv(k, old) } - host, source := DefaultHost(tt.config) + host, source := defaultHost(tt.config) assert.Equal(t, tt.wantHost, host) assert.Equal(t, tt.wantSource, source) }) @@ -217,7 +217,7 @@ func TestKnownHosts(t *testing.T) { os.Setenv(k, tt.ghToken) defer os.Setenv(k, old) } - hosts := KnownHosts(tt.config) + hosts := knownHosts(tt.config) assert.Equal(t, tt.wantHosts, hosts) }) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 56962ad..e89c60f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -10,7 +10,6 @@ import ( "path/filepath" "runtime" - "github.com/MakeNowJust/heredoc" "github.com/cli/go-gh/internal/yamlmap" ) @@ -33,7 +32,7 @@ func Get(c *Config, keys []string) (string, error) { var err error m, err = m.FindEntry(key) if err != nil { - return "", NotFoundError{key} + return "", KeyNotFoundError{key} } } return m.Value, nil @@ -45,7 +44,7 @@ func Keys(c *Config, keys []string) ([]string, error) { var err error m, err = m.FindEntry(key) if err != nil { - return nil, NotFoundError{key} + return nil, KeyNotFoundError{key} } } return m.Keys(), nil @@ -58,12 +57,12 @@ func Remove(c *Config, keys []string) error { key := keys[i] m, err = m.FindEntry(key) if err != nil { - return NotFoundError{key} + return KeyNotFoundError{key} } } err := m.RemoveEntry(keys[len(keys)-1]) if err != nil { - return NotFoundError{keys[len(keys)-1]} + return KeyNotFoundError{keys[len(keys)-1]} } return nil } @@ -250,7 +249,7 @@ func writeFile(filename string, data []byte) error { return err } -var defaultGeneralEntries = heredoc.Doc(` +var defaultGeneralEntries = ` # What protocol to use when performing git operations. Supported values: ssh, https git_protocol: https # What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment. @@ -266,4 +265,4 @@ aliases: http_unix_socket: # What web browser gh should use when opening URLs. If blank, will refer to environment. browser: -`) +` diff --git a/pkg/config/errors.go b/pkg/config/errors.go index 77de5ef..28120ac 100644 --- a/pkg/config/errors.go +++ b/pkg/config/errors.go @@ -15,13 +15,18 @@ func (e InvalidConfigFileError) Error() string { return fmt.Sprintf("invalid config file %s: %s", e.Path, e.Err) } -// NotFoundError represents an error when trying to find a config key +// Allow InvalidConfigFileError to be unwrapped. +func (e InvalidConfigFileError) Unwrap() error { + return e.Err +} + +// KeyNotFoundError represents an error when trying to find a config key // that does not exist. -type NotFoundError struct { +type KeyNotFoundError struct { Key string } -// Allow NotFoundError to satisfy error interface. -func (e NotFoundError) Error() string { +// Allow KeyNotFoundError to satisfy error interface. +func (e KeyNotFoundError) Error() string { return fmt.Sprintf("could not find key %q", e.Key) } diff --git a/pkg/repository/repository.go b/pkg/repository/repository.go index f5dfa9a..3e86c9c 100644 --- a/pkg/repository/repository.go +++ b/pkg/repository/repository.go @@ -9,7 +9,6 @@ import ( "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" "github.com/cli/go-gh/pkg/auth" - "github.com/cli/go-gh/pkg/config" ) // Repository is the interface that wraps repository information methods. @@ -48,11 +47,7 @@ func Parse(s string) (Repository, error) { case 3: return irepo.New(parts[0], parts[1], parts[2]), nil case 2: - host := "github.com" - cfg, err := config.Read() - if err == nil { - host, _ = auth.DefaultHost(cfg) - } + host, _ := auth.DefaultHost() return irepo.New(host, parts[0], parts[1]), nil default: return nil, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s) From ae1157f1caa8ca93bdbe1c60f6e5054015e1f009 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 26 May 2022 14:46:43 +0200 Subject: [PATCH 3/5] Add comments --- pkg/auth/auth.go | 14 ++++++++++++++ pkg/config/config.go | 29 ++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 6939539..f4fc7d7 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -1,3 +1,5 @@ +// Package auth is a set of functions for retrieving authentication tokens +// and authenticated hosts. package auth import ( @@ -19,6 +21,10 @@ const ( hostsKey = "hosts" ) +// TokenForHost retrieves an authentication token and the source of +// that token for the specified host. The source can be either an +// environment variable or the configuration file. +// Returns blank strings if no applicable token is found. func TokenForHost(host string) (string, string) { cfg, _ := config.Read() return tokenForHost(cfg, host) @@ -51,6 +57,10 @@ func tokenForHost(cfg *config.Config, host string) (string, string) { return "", "" } +// KnownHosts retrieves a list of hosts that have corresponding +// authentication tokens, either from environment variables +// or the configuration file. +// Returns an empty string slice if no hosts are found. func KnownHosts() []string { cfg, _ := config.Read() return knownHosts(cfg) @@ -73,6 +83,10 @@ func knownHosts(cfg *config.Config) []string { return hosts.ToSlice() } +// DefaultHost retrieves an authenticated host and the source of host. +// The source can be either an environment variable or the +// configuration file. +// Returns "github.com", "default" if no viable host is found. func DefaultHost() (string, string) { cfg, _ := config.Read() return defaultHost(cfg) diff --git a/pkg/config/config.go b/pkg/config/config.go index e89c60f..f63a2e7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,5 +1,5 @@ // Package config is a set of types for interacting with the gh configuration files. -// Note: This package is intended for use only in `gh`, any other use cases are subject +// Note: This package is intended for use only in gh, any other use cases are subject // to breakage and non-backwards compatible updates. package config @@ -22,10 +22,19 @@ const ( xdgStateHome = "XDG_STATE_HOME" ) +// Config is a in memory representation of the gh configuration files. +// It can be thought of as map where entries consist of a key that +// correspond to either a string value or a map value, allowing for +// multi-level maps. type Config struct { entries *yamlmap.Map } +// Get a string value from a Config. +// The keys argument is a sequence of key values so that nested +// entries can be retrieved. A undefined string will be returned +// if trying to retrieve a key that corresponds to a map value. +// Returns "", KeyNotFoundError if any of the keys can not be found. func Get(c *Config, keys []string) (string, error) { m := c.entries for _, key := range keys { @@ -38,6 +47,10 @@ func Get(c *Config, keys []string) (string, error) { return m.Value, nil } +// Keys enumerates a Configs keys. +// The keys argument is a sequence of key values so that nested +// map values can be have their keys enumerated. +// Returns nil, KeyNotFoundError if any of the keys can not be found. func Keys(c *Config, keys []string) ([]string, error) { m := c.entries for _, key := range keys { @@ -50,6 +63,11 @@ func Keys(c *Config, keys []string) ([]string, error) { return m.Keys(), nil } +// Remove an entry from a Config. +// The keys argument is a sequence of key values so that nested +// entries can be removed. Removing an entry that has nested +// entries removes those also. +// Returns "", KeyNotFoundError if any of the keys can not be found. func Remove(c *Config, keys []string) error { m := c.entries for i := 0; i < len(keys)-1; i++ { @@ -67,6 +85,10 @@ func Remove(c *Config, keys []string) error { return nil } +// Set an string value in a Config. +// The keys argument is a sequence of key values so that nested +// entries can be set. If any of the keys do not exist they will +// be created. func Set(c *Config, keys []string, value string) { m := c.entries for i := 0; i < len(keys)-1; i++ { @@ -81,6 +103,8 @@ func Set(c *Config, keys []string, value string) { m.SetEntry(keys[len(keys)-1], yamlmap.StringValue(value)) } +// Read gh configuration files from the local file system and +// return a Config. func Read() (*Config, error) { // TODO: Make global config singleton using sync.Once // so as not to read from file every time. @@ -98,6 +122,9 @@ func ReadFromString(str string) *Config { return &Config{m} } +// Write gh configuration files to the local file system. +// It will only write gh configuration files that have been modified +// since last being read. func Write(config *Config) error { hosts, err := config.entries.FindEntry("hosts") if err == nil && hosts.IsModified() { From 9edf724f176ec43ce9aff33221efd630e8d1b7a6 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 30 May 2022 10:56:18 +0200 Subject: [PATCH 4/5] Change functions to methods --- gh.go | 2 +- pkg/auth/auth.go | 8 +++--- pkg/config/config.go | 12 ++++----- pkg/config/config_test.go | 54 +++++++++++++++++++-------------------- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/gh.go b/gh.go index d9b3402..4597400 100644 --- a/gh.go +++ b/gh.go @@ -165,7 +165,7 @@ func resolveOptions(opts *api.ClientOptions) error { } } if opts.UnixDomainSocket == "" && cfg != nil { - opts.UnixDomainSocket, _ = config.Get(cfg, []string{"http_unix_socket"}) + opts.UnixDomainSocket, _ = cfg.Get([]string{"http_unix_socket"}) } return nil } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index f4fc7d7..7399dc2 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -40,7 +40,7 @@ func tokenForHost(cfg *config.Config, host string) (string, string) { return token, githubEnterpriseToken } if cfg != nil { - token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + token, _ := cfg.Get([]string{hostsKey, host, oauthToken}) return token, oauthToken } } @@ -51,7 +51,7 @@ func tokenForHost(cfg *config.Config, host string) (string, string) { return token, githubToken } if cfg != nil { - token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken}) + token, _ := cfg.Get([]string{hostsKey, host, oauthToken}) return token, oauthToken } return "", "" @@ -75,7 +75,7 @@ func knownHosts(cfg *config.Config) []string { hosts.Add(github) } if cfg != nil { - keys, err := config.Keys(cfg, []string{hostsKey}) + keys, err := cfg.Keys([]string{hostsKey}) if err == nil { hosts.AddValues(keys) } @@ -97,7 +97,7 @@ func defaultHost(cfg *config.Config) (string, string) { return host, ghHost } if cfg != nil { - keys, err := config.Keys(cfg, []string{hostsKey}) + keys, err := cfg.Keys([]string{hostsKey}) if err == nil && len(keys) == 1 { return keys[0], hostsKey } diff --git a/pkg/config/config.go b/pkg/config/config.go index f63a2e7..6685275 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -35,7 +35,7 @@ type Config struct { // entries can be retrieved. A undefined string will be returned // if trying to retrieve a key that corresponds to a map value. // Returns "", KeyNotFoundError if any of the keys can not be found. -func Get(c *Config, keys []string) (string, error) { +func (c *Config) Get(keys []string) (string, error) { m := c.entries for _, key := range keys { var err error @@ -47,11 +47,11 @@ func Get(c *Config, keys []string) (string, error) { return m.Value, nil } -// Keys enumerates a Configs keys. +// Keys enumerates a Config's keys. // The keys argument is a sequence of key values so that nested // map values can be have their keys enumerated. // Returns nil, KeyNotFoundError if any of the keys can not be found. -func Keys(c *Config, keys []string) ([]string, error) { +func (c *Config) Keys(keys []string) ([]string, error) { m := c.entries for _, key := range keys { var err error @@ -68,7 +68,7 @@ func Keys(c *Config, keys []string) ([]string, error) { // entries can be removed. Removing an entry that has nested // entries removes those also. // Returns "", KeyNotFoundError if any of the keys can not be found. -func Remove(c *Config, keys []string) error { +func (c *Config) Remove(keys []string) error { m := c.entries for i := 0; i < len(keys)-1; i++ { var err error @@ -85,11 +85,11 @@ func Remove(c *Config, keys []string) error { return nil } -// Set an string value in a Config. +// Set a string value in a Config. // The keys argument is a sequence of key values so that nested // entries can be set. If any of the keys do not exist they will // be created. -func Set(c *Config, keys []string, value string) { +func (c *Config) Set(keys []string, value string) { m := c.entries for i := 0; i < len(keys)-1; i++ { key := keys[i] diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 130d45e..f926fea 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -294,10 +294,10 @@ func TestLoad(t *testing.T) { return } assert.NoError(t, err) - protocol, err := Get(cfg, []string{"git_protocol"}) + protocol, err := cfg.Get([]string{"git_protocol"}) assert.NoError(t, err) assert.Equal(t, tt.wantGitProtocol, protocol) - token, err := Get(cfg, []string{"hosts", "enterprise.com", "oauth_token"}) + token, err := cfg.Get([]string{"hosts", "enterprise.com", "oauth_token"}) if tt.wantGetErr { assert.EqualError(t, err, `could not find key "hosts"`) } else { @@ -320,8 +320,8 @@ func TestWrite(t *testing.T) { name: "writes config and hosts files", createConfig: func() *Config { cfg := ReadFromString(testFullConfig()) - Set(cfg, []string{"editor"}, "vim") - Set(cfg, []string{"hosts", "github.com", "git_protocol"}, "https") + cfg.Set([]string{"editor"}, "vim") + cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "https") return cfg }, }, @@ -329,19 +329,19 @@ func TestWrite(t *testing.T) { name: "only writes hosts file", createConfig: func() *Config { cfg := ReadFromString(testFullConfig()) - Set(cfg, []string{"hosts", "enterprise.com", "git_protocol"}, "ssh") + cfg.Set([]string{"hosts", "enterprise.com", "git_protocol"}, "ssh") return cfg }, wantConfig: func() *Config { // The hosts file is writen but not the general config file. // When we use Read in the test the defaultGeneralEntries are used. cfg := ReadFromString(defaultGeneralEntries) - Set(cfg, []string{"hosts", "github.com", "user"}, "user1") - Set(cfg, []string{"hosts", "github.com", "oauth_token"}, "xxxxxxxxxxxxxxxxxxxx") - Set(cfg, []string{"hosts", "github.com", "git_protocol"}, "ssh") - Set(cfg, []string{"hosts", "enterprise.com", "user"}, "user2") - Set(cfg, []string{"hosts", "enterprise.com", "oauth_token"}, "yyyyyyyyyyyyyyyyyyyy") - Set(cfg, []string{"hosts", "enterprise.com", "git_protocol"}, "ssh") + cfg.Set([]string{"hosts", "github.com", "user"}, "user1") + cfg.Set([]string{"hosts", "github.com", "oauth_token"}, "xxxxxxxxxxxxxxxxxxxx") + cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "ssh") + cfg.Set([]string{"hosts", "enterprise.com", "user"}, "user2") + cfg.Set([]string{"hosts", "enterprise.com", "oauth_token"}, "yyyyyyyyyyyyyyyyyyyy") + cfg.Set([]string{"hosts", "enterprise.com", "git_protocol"}, "ssh") return cfg }, }, @@ -349,15 +349,15 @@ func TestWrite(t *testing.T) { name: "only writes config file", createConfig: func() *Config { cfg := ReadFromString(testFullConfig()) - Set(cfg, []string{"editor"}, "vim") + cfg.Set([]string{"editor"}, "vim") return cfg }, wantConfig: func() *Config { // The general config file is written but not the hosts config file. // When we use Read in the test there will not be any hosts entries. cfg := ReadFromString(testFullConfig()) - Set(cfg, []string{"editor"}, "vim") - _ = Remove(cfg, []string{"hosts"}) + cfg.Set([]string{"editor"}, "vim") + _ = cfg.Remove([]string{"hosts"}) return cfg }, }, @@ -448,7 +448,7 @@ func TestGet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := testConfig() - value, err := Get(cfg, tt.keys) + value, err := cfg.Get(tt.keys) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) } else { @@ -490,7 +490,7 @@ func TestKeys(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := testConfig() - ks, err := Keys(cfg, tt.findKeys) + ks, err := cfg.Keys(tt.findKeys) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) } else { @@ -538,7 +538,7 @@ func TestRemove(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := testConfig() - err := Remove(cfg, tt.keys) + err := cfg.Remove(tt.keys) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) assert.False(t, cfg.entries.IsModified()) @@ -546,7 +546,7 @@ func TestRemove(t *testing.T) { assert.NoError(t, err) assert.True(t, cfg.entries.IsModified()) } - _, getErr := Get(cfg, tt.keys) + _, getErr := cfg.Get(tt.keys) assert.Error(t, getErr) }) } @@ -593,9 +593,9 @@ func TestSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := testConfig() - Set(cfg, tt.keys, tt.value) + cfg.Set(tt.keys, tt.value) assert.True(t, cfg.entries.IsModified()) - value, err := Get(cfg, tt.keys) + value, err := cfg.Get(tt.keys) assert.NoError(t, err) assert.Equal(t, tt.value, value) }) @@ -605,31 +605,31 @@ func TestSet(t *testing.T) { func TestDefaultGeneralEntries(t *testing.T) { cfg := ReadFromString(defaultGeneralEntries) - protocol, err := Get(cfg, []string{"git_protocol"}) + protocol, err := cfg.Get([]string{"git_protocol"}) assert.NoError(t, err) assert.Equal(t, "https", protocol) - editor, err := Get(cfg, []string{"editor"}) + editor, err := cfg.Get([]string{"editor"}) assert.NoError(t, err) assert.Equal(t, "", editor) - prompt, err := Get(cfg, []string{"prompt"}) + prompt, err := cfg.Get([]string{"prompt"}) assert.NoError(t, err) assert.Equal(t, "enabled", prompt) - pager, err := Get(cfg, []string{"pager"}) + pager, err := cfg.Get([]string{"pager"}) assert.NoError(t, err) assert.Equal(t, "", pager) - socket, err := Get(cfg, []string{"http_unix_socket"}) + socket, err := cfg.Get([]string{"http_unix_socket"}) assert.NoError(t, err) assert.Equal(t, "", socket) - browser, err := Get(cfg, []string{"browser"}) + browser, err := cfg.Get([]string{"browser"}) assert.NoError(t, err) assert.Equal(t, "", browser) - unknown, err := Get(cfg, []string{"unknown"}) + unknown, err := cfg.Get([]string{"unknown"}) assert.EqualError(t, err, `could not find key "unknown"`) assert.Equal(t, "", unknown) } From bb3e52bc6da80e0dd96d7f7949917a21983a3176 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 30 May 2022 11:09:16 +0200 Subject: [PATCH 5/5] Small polish --- pkg/auth/auth.go | 17 +++++++++-------- pkg/config/config.go | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 7399dc2..6c663ee 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -11,20 +11,21 @@ import ( ) const ( - github = "github.com" + defaultSource = "default" ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" ghHost = "GH_HOST" ghToken = "GH_TOKEN" + github = "github.com" githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN" githubToken = "GITHUB_TOKEN" - oauthToken = "oauth_token" hostsKey = "hosts" + oauthToken = "oauth_token" ) // TokenForHost retrieves an authentication token and the source of // that token for the specified host. The source can be either an -// environment variable or the configuration file. -// Returns blank strings if no applicable token is found. +// environment variable or from the configuration file. +// Returns "", "default" if no applicable token is found. func TokenForHost(host string) (string, string) { cfg, _ := config.Read() return tokenForHost(cfg, host) @@ -54,12 +55,12 @@ func tokenForHost(cfg *config.Config, host string) (string, string) { token, _ := cfg.Get([]string{hostsKey, host, oauthToken}) return token, oauthToken } - return "", "" + return "", defaultSource } // KnownHosts retrieves a list of hosts that have corresponding // authentication tokens, either from environment variables -// or the configuration file. +// or from the configuration file. // Returns an empty string slice if no hosts are found. func KnownHosts() []string { cfg, _ := config.Read() @@ -84,7 +85,7 @@ func knownHosts(cfg *config.Config) []string { } // DefaultHost retrieves an authenticated host and the source of host. -// The source can be either an environment variable or the +// The source can be either an environment variable or from the // configuration file. // Returns "github.com", "default" if no viable host is found. func DefaultHost() (string, string) { @@ -102,7 +103,7 @@ func defaultHost(cfg *config.Config) (string, string) { return keys[0], hostsKey } } - return github, "default" + return github, defaultSource } func isEnterprise(host string) bool { diff --git a/pkg/config/config.go b/pkg/config/config.go index 6685275..307434d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -67,7 +67,7 @@ func (c *Config) Keys(keys []string) ([]string, error) { // The keys argument is a sequence of key values so that nested // entries can be removed. Removing an entry that has nested // entries removes those also. -// Returns "", KeyNotFoundError if any of the keys can not be found. +// Returns KeyNotFoundError if any of the keys can not be found. func (c *Config) Remove(keys []string) error { m := c.entries for i := 0; i < len(keys)-1; i++ {