Skip to content

Commit

Permalink
only check for availability for IMDS
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-the-programmer committed Apr 22, 2022
1 parent 4722437 commit d253b6c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 33 deletions.
27 changes: 12 additions & 15 deletions autorest/adal/token.go
Expand Up @@ -1317,28 +1317,25 @@ func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTena

// MSIAvailable returns true if the MSI endpoint is available for authentication.
func MSIAvailable(ctx context.Context, s Sender) bool {
_, _, err := getMSIAvailable(ctx, s)
return err == nil
}
msiType, _, err := getMSIType()

// getMSIAvailable returns the response from the resolved MSI endpoint, the MSI endpoint used, and the error returned
func getMSIAvailable(ctx context.Context, s Sender) (*http.Response, string, error) {
if s == nil {
s = sender()
}
msiType, endpoint, err := getMSIType()
if err != nil {
return nil, endpoint, err
return false
}

var apiVersion string
if apiVersion = msiAPIVersion; msiType == msiTypeAppServiceV20170901 {
apiVersion = appServiceAPIVersion2017
if msiType != msiTypeIMDS {
return true
}

resp, err := getMSIEndpoint(ctx, s, endpoint, apiVersion)
if s == nil {
s = sender()
}

resp, err := getMSIEndpoint(ctx, s)

if err == nil {
resp.Body.Close()
}
return resp, endpoint, err

return err == nil
}
6 changes: 3 additions & 3 deletions autorest/adal/token_1.13.go
Expand Up @@ -24,13 +24,13 @@ import (
"time"
)

func getMSIEndpoint(ctx context.Context, sender Sender, endpoint string, apiVersion string) (*http.Response, error) {
func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) {
tempCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
// http.NewRequestWithContext() was added in Go 1.13
req, _ := http.NewRequestWithContext(tempCtx, http.MethodGet, endpoint, nil)
req, _ := http.NewRequestWithContext(tempCtx, http.MethodGet, msiEndpoint, nil)
q := req.URL.Query()
q.Add("api-version", apiVersion)
q.Add("api-version", msiAPIVersion)
req.URL.RawQuery = q.Encode()
return sender.Do(req)
}
Expand Down
22 changes: 7 additions & 15 deletions autorest/adal/token_test.go
Expand Up @@ -1352,28 +1352,20 @@ func TestMSIAvailableAppService(t *testing.T) {
}()
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponse())
_, endpoint, err := getMSIAvailable(context.Background(), c)
available := MSIAvailable(context.Background(), c)

if err != nil {
t.Fatal("unexpected error")
}

if endpoint != "http://localhost" {
t.Fatal("incorrect endpoint returned")
if !available {
t.Fatal("expected MSI to be available")
}
}

func TestMSIAvailableIMDS(t *testing.T) {
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponse())
_, endpoint, err := getMSIAvailable(context.Background(), c)
available := MSIAvailable(context.Background(), c)

if err != nil {
t.Fatal("unexpected error")
}

if endpoint != msiEndpoint {
t.Fatal("incorrect endpoint returned")
if !available {
t.Fatal("expected MSI to be available")
}
}

Expand All @@ -1393,7 +1385,7 @@ func TestMSIAvailableFail(t *testing.T) {
if MSIAvailable(context.Background(), c) {
t.Fatal("unexpected true")
}
_, err := getMSIEndpoint(context.Background(), c, msiEndpoint, msiAPIVersion)
_, err := getMSIEndpoint(context.Background(), c)
if !strings.Contains(err.Error(), "") {
t.Fatalf("expected error: '%s', but got error '%s'", expectErr, err)
}
Expand Down

0 comments on commit d253b6c

Please sign in to comment.