Skip to content

Commit

Permalink
Handle expires_on in int format (#698)
Browse files Browse the repository at this point in the history
* Handle expires_on in int format

Unmarshal the value into an interface{} and perform the proper
conversion depending on the underlying type.

* remove the cruft

* convert date-time expires_on to Unix time

* fix padding for test

* remove comment
  • Loading branch information
jhendrixMSFT committed May 19, 2022
1 parent e10c7aa commit 7dd32b6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 21 deletions.
31 changes: 20 additions & 11 deletions autorest/adal/token.go
Expand Up @@ -1104,8 +1104,8 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource

// AAD returns expires_in as a string, ADFS returns it as an int
ExpiresIn json.Number `json:"expires_in"`
// expires_on can be in two formats, a UTC time stamp or the number of seconds.
ExpiresOn string `json:"expires_on"`
// expires_on can be in three formats, a UTC time stamp, or the number of seconds as a string *or* int.
ExpiresOn interface{} `json:"expires_on"`
NotBefore json.Number `json:"not_before"`

Resource string `json:"resource"`
Expand All @@ -1118,7 +1118,7 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
}
expiresOn := json.Number("")
// ADFS doesn't include the expires_on field
if token.ExpiresOn != "" {
if token.ExpiresOn != nil {
if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
}
Expand All @@ -1135,18 +1135,27 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
}

// converts expires_on to the number of seconds
func parseExpiresOn(s string) (json.Number, error) {
// convert the expiration date to the number of seconds from now
func parseExpiresOn(s interface{}) (json.Number, error) {
// the JSON unmarshaler treats JSON numbers unmarshaled into an interface{} as float64
asFloat64, ok := s.(float64)
if ok {
// this is the number of seconds as int case
return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil
}
asStr, ok := s.(string)
if !ok {
return "", fmt.Errorf("unexpected expires_on type %T", s)
}
// convert the expiration date to the number of seconds from the unix epoch
timeToDuration := func(t time.Time) json.Number {
dur := t.Sub(time.Now().UTC())
return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10))
return json.Number(strconv.FormatInt(t.UTC().Unix(), 10))
}
if _, err := strconv.ParseInt(s, 10, 64); err == nil {
if _, err := json.Number(asStr).Int64(); err == nil {
// this is the number of seconds case, no conversion required
return json.Number(s), nil
} else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil {
return json.Number(asStr), nil
} else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil {
return timeToDuration(eo), nil
} else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil {
} else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil {
return timeToDuration(eo), nil
} else {
// unknown format
Expand Down
59 changes: 49 additions & 10 deletions autorest/adal/token_test.go
Expand Up @@ -88,8 +88,7 @@ func TestTokenWillExpireIn(t *testing.T) {
}

func TestParseExpiresOn(t *testing.T) {
// get current time, round to nearest second, and add one hour
n := time.Now().UTC().Round(time.Second).Add(time.Hour)
n := time.Now().UTC()
amPM := "AM"
if n.Hour() >= 12 {
amPM = "PM"
Expand All @@ -107,12 +106,12 @@ func TestParseExpiresOn(t *testing.T) {
{
Name: "timestamp with AM/PM",
String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d %s +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second(), amPM),
Value: 3600,
Value: n.Unix(),
},
{
Name: "timestamp without AM/PM",
String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()),
Value: 3600,
String: fmt.Sprintf("%02d/%02d/%02d %02d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()),
Value: n.Unix(),
},
}
for _, tc := range testcases {
Expand Down Expand Up @@ -368,7 +367,8 @@ func TestServicePrincipalTokenFromASE(t *testing.T) {
}
spt.MaxMSIRefreshAttempts = 1
// expires_on is sent in UTC
expiresOn := time.Now().UTC().Add(time.Hour)
nowTime := time.Now()
expiresOn := nowTime.UTC().Add(time.Hour)
// use int format for expires_in
body := mocks.NewBody(newTokenJSON("3600", expiresOn.Format(expiresOnDateFormat), "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
Expand Down Expand Up @@ -407,10 +407,8 @@ func TestServicePrincipalTokenFromASE(t *testing.T) {
if err != nil {
t.Fatalf("adal: failed to get ExpiresOn %v", err)
}
// depending on elapsed time it might be slightly less that one hour
const hourInSeconds = int64(time.Hour / time.Second)
if v > hourInSeconds || v < hourInSeconds-1 {
t.Fatalf("adal: expected %v, got %v", int64(time.Hour/time.Second), v)
if nowAsUnix := nowTime.Add(time.Hour).Unix(); v != nowAsUnix {
t.Fatalf("adal: expected %v, got %v", nowAsUnix, v)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
Expand Down Expand Up @@ -891,6 +889,34 @@ func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
}
}

func TestServicePrincipalTokenEnsureFreshWithIntExpiresOn(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.inner.Token)

body := mocks.NewBody(newTokenJSONIntExpiresOn(`"3600"`, 12345, "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")

f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if !f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
}
}

func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.inner.Token)
Expand Down Expand Up @@ -1461,6 +1487,19 @@ func newTokenJSON(expiresIn, expiresOn, resource string) string {
expiresIn, expiresOn, nb, resource)
}

func newTokenJSONIntExpiresOn(expiresIn string, expiresOn int, resource string) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
"expires_in" : %s,
"expires_on" : %d,
"not_before" : "%d",
"resource" : "%s",
"token_type" : "Bearer",
"refresh_token": "ABC123"
}`,
expiresIn, expiresOn, expiresOn, resource)
}

func newADFSTokenJSON(expiresIn int) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
Expand Down

0 comments on commit 7dd32b6

Please sign in to comment.