Skip to content

Commit

Permalink
use MSI_ENDPOINT when required for availability
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-murray-zip committed Feb 24, 2022
1 parent b3899c1 commit d0b73fb
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
20 changes: 18 additions & 2 deletions autorest/adal/token.go
Expand Up @@ -1317,12 +1317,28 @@ func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTena

// MSIAvailable returns true if the MSI endpoint is available for authentication.
func MSIAvailable(ctx context.Context, s Sender) bool {
_, _, err := getMSIAvailable(ctx, s)
return err == nil
}

// getMSIAvailable returns the response from the resolved MSI endpoint, the MSI endpoint used, and the error returned
func getMSIAvailable(ctx context.Context, s Sender) (*http.Response, string, error) {
if s == nil {
s = sender()
}
resp, err := getMSIEndpoint(ctx, s)
msiType, endpoint, err := getMSIType()
if err != nil {
return nil, endpoint, err
}

var apiVersion string
if apiVersion = msiAPIVersion; msiType == msiTypeAppServiceV20170901 {
apiVersion = appServiceAPIVersion2017
}

resp, err := getMSIEndpoint(ctx, s, endpoint, apiVersion)
if err == nil {
resp.Body.Close()
}
return err == nil
return resp, endpoint, err
}
6 changes: 3 additions & 3 deletions autorest/adal/token_1.13.go
Expand Up @@ -24,13 +24,13 @@ import (
"time"
)

func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) {
func getMSIEndpoint(ctx context.Context, sender Sender, endpoint string, apiVersion string) (*http.Response, error) {
tempCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
// http.NewRequestWithContext() was added in Go 1.13
req, _ := http.NewRequestWithContext(tempCtx, http.MethodGet, msiEndpoint, nil)
req, _ := http.NewRequestWithContext(tempCtx, http.MethodGet, endpoint, nil)
q := req.URL.Query()
q.Add("api-version", msiAPIVersion)
q.Add("api-version", apiVersion)
req.URL.RawQuery = q.Encode()
return sender.Do(req)
}
Expand Down
36 changes: 35 additions & 1 deletion autorest/adal/token_test.go
Expand Up @@ -1343,6 +1343,40 @@ func TestMSIAvailableSuccess(t *testing.T) {
}
}

func TestMSIAvailableAppService(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "http://localhost")
os.Setenv("MSI_SECRET", "super")
defer func() {
os.Unsetenv("MSI_ENDPOINT")
os.Unsetenv("MSI_SECRET")
}()
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponse())
_, endpoint, err := getMSIAvailable(context.Background(), c)

if err != nil {
t.Fatal("unexpected error")
}

if endpoint != "http://localhost" {
t.Fatal("incorrect endpoint returned")
}
}

func TestMSIAvailableIMDS(t *testing.T) {
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponse())
_, endpoint, err := getMSIAvailable(context.Background(), c)

if err != nil {
t.Fatal("unexpected error")
}

if endpoint != msiEndpoint {
t.Fatal("incorrect endpoint returned")
}
}

func TestMSIAvailableSlow(t *testing.T) {
c := mocks.NewSender()
// introduce a long response delay to simulate the endpoint not being available
Expand All @@ -1359,7 +1393,7 @@ func TestMSIAvailableFail(t *testing.T) {
if MSIAvailable(context.Background(), c) {
t.Fatal("unexpected true")
}
_, err := getMSIEndpoint(context.Background(), c)
_, err := getMSIEndpoint(context.Background(), c, msiEndpoint, msiAPIVersion)
if !strings.Contains(err.Error(), "") {
t.Fatalf("expected error: '%s', but got error '%s'", expectErr, err)
}
Expand Down

0 comments on commit d0b73fb

Please sign in to comment.