Skip to content

Commit

Permalink
simplify oauth2 by removing session controller & converting a session…
Browse files Browse the repository at this point in the history
… to a struct (#241)
  • Loading branch information
topi314 committed Feb 20, 2023
1 parent 76fb9f0 commit faa7947
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 175 deletions.
8 changes: 5 additions & 3 deletions _examples/oauth2/example.go
Expand Up @@ -25,6 +25,7 @@ var (
logger = log.Default()
httpClient = http.DefaultClient
client oauth2.Client
sessions map[string]oauth2.Session
)

func init() {
Expand All @@ -49,8 +50,8 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
var body string
cookie, err := r.Cookie("token")
if err == nil {
session := client.SessionController().GetSession(cookie.Value)
if session != nil {
session, ok := sessions[cookie.Value]
if ok {
var user *discord.OAuth2User
user, err = client.GetUser(session)
if err != nil {
Expand Down Expand Up @@ -100,11 +101,12 @@ func handleTryLogin(w http.ResponseWriter, r *http.Request) {
)
if code != "" && state != "" {
identifier := randStr(32)
_, err := client.StartSession(code, state, identifier)
session, _, err := client.StartSession(code, state)
if err != nil {
writeError(w, "error while starting session", err)
return
}
sessions[identifier] = session
http.SetCookie(w, &http.Cookie{Name: "token", Value: identifier})
}
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
Expand Down
61 changes: 42 additions & 19 deletions oauth2/client.go
Expand Up @@ -3,6 +3,7 @@ package oauth2
import (
"errors"
"fmt"
"time"

"github.com/disgoorg/snowflake/v2"

Expand All @@ -14,49 +15,71 @@ var (
// ErrStateNotFound is returned when the state is not found in the SessionController.
ErrStateNotFound = errors.New("state could not be found")

// ErrAccessTokenExpired is returned when the access token has expired.
ErrAccessTokenExpired = errors.New("access token expired. refresh the session")
// ErrSessionExpired is returned when the Session has expired.
ErrSessionExpired = errors.New("access token expired. refresh the session")

// ErrMissingOAuth2Scope is returned when a specific OAuth2 scope is missing.
ErrMissingOAuth2Scope = func(scope discord.OAuth2Scope) error {
return fmt.Errorf("missing '%s' scope", scope)
}
)

// Session represents a discord access token response (https://discord.com/developers/docs/topics/oauth2#authorization-code-grant-access-token-response)
type Session struct {
// AccessToken allows requesting user information
AccessToken string `json:"access_token"`

// RefreshToken allows refreshing the AccessToken
RefreshToken string `json:"refresh_token"`

// Scopes returns the discord.OAuth2Scope(s) of the Session
Scopes []discord.OAuth2Scope `json:"scope"`

// TokenType returns the discord.TokenType of the AccessToken
TokenType discord.TokenType `json:"token_type"`

// Expiration returns the time.Time when the AccessToken expires and needs to be refreshed
Expiration time.Time `json:"expiration"`
}

func (s Session) Expired() bool {
return s.Expiration.Before(time.Now())
}

// Client is a high level wrapper around Discord's OAuth2 API.
type Client interface {
// ID returns the configured client ID
// ID returns the configured client ID.
ID() snowflake.ID
// Secret returns the configured client secret
// Secret returns the configured client secret.
Secret() string
// Rest returns the underlying rest.OAuth2
// Rest returns the underlying rest.OAuth2.
Rest() rest.OAuth2

// SessionController returns the configured SessionController
SessionController() SessionController
// StateController returns the configured StateController
// StateController returns the configured StateController.
StateController() StateController

// GenerateAuthorizationURL generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated
// GenerateAuthorizationURL generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated.
GenerateAuthorizationURL(redirectURI string, permissions discord.Permissions, guildID snowflake.ID, disableGuildSelect bool, scopes ...discord.OAuth2Scope) string
// GenerateAuthorizationURLState generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated & returned
// GenerateAuthorizationURLState generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated & returned.
GenerateAuthorizationURLState(redirectURI string, permissions discord.Permissions, guildID snowflake.ID, disableGuildSelect bool, scopes ...discord.OAuth2Scope) (string, string)

// StartSession starts a new Session with the given authorization code & state
StartSession(code string, state string, identifier string, opts ...rest.RequestOpt) (Session, error)
// RefreshSession refreshes the given Session with the refresh token
RefreshSession(identifier string, session Session, opts ...rest.RequestOpt) (Session, error)
// StartSession starts a new Session with the given authorization code & state.
StartSession(code string, state string, opts ...rest.RequestOpt) (Session, *discord.IncomingWebhook, error)
// RefreshSession refreshes the given Session with the refresh token.
RefreshSession(session Session, opts ...rest.RequestOpt) (Session, error)
// VerifySession verifies the given Session & refreshes it if needed.
VerifySession(session Session, opts ...rest.RequestOpt) (Session, error)

// GetUser returns the discord.OAuth2User associated with the given Session. Fields filled in the struct depend on the Session.Scopes
// GetUser returns the discord.OAuth2User associated with the given Session. Fields filled in the struct depend on the Session.Scopes.
GetUser(session Session, opts ...rest.RequestOpt) (*discord.OAuth2User, error)
// GetMember returns the discord.Member associated with the given Session in a specific guild.
GetMember(session Session, guildID snowflake.ID, opts ...rest.RequestOpt) (*discord.Member, error)
// GetGuilds returns the discord.OAuth2Guild(s) the user is a member of. This requires the discord.OAuth2ScopeGuilds scope in the Session
// GetGuilds returns the discord.OAuth2Guild(s) the user is a member of. This requires the discord.OAuth2ScopeGuilds scope in the Session.
GetGuilds(session Session, opts ...rest.RequestOpt) ([]discord.OAuth2Guild, error)
// GetConnections returns the discord.Connection(s) the user has connected. This requires the discord.OAuth2ScopeConnections scope in the Session
// GetConnections returns the discord.Connection(s) the user has connected. This requires the discord.OAuth2ScopeConnections scope in the Session.
GetConnections(session Session, opts ...rest.RequestOpt) ([]discord.Connection, error)
// GetApplicationRoleConnection returns the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session
// GetApplicationRoleConnection returns the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session.
GetApplicationRoleConnection(session Session, applicationID snowflake.ID, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error)
// UpdateApplicationRoleConnection updates the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session
// UpdateApplicationRoleConnection updates the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session.
UpdateApplicationRoleConnection(session Session, applicationID snowflake.ID, update discord.ApplicationRoleConnectionUpdate, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error)
}
58 changes: 36 additions & 22 deletions oauth2/client_impl.go
Expand Up @@ -35,10 +35,6 @@ func (c *clientImpl) Rest() rest.OAuth2 {
return c.config.OAuth2
}

func (c *clientImpl) SessionController() SessionController {
return c.config.SessionController
}

func (c *clientImpl) StateController() StateController {
return c.config.StateController
}
Expand Down Expand Up @@ -70,74 +66,92 @@ func (c *clientImpl) GenerateAuthorizationURLState(redirectURI string, permissio
return discord.AuthorizeURL(values), state
}

func (c *clientImpl) StartSession(code string, state string, identifier string, opts ...rest.RequestOpt) (Session, error) {
func (c *clientImpl) StartSession(code string, state string, opts ...rest.RequestOpt) (Session, *discord.IncomingWebhook, error) {
redirectURI := c.StateController().ConsumeState(state)
if redirectURI == "" {
return nil, ErrStateNotFound
return Session{}, nil, ErrStateNotFound
}
exchange, err := c.Rest().GetAccessToken(c.id, c.secret, code, redirectURI, opts...)
accessToken, err := c.Rest().GetAccessToken(c.id, c.secret, code, redirectURI, opts...)
if err != nil {
return nil, err
return Session{}, nil, err
}
return c.SessionController().CreateSessionFromResponse(identifier, *exchange), nil

return newSession(*accessToken), accessToken.Webhook, nil
}

func (c *clientImpl) RefreshSession(identifier string, session Session, opts ...rest.RequestOpt) (Session, error) {
exchange, err := c.Rest().RefreshAccessToken(c.id, c.secret, session.RefreshToken(), opts...)
func (c *clientImpl) RefreshSession(session Session, opts ...rest.RequestOpt) (Session, error) {
accessToken, err := c.Rest().RefreshAccessToken(c.id, c.secret, session.RefreshToken, opts...)
if err != nil {
return nil, err
return Session{}, err
}
return c.SessionController().CreateSessionFromResponse(identifier, *exchange), nil
return newSession(*accessToken), nil
}

func (c *clientImpl) VerifySession(session Session, opts ...rest.RequestOpt) (Session, error) {
if session.Expired() {
return c.RefreshSession(session, opts...)
}
return session, nil
}

func (c *clientImpl) GetUser(session Session, opts ...rest.RequestOpt) (*discord.OAuth2User, error) {
if err := checkSession(session, discord.OAuth2ScopeIdentify); err != nil {
return nil, err
}
return c.Rest().GetCurrentUser(session.AccessToken(), opts...)
return c.Rest().GetCurrentUser(session.AccessToken, opts...)
}

func (c *clientImpl) GetMember(session Session, guildID snowflake.ID, opts ...rest.RequestOpt) (*discord.Member, error) {
if err := checkSession(session, discord.OAuth2ScopeGuildsMembersRead); err != nil {
return nil, err
}
return c.Rest().GetCurrentMember(session.AccessToken(), guildID, opts...)
return c.Rest().GetCurrentMember(session.AccessToken, guildID, opts...)
}

func (c *clientImpl) GetGuilds(session Session, opts ...rest.RequestOpt) ([]discord.OAuth2Guild, error) {
if err := checkSession(session, discord.OAuth2ScopeGuilds); err != nil {
return nil, err
}
return c.Rest().GetCurrentUserGuilds(session.AccessToken(), 0, 0, 0, opts...)
return c.Rest().GetCurrentUserGuilds(session.AccessToken, 0, 0, 0, opts...)
}

func (c *clientImpl) GetConnections(session Session, opts ...rest.RequestOpt) ([]discord.Connection, error) {
if err := checkSession(session, discord.OAuth2ScopeConnections); err != nil {
return nil, err
}
return c.Rest().GetCurrentUserConnections(session.AccessToken(), opts...)
return c.Rest().GetCurrentUserConnections(session.AccessToken, opts...)
}

func (c *clientImpl) GetApplicationRoleConnection(session Session, applicationID snowflake.ID, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) {
if err := checkSession(session, discord.OAuth2ScopeRoleConnectionsWrite); err != nil {
return nil, err
}
return c.Rest().GetCurrentUserApplicationRoleConnection(session.AccessToken(), applicationID, opts...)
return c.Rest().GetCurrentUserApplicationRoleConnection(session.AccessToken, applicationID, opts...)
}

func (c *clientImpl) UpdateApplicationRoleConnection(session Session, applicationID snowflake.ID, update discord.ApplicationRoleConnectionUpdate, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) {
if err := checkSession(session, discord.OAuth2ScopeRoleConnectionsWrite); err != nil {
return nil, err
}
return c.Rest().UpdateCurrentUserApplicationRoleConnection(session.AccessToken(), applicationID, update, opts...)
return c.Rest().UpdateCurrentUserApplicationRoleConnection(session.AccessToken, applicationID, update, opts...)
}

func checkSession(session Session, scope discord.OAuth2Scope) error {
if session.Expiration().Before(time.Now()) {
return ErrAccessTokenExpired
if session.Expired() {
return ErrSessionExpired
}
if !discord.HasScope(scope, session.Scopes()...) {
if !discord.HasScope(scope, session.Scopes...) {
return ErrMissingOAuth2Scope(scope)
}
return nil
}

func newSession(accessToken discord.AccessTokenResponse) Session {
return Session{
AccessToken: accessToken.AccessToken,
RefreshToken: accessToken.RefreshToken,
Scopes: accessToken.Scope,
TokenType: accessToken.TokenType,
Expiration: time.Now().Add(accessToken.ExpiresIn * time.Second),
}
}
11 changes: 1 addition & 10 deletions oauth2/config.go
Expand Up @@ -9,8 +9,7 @@ import (
// DefaultConfig is the configuration which is used by default
func DefaultConfig() *Config {
return &Config{
Logger: log.Default(),
SessionController: NewSessionController(),
Logger: log.Default(),
}
}

Expand All @@ -20,7 +19,6 @@ type Config struct {
RestClient rest.Client
RestClientConfigOpts []rest.ConfigOpt
OAuth2 rest.OAuth2
SessionController SessionController
StateController StateController
StateControllerConfigOpts []StateControllerConfigOpt
}
Expand Down Expand Up @@ -72,13 +70,6 @@ func WithOAuth2(oauth2 rest.OAuth2) ConfigOpt {
}
}

// WithSessionController applies a custom SessionController to the OAuth2 client
func WithSessionController(sessionController SessionController) ConfigOpt {
return func(config *Config) {
config.SessionController = sessionController
}
}

// WithStateController applies a custom StateController to the OAuth2 client
func WithStateController(stateController StateController) ConfigOpt {
return func(config *Config) {
Expand Down
63 changes: 0 additions & 63 deletions oauth2/session.go

This file was deleted.

0 comments on commit faa7947

Please sign in to comment.