diff --git a/plumbing/transport/common.go b/plumbing/transport/common.go index b993c4e9f..a9ee2caee 100644 --- a/plumbing/transport/common.go +++ b/plumbing/transport/common.go @@ -58,6 +58,11 @@ type Session interface { // If the repository does not exist, returns ErrRepositoryNotFound. // If the repository exists, but is empty, returns ErrEmptyRemoteRepository. AdvertisedReferences() (*packp.AdvRefs, error) + // AdvertisedReferencesContext retrieves the advertised references for a + // repository. + // If the repository does not exist, returns ErrRepositoryNotFound. + // If the repository exists, but is empty, returns ErrEmptyRemoteRepository. + AdvertisedReferencesContext(context.Context) (*packp.AdvRefs, error) io.Closer } diff --git a/plumbing/transport/http/common.go b/plumbing/transport/http/common.go index aeedc5bb5..d57c0feef 100644 --- a/plumbing/transport/http/common.go +++ b/plumbing/transport/http/common.go @@ -3,6 +3,7 @@ package http import ( "bytes" + "context" "fmt" "net" "net/http" @@ -32,7 +33,7 @@ func applyHeadersToRequest(req *http.Request, content *bytes.Buffer, host string const infoRefsPath = "/info/refs" -func advertisedReferences(s *session, serviceName string) (ref *packp.AdvRefs, err error) { +func advertisedReferences(ctx context.Context, s *session, serviceName string) (ref *packp.AdvRefs, err error) { url := fmt.Sprintf( "%s%s?service=%s", s.endpoint.String(), infoRefsPath, serviceName, @@ -45,7 +46,7 @@ func advertisedReferences(s *session, serviceName string) (ref *packp.AdvRefs, e s.ApplyAuthToRequest(req) applyHeadersToRequest(req, nil, s.endpoint.Host, serviceName) - res, err := s.client.Do(req) + res, err := s.client.Do(req.WithContext(ctx)) if err != nil { return nil, err } diff --git a/plumbing/transport/http/receive_pack.go b/plumbing/transport/http/receive_pack.go index 433dfcfda..4d14ff21e 100644 --- a/plumbing/transport/http/receive_pack.go +++ b/plumbing/transport/http/receive_pack.go @@ -25,7 +25,11 @@ func newReceivePackSession(c *http.Client, ep *transport.Endpoint, auth transpor } func (s *rpSession) AdvertisedReferences() (*packp.AdvRefs, error) { - return advertisedReferences(s.session, transport.ReceivePackServiceName) + return advertisedReferences(context.TODO(), s.session, transport.ReceivePackServiceName) +} + +func (s *rpSession) AdvertisedReferencesContext(ctx context.Context) (*packp.AdvRefs, error) { + return advertisedReferences(ctx, s.session, transport.ReceivePackServiceName) } func (s *rpSession) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) ( diff --git a/plumbing/transport/http/upload_pack.go b/plumbing/transport/http/upload_pack.go index db3708940..e735b3d7c 100644 --- a/plumbing/transport/http/upload_pack.go +++ b/plumbing/transport/http/upload_pack.go @@ -25,7 +25,11 @@ func newUploadPackSession(c *http.Client, ep *transport.Endpoint, auth transport } func (s *upSession) AdvertisedReferences() (*packp.AdvRefs, error) { - return advertisedReferences(s.session, transport.UploadPackServiceName) + return advertisedReferences(context.TODO(), s.session, transport.UploadPackServiceName) +} + +func (s *upSession) AdvertisedReferencesContext(ctx context.Context) (*packp.AdvRefs, error) { + return advertisedReferences(ctx, s.session, transport.UploadPackServiceName) } func (s *upSession) UploadPack( diff --git a/plumbing/transport/internal/common/common.go b/plumbing/transport/internal/common/common.go index 89432e34c..75405c72f 100644 --- a/plumbing/transport/internal/common/common.go +++ b/plumbing/transport/internal/common/common.go @@ -162,14 +162,18 @@ func (c *client) listenFirstError(r io.Reader) chan string { return errLine } -// AdvertisedReferences retrieves the advertised references from the server. func (s *session) AdvertisedReferences() (*packp.AdvRefs, error) { + return s.AdvertisedReferencesContext(context.TODO()) +} + +// AdvertisedReferences retrieves the advertised references from the server. +func (s *session) AdvertisedReferencesContext(ctx context.Context) (*packp.AdvRefs, error) { if s.advRefs != nil { return s.advRefs, nil } ar := packp.NewAdvRefs() - if err := ar.Decode(s.Stdout); err != nil { + if err := ar.Decode(s.StdoutContext(ctx)); err != nil { if err := s.handleAdvRefDecodeError(err); err != nil { return nil, err } @@ -237,7 +241,7 @@ func (s *session) UploadPack(ctx context.Context, req *packp.UploadPackRequest) return nil, err } - if _, err := s.AdvertisedReferences(); err != nil { + if _, err := s.AdvertisedReferencesContext(ctx); err != nil { return nil, err } diff --git a/plumbing/transport/server/server.go b/plumbing/transport/server/server.go index 727f90215..6f89ec397 100644 --- a/plumbing/transport/server/server.go +++ b/plumbing/transport/server/server.go @@ -108,6 +108,10 @@ type upSession struct { } func (s *upSession) AdvertisedReferences() (*packp.AdvRefs, error) { + return s.AdvertisedReferencesContext(context.TODO()) +} + +func (s *upSession) AdvertisedReferencesContext(ctx context.Context) (*packp.AdvRefs, error) { ar := packp.NewAdvRefs() if err := s.setSupportedCapabilities(ar.Capabilities); err != nil { @@ -204,6 +208,10 @@ type rpSession struct { } func (s *rpSession) AdvertisedReferences() (*packp.AdvRefs, error) { + return s.AdvertisedReferencesContext(context.TODO()) +} + +func (s *rpSession) AdvertisedReferencesContext(ctx context.Context) (*packp.AdvRefs, error) { ar := packp.NewAdvRefs() if err := s.setSupportedCapabilities(ar.Capabilities); err != nil { diff --git a/remote.go b/remote.go index 66ba71edc..382545b0f 100644 --- a/remote.go +++ b/remote.go @@ -109,7 +109,7 @@ func (r *Remote) PushContext(ctx context.Context, o *PushOptions) (err error) { defer ioutil.CheckClose(s, &err) - ar, err := s.AdvertisedReferences() + ar, err := s.AdvertisedReferencesContext(ctx) if err != nil { return err } @@ -316,7 +316,7 @@ func (r *Remote) fetch(ctx context.Context, o *FetchOptions) (sto storer.Referen defer ioutil.CheckClose(s, &err) - ar, err := s.AdvertisedReferences() + ar, err := s.AdvertisedReferencesContext(ctx) if err != nil { return nil, err } @@ -1034,7 +1034,7 @@ func (r *Remote) List(o *ListOptions) (rfs []*plumbing.Reference, err error) { defer ioutil.CheckClose(s, &err) - ar, err := s.AdvertisedReferences() + ar, err := s.AdvertisedReferencesContext(context.TODO()) if err != nil { return nil, err } diff --git a/remote_test.go b/remote_test.go index 3446f1a86..420a29784 100644 --- a/remote_test.go +++ b/remote_test.go @@ -162,7 +162,7 @@ func (s *RemoteSuite) TestFetchContext(c *C) { config.RefSpec("+refs/heads/master:refs/remotes/origin/master"), }, }) - c.Assert(err, NotNil) + c.Assert(err, Equals, context.Canceled) } func (s *RemoteSuite) TestFetchWithAllTags(c *C) { @@ -486,7 +486,7 @@ func (s *RemoteSuite) TestPushContext(c *C) { err = r.PushContext(ctx, &PushOptions{ RefSpecs: []config.RefSpec{"refs/tags/*:refs/tags/*"}, }) - c.Assert(err, NotNil) + c.Assert(err, Equals, context.Canceled) // let the goroutine from pushHashes finish and check that the number of // goroutines is the same as before diff --git a/repository_test.go b/repository_test.go index 7d4ddce74..0239a2bb1 100644 --- a/repository_test.go +++ b/repository_test.go @@ -187,7 +187,7 @@ func (s *RepositorySuite) TestCloneContext(c *C) { }) c.Assert(r, NotNil) - c.Assert(err, ErrorMatches, ".* context canceled") + c.Assert(err, Equals, context.Canceled) } func (s *RepositorySuite) TestCloneWithTags(c *C) { @@ -655,12 +655,12 @@ func (s *RepositorySuite) TestPlainCloneContextCancel(c *C) { }) c.Assert(r, NotNil) - c.Assert(err, ErrorMatches, ".* context canceled") + c.Assert(err, Equals, context.Canceled) } func (s *RepositorySuite) TestPlainCloneContextNonExistentWithExistentDir(c *C) { ctx, cancel := context.WithCancel(context.Background()) - cancel() + defer cancel() tmpDir := c.MkDir() repoDir := tmpDir @@ -681,7 +681,7 @@ func (s *RepositorySuite) TestPlainCloneContextNonExistentWithExistentDir(c *C) func (s *RepositorySuite) TestPlainCloneContextNonExistentWithNonExistentDir(c *C) { ctx, cancel := context.WithCancel(context.Background()) - cancel() + defer cancel() tmpDir := c.MkDir() repoDir := filepath.Join(tmpDir, "repoDir") @@ -719,7 +719,7 @@ func (s *RepositorySuite) TestPlainCloneContextNonExistentWithNotDir(c *C) { func (s *RepositorySuite) TestPlainCloneContextNonExistentWithNotEmptyDir(c *C) { ctx, cancel := context.WithCancel(context.Background()) - cancel() + defer cancel() tmpDir := c.MkDir() repoDirPath := filepath.Join(tmpDir, "repoDir") @@ -743,7 +743,7 @@ func (s *RepositorySuite) TestPlainCloneContextNonExistentWithNotEmptyDir(c *C) func (s *RepositorySuite) TestPlainCloneContextNonExistingOverExistingGitDirectory(c *C) { ctx, cancel := context.WithCancel(context.Background()) - cancel() + defer cancel() tmpDir := c.MkDir() r, err := PlainInit(tmpDir, false)