diff --git a/gh.go b/gh.go index 4597400..43dd6e7 100644 --- a/gh.go +++ b/gh.go @@ -127,7 +127,14 @@ func CurrentRepository() (repo.Repository, error) { } translator := ssh.NewTranslator() - translateRemotes(remotes, translator) + for _, r := range remotes { + if r.FetchURL != nil { + r.FetchURL = translator.Translate(r.FetchURL) + } + if r.PushURL != nil { + r.PushURL = translator.Translate(r.PushURL) + } + } hosts := auth.KnownHosts() @@ -169,14 +176,3 @@ func resolveOptions(opts *api.ClientOptions) error { } return nil } - -func translateRemotes(remotes git.RemoteSet, translator ssh.Translator) { - for _, r := range remotes { - if r.FetchURL != nil { - r.FetchURL = translator.Translate(r.FetchURL) - } - if r.PushURL != nil { - r.PushURL = translator.Translate(r.PushURL) - } - } -} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 4ec651e..0ef9137 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -1,79 +1,39 @@ -// Package ssh is a set of types and functions for parsing and -// applying a user's SSH hostname aliases. +// Package ssh resolves local SSH hostname aliases. package ssh import ( "bufio" - "io" "net/url" - "os" - "path/filepath" - "regexp" + "os/exec" "strings" -) + "sync" -var ( - configLineRE = regexp.MustCompile(`\A\s*(?P[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P.+)`) - tokenRE = regexp.MustCompile(`%[%h]`) + "github.com/cli/safeexec" ) -// Translator is the interface that encapsulates the SSH hostname alias translate method. -type Translator interface { - Translate(*url.URL) *url.URL -} - -type config struct { - aliases map[string]string -} - -type parser struct { - dir string - cfg config - hosts []string - open func(string) (io.Reader, error) - glob func(string) ([]string, error) -} - -// NewTranslator constructs a map of SSH hostname aliases based on user and system configuration files. -// It returns a Translator to apply these mappings. -func NewTranslator() Translator { - configFiles := []string{ - "/etc/ssh_config", - "/etc/ssh/ssh_config", - } - - p := parser{} +type Translator struct { + cacheMap map[string]string + cacheMu sync.RWMutex + sshPath string + sshPathErr error + sshPathMu sync.Mutex - if sshDir, err := homeDirPath(".ssh"); err == nil { - userConfig := filepath.Join(sshDir, "config") - configFiles = append([]string{userConfig}, configFiles...) - p.dir = filepath.Dir(sshDir) - } - - for _, file := range configFiles { - _ = p.read(file) - } - - return p.cfg + lookPath func(string) (string, error) + newCommand func(string, ...string) *exec.Cmd } -func homeDirPath(subdir string) (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", err - } - - newPath := filepath.Join(homeDir, subdir) - return newPath, nil +// NewTranslator initializes a new Translator instance. +func NewTranslator() *Translator { + return &Translator{} } // Translate applies applicable SSH hostname aliases to the specified URL and returns the resulting URL. -func (c config) Translate(u *url.URL) *url.URL { +func (t *Translator) Translate(u *url.URL) *url.URL { if u.Scheme != "ssh" { return u } - resolvedHost, ok := c.aliases[u.Hostname()] - if !ok { + resolvedHost, err := t.resolve(u.Hostname()) + if err != nil { return u } if strings.EqualFold(resolvedHost, "ssh.github.com") { @@ -84,101 +44,62 @@ func (c config) Translate(u *url.URL) *url.URL { return newURL } -func (p *parser) read(fileName string) error { - var file io.Reader - if p.open == nil { - f, err := os.Open(fileName) - if err != nil { - return err - } - defer f.Close() - file = f - } else { - var err error - file, err = p.open(fileName) - if err != nil { - return err - } +func (t *Translator) resolve(hostname string) (string, error) { + t.cacheMu.RLock() + cached, cacheFound := t.cacheMap[strings.ToLower(hostname)] + t.cacheMu.RUnlock() + if cacheFound { + return cached, nil } - if len(p.hosts) == 0 { - p.hosts = []string{"*"} - } - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - m := configLineRE.FindStringSubmatch(scanner.Text()) - if len(m) < 3 { - continue - } - - keyword, arguments := strings.ToLower(m[1]), m[2] - switch keyword { - case "host": - p.hosts = strings.Fields(arguments) - case "hostname": - for _, host := range p.hosts { - for _, name := range strings.Fields(arguments) { - if p.cfg.aliases == nil { - p.cfg.aliases = make(map[string]string) - } - p.cfg.aliases[host] = expandTokens(name, host) - } - } - case "include": - for _, arg := range strings.Fields(arguments) { - path := p.absolutePath(fileName, arg) - - var fileNames []string - if p.glob == nil { - paths, _ := filepath.Glob(path) - for _, p := range paths { - if s, err := os.Stat(p); err == nil && !s.IsDir() { - fileNames = append(fileNames, p) - } - } - } else { - var err error - fileNames, err = p.glob(path) - if err != nil { - continue - } - } - - for _, fileName := range fileNames { - _ = p.read(fileName) - } - } + var sshPath string + t.sshPathMu.Lock() + if t.sshPath == "" && t.sshPathErr == nil { + lookPath := t.lookPath + if lookPath == nil { + lookPath = safeexec.LookPath } + t.sshPath, t.sshPathErr = lookPath("ssh") + } + if t.sshPathErr != nil { + defer t.sshPathMu.Unlock() + return t.sshPath, t.sshPathErr } + sshPath = t.sshPath + t.sshPathMu.Unlock() - return scanner.Err() -} + t.cacheMu.Lock() + defer t.cacheMu.Unlock() -func (p *parser) absolutePath(parentFile, path string) string { - if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") { - return path + newCommand := t.newCommand + if newCommand == nil { + newCommand = exec.Command + } + sshCmd := newCommand(sshPath, "-G", hostname) + stdout, err := sshCmd.StdoutPipe() + if err != nil { + return "", err } - if strings.HasPrefix(path, "~") { - return filepath.Join(p.dir, strings.TrimPrefix(path, "~")) + if err := sshCmd.Start(); err != nil { + return "", err } - if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") { - return filepath.Join("/etc/ssh", path) + var resolvedHost string + s := bufio.NewScanner(stdout) + for s.Scan() { + line := s.Text() + parts := strings.SplitN(line, " ", 2) + if len(parts) == 2 && parts[0] == "hostname" { + resolvedHost = parts[1] + } } - return filepath.Join(p.dir, ".ssh", path) -} + _ = sshCmd.Wait() -func expandTokens(text, host string) string { - return tokenRE.ReplaceAllStringFunc(text, func(match string) string { - switch match { - case "%h": - return host - case "%%": - return "%" - } - return "" - }) + if t.cacheMap == nil { + t.cacheMap = map[string]string{} + } + t.cacheMap[strings.ToLower(hostname)] = resolvedHost + return resolvedHost, nil } diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go index e24df60..46ea3d4 100644 --- a/pkg/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -1,149 +1,142 @@ package ssh import ( - "bytes" "fmt" - "io" "net/url" - "path/filepath" + "os" + "os/exec" "testing" "github.com/MakeNowJust/heredoc" + "github.com/cli/safeexec" ) -func Test_sshParser_read(t *testing.T) { - testFiles := map[string]string{ - "/etc/ssh/config": heredoc.Doc(` - Include sites/* - `), - "/etc/ssh/sites/cfg1": heredoc.Doc(` - Host s1 - Hostname=site1.net - `), - "/etc/ssh/sites/cfg2": heredoc.Doc(` - Host s2 - Hostname = site2.net - `), - "HOME/.ssh/config": heredoc.Doc(` - Host * - Host gh gittyhubby - Hostname github.com - #Hostname example.com - Host ex - Include ex_config/* - `), - "HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(` - Hostname example.com - `), - } - globResults := map[string][]string{ - "/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"}, - "HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"}, +func TestTranslator(t *testing.T) { + if _, err := safeexec.LookPath("ssh"); err != nil { + t.Skip("no ssh found on system") } - p := &parser{ - dir: "HOME", - open: func(s string) (io.Reader, error) { - if contents, ok := testFiles[filepath.ToSlash(s)]; ok { - return bytes.NewBufferString(contents), nil - } else { - return nil, fmt.Errorf("no test file stub found: %q", s) - } + tests := []struct { + name string + sshConfig string + arg string + want string + }{ + { + name: "translate SSH URL", + sshConfig: heredoc.Doc(` + Host github-* + Hostname github.com + `), + arg: "ssh://git@github-foo/owner/repo.git", + want: "ssh://git@github.com/owner/repo.git", }, - glob: func(p string) ([]string, error) { - if results, ok := globResults[filepath.ToSlash(p)]; ok { - return results, nil - } else { - return nil, fmt.Errorf("no glob stubs found: %q", p) - } + { + name: "does not translate HTTPS URL", + sshConfig: heredoc.Doc(` + Host github-* + Hostname github.com + `), + arg: "https://github-foo/owner/repo.git", + want: "https://github-foo/owner/repo.git", + }, + { + name: "treats ssh.github.com as github.com", + sshConfig: heredoc.Doc(` + Host github.com + Hostname ssh.github.com + `), + arg: "ssh://git@github.com/owner/repo.git", + want: "ssh://git@github.com/owner/repo.git", }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := os.CreateTemp("", "ssh-config.*") + if err != nil { + t.Fatalf("error creating file: %v", err) + } + _, err = f.WriteString(tt.sshConfig) + _ = f.Close() + if err != nil { + t.Fatalf("error writing ssh config: %v", err) + } - if err := p.read("/etc/ssh/config"); err != nil { - t.Fatalf("read(global config) = %v", err) - } - if err := p.read("HOME/.ssh/config"); err != nil { - t.Fatalf("read(user config) = %v", err) + tr := &Translator{ + newCommand: func(exe string, args ...string) *exec.Cmd { + args = append([]string{"-F", f.Name()}, args...) + return exec.Command(exe, args...) + }, + } + u, err := url.Parse(tt.arg) + if err != nil { + t.Fatalf("error parsing URL: %v", err) + } + res := tr.Translate(u) + if got := res.String(); got != tt.want { + t.Errorf("expected %q, got %q", tt.want, got) + } + }) } +} - if got := p.cfg.aliases["gh"]; got != "github.com" { - t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got) - } - if got := p.cfg.aliases["gittyhubby"]; got != "github.com" { - t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got) - } - if got := p.cfg.aliases["example.com"]; got != "" { - t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got) +func TestHelperProcess(t *testing.T) { + if os.Getenv("GH_WANT_HELPER_PROCESS") != "1" { + return } - if got := p.cfg.aliases["ex"]; got != "example.com" { - t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got) - } - if got := p.cfg.aliases["s1"]; got != "site1.net" { - t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got) + if err := func(args []string) error { + fmt.Fprint(os.Stdout, "hostname github.com\n") + return nil + }(os.Args[3:]); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) } + os.Exit(0) } -func Test_sshParser_absolutePath(t *testing.T) { - dir := "HOME" - p := &parser{dir: dir} - - tests := map[string]struct { - parentFile string - arg string - want string - }{ - "absolute path": { - parentFile: "/etc/ssh/ssh_config", - arg: "/etc/ssh/config", - want: "/etc/ssh/config", - }, - "system relative path": { - parentFile: "/etc/ssh/config", - arg: "configs/*.conf", - want: filepath.Join("/etc", "ssh", "configs", "*.conf"), +func TestTranslator_caching(t *testing.T) { + countLookPath := 0 + countNewCommand := 0 + tr := &Translator{ + lookPath: func(s string) (string, error) { + countLookPath++ + return "/path/to/ssh", nil }, - "user relative path": { - parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - arg: "configs/*.conf", - want: filepath.Join(dir, ".ssh", "configs/*.conf"), - }, - "shell-like ~ rerefence": { - parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - arg: "~/.ssh/*.conf", - want: filepath.Join(dir, ".ssh", "*.conf"), + newCommand: func(exe string, args ...string) *exec.Cmd { + args = append([]string{"-test.run=TestHelperProcess", "--", exe}, args...) + c := exec.Command(os.Args[0], args...) + c.Env = []string{"GH_WANT_HELPER_PROCESS=1"} + countNewCommand++ + return c }, } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want { - t.Errorf("absolutePath(): %q, wants %q", got, tt.want) - } - }) + u1, err := url.Parse("ssh://github1.com/owner/repo.git") + if err != nil { + t.Fatalf("error parsing URL: %v", err) } -} - -func Test_Translate(t *testing.T) { - m := config{ - aliases: map[string]string{ - "gh": "github.com", - "github.com": "ssh.github.com", - "my.gh.com": "ssh.github.com", - }, + if res := tr.Translate(u1); res.Host != "github.com" { + t.Errorf("expected github.com, got: %q", res.Host) + } + if res := tr.Translate(u1); res.Host != "github.com" { + t.Errorf("expected github.com, got: %q", res.Host) } - cases := [][]string{ - {"ssh://gh/o/r", "ssh://github.com/o/r"}, - {"ssh://github.com/o/r", "ssh://github.com/o/r"}, - {"ssh://my.gh.com", "ssh://github.com"}, - {"https://gh/o/r", "https://gh/o/r"}, + u2, err := url.Parse("ssh://github2.com/owner/repo.git") + if err != nil { + t.Fatalf("error parsing URL: %v", err) + } + if res := tr.Translate(u2); res.Host != "github.com" { + t.Errorf("expected github.com, got: %q", res.Host) + } + if res := tr.Translate(u2); res.Host != "github.com" { + t.Errorf("expected github.com, got: %q", res.Host) } - for _, c := range cases { - u, _ := url.Parse(c[0]) - got := m.Translate(u) - if got.String() != c[1] { - t.Errorf("%q: expected %q, got %q", c[0], c[1], got) - } + if countLookPath != 1 { + t.Errorf("expected lookPath to happen 1 time; actual: %d", countLookPath) + } + if countNewCommand != 2 { + t.Errorf("expected ssh command to shell out 2 times; actual: %d", countNewCommand) } }