diff --git a/integration_tests/samples/oauth/oauth_v2_async.py b/integration_tests/samples/oauth/oauth_v2_async.py index 1f6b9b7e6..21a2a1d38 100644 --- a/integration_tests/samples/oauth/oauth_v2_async.py +++ b/integration_tests/samples/oauth/oauth_v2_async.py @@ -116,7 +116,7 @@ async def oauth_callback(req: Request): body=html, ) - error = req.args["error"] if "error" in req.args else "" + error = req.args.get("error") if "error" in req.args else "" return HTTPResponse( status=400, body=f"Something is wrong with the installation (error: {error})" ) @@ -143,10 +143,10 @@ async def slack_app(req: Request): ): return HTTPResponse(status=403, body="invalid request") - if "command" in req.form and req.form["command"] == "/open-modal": + if "command" in req.form and req.form.get("command") == "/open-modal": try: enterprise_id = req.form.get("enterprise_id") - team_id = req.form["team_id"] + team_id = req.form.get("team_id") bot = installation_store.find_bot( enterprise_id=enterprise_id, team_id=team_id, @@ -157,7 +157,7 @@ async def slack_app(req: Request): client = AsyncWebClient(token=bot_token) await client.views_open( - trigger_id=req.form["trigger_id"], + trigger_id=req.form.get("trigger_id"), view={ "type": "modal", "callback_id": "modal-id", @@ -188,12 +188,12 @@ async def slack_app(req: Request): ) elif "payload" in req.form: - payload = json.loads(req.form["payload"]) + payload = json.loads(req.form.get("payload")) if ( - payload["type"] == "view_submission" - and payload["view"]["callback_id"] == "modal-id" + payload.get("type") == "view_submission" + and payload.get("view").get("callback_id") == "modal-id" ): - submitted_data = payload["view"]["state"]["values"] + submitted_data = payload.get("view").get("state").get("values") print( submitted_data ) # {'b-id': {'a-id': {'type': 'plain_text_input', 'value': 'your input'}}} @@ -203,9 +203,8 @@ async def slack_app(req: Request): if __name__ == "__main__": - # export SLACK_TEST_CLIENT_ID=123.123 - # export SLACK_TEST_CLIENT_SECRET=xxx - # export SLACK_TEST_REDIRECT_URI=https://{yours}.ngrok.io/slack/oauth/callback + # export SLACK_CLIENT_ID=123.123 + # export SLACK_CLIENT_SECRET=xxx # export SLACK_SIGNING_SECRET=*** app.run(host="0.0.0.0", port=3000) diff --git a/integration_tests/samples/token_rotation/.gitignore b/integration_tests/samples/token_rotation/.gitignore new file mode 100644 index 000000000..e6905a239 --- /dev/null +++ b/integration_tests/samples/token_rotation/.gitignore @@ -0,0 +1 @@ +.env* \ No newline at end of file diff --git a/integration_tests/samples/token_rotation/oauth.py b/integration_tests/samples/token_rotation/oauth.py new file mode 100644 index 000000000..df216bc50 --- /dev/null +++ b/integration_tests/samples/token_rotation/oauth.py @@ -0,0 +1,274 @@ +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +# pip3 install flask +from typing import Optional + +from integration_tests.samples.token_rotation.util import ( + parse_body, + extract_enterprise_id, + extract_user_id, + extract_team_id, + extract_is_enterprise_install, + extract_content_type, +) + +import logging +import os +from slack_sdk.web import WebClient +from slack_sdk.oauth.token_rotation import TokenRotator +from slack_sdk.oauth import AuthorizeUrlGenerator, RedirectUriPageRenderer +from slack_sdk.oauth.installation_store import FileInstallationStore, Installation +from slack_sdk.oauth.state_store import FileOAuthStateStore + +client_id = os.environ["SLACK_CLIENT_ID"] +client_secret = os.environ["SLACK_CLIENT_SECRET"] +scopes = ["app_mentions:read", "chat:write", "commands"] +user_scopes = ["search:read"] + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + +state_store = FileOAuthStateStore(expiration_seconds=300) +installation_store = FileInstallationStore() +token_rotator = TokenRotator( + client_id=client_id, + client_secret=client_secret, +) + +# --------------------- +# Flask App for Slack events +# --------------------- + +import json +from slack_sdk.errors import SlackApiError +from slack_sdk.signature import SignatureVerifier + +signing_secret = os.environ["SLACK_SIGNING_SECRET"] +signature_verifier = SignatureVerifier(signing_secret=signing_secret) + + +def rotate_tokens( + enterprise_id: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + is_enterprise_install: Optional[bool] = None, +): + installation = installation_store.find_installation( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + is_enterprise_install=is_enterprise_install, + ) + if installation is not None: + updated_installation = token_rotator.perform_token_rotation( + installation=installation, + minutes_before_expiration=60 * 24 * 365, # one year for testing + ) + if updated_installation is not None: + installation_store.save(updated_installation) + + +from flask import Flask, request, make_response + +app = Flask(__name__) +app.debug = True + + +@app.route("/slack/events", methods=["POST"]) +def slack_app(): + if not signature_verifier.is_valid( + body=request.get_data(), + timestamp=request.headers.get("X-Slack-Request-Timestamp"), + signature=request.headers.get("X-Slack-Signature"), + ): + return make_response("invalid request", 403) + + raw_body = request.data.decode("utf-8") + body = parse_body(body=raw_body, content_type=extract_content_type(request.headers)) + rotate_tokens( + enterprise_id=extract_enterprise_id(body), + team_id=extract_team_id(body), + user_id=extract_user_id(body), + is_enterprise_install=extract_is_enterprise_install(body), + ) + + if "command" in request.form and request.form["command"] == "/token-rotation-modal": + try: + enterprise_id = request.form.get("enterprise_id") + team_id = request.form["team_id"] + bot = installation_store.find_bot( + enterprise_id=enterprise_id, + team_id=team_id, + ) + bot_token = bot.bot_token if bot else None + if not bot_token: + return make_response("Please install this app first!", 200) + + client = WebClient(token=bot_token) + trigger_id = request.form["trigger_id"] + response = client.views_open( + trigger_id=trigger_id, + view={ + "type": "modal", + "callback_id": "modal-id", + "title": {"type": "plain_text", "text": "Awesome Modal"}, + "submit": {"type": "plain_text", "text": "Submit"}, + "close": {"type": "plain_text", "text": "Cancel"}, + "blocks": [ + { + "type": "input", + "block_id": "b-id", + "label": { + "type": "plain_text", + "text": "Input label", + }, + "element": { + "action_id": "a-id", + "type": "plain_text_input", + }, + } + ], + }, + ) + return make_response("", 200) + except SlackApiError as e: + code = e.response["error"] + return make_response(f"Failed to open a modal due to {code}", 200) + + elif "payload" in request.form: + payload = json.loads(request.form["payload"]) + if ( + payload["type"] == "view_submission" + and payload["view"]["callback_id"] == "modal-id" + ): + submitted_data = payload["view"]["state"]["values"] + print( + submitted_data + ) # {'b-id': {'a-id': {'type': 'plain_text_input', 'value': 'your input'}}} + return make_response("", 200) + + else: + if raw_body.startswith("{"): + event_payload = json.loads(raw_body) + logger.info(f"Events API payload: {event_payload}") + if event_payload.get("type") == "url_verification": + return make_response(event_payload.get("challenge"), 200) + return make_response("", 200) + + return make_response("", 404) + + +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +authorization_url_generator = AuthorizeUrlGenerator( + client_id=client_id, + scopes=scopes, + user_scopes=user_scopes, +) +redirect_page_renderer = RedirectUriPageRenderer( + install_path="/slack/install", + redirect_uri_path="/slack/oauth_redirect", +) + + +@app.route("/slack/install", methods=["GET"]) +def oauth_start(): + state = state_store.issue() + url = authorization_url_generator.generate(state) + return ( + '' + f'' + f'' + "" + ) + + +@app.route("/slack/oauth_redirect", methods=["GET"]) +def oauth_callback(): + # Retrieve the auth code and state from the request params + if "code" in request.args: + state = request.args["state"] + if state_store.consume(state): + code = request.args["code"] + client = WebClient() # no prepared token needed for this app + oauth_response = client.oauth_v2_access( + client_id=client_id, client_secret=client_secret, code=code + ) + logger.info(f"oauth.v2.access response: {oauth_response}") + + installed_enterprise = oauth_response.get("enterprise", {}) + is_enterprise_install = oauth_response.get("is_enterprise_install") + installed_team = oauth_response.get("team", {}) + installer = oauth_response.get("authed_user", {}) + incoming_webhook = oauth_response.get("incoming_webhook", {}) + + bot_token = oauth_response.get("access_token") + # NOTE: oauth.v2.access doesn't include bot_id in response + bot_id = None + enterprise_url = None + if bot_token is not None: + auth_test = client.auth_test(token=bot_token) + bot_id = auth_test["bot_id"] + if is_enterprise_install is True: + enterprise_url = auth_test.get("url") + + installation = Installation( + app_id=oauth_response.get("app_id"), + enterprise_id=installed_enterprise.get("id"), + enterprise_name=installed_enterprise.get("name"), + enterprise_url=enterprise_url, + team_id=installed_team.get("id"), + team_name=installed_team.get("name"), + bot_token=bot_token, + bot_id=bot_id, + bot_user_id=oauth_response.get("bot_user_id"), + bot_scopes=oauth_response.get("scope"), # comma-separated string + bot_refresh_token=oauth_response.get("refresh_token"), + bot_token_expires_in=oauth_response.get("expires_in"), + user_id=installer.get("id"), + user_token=installer.get("access_token"), + user_scopes=installer.get("scope"), # comma-separated string + user_refresh_token=installer.get("refresh_token"), + user_token_expires_in=installer.get("expires_in"), + incoming_webhook_url=incoming_webhook.get("url"), + incoming_webhook_channel=incoming_webhook.get("channel"), + incoming_webhook_channel_id=incoming_webhook.get("channel_id"), + incoming_webhook_configuration_url=incoming_webhook.get( + "configuration_url" + ), + is_enterprise_install=is_enterprise_install, + token_type=oauth_response.get("token_type"), + ) + installation_store.save(installation) + return redirect_page_renderer.render_success_page( + app_id=installation.app_id, + team_id=installation.team_id, + is_enterprise_install=installation.is_enterprise_install, + enterprise_url=installation.enterprise_url, + ) + else: + return redirect_page_renderer.render_failure_page( + "the state value is already expired" + ) + + error = request.args["error"] if "error" in request.args else "" + return make_response( + f"Something is wrong with the installation (error: {error})", 400 + ) + + +if __name__ == "__main__": + # export SLACK_CLIENT_ID=123.123 + # export SLACK_CLIENT_SECRET=xxx + # export SLACK_SIGNING_SECRET=*** + # export FLASK_ENV=development + + app.run("localhost", 3000) + + # python3 integration_tests/samples/token_rotation/oauth.py + # ngrok http 3000 + # https://{yours}.ngrok.io/slack/oauth/start diff --git a/integration_tests/samples/token_rotation/oauth_async.py b/integration_tests/samples/token_rotation/oauth_async.py new file mode 100644 index 000000000..27661f906 --- /dev/null +++ b/integration_tests/samples/token_rotation/oauth_async.py @@ -0,0 +1,276 @@ +# --------------------- +# Sanic App for Slack OAuth flow +# --------------------- + +from typing import Optional + +from integration_tests.samples.token_rotation.util import ( + parse_body, + extract_enterprise_id, + extract_user_id, + extract_team_id, + extract_is_enterprise_install, + extract_content_type, +) + +import logging +import os +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.oauth.token_rotation.async_rotator import AsyncTokenRotator +from slack_sdk.oauth import AuthorizeUrlGenerator, RedirectUriPageRenderer +from slack_sdk.oauth.installation_store import FileInstallationStore, Installation +from slack_sdk.oauth.state_store import FileOAuthStateStore + +client_id = os.environ["SLACK_CLIENT_ID"] +client_secret = os.environ["SLACK_CLIENT_SECRET"] +scopes = ["app_mentions:read", "chat:write", "commands"] +user_scopes = ["search:read"] + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + +state_store = FileOAuthStateStore(expiration_seconds=300) +installation_store = FileInstallationStore() +token_rotator = AsyncTokenRotator( + client_id=client_id, + client_secret=client_secret, +) + +# --------------------- +# Sanic App for Slack events +# --------------------- + +import json +from slack_sdk.errors import SlackApiError +from slack_sdk.signature import SignatureVerifier + +signing_secret = os.environ["SLACK_SIGNING_SECRET"] +signature_verifier = SignatureVerifier(signing_secret=signing_secret) + + +async def rotate_tokens( + enterprise_id: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + is_enterprise_install: Optional[bool] = None, +): + installation = await installation_store.async_find_installation( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + is_enterprise_install=is_enterprise_install, + ) + if installation is not None: + updated_installation = await token_rotator.perform_token_rotation( + installation=installation, + minutes_before_expiration=60 * 24 * 365, # one year for testing + ) + if updated_installation is not None: + await installation_store.async_save(updated_installation) + + +# https://sanicframework.org/ +from sanic import Sanic +from sanic.request import Request +from sanic.response import HTTPResponse + +app = Sanic("my-awesome-slack-app") + + +@app.post("/slack/events") +async def slack_app(req: Request): + if not signature_verifier.is_valid( + body=req.body.decode("utf-8"), + timestamp=req.headers.get("X-Slack-Request-Timestamp"), + signature=req.headers.get("X-Slack-Signature"), + ): + return HTTPResponse(status=403, body="invalid request") + + raw_body = req.body.decode("utf-8") + body = parse_body(body=raw_body, content_type=extract_content_type(req.headers)) + await rotate_tokens( + enterprise_id=extract_enterprise_id(body), + team_id=extract_team_id(body), + user_id=extract_user_id(body), + is_enterprise_install=extract_is_enterprise_install(body), + ) + + if "command" in req.form and req.form.get("command") == "/token-rotation-modal": + try: + enterprise_id = req.form.get("enterprise_id") + team_id = req.form.get("team_id") + bot = await installation_store.async_find_bot( + enterprise_id=enterprise_id, + team_id=team_id, + ) + bot_token = bot.bot_token if bot else None + if not bot_token: + return HTTPResponse(status=200, body="Please install this app first!") + + client = AsyncWebClient(token=bot_token) + await client.views_open( + trigger_id=req.form.get("trigger_id"), + view={ + "type": "modal", + "callback_id": "modal-id", + "title": {"type": "plain_text", "text": "Awesome Modal"}, + "submit": {"type": "plain_text", "text": "Submit"}, + "close": {"type": "plain_text", "text": "Cancel"}, + "blocks": [ + { + "type": "input", + "block_id": "b-id", + "label": { + "type": "plain_text", + "text": "Input label", + }, + "element": { + "action_id": "a-id", + "type": "plain_text_input", + }, + } + ], + }, + ) + return HTTPResponse(status=200, body="") + except SlackApiError as e: + code = e.response["error"] + return HTTPResponse( + status=200, body=f"Failed to open a modal due to {code}" + ) + + elif "payload" in req.form: + payload = json.loads(req.form.get("payload")) + if ( + payload.get("type") == "view_submission" + and payload.get("view").get("callback_id") == "modal-id" + ): + submitted_data = payload.get("view").get("state").get("values") + print( + submitted_data + ) # {'b-id': {'a-id': {'type': 'plain_text_input', 'value': 'your input'}}} + return HTTPResponse(status=200, body="") + + else: + if raw_body.startswith("{"): + event_payload = json.loads(raw_body) + if event_payload.get("type") == "url_verification": + return HTTPResponse(status=200, body=event_payload.get("challenge")) + return HTTPResponse(status=200, body="") + + return HTTPResponse(status=404, body="Not found") + + +# --------------------- +# Sanic App for Slack OAuth flow +# --------------------- + +authorization_url_generator = AuthorizeUrlGenerator( + client_id=client_id, + scopes=scopes, + user_scopes=user_scopes, +) +redirect_page_renderer = RedirectUriPageRenderer( + install_path="/slack/install", + redirect_uri_path="/slack/oauth_redirect", +) + + +@app.get("/slack/install") +async def oauth_start(req: Request): + state = state_store.issue() + url = authorization_url_generator.generate(state) + response_body = ( + '' + f'' + f'' + "" + ) + return HTTPResponse( + status=200, + body=response_body, + ) + + +@app.get("/slack/oauth_redirect") +async def oauth_callback(req: Request): + # Retrieve the auth code and state from the request params + if "code" in req.args: + state = req.args.get("state") + if state_store.consume(state): + code = req.args.get("code") + client = AsyncWebClient() # no prepared token needed for this app + oauth_response = await client.oauth_v2_access( + client_id=client_id, client_secret=client_secret, code=code + ) + logger.info(f"oauth.v2.access response: {oauth_response}") + + installed_enterprise = oauth_response.get("enterprise") or {} + installed_team = oauth_response.get("team") or {} + installer = oauth_response.get("authed_user") or {} + incoming_webhook = oauth_response.get("incoming_webhook") or {} + bot_token = oauth_response.get("access_token") + # NOTE: oauth.v2.access doesn't include bot_id in response + bot_id = None + if bot_token is not None: + auth_test = await client.auth_test(token=bot_token) + bot_id = auth_test["bot_id"] + + installation = Installation( + app_id=oauth_response.get("app_id"), + enterprise_id=installed_enterprise.get("id"), + team_id=installed_team.get("id"), + bot_token=bot_token, + bot_id=bot_id, + bot_user_id=oauth_response.get("bot_user_id"), + bot_scopes=oauth_response.get("scope"), # comma-separated string + bot_refresh_token=oauth_response.get("refresh_token"), + bot_token_expires_in=oauth_response.get("expires_in"), + user_id=installer.get("id"), + user_token=installer.get("access_token"), + user_scopes=installer.get("scope"), # comma-separated string + user_refresh_token=installer.get("refresh_token"), + user_token_expires_in=installer.get("expires_in"), + incoming_webhook_url=incoming_webhook.get("url"), + incoming_webhook_channel_id=incoming_webhook.get("channel_id"), + incoming_webhook_configuration_url=incoming_webhook.get( + "configuration_url" + ), + ) + await installation_store.async_save(installation) + html = redirect_page_renderer.render_success_page( + app_id=installation.app_id, + team_id=installation.team_id, + is_enterprise_install=installation.is_enterprise_install, + enterprise_url=installation.enterprise_url, + ) + return HTTPResponse( + status=200, + headers={ + "Content-Type": "text/html; charset=utf-8", + }, + body=html, + ) + else: + html = redirect_page_renderer.render_failure_page( + "the state value is already expired" + ) + return HTTPResponse( + status=400, + headers={ + "Content-Type": "text/html; charset=utf-8", + }, + body=html, + ) + + error = req.args.get("error") if "error" in req.args else "" + return HTTPResponse( + status=400, body=f"Something is wrong with the installation (error: {error})" + ) + + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=3000) + # python3 integration_tests/samples/token_rotation/oauth_async.py + # ngrok http 3000 + # https://{yours}.ngrok.io/slack/install diff --git a/integration_tests/samples/token_rotation/oauth_sqlalchemy.py b/integration_tests/samples/token_rotation/oauth_sqlalchemy.py new file mode 100644 index 000000000..be95e475a --- /dev/null +++ b/integration_tests/samples/token_rotation/oauth_sqlalchemy.py @@ -0,0 +1,303 @@ +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +# pip3 install flask +from typing import Optional + +from integration_tests.samples.token_rotation.util import ( + parse_body, + extract_enterprise_id, + extract_user_id, + extract_team_id, + extract_is_enterprise_install, + extract_content_type, +) + +import logging +import os + +from slack_sdk.oauth.installation_store.sqlalchemy import SQLAlchemyInstallationStore +from slack_sdk.web import WebClient +from slack_sdk.oauth.token_rotation import TokenRotator +from slack_sdk.oauth import AuthorizeUrlGenerator, RedirectUriPageRenderer +from slack_sdk.oauth.installation_store import Installation +from slack_sdk.oauth.state_store.sqlalchemy import SQLAlchemyOAuthStateStore + +client_id = os.environ["SLACK_CLIENT_ID"] +client_secret = os.environ["SLACK_CLIENT_SECRET"] +scopes = ["app_mentions:read", "chat:write", "commands"] +user_scopes = ["search:read"] + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) +logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + +import sqlalchemy +from sqlalchemy.engine import Engine + +database_url = "sqlite:///slackapp.db" +# database_url = "postgresql://localhost/slackapp" # pip install psycopg2 +engine: Engine = sqlalchemy.create_engine(database_url) + +installation_store = SQLAlchemyInstallationStore( + client_id=client_id, + engine=engine, + logger=logger, +) +token_rotator = TokenRotator( + client_id=client_id, + client_secret=client_secret, +) + +state_store = SQLAlchemyOAuthStateStore( + engine=engine, + logger=logger, + expiration_seconds=300, +) + +try: + engine.execute("select count(*) from slack_bots") +except Exception as e: + installation_store.metadata.create_all(engine) + +try: + engine.execute("select count(*) from slack_oauth_states") +except Exception as e: + state_store.metadata.create_all(engine) + +# --------------------- +# Flask App for Slack events +# --------------------- + +import json +from slack_sdk.errors import SlackApiError +from slack_sdk.signature import SignatureVerifier + +signing_secret = os.environ["SLACK_SIGNING_SECRET"] +signature_verifier = SignatureVerifier(signing_secret=signing_secret) + + +def rotate_tokens( + enterprise_id: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + is_enterprise_install: Optional[bool] = None, +): + installation = installation_store.find_installation( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + is_enterprise_install=is_enterprise_install, + ) + if installation is not None: + updated_installation = token_rotator.perform_token_rotation( + installation=installation, + minutes_before_expiration=60 * 24 * 365, # one year for testing + ) + if updated_installation is not None: + installation_store.save(updated_installation) + + +from flask import Flask, request, make_response + +app = Flask(__name__) +app.debug = True + + +@app.route("/slack/events", methods=["POST"]) +def slack_app(): + if not signature_verifier.is_valid( + body=request.get_data(), + timestamp=request.headers.get("X-Slack-Request-Timestamp"), + signature=request.headers.get("X-Slack-Signature"), + ): + return make_response("invalid request", 403) + + raw_body = request.data.decode("utf-8") + body = parse_body(body=raw_body, content_type=extract_content_type(request.headers)) + rotate_tokens( + enterprise_id=extract_enterprise_id(body), + team_id=extract_team_id(body), + user_id=extract_user_id(body), + is_enterprise_install=extract_is_enterprise_install(body), + ) + + if "command" in request.form and request.form["command"] == "/token-rotation-modal": + try: + enterprise_id = request.form.get("enterprise_id") + team_id = request.form["team_id"] + bot = installation_store.find_bot( + enterprise_id=enterprise_id, + team_id=team_id, + ) + bot_token = bot.bot_token if bot else None + if not bot_token: + return make_response("Please install this app first!", 200) + + client = WebClient(token=bot_token) + trigger_id = request.form["trigger_id"] + response = client.views_open( + trigger_id=trigger_id, + view={ + "type": "modal", + "callback_id": "modal-id", + "title": {"type": "plain_text", "text": "Awesome Modal"}, + "submit": {"type": "plain_text", "text": "Submit"}, + "close": {"type": "plain_text", "text": "Cancel"}, + "blocks": [ + { + "type": "input", + "block_id": "b-id", + "label": { + "type": "plain_text", + "text": "Input label", + }, + "element": { + "action_id": "a-id", + "type": "plain_text_input", + }, + } + ], + }, + ) + return make_response("", 200) + except SlackApiError as e: + code = e.response["error"] + return make_response(f"Failed to open a modal due to {code}", 200) + + elif "payload" in request.form: + payload = json.loads(request.form["payload"]) + if ( + payload["type"] == "view_submission" + and payload["view"]["callback_id"] == "modal-id" + ): + submitted_data = payload["view"]["state"]["values"] + print( + submitted_data + ) # {'b-id': {'a-id': {'type': 'plain_text_input', 'value': 'your input'}}} + return make_response("", 200) + + else: + if raw_body.startswith("{"): + event_payload = json.loads(raw_body) + logger.info(f"Events API payload: {event_payload}") + if event_payload.get("type") == "url_verification": + return make_response(event_payload.get("challenge"), 200) + return make_response("", 200) + + return make_response("", 404) + + +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +authorization_url_generator = AuthorizeUrlGenerator( + client_id=client_id, + scopes=scopes, + user_scopes=user_scopes, +) +redirect_page_renderer = RedirectUriPageRenderer( + install_path="/slack/install", + redirect_uri_path="/slack/oauth_redirect", +) + + +@app.route("/slack/install", methods=["GET"]) +def oauth_start(): + state = state_store.issue() + url = authorization_url_generator.generate(state) + return ( + '' + f'' + f'' + "" + ) + + +@app.route("/slack/oauth_redirect", methods=["GET"]) +def oauth_callback(): + # Retrieve the auth code and state from the request params + if "code" in request.args: + state = request.args["state"] + if state_store.consume(state): + code = request.args["code"] + client = WebClient() # no prepared token needed for this app + oauth_response = client.oauth_v2_access( + client_id=client_id, client_secret=client_secret, code=code + ) + logger.info(f"oauth.v2.access response: {oauth_response}") + + installed_enterprise = oauth_response.get("enterprise", {}) + is_enterprise_install = oauth_response.get("is_enterprise_install") + installed_team = oauth_response.get("team", {}) + installer = oauth_response.get("authed_user", {}) + incoming_webhook = oauth_response.get("incoming_webhook", {}) + + bot_token = oauth_response.get("access_token") + # NOTE: oauth.v2.access doesn't include bot_id in response + bot_id = None + enterprise_url = None + if bot_token is not None: + auth_test = client.auth_test(token=bot_token) + bot_id = auth_test["bot_id"] + if is_enterprise_install is True: + enterprise_url = auth_test.get("url") + + installation = Installation( + app_id=oauth_response.get("app_id"), + enterprise_id=installed_enterprise.get("id"), + enterprise_name=installed_enterprise.get("name"), + enterprise_url=enterprise_url, + team_id=installed_team.get("id"), + team_name=installed_team.get("name"), + bot_token=bot_token, + bot_id=bot_id, + bot_user_id=oauth_response.get("bot_user_id"), + bot_scopes=oauth_response.get("scope"), # comma-separated string + bot_refresh_token=oauth_response.get("refresh_token"), + bot_token_expires_in=oauth_response.get("expires_in"), + user_id=installer.get("id"), + user_token=installer.get("access_token"), + user_scopes=installer.get("scope"), # comma-separated string + user_refresh_token=installer.get("refresh_token"), + user_token_expires_in=installer.get("expires_in"), + incoming_webhook_url=incoming_webhook.get("url"), + incoming_webhook_channel=incoming_webhook.get("channel"), + incoming_webhook_channel_id=incoming_webhook.get("channel_id"), + incoming_webhook_configuration_url=incoming_webhook.get( + "configuration_url" + ), + is_enterprise_install=is_enterprise_install, + token_type=oauth_response.get("token_type"), + ) + installation_store.save(installation) + return redirect_page_renderer.render_success_page( + app_id=installation.app_id, + team_id=installation.team_id, + is_enterprise_install=installation.is_enterprise_install, + enterprise_url=installation.enterprise_url, + ) + else: + return redirect_page_renderer.render_failure_page( + "the state value is already expired" + ) + + error = request.args["error"] if "error" in request.args else "" + return make_response( + f"Something is wrong with the installation (error: {error})", 400 + ) + + +if __name__ == "__main__": + # export SLACK_CLIENT_ID=123.123 + # export SLACK_CLIENT_SECRET=xxx + # export SLACK_SIGNING_SECRET=*** + # export FLASK_ENV=development + + app.run("localhost", 3000) + + # python3 integration_tests/samples/token_rotation/oauth.py + # ngrok http 3000 + # https://{yours}.ngrok.io/slack/oauth/start diff --git a/integration_tests/samples/token_rotation/oauth_sqlite3.py b/integration_tests/samples/token_rotation/oauth_sqlite3.py new file mode 100644 index 000000000..4c2adde7f --- /dev/null +++ b/integration_tests/samples/token_rotation/oauth_sqlite3.py @@ -0,0 +1,296 @@ +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +# pip3 install flask +from typing import Optional + +from integration_tests.samples.token_rotation.util import ( + parse_body, + extract_enterprise_id, + extract_user_id, + extract_team_id, + extract_is_enterprise_install, + extract_content_type, +) + +import logging +import os + +from slack_sdk.oauth.installation_store.sqlite3 import SQLite3InstallationStore +from slack_sdk.web import WebClient +from slack_sdk.oauth.token_rotation import TokenRotator +from slack_sdk.oauth import AuthorizeUrlGenerator, RedirectUriPageRenderer +from slack_sdk.oauth.installation_store import Installation +from slack_sdk.oauth.state_store.sqlite3 import SQLite3OAuthStateStore + +client_id = os.environ["SLACK_CLIENT_ID"] +client_secret = os.environ["SLACK_CLIENT_SECRET"] +scopes = ["app_mentions:read", "chat:write", "commands"] +user_scopes = ["search:read"] + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) +logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + +import sqlalchemy +from sqlalchemy.engine import Engine + +database_url = "sqlite:///slackapp.db" +# database_url = "postgresql://localhost/slackapp" # pip install psycopg2 +engine: Engine = sqlalchemy.create_engine(database_url) + +installation_store = SQLite3InstallationStore( + database="test.db", + client_id=client_id, + logger=logger, +) +installation_store.init() + +token_rotator = TokenRotator( + client_id=client_id, + client_secret=client_secret, +) + +state_store = SQLite3OAuthStateStore( + database="test.db", + logger=logger, + expiration_seconds=300, +) +state_store.init() + +# --------------------- +# Flask App for Slack events +# --------------------- + +import json +from slack_sdk.errors import SlackApiError +from slack_sdk.signature import SignatureVerifier + +signing_secret = os.environ["SLACK_SIGNING_SECRET"] +signature_verifier = SignatureVerifier(signing_secret=signing_secret) + + +def rotate_tokens( + enterprise_id: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + is_enterprise_install: Optional[bool] = None, +): + installation = installation_store.find_installation( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + is_enterprise_install=is_enterprise_install, + ) + if installation is not None: + updated_installation = token_rotator.perform_token_rotation( + installation=installation, + minutes_before_expiration=60 * 24 * 365, # one year for testing + ) + if updated_installation is not None: + installation_store.save(updated_installation) + + +from flask import Flask, request, make_response + +app = Flask(__name__) +app.debug = True + + +@app.route("/slack/events", methods=["POST"]) +def slack_app(): + if not signature_verifier.is_valid( + body=request.get_data(), + timestamp=request.headers.get("X-Slack-Request-Timestamp"), + signature=request.headers.get("X-Slack-Signature"), + ): + return make_response("invalid request", 403) + + raw_body = request.data.decode("utf-8") + body = parse_body(body=raw_body, content_type=extract_content_type(request.headers)) + rotate_tokens( + enterprise_id=extract_enterprise_id(body), + team_id=extract_team_id(body), + user_id=extract_user_id(body), + is_enterprise_install=extract_is_enterprise_install(body), + ) + + if "command" in request.form and request.form["command"] == "/token-rotation-modal": + try: + enterprise_id = request.form.get("enterprise_id") + team_id = request.form["team_id"] + bot = installation_store.find_bot( + enterprise_id=enterprise_id, + team_id=team_id, + ) + bot_token = bot.bot_token if bot else None + if not bot_token: + return make_response("Please install this app first!", 200) + + client = WebClient(token=bot_token) + trigger_id = request.form["trigger_id"] + response = client.views_open( + trigger_id=trigger_id, + view={ + "type": "modal", + "callback_id": "modal-id", + "title": {"type": "plain_text", "text": "Awesome Modal"}, + "submit": {"type": "plain_text", "text": "Submit"}, + "close": {"type": "plain_text", "text": "Cancel"}, + "blocks": [ + { + "type": "input", + "block_id": "b-id", + "label": { + "type": "plain_text", + "text": "Input label", + }, + "element": { + "action_id": "a-id", + "type": "plain_text_input", + }, + } + ], + }, + ) + return make_response("", 200) + except SlackApiError as e: + code = e.response["error"] + return make_response(f"Failed to open a modal due to {code}", 200) + + elif "payload" in request.form: + payload = json.loads(request.form["payload"]) + if ( + payload["type"] == "view_submission" + and payload["view"]["callback_id"] == "modal-id" + ): + submitted_data = payload["view"]["state"]["values"] + print( + submitted_data + ) # {'b-id': {'a-id': {'type': 'plain_text_input', 'value': 'your input'}}} + return make_response("", 200) + + else: + if raw_body.startswith("{"): + event_payload = json.loads(raw_body) + logger.info(f"Events API payload: {event_payload}") + if event_payload.get("type") == "url_verification": + return make_response(event_payload.get("challenge"), 200) + return make_response("", 200) + + return make_response("", 404) + + +# --------------------- +# Flask App for Slack OAuth flow +# --------------------- + +authorization_url_generator = AuthorizeUrlGenerator( + client_id=client_id, + scopes=scopes, + user_scopes=user_scopes, +) +redirect_page_renderer = RedirectUriPageRenderer( + install_path="/slack/install", + redirect_uri_path="/slack/oauth_redirect", +) + + +@app.route("/slack/install", methods=["GET"]) +def oauth_start(): + state = state_store.issue() + url = authorization_url_generator.generate(state) + return ( + '' + f'' + f'' + "" + ) + + +@app.route("/slack/oauth_redirect", methods=["GET"]) +def oauth_callback(): + # Retrieve the auth code and state from the request params + if "code" in request.args: + state = request.args["state"] + if state_store.consume(state): + code = request.args["code"] + client = WebClient() # no prepared token needed for this app + oauth_response = client.oauth_v2_access( + client_id=client_id, client_secret=client_secret, code=code + ) + logger.info(f"oauth.v2.access response: {oauth_response}") + + installed_enterprise = oauth_response.get("enterprise", {}) + is_enterprise_install = oauth_response.get("is_enterprise_install") + installed_team = oauth_response.get("team", {}) + installer = oauth_response.get("authed_user", {}) + incoming_webhook = oauth_response.get("incoming_webhook", {}) + + bot_token = oauth_response.get("access_token") + # NOTE: oauth.v2.access doesn't include bot_id in response + bot_id = None + enterprise_url = None + if bot_token is not None: + auth_test = client.auth_test(token=bot_token) + bot_id = auth_test["bot_id"] + if is_enterprise_install is True: + enterprise_url = auth_test.get("url") + + installation = Installation( + app_id=oauth_response.get("app_id"), + enterprise_id=installed_enterprise.get("id"), + enterprise_name=installed_enterprise.get("name"), + enterprise_url=enterprise_url, + team_id=installed_team.get("id"), + team_name=installed_team.get("name"), + bot_token=bot_token, + bot_id=bot_id, + bot_user_id=oauth_response.get("bot_user_id"), + bot_scopes=oauth_response.get("scope"), # comma-separated string + bot_refresh_token=oauth_response.get("refresh_token"), + bot_token_expires_in=oauth_response.get("expires_in"), + user_id=installer.get("id"), + user_token=installer.get("access_token"), + user_scopes=installer.get("scope"), # comma-separated string + user_refresh_token=installer.get("refresh_token"), + user_token_expires_in=installer.get("expires_in"), + incoming_webhook_url=incoming_webhook.get("url"), + incoming_webhook_channel=incoming_webhook.get("channel"), + incoming_webhook_channel_id=incoming_webhook.get("channel_id"), + incoming_webhook_configuration_url=incoming_webhook.get( + "configuration_url" + ), + is_enterprise_install=is_enterprise_install, + token_type=oauth_response.get("token_type"), + ) + installation_store.save(installation) + return redirect_page_renderer.render_success_page( + app_id=installation.app_id, + team_id=installation.team_id, + is_enterprise_install=installation.is_enterprise_install, + enterprise_url=installation.enterprise_url, + ) + else: + return redirect_page_renderer.render_failure_page( + "the state value is already expired" + ) + + error = request.args["error"] if "error" in request.args else "" + return make_response( + f"Something is wrong with the installation (error: {error})", 400 + ) + + +if __name__ == "__main__": + # export SLACK_CLIENT_ID=123.123 + # export SLACK_CLIENT_SECRET=xxx + # export SLACK_SIGNING_SECRET=*** + # export FLASK_ENV=development + + app.run("localhost", 3000) + + # python3 integration_tests/samples/token_rotation/oauth_sqlite3.py + # ngrok http 3000 + # https://{yours}.ngrok.io/slack/oauth/start diff --git a/integration_tests/samples/token_rotation/util.py b/integration_tests/samples/token_rotation/util.py new file mode 100644 index 000000000..9948db5ba --- /dev/null +++ b/integration_tests/samples/token_rotation/util.py @@ -0,0 +1,92 @@ +import json +from typing import Optional, Dict, Any, Sequence +from urllib.parse import parse_qsl + + +def parse_body(body: str, content_type: Optional[str]) -> Dict[str, Any]: + if not body: + return {} + if ( + content_type is not None and content_type == "application/json" + ) or body.startswith("{"): + return json.loads(body) + else: + if "payload" in body: # This is not JSON format yet + params = dict(parse_qsl(body)) + if params.get("payload") is not None: + return json.loads(params.get("payload")) + else: + return {} + else: + return dict(parse_qsl(body)) + + +def extract_is_enterprise_install(payload: Dict[str, Any]) -> Optional[bool]: + if "is_enterprise_install" in payload: + is_enterprise_install = payload.get("is_enterprise_install") + return is_enterprise_install is not None and ( + is_enterprise_install is True or is_enterprise_install == "true" + ) + return False + + +def extract_enterprise_id(payload: Dict[str, Any]) -> Optional[str]: + if payload.get("enterprise") is not None: + org = payload.get("enterprise") + if isinstance(org, str): + return org + elif "id" in org: + return org.get("id") # type: ignore + if payload.get("authorizations") is not None and len(payload["authorizations"]) > 0: + # To make Events API handling functioning also for shared channels, + # we should use .authorizations[0].enterprise_id over .enterprise_id + return extract_enterprise_id(payload["authorizations"][0]) + if "enterprise_id" in payload: + return payload.get("enterprise_id") + if payload.get("team") is not None and "enterprise_id" in payload["team"]: + # In the case where the type is view_submission + return payload["team"].get("enterprise_id") + if payload.get("event") is not None: + return extract_enterprise_id(payload["event"]) + return None + + +def extract_team_id(payload: Dict[str, Any]) -> Optional[str]: + if payload.get("team") is not None: + team = payload.get("team") + if isinstance(team, str): + return team + elif team and "id" in team: + return team.get("id") + if payload.get("authorizations") is not None and len(payload["authorizations"]) > 0: + # To make Events API handling functioning also for shared channels, + # we should use .authorizations[0].team_id over .team_id + return extract_team_id(payload["authorizations"][0]) + if "team_id" in payload: + return payload.get("team_id") + if payload.get("event") is not None: + return extract_team_id(payload["event"]) + if payload.get("user") is not None: + return payload.get("user")["team_id"] + return None + + +def extract_user_id(payload: Dict[str, Any]) -> Optional[str]: + if payload.get("user") is not None: + user = payload.get("user") + if isinstance(user, str): + return user + elif "id" in user: + return user.get("id") # type: ignore + if "user_id" in payload: + return payload.get("user_id") + if payload.get("event") is not None: + return extract_user_id(payload["event"]) + return None + + +def extract_content_type(headers: Dict[str, Sequence[str]]) -> Optional[str]: + content_type: Optional[str] = headers.get("content-type", [None])[0] + if content_type: + return content_type.split(";")[0] + return None diff --git a/slack_sdk/audit_logs/v1/logs.py b/slack_sdk/audit_logs/v1/logs.py index 5b398bc62..a0137e73c 100644 --- a/slack_sdk/audit_logs/v1/logs.py +++ b/slack_sdk/audit_logs/v1/logs.py @@ -139,6 +139,7 @@ class Details: channel_id: Optional[str] added_team_id: Optional[str] unknown_fields: Dict[str, Any] + is_token_rotation_enabled_app: Optional[bool] def __init__( self, @@ -194,6 +195,7 @@ def __init__( external_user_email: Optional[str] = None, channel_id: Optional[str] = None, added_team_id: Optional[str] = None, + is_token_rotation_enabled_app: Optional[bool] = None, **kwargs, ) -> None: self.name = name @@ -248,6 +250,7 @@ def __init__( self.external_user_email = external_user_email self.channel_id = channel_id self.added_team_id = added_team_id + self.is_token_rotation_enabled_app = is_token_rotation_enabled_app class App: diff --git a/slack_sdk/errors/__init__.py b/slack_sdk/errors/__init__.py index ef3ef210b..7816be555 100644 --- a/slack_sdk/errors/__init__.py +++ b/slack_sdk/errors/__init__.py @@ -33,6 +33,15 @@ def __init__(self, message, response): super(SlackApiError, self).__init__(msg) +class SlackTokenRotationError(SlackClientError): + """Error raised when the oauth.v2.access call for token rotation fails""" + + api_error: SlackApiError + + def __init__(self, api_error: SlackApiError): + self.api_error = api_error + + class SlackClientNotConnectedError(SlackClientError): """Error raised when attempting to send messages over the websocket when the connection is closed.""" diff --git a/slack_sdk/oauth/installation_store/async_cacheable_installation_store.py b/slack_sdk/oauth/installation_store/async_cacheable_installation_store.py index 0fff80f84..072150c3d 100644 --- a/slack_sdk/oauth/installation_store/async_cacheable_installation_store.py +++ b/slack_sdk/oauth/installation_store/async_cacheable_installation_store.py @@ -27,6 +27,13 @@ def logger(self) -> Logger: return self.underlying.logger async def async_save(self, installation: Installation): + # Invalidate cache data for update operations + key = f"{installation.enterprise_id or ''}-{installation.team_id or ''}" + if key in self.cached_bots: + self.cached_bots.pop(key) + key = f"{installation.enterprise_id or ''}-{installation.team_id or ''}-{installation.user_id or ''}" + if key in self.cached_installations: + self.cached_installations.pop(key) return await self.underlying.async_save(installation) async def async_find_bot( @@ -98,8 +105,10 @@ async def async_delete_installation( team_id=team_id, user_id=user_id, ) - key = f"{enterprise_id or ''}-{team_id or ''}={user_id or ''}" - self.cached_installations.pop(key) + key_prefix = f"{enterprise_id or ''}-{team_id or ''}" + for key in list(self.cached_installations.keys()): + if key.startswith(key_prefix): + self.cached_installations.pop(key) async def async_delete_all( self, @@ -112,9 +121,9 @@ async def async_delete_all( team_id=team_id, ) key_prefix = f"{enterprise_id or ''}-{team_id or ''}" - for key in self.cached_bots.keys(): + for key in list(self.cached_bots.keys()): if key.startswith(key_prefix): self.cached_bots.pop(key) - for key in self.cached_installations.keys(): + for key in list(self.cached_installations.keys()): if key.startswith(key_prefix): self.cached_installations.pop(key) diff --git a/slack_sdk/oauth/installation_store/cacheable_installation_store.py b/slack_sdk/oauth/installation_store/cacheable_installation_store.py index aff4b7a27..b37e87e2f 100644 --- a/slack_sdk/oauth/installation_store/cacheable_installation_store.py +++ b/slack_sdk/oauth/installation_store/cacheable_installation_store.py @@ -25,6 +25,14 @@ def logger(self) -> Logger: return self.underlying.logger def save(self, installation: Installation): + # Invalidate cache data for update operations + key = f"{installation.enterprise_id or ''}-{installation.team_id or ''}" + if key in self.cached_bots: + self.cached_bots.pop(key) + key = f"{installation.enterprise_id or ''}-{installation.team_id or ''}-{installation.user_id or ''}" + if key in self.cached_installations: + self.cached_installations.pop(key) + return self.underlying.save(installation) def find_bot( @@ -58,7 +66,7 @@ def find_installation( ) -> Optional[Installation]: if is_enterprise_install or team_id is None: team_id = "" - key = f"{enterprise_id or ''}-{team_id or ''}={user_id or ''}" + key = f"{enterprise_id or ''}-{team_id or ''}-{user_id or ''}" if key in self.cached_installations: return self.cached_installations[key] installation = self.underlying.find_installation( @@ -82,7 +90,8 @@ def delete_bot( team_id=team_id, ) key = f"{enterprise_id or ''}-{team_id or ''}" - self.cached_bots.pop(key) + if key in self.cached_bots: + self.cached_bots.pop(key) def delete_installation( self, @@ -96,8 +105,10 @@ def delete_installation( team_id=team_id, user_id=user_id, ) - key = f"{enterprise_id or ''}-{team_id or ''}={user_id or ''}" - self.cached_installations.pop(key) + key_prefix = f"{enterprise_id or ''}-{team_id or ''}" + for key in list(self.cached_installations.keys()): + if key.startswith(key_prefix): + self.cached_installations.pop(key) def delete_all( self, @@ -110,9 +121,9 @@ def delete_all( team_id=team_id, ) key_prefix = f"{enterprise_id or ''}-{team_id or ''}" - for key in self.cached_bots.keys(): + for key in list(self.cached_bots.keys()): if key.startswith(key_prefix): self.cached_bots.pop(key) - for key in self.cached_installations.keys(): + for key in list(self.cached_installations.keys()): if key.startswith(key_prefix): self.cached_installations.pop(key) diff --git a/slack_sdk/oauth/installation_store/internals.py b/slack_sdk/oauth/installation_store/internals.py new file mode 100644 index 000000000..b5c0fe9dd --- /dev/null +++ b/slack_sdk/oauth/installation_store/internals.py @@ -0,0 +1,32 @@ +import platform +import datetime + +(major, minor, patch) = platform.python_version_tuple() +is_python_3_6: bool = int(major) == 3 and int(minor) >= 6 + +utc_timezone = datetime.timezone.utc + + +def _from_iso_format_to_datetime(iso_datetime_str: str) -> datetime: + if is_python_3_6: + elements = iso_datetime_str.split(" ") + ymd = elements[0].split("-") + hms = elements[1].split(":") + return datetime.datetime( + int(ymd[0]), + int(ymd[1]), + int(ymd[2]), + int(hms[0]), + int(hms[1]), + int(hms[2]), + 0, + utc_timezone, + ) + else: + if "+" not in iso_datetime_str: + iso_datetime_str += "+00:00" + return datetime.datetime.fromisoformat(iso_datetime_str) + + +def _from_iso_format_to_unix_timestamp(iso_datetime_str: str) -> float: + return _from_iso_format_to_datetime(iso_datetime_str).timestamp() diff --git a/slack_sdk/oauth/installation_store/models/bot.py b/slack_sdk/oauth/installation_store/models/bot.py index e2df9576d..0eaffbc80 100644 --- a/slack_sdk/oauth/installation_store/models/bot.py +++ b/slack_sdk/oauth/installation_store/models/bot.py @@ -1,6 +1,12 @@ +import re from datetime import datetime +from time import time from typing import Optional, Union, Dict, Any, Sequence +from slack_sdk.oauth.installation_store.internals import ( + _from_iso_format_to_unix_timestamp, +) + class Bot: app_id: Optional[str] @@ -12,6 +18,10 @@ class Bot: bot_id: str bot_user_id: str bot_scopes: Sequence[str] + # only when token rotation is enabled + bot_refresh_token: Optional[str] + # only when token rotation is enabled + bot_token_expires_at: Optional[int] is_enterprise_install: bool installed_at: float @@ -31,11 +41,20 @@ def __init__( bot_id: str, bot_user_id: str, bot_scopes: Union[str, Sequence[str]] = "", + # only when token rotation is enabled + bot_refresh_token: Optional[str] = None, + # only when token rotation is enabled + bot_token_expires_in: Optional[int] = None, + # only for duplicating this object + # only when token rotation is enabled + bot_token_expires_at: Optional[Union[int, datetime, str]] = None, is_enterprise_install: Optional[bool] = False, # timestamps - installed_at: float, + # The expected value type is float but the internals handle other types too + # for str values, we supports only ISO datetime format. + installed_at: Union[float, datetime, str], # custom values - custom_values: Optional[Dict[str, Any]] = None + custom_values: Optional[Dict[str, Any]] = None, ): self.app_id = app_id self.enterprise_id = enterprise_id @@ -50,8 +69,36 @@ def __init__( self.bot_scopes = bot_scopes.split(",") if len(bot_scopes) > 0 else [] else: self.bot_scopes = bot_scopes + self.bot_refresh_token = bot_refresh_token + if bot_token_expires_at is not None: + if type(bot_token_expires_at) == datetime: + self.bot_token_expires_at = int(bot_token_expires_at.timestamp()) + elif type(bot_token_expires_at) == str and not re.match( + "^\\d+$", bot_token_expires_at + ): + self.bot_token_expires_at = int( + _from_iso_format_to_unix_timestamp(bot_token_expires_at) + ) + else: + self.bot_token_expires_at = int(bot_token_expires_at) + elif bot_token_expires_in is not None: + self.bot_token_expires_at = int(time()) + bot_token_expires_in + else: + self.bot_token_expires_at = None self.is_enterprise_install = is_enterprise_install or False - self.installed_at = installed_at + + if type(installed_at) == float: + self.installed_at = installed_at + elif type(installed_at) == datetime: + self.installed_at = installed_at.timestamp() + elif type(installed_at) == str: + if re.match("^\\d+.\\d+$", installed_at): + self.installed_at = float(installed_at) + else: + self.installed_at = _from_iso_format_to_unix_timestamp(installed_at) + else: + raise ValueError(f"Unsupported data format for installed_at {installed_at}") + self.custom_values = custom_values if custom_values is not None else {} def set_custom_value(self, name: str, value: Any): @@ -71,6 +118,10 @@ def to_dict(self) -> Dict[str, Any]: "bot_id": self.bot_id, "bot_user_id": self.bot_user_id, "bot_scopes": ",".join(self.bot_scopes) if self.bot_scopes else None, + "bot_refresh_token": self.bot_refresh_token, + "bot_token_expires_at": datetime.utcfromtimestamp(self.bot_token_expires_at) + if self.bot_token_expires_at is not None + else None, "is_enterprise_install": self.is_enterprise_install, "installed_at": datetime.utcfromtimestamp(self.installed_at), } diff --git a/slack_sdk/oauth/installation_store/models/installation.py b/slack_sdk/oauth/installation_store/models/installation.py index 851ecb3b7..78fbc4a6c 100644 --- a/slack_sdk/oauth/installation_store/models/installation.py +++ b/slack_sdk/oauth/installation_store/models/installation.py @@ -1,7 +1,11 @@ +import re from datetime import datetime from time import time from typing import Optional, Union, Dict, Any, Sequence +from slack_sdk.oauth.installation_store.internals import ( + _from_iso_format_to_unix_timestamp, +) from slack_sdk.oauth.installation_store.models.bot import Bot @@ -16,9 +20,16 @@ class Installation: bot_id: Optional[str] bot_user_id: Optional[str] bot_scopes: Optional[Sequence[str]] + bot_refresh_token: Optional[str] # only when token rotation is enabled + # only when token rotation is enabled + # Unix time (seconds): only when token rotation is enabled + bot_token_expires_at: Optional[int] user_id: str user_token: Optional[str] user_scopes: Optional[Sequence[str]] + user_refresh_token: Optional[str] # only when token rotation is enabled + # Unix time (seconds): only when token rotation is enabled + user_token_expires_at: Optional[int] incoming_webhook_url: Optional[str] incoming_webhook_channel: Optional[str] incoming_webhook_channel_id: Optional[str] @@ -44,10 +55,22 @@ def __init__( bot_id: Optional[str] = None, bot_user_id: Optional[str] = None, bot_scopes: Union[str, Sequence[str]] = "", + bot_refresh_token: Optional[str] = None, # only when token rotation is enabled + # only when token rotation is enabled + bot_token_expires_in: Optional[int] = None, + # only for duplicating this object + # only when token rotation is enabled + bot_token_expires_at: Optional[Union[int, datetime, str]] = None, # installer user_id: str, user_token: Optional[str] = None, user_scopes: Union[str, Sequence[str]] = "", + user_refresh_token: Optional[str] = None, # only when token rotation is enabled + # only when token rotation is enabled + user_token_expires_in: Optional[int] = None, + # only for duplicating this object + # only when token rotation is enabled + user_token_expires_at: Optional[Union[int, datetime, str]] = None, # incoming webhook incoming_webhook_url: Optional[str] = None, incoming_webhook_channel: Optional[str] = None, @@ -57,9 +80,11 @@ def __init__( is_enterprise_install: Optional[bool] = False, token_type: Optional[str] = None, # timestamps - installed_at: Optional[float] = None, + # The expected value type is float but the internals handle other types too + # for str values, we supports only ISO datetime format. + installed_at: Optional[Union[float, datetime, str]] = None, # custom values - custom_values: Optional[Dict[str, Any]] = None + custom_values: Optional[Dict[str, Any]] = None, ): self.app_id = app_id self.enterprise_id = enterprise_id @@ -74,6 +99,22 @@ def __init__( self.bot_scopes = bot_scopes.split(",") if len(bot_scopes) > 0 else [] else: self.bot_scopes = bot_scopes + self.bot_refresh_token = bot_refresh_token + if bot_token_expires_at is not None: + if type(bot_token_expires_at) == datetime: + self.bot_token_expires_at = int(bot_token_expires_at.timestamp()) + elif type(bot_token_expires_at) == str and not re.match( + "^\\d+$", bot_token_expires_at + ): + self.bot_token_expires_at = int( + _from_iso_format_to_unix_timestamp(bot_token_expires_at) + ) + else: + self.bot_token_expires_at = bot_token_expires_at + elif bot_token_expires_in is not None: + self.bot_token_expires_at = int(time()) + bot_token_expires_in + else: + self.bot_token_expires_at = None self.user_id = user_id self.user_token = user_token @@ -81,6 +122,22 @@ def __init__( self.user_scopes = user_scopes.split(",") if len(user_scopes) > 0 else [] else: self.user_scopes = user_scopes + self.user_refresh_token = user_refresh_token + if user_token_expires_at is not None: + if type(user_token_expires_at) == datetime: + self.user_token_expires_at = int(user_token_expires_at.timestamp()) + elif type(user_token_expires_at) == str and not re.match( + "^\\d+$", user_token_expires_at + ): + self.user_token_expires_at = int( + _from_iso_format_to_unix_timestamp(user_token_expires_at) + ) + else: + self.user_token_expires_at = user_token_expires_at + elif user_token_expires_in is not None: + self.user_token_expires_at = int(time()) + user_token_expires_in + else: + self.user_token_expires_at = None self.incoming_webhook_url = incoming_webhook_url self.incoming_webhook_channel = incoming_webhook_channel @@ -90,7 +147,20 @@ def __init__( self.is_enterprise_install = is_enterprise_install or False self.token_type = token_type - self.installed_at = time() if installed_at is None else installed_at + if installed_at is None: + self.installed_at = datetime.now().timestamp() + elif type(installed_at) == float: + self.installed_at = installed_at + elif type(installed_at) == datetime: + self.installed_at = installed_at.timestamp() + elif type(installed_at) == str: + if re.match("^\\d+.\\d+$", installed_at): + self.installed_at = float(installed_at) + else: + self.installed_at = _from_iso_format_to_unix_timestamp(installed_at) + else: + raise ValueError(f"Unsupported data format for installed_at {installed_at}") + self.custom_values = custom_values if custom_values is not None else {} def to_bot(self) -> Bot: @@ -104,6 +174,8 @@ def to_bot(self) -> Bot: bot_id=self.bot_id, bot_user_id=self.bot_user_id, bot_scopes=self.bot_scopes, + bot_refresh_token=self.bot_refresh_token, + bot_token_expires_at=self.bot_token_expires_at, is_enterprise_install=self.is_enterprise_install, installed_at=self.installed_at, custom_values=self.custom_values, @@ -127,9 +199,19 @@ def to_dict(self) -> Dict[str, Any]: "bot_id": self.bot_id, "bot_user_id": self.bot_user_id, "bot_scopes": ",".join(self.bot_scopes) if self.bot_scopes else None, + "bot_refresh_token": self.bot_refresh_token, + "bot_token_expires_at": datetime.utcfromtimestamp(self.bot_token_expires_at) + if self.bot_token_expires_at is not None + else None, "user_id": self.user_id, "user_token": self.user_token, "user_scopes": ",".join(self.user_scopes) if self.user_scopes else None, + "user_refresh_token": self.user_refresh_token, + "user_token_expires_at": datetime.utcfromtimestamp( + self.user_token_expires_at + ) + if self.user_token_expires_at is not None + else None, "incoming_webhook_url": self.incoming_webhook_url, "incoming_webhook_channel": self.incoming_webhook_channel, "incoming_webhook_channel_id": self.incoming_webhook_channel_id, diff --git a/slack_sdk/oauth/installation_store/sqlalchemy/__init__.py b/slack_sdk/oauth/installation_store/sqlalchemy/__init__.py index 9685fced6..5b59e621b 100644 --- a/slack_sdk/oauth/installation_store/sqlalchemy/__init__.py +++ b/slack_sdk/oauth/installation_store/sqlalchemy/__init__.py @@ -48,9 +48,13 @@ def build_installations_table(cls, metadata: MetaData, table_name: str) -> Table Column("bot_id", String(32)), Column("bot_user_id", String(32)), Column("bot_scopes", String(1000)), + Column("bot_refresh_token", String(200)), # added in v3.8.0 + Column("bot_token_expires_at", DateTime), # added in v3.8.0 Column("user_id", String(32), nullable=False), Column("user_token", String(200)), Column("user_scopes", String(1000)), + Column("user_refresh_token", String(200)), # added in v3.8.0 + Column("user_token_expires_at", DateTime), # added in v3.8.0 Column("incoming_webhook_url", String(200)), Column("incoming_webhook_channel", String(200)), Column("incoming_webhook_channel_id", String(200)), @@ -89,6 +93,8 @@ def build_bots_table(cls, metadata: MetaData, table_name: str) -> Table: Column("bot_id", String(32)), Column("bot_user_id", String(32)), Column("bot_scopes", String(1000)), + Column("bot_refresh_token", String(200)), # added in v3.8.0 + Column("bot_token_expires_at", DateTime), # added in v3.8.0 Column("is_enterprise_install", Boolean, default=False, nullable=False), Column( "installed_at", @@ -135,10 +141,60 @@ def save(self, installation: Installation): with self.engine.begin() as conn: i = installation.to_dict() i["client_id"] = self.client_id - conn.execute(self.installations.insert(), i) + + i_column = self.installations.c + installations_rows = conn.execute( + sqlalchemy.select([i_column.id]) + .where( + and_( + i_column.client_id == self.client_id, + i_column.enterprise_id == installation.enterprise_id, + i_column.team_id == installation.team_id, + i_column.installed_at == i.get("installed_at"), + ) + ) + .limit(1) + ) + installations_row_id: Optional[str] = None + for row in installations_rows: + installations_row_id = row["id"] + if installations_row_id is None: + conn.execute(self.installations.insert(), i) + else: + update_statement = ( + self.installations.update() + .where(i_column.id == installations_row_id) + .values(**i) + ) + conn.execute(update_statement, i) + + # bots b = installation.to_bot().to_dict() b["client_id"] = self.client_id - conn.execute(self.bots.insert(), b) + + b_column = self.bots.c + bots_rows = conn.execute( + sqlalchemy.select([b_column.id]) + .where( + and_( + b_column.client_id == self.client_id, + b_column.enterprise_id == installation.enterprise_id, + b_column.team_id == installation.team_id, + b_column.installed_at == b.get("installed_at"), + ) + ) + .limit(1) + ) + bots_row_id: Optional[str] = None + for row in bots_rows: + bots_row_id = row["id"] + if bots_row_id is None: + conn.execute(self.bots.insert(), b) + else: + update_statement = ( + self.bots.update().where(b_column.id == bots_row_id).values(**b) + ) + conn.execute(update_statement, b) def find_bot( self, @@ -153,7 +209,13 @@ def find_bot( c = self.bots.c query = ( self.bots.select() - .where(and_(c.enterprise_id == enterprise_id, c.team_id == team_id)) + .where( + and_( + c.client_id == self.client_id, + c.enterprise_id == enterprise_id, + c.team_id == team_id, + ) + ) .order_by(desc(c.installed_at)) .limit(1) ) @@ -171,6 +233,8 @@ def find_bot( bot_id=row["bot_id"], bot_user_id=row["bot_user_id"], bot_scopes=row["bot_scopes"], + bot_refresh_token=row["bot_refresh_token"], + bot_token_expires_at=row["bot_token_expires_at"], is_enterprise_install=row["is_enterprise_install"], installed_at=row["installed_at"], ) @@ -191,6 +255,7 @@ def find_installation( where_clause = and_(c.enterprise_id == enterprise_id, c.team_id == team_id) if user_id is not None: where_clause = and_( + c.client_id == self.client_id, c.enterprise_id == enterprise_id, c.team_id == team_id, c.user_id == user_id, @@ -217,9 +282,13 @@ def find_installation( bot_id=row["bot_id"], bot_user_id=row["bot_user_id"], bot_scopes=row["bot_scopes"], + bot_refresh_token=row["bot_refresh_token"], + bot_token_expires_at=row["bot_token_expires_at"], user_id=row["user_id"], user_token=row["user_token"], user_scopes=row["user_scopes"], + user_refresh_token=row["user_refresh_token"], + user_token_expires_at=row["user_token_expires_at"], # Only the incoming webhook issued in the latest installation is set in this logic incoming_webhook_url=row["incoming_webhook_url"], incoming_webhook_channel=row["incoming_webhook_channel"], diff --git a/slack_sdk/oauth/installation_store/sqlite3/__init__.py b/slack_sdk/oauth/installation_store/sqlite3/__init__.py index 42f915528..eef059c61 100644 --- a/slack_sdk/oauth/installation_store/sqlite3/__init__.py +++ b/slack_sdk/oauth/installation_store/sqlite3/__init__.py @@ -65,9 +65,13 @@ def create_tables(self): bot_id text not null, bot_user_id text not null, bot_scopes text, + bot_refresh_token text, -- since v3.8 + bot_token_expires_at datetime, -- since v3.8 user_id text not null, user_token text, user_scopes text, + user_refresh_token text, -- since v3.8 + user_token_expires_at datetime, -- since v3.8 incoming_webhook_url text, incoming_webhook_channel text, incoming_webhook_channel_id text, @@ -103,6 +107,8 @@ def create_tables(self): bot_id text not null, bot_user_id text not null, bot_scopes text, + bot_refresh_token text, -- since v3.8 + bot_token_expires_at datetime, -- since v3.8 is_enterprise_install boolean not null default 0, installed_at datetime not null default current_timestamp ); @@ -139,6 +145,8 @@ def save(self, installation: Installation): bot_id, bot_user_id, bot_scopes, + bot_refresh_token, -- since v3.8 + bot_token_expires_at, -- since v3.8 is_enterprise_install ) values @@ -153,6 +161,8 @@ def save(self, installation: Installation): ?, ?, ?, + ?, + ?, ? ); """, @@ -167,6 +177,8 @@ def save(self, installation: Installation): installation.bot_id, installation.bot_user_id, ",".join(installation.bot_scopes), + installation.bot_refresh_token, + installation.bot_token_expires_at, installation.is_enterprise_install, ], ) @@ -184,9 +196,13 @@ def save(self, installation: Installation): bot_id, bot_user_id, bot_scopes, + bot_refresh_token, -- since v3.8 + bot_token_expires_at, -- since v3.8 user_id, user_token, user_scopes, + user_refresh_token, -- since v3.8 + user_token_expires_at, -- since v3.8 incoming_webhook_url, incoming_webhook_channel, incoming_webhook_channel_id, @@ -215,6 +231,10 @@ def save(self, installation: Installation): ?, ?, ?, + ?, + ?, + ?, + ?, ? ); """, @@ -230,11 +250,15 @@ def save(self, installation: Installation): installation.bot_id, installation.bot_user_id, ",".join(installation.bot_scopes), + installation.bot_refresh_token, + installation.bot_token_expires_at, installation.user_id, installation.user_token, ",".join(installation.user_scopes) if installation.user_scopes else None, + installation.user_refresh_token, + installation.user_token_expires_at, installation.incoming_webhook_url, installation.incoming_webhook_channel, installation.incoming_webhook_channel_id, @@ -285,6 +309,8 @@ def find_bot( bot_id, bot_user_id, bot_scopes, + bot_refresh_token, -- since v3.8 + bot_token_expires_at, -- since v3.8 is_enterprise_install, installed_at from @@ -316,15 +342,20 @@ def find_bot( bot_id=row[6], bot_user_id=row[7], bot_scopes=row[8], - is_enterprise_install=row[9], - installed_at=row[10], + bot_refresh_token=row[9], + bot_token_expires_at=row[10], + is_enterprise_install=row[11], + installed_at=row[12], ) return bot return None except Exception as e: # skipcq: PYL-W0703 message = f"Failed to find bot installation data for enterprise: {enterprise_id}, team: {team_id}: {e}" - self.logger.warning(message) + if self.logger.level <= logging.DEBUG: + self.logger.exception(message) + else: + self.logger.warning(message) return None async def async_find_installation( @@ -367,9 +398,13 @@ def find_installation( bot_id, bot_user_id, bot_scopes, + bot_refresh_token, -- since v3.8 + bot_token_expires_at, -- since v3.8 user_id, user_token, user_scopes, + user_refresh_token, -- since v3.8 + user_token_expires_at, -- since v3.8 incoming_webhook_url, incoming_webhook_channel, incoming_webhook_channel_id, @@ -438,23 +473,30 @@ def find_installation( bot_id=row[7], bot_user_id=row[8], bot_scopes=row[9], - user_id=row[10], - user_token=row[11], - user_scopes=row[12], - incoming_webhook_url=row[13], - incoming_webhook_channel=row[14], - incoming_webhook_channel_id=row[15], - incoming_webhook_configuration_url=row[16], - is_enterprise_install=row[17], - token_type=row[18], - installed_at=row[19], + bot_refresh_token=row[10], + bot_token_expires_at=row[11], + user_id=row[12], + user_token=row[13], + user_scopes=row[14], + user_refresh_token=row[15], + user_token_expires_at=row[16], + incoming_webhook_url=row[17], + incoming_webhook_channel=row[18], + incoming_webhook_channel_id=row[19], + incoming_webhook_configuration_url=row[20], + is_enterprise_install=row[21], + token_type=row[22], + installed_at=row[23], ) return installation return None except Exception as e: # skipcq: PYL-W0703 message = f"Failed to find an installation data for enterprise: {enterprise_id}, team: {team_id}: {e}" - self.logger.warning(message) + if self.logger.level <= logging.DEBUG: + self.logger.exception(message) + else: + self.logger.warning(message) return None def delete_bot( @@ -479,7 +521,10 @@ def delete_bot( conn.commit() except Exception as e: # skipcq: PYL-W0703 message = f"Failed to delete bot installation data for enterprise: {enterprise_id}, team: {team_id}: {e}" - self.logger.warning(message) + if self.logger.level <= logging.DEBUG: + self.logger.exception(message) + else: + self.logger.warning(message) def delete_installation( self, @@ -525,4 +570,7 @@ def delete_installation( conn.commit() except Exception as e: # skipcq: PYL-W0703 message = f"Failed to delete installation data for enterprise: {enterprise_id}, team: {team_id}: {e}" - self.logger.warning(message) + if self.logger.level <= logging.DEBUG: + self.logger.exception(message) + else: + self.logger.warning(message) diff --git a/slack_sdk/oauth/token_rotation/__init__.py b/slack_sdk/oauth/token_rotation/__init__.py new file mode 100644 index 000000000..c8f4c483d --- /dev/null +++ b/slack_sdk/oauth/token_rotation/__init__.py @@ -0,0 +1 @@ +from .rotator import TokenRotator # noqa diff --git a/slack_sdk/oauth/token_rotation/async_rotator.py b/slack_sdk/oauth/token_rotation/async_rotator.py new file mode 100644 index 000000000..5d5ebe4e6 --- /dev/null +++ b/slack_sdk/oauth/token_rotation/async_rotator.py @@ -0,0 +1,150 @@ +from time import time +from typing import Optional + +from slack_sdk.errors import SlackApiError, SlackTokenRotationError +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.oauth.installation_store import Installation, Bot + + +class AsyncTokenRotator: + client: AsyncWebClient + client_id: str + client_secret: str + + def __init__( + self, + *, + client_id: str, + client_secret: str, + client: Optional[AsyncWebClient] = None, + ): + self.client = client if client is not None else AsyncWebClient(token=None) + self.client_id = client_id + self.client_secret = client_secret + + async def perform_token_rotation( + self, + *, + installation: Installation, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Installation]: + """Performs token rotation if the underlying tokens (bot / user) are expired / expiring. + + Args: + installation: the current installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + + # bot + rotated_bot: Optional[Bot] = await self.perform_bot_token_rotation( + bot=installation.to_bot(), + minutes_before_expiration=minutes_before_expiration, + ) + + # user + rotated_installation: Optional[ + Installation + ] = await self.perform_user_token_rotation( + installation=installation, + minutes_before_expiration=minutes_before_expiration, + ) + + if rotated_bot is not None: + if rotated_installation is None: + rotated_installation = Installation(**installation.to_dict()) + rotated_installation.bot_token = rotated_bot.bot_token + rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token + rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at + + return rotated_installation + + async def perform_bot_token_rotation( + self, + *, + bot: Bot, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Bot]: + """Performs bot token rotation if the underlying bot token is expired / expiring. + + Args: + bot: the current bot installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + if bot.bot_token_expires_at is None: + return None + if bot.bot_token_expires_at > time() + minutes_before_expiration * 60: + return None + + try: + refresh_response = await self.client.oauth_v2_access( + client_id=self.client_id, + client_secret=self.client_secret, + grant_type="refresh_token", + refresh_token=bot.bot_refresh_token, + ) + # TODO: error handling + + if refresh_response.get("token_type") != "bot": + return None + + refreshed_bot = Bot(**bot.to_dict()) + refreshed_bot.bot_token = refresh_response.get("access_token") + refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token") + refreshed_bot.bot_token_expires_at = int(time()) + refresh_response.get( + "expires_in" + ) + return refreshed_bot + + except SlackApiError as e: + raise SlackTokenRotationError(e) + + async def perform_user_token_rotation( + self, + *, + installation: Installation, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Bot]: + """Performs user token rotation if the underlying user token is expired / expiring. + + Args: + installation: the current installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + if installation.user_token_expires_at is None: + return None + if installation.user_token_expires_at > time() + minutes_before_expiration * 60: + return None + + try: + refresh_response = await self.client.oauth_v2_access( + client_id=self.client_id, + client_secret=self.client_secret, + grant_type="refresh_token", + refresh_token=installation.user_refresh_token, + ) + # TODO: error handling + + if refresh_response.get("token_type") != "user": + return None + + refreshed_installation = Installation(**installation.to_dict()) + refreshed_installation.user_token = refresh_response.get("access_token") + refreshed_installation.user_refresh_token = refresh_response.get( + "refresh_token" + ) + refreshed_installation.user_token_expires_at = int( + time() + ) + refresh_response.get("expires_in") + return refreshed_installation + + except SlackApiError as e: + raise SlackTokenRotationError(e) diff --git a/slack_sdk/oauth/token_rotation/rotator.py b/slack_sdk/oauth/token_rotation/rotator.py new file mode 100644 index 000000000..686a88fbc --- /dev/null +++ b/slack_sdk/oauth/token_rotation/rotator.py @@ -0,0 +1,141 @@ +from time import time +from typing import Optional + +from slack_sdk.errors import SlackApiError, SlackTokenRotationError +from slack_sdk.web import WebClient +from slack_sdk.oauth.installation_store import Installation, Bot + + +class TokenRotator: + client: WebClient + client_id: str + client_secret: str + + def __init__( + self, *, client_id: str, client_secret: str, client: Optional[WebClient] = None + ): + self.client = client if client is not None else WebClient(token=None) + self.client_id = client_id + self.client_secret = client_secret + + def perform_token_rotation( + self, + *, + installation: Installation, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Installation]: + """Performs token rotation if the underlying tokens (bot / user) are expired / expiring. + + Args: + installation: the current installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + + # bot + rotated_bot: Optional[Bot] = self.perform_bot_token_rotation( + bot=installation.to_bot(), + minutes_before_expiration=minutes_before_expiration, + ) + + # user + rotated_installation: Optional[Installation] = self.perform_user_token_rotation( + installation=installation, + minutes_before_expiration=minutes_before_expiration, + ) + + if rotated_bot is not None: + if rotated_installation is None: + rotated_installation = Installation(**installation.to_dict()) + rotated_installation.bot_token = rotated_bot.bot_token + rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token + rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at + + return rotated_installation + + def perform_bot_token_rotation( + self, + *, + bot: Bot, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Bot]: + """Performs bot token rotation if the underlying bot token is expired / expiring. + + Args: + bot: the current bot installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + if bot.bot_token_expires_at is None: + return None + if bot.bot_token_expires_at > time() + minutes_before_expiration * 60: + return None + + try: + refresh_response = self.client.oauth_v2_access( + client_id=self.client_id, + client_secret=self.client_secret, + grant_type="refresh_token", + refresh_token=bot.bot_refresh_token, + ) + if refresh_response.get("token_type") != "bot": + return None + + refreshed_bot = Bot(**bot.to_dict()) + refreshed_bot.bot_token = refresh_response.get("access_token") + refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token") + refreshed_bot.bot_token_expires_at = int(time()) + refresh_response.get( + "expires_in" + ) + return refreshed_bot + + except SlackApiError as e: + raise SlackTokenRotationError(e) + + def perform_user_token_rotation( + self, + *, + installation: Installation, + minutes_before_expiration: int = 120, # 2 hours by default + ) -> Optional[Bot]: + """Performs user token rotation if the underlying user token is expired / expiring. + + Args: + installation: the current installation data + minutes_before_expiration: the minutes before the token expiration + + Returns: + None if no rotation is necessary for now. + """ + if installation.user_token_expires_at is None: + return None + if installation.user_token_expires_at > time() + minutes_before_expiration * 60: + return None + + try: + refresh_response = self.client.oauth_v2_access( + client_id=self.client_id, + client_secret=self.client_secret, + grant_type="refresh_token", + refresh_token=installation.user_refresh_token, + ) + + if refresh_response.get("token_type") != "user": + return None + + refreshed_installation = Installation(**installation.to_dict()) + refreshed_installation.user_token = refresh_response.get("access_token") + refreshed_installation.user_refresh_token = refresh_response.get( + "refresh_token" + ) + refreshed_installation.user_token_expires_at = int( + time() + ) + refresh_response.get("expires_in") + return refreshed_installation + + except SlackApiError as e: + raise SlackTokenRotationError(e) diff --git a/slack_sdk/web/async_client.py b/slack_sdk/web/async_client.py index c6d94e9a2..60dc1d1a9 100644 --- a/slack_sdk/web/async_client.py +++ b/slack_sdk/web/async_client.py @@ -2103,8 +2103,14 @@ async def oauth_v2_access( *, client_id: str, client_secret: str, - code: str, + # This field is required when processing the OAuth redirect URL requests + # while it's absent for token rotation + code: Optional[str] = None, redirect_uri: Optional[str] = None, + # This field is required for token rotation + grant_type: Optional[str] = None, + # This field is required for token rotation + refresh_token: Optional[str] = None, **kwargs ) -> AsyncSlackResponse: """Exchanges a temporary OAuth verifier code for an access token. @@ -2115,10 +2121,17 @@ async def oauth_v2_access( code (str): The code param returned via the OAuth callback. e.g. 'ccdaa72ad' redirect_uri (optional str): Must match the originally submitted URI (if one was sent). e.g. 'https://example.com' + grant_type: The grant type. The possible value is only 'refresh_token' as of July 2021. + refresh_token: The refresh token for token rotation. """ if redirect_uri is not None: kwargs.update({"redirect_uri": redirect_uri}) - kwargs.update({"code": code}) + if code is not None: + kwargs.update({"code": code}) + if grant_type is not None: + kwargs.update({"grant_type": grant_type}) + if refresh_token is not None: + kwargs.update({"refresh_token": refresh_token}) return await self.api_call( "oauth.v2.access", data=kwargs, @@ -2152,6 +2165,21 @@ async def oauth_access( auth={"client_id": client_id, "client_secret": client_secret}, ) + async def oauth_v2_exchange( + self, *, token: str, client_id: str, client_secret: str, **kwargs + ) -> AsyncSlackResponse: + """Exchanges a legacy access token for a new expiring access token and refresh token + + Args: + token: The legacy xoxb or xoxp token being migrated to use token rotation. + client_id: Issued when you created your application. + client_secret:Issued when you created your application. + """ + kwargs.update( + {"client_id": client_id, "client_secret": client_secret, "token": token} + ) + return await self.api_call("oauth.v2.exchange", params=kwargs) + async def pins_add(self, *, channel: str, **kwargs) -> AsyncSlackResponse: """Pins an item to a channel. diff --git a/slack_sdk/web/client.py b/slack_sdk/web/client.py index 6900df34f..e4f9bf14f 100644 --- a/slack_sdk/web/client.py +++ b/slack_sdk/web/client.py @@ -1976,8 +1976,14 @@ def oauth_v2_access( *, client_id: str, client_secret: str, - code: str, + # This field is required when processing the OAuth redirect URL requests + # while it's absent for token rotation + code: Optional[str] = None, redirect_uri: Optional[str] = None, + # This field is required for token rotation + grant_type: Optional[str] = None, + # This field is required for token rotation + refresh_token: Optional[str] = None, **kwargs ) -> SlackResponse: """Exchanges a temporary OAuth verifier code for an access token. @@ -1988,10 +1994,17 @@ def oauth_v2_access( code (str): The code param returned via the OAuth callback. e.g. 'ccdaa72ad' redirect_uri (optional str): Must match the originally submitted URI (if one was sent). e.g. 'https://example.com' + grant_type: The grant type. The possible value is only 'refresh_token' as of July 2021. + refresh_token: The refresh token for token rotation. """ if redirect_uri is not None: kwargs.update({"redirect_uri": redirect_uri}) - kwargs.update({"code": code}) + if code is not None: + kwargs.update({"code": code}) + if grant_type is not None: + kwargs.update({"grant_type": grant_type}) + if refresh_token is not None: + kwargs.update({"refresh_token": refresh_token}) return self.api_call( "oauth.v2.access", data=kwargs, @@ -2025,6 +2038,21 @@ def oauth_access( auth={"client_id": client_id, "client_secret": client_secret}, ) + def oauth_v2_exchange( + self, *, token: str, client_id: str, client_secret: str, **kwargs + ) -> SlackResponse: + """Exchanges a legacy access token for a new expiring access token and refresh token + + Args: + token: The legacy xoxb or xoxp token being migrated to use token rotation. + client_id: Issued when you created your application. + client_secret:Issued when you created your application. + """ + kwargs.update( + {"client_id": client_id, "client_secret": client_secret, "token": token} + ) + return self.api_call("oauth.v2.exchange", params=kwargs) + def pins_add(self, *, channel: str, **kwargs) -> SlackResponse: """Pins an item to a channel. diff --git a/slack_sdk/web/legacy_client.py b/slack_sdk/web/legacy_client.py index 91bab73cd..5aa8e256f 100644 --- a/slack_sdk/web/legacy_client.py +++ b/slack_sdk/web/legacy_client.py @@ -2097,8 +2097,14 @@ def oauth_v2_access( *, client_id: str, client_secret: str, - code: str, + # This field is required when processing the OAuth redirect URL requests + # while it's absent for token rotation + code: Optional[str] = None, redirect_uri: Optional[str] = None, + # This field is required for token rotation + grant_type: Optional[str] = None, + # This field is required for token rotation + refresh_token: Optional[str] = None, **kwargs ) -> Union[Future, SlackResponse]: """Exchanges a temporary OAuth verifier code for an access token. @@ -2109,10 +2115,17 @@ def oauth_v2_access( code (str): The code param returned via the OAuth callback. e.g. 'ccdaa72ad' redirect_uri (optional str): Must match the originally submitted URI (if one was sent). e.g. 'https://example.com' + grant_type: The grant type. The possible value is only 'refresh_token' as of July 2021. + refresh_token: The refresh token for token rotation. """ if redirect_uri is not None: kwargs.update({"redirect_uri": redirect_uri}) - kwargs.update({"code": code}) + if code is not None: + kwargs.update({"code": code}) + if grant_type is not None: + kwargs.update({"grant_type": grant_type}) + if refresh_token is not None: + kwargs.update({"refresh_token": refresh_token}) return self.api_call( "oauth.v2.access", data=kwargs, @@ -2146,6 +2159,21 @@ def oauth_access( auth={"client_id": client_id, "client_secret": client_secret}, ) + def oauth_v2_exchange( + self, *, token: str, client_id: str, client_secret: str, **kwargs + ) -> Union[Future, SlackResponse]: + """Exchanges a legacy access token for a new expiring access token and refresh token + + Args: + token: The legacy xoxb or xoxp token being migrated to use token rotation. + client_id: Issued when you created your application. + client_secret:Issued when you created your application. + """ + kwargs.update( + {"client_id": client_id, "client_secret": client_secret, "token": token} + ) + return self.api_call("oauth.v2.exchange", params=kwargs) + def pins_add(self, *, channel: str, **kwargs) -> Union[Future, SlackResponse]: """Pins an item to a channel. diff --git a/tests/slack_sdk/oauth/installation_store/test_internals.py b/tests/slack_sdk/oauth/installation_store/test_internals.py new file mode 100644 index 000000000..5d0f3d435 --- /dev/null +++ b/tests/slack_sdk/oauth/installation_store/test_internals.py @@ -0,0 +1,16 @@ +import unittest + +from slack_sdk.oauth.installation_store import Installation, FileInstallationStore +from slack_sdk.oauth.installation_store.internals import _from_iso_format_to_datetime + + +class TestFile(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_iso_format(self): + dt = _from_iso_format_to_datetime("2021-07-14 08:00:17") + self.assertEqual(dt.timestamp(), 1626249617.0) diff --git a/tests/slack_sdk/oauth/installation_store/test_simple_cache.py b/tests/slack_sdk/oauth/installation_store/test_simple_cache.py index ad0d088c6..5cd2412d9 100644 --- a/tests/slack_sdk/oauth/installation_store/test_simple_cache.py +++ b/tests/slack_sdk/oauth/installation_store/test_simple_cache.py @@ -54,3 +54,100 @@ def test_save_and_find(self): self.assertIsNone(bot) bot = store.find_bot(enterprise_id="E111", team_id="T111") self.assertIsNotNone(bot) + + def test_save_and_find_token_rotation(self): + sqlite3_store = SQLite3InstallationStore( + database="logs/cacheable.db", client_id="111.222" + ) + sqlite3_store.init() + store = CacheableInstallationStore(sqlite3_store) + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + store.save(installation) + + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-initial") + + # Update the existing data + refreshed_installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-refreshed", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-refreshed", + bot_token_expires_in=43200, + ) + store.save(refreshed_installation) + + # find bots + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-refreshed") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + bot = store.find_bot(enterprise_id=None, team_id="T111") + self.assertIsNone(bot) + + # delete bots + store.delete_bot(enterprise_id="E111", team_id="T222") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + + # find installations + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T222") + self.assertIsNone(i) + i = store.find_installation(enterprise_id=None, team_id="T111") + self.assertIsNone(i) + + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNotNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U222" + ) + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T222", user_id="U111" + ) + self.assertIsNone(i) + + # delete installations + store.delete_installation(enterprise_id="E111", team_id="T111", user_id="U111") + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + + # delete all + store.save(installation) + store.delete_all(enterprise_id="E111", team_id="T111") + + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) diff --git a/tests/slack_sdk/oauth/installation_store/test_sqlalchemy.py b/tests/slack_sdk/oauth/installation_store/test_sqlalchemy.py index 53d8e273f..b42471ac9 100644 --- a/tests/slack_sdk/oauth/installation_store/test_sqlalchemy.py +++ b/tests/slack_sdk/oauth/installation_store/test_sqlalchemy.py @@ -169,3 +169,96 @@ def test_org_installation(self): self.assertIsNone(i) bot = store.find_bot(enterprise_id=None, team_id="T222") self.assertIsNone(bot) + + def test_save_and_find_token_rotation(self): + store = self.store + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + store.save(installation) + + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-initial") + + # Update the existing data + refreshed_installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-refreshed", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-refreshed", + bot_token_expires_in=43200, + ) + store.save(refreshed_installation) + + # find bots + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-refreshed") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + bot = store.find_bot(enterprise_id=None, team_id="T111") + self.assertIsNone(bot) + + # delete bots + store.delete_bot(enterprise_id="E111", team_id="T222") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + + # find installations + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T222") + self.assertIsNone(i) + i = store.find_installation(enterprise_id=None, team_id="T111") + self.assertIsNone(i) + + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNotNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U222" + ) + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T222", user_id="U111" + ) + self.assertIsNone(i) + + # delete installations + store.delete_installation(enterprise_id="E111", team_id="T111", user_id="U111") + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + + # delete all + store.save(installation) + store.delete_all(enterprise_id="E111", team_id="T111") + + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) diff --git a/tests/slack_sdk/oauth/installation_store/test_sqlite3.py b/tests/slack_sdk/oauth/installation_store/test_sqlite3.py index 6a831a38b..623ee4419 100644 --- a/tests/slack_sdk/oauth/installation_store/test_sqlite3.py +++ b/tests/slack_sdk/oauth/installation_store/test_sqlite3.py @@ -168,3 +168,96 @@ def test_org_installation(self): self.assertIsNone(i) bot = store.find_bot(enterprise_id=None, team_id="T222") self.assertIsNone(bot) + + def test_save_and_find_token_rotation(self): + store = SQLite3InstallationStore(database="logs/test.db", client_id="111.222") + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + store.save(installation) + + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-initial") + + # Update the existing data + refreshed_installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-refreshed", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-refreshed", + bot_token_expires_in=43200, + ) + store.save(refreshed_installation) + + # find bots + bot = store.find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-refreshed") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + bot = store.find_bot(enterprise_id=None, team_id="T111") + self.assertIsNone(bot) + + # delete bots + store.delete_bot(enterprise_id="E111", team_id="T222") + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) + + # find installations + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T222") + self.assertIsNone(i) + i = store.find_installation(enterprise_id=None, team_id="T111") + self.assertIsNone(i) + + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNotNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U222" + ) + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T222", user_id="U111" + ) + self.assertIsNone(i) + + # delete installations + store.delete_installation(enterprise_id="E111", team_id="T111", user_id="U111") + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + + # delete all + store.save(installation) + store.delete_all(enterprise_id="E111", team_id="T111") + + i = store.find_installation(enterprise_id="E111", team_id="T111") + self.assertIsNone(i) + i = store.find_installation( + enterprise_id="E111", team_id="T111", user_id="U111" + ) + self.assertIsNone(i) + bot = store.find_bot(enterprise_id="E111", team_id="T222") + self.assertIsNone(bot) diff --git a/tests/slack_sdk/oauth/token_rotation/__init__.py b/tests/slack_sdk/oauth/token_rotation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py b/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py new file mode 100644 index 000000000..969f43b4a --- /dev/null +++ b/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py @@ -0,0 +1,91 @@ +import unittest + +from slack_sdk.errors import SlackTokenRotationError +from slack_sdk.oauth.installation_store import Installation +from slack_sdk.oauth.token_rotation import TokenRotator +from slack_sdk.web import WebClient +from tests.slack_sdk.web.mock_web_api_server import ( + setup_mock_web_api_server, + cleanup_mock_web_api_server, +) + + +class TestTokenRotator(unittest.TestCase): + def setUp(self): + setup_mock_web_api_server(self) + self.token_rotator = TokenRotator( + client=WebClient(base_url="http://localhost:8888", token=None), + client_id="111.222", + client_secret="token_rotation_secret", + ) + + def tearDown(self): + cleanup_mock_web_api_server(self) + + def test_refresh(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNotNone(refreshed) + + should_not_be_refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + + def test_token_rotation_disabled(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + ) + should_not_be_refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNone(should_not_be_refreshed) + + should_not_be_refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + + def test_refresh_error(self): + token_rotator = TokenRotator( + client=WebClient(base_url="http://localhost:8888", token=None), + client_id="111.222", + client_secret="invalid_value", + ) + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + with self.assertRaises(SlackTokenRotationError): + token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) diff --git a/tests/slack_sdk/web/mock_web_api_server.py b/tests/slack_sdk/web/mock_web_api_server.py index 3ff52dcc1..e8e866be7 100644 --- a/tests/slack_sdk/web/mock_web_api_server.py +++ b/tests/slack_sdk/web/mock_web_api_server.py @@ -56,6 +56,28 @@ def set_common_headers(self): "error": "test_data_not_found", } + token_refresh = { + "ok": True, + "app_id": "A111", + "authed_user": { + "id": "W111", + "scope": "search:read", + "access_token": "xoxe.xoxp-1-xxx", + "token_type": "user", + "refresh_token": "xoxe-1-xxx", + "expires_in": 43200, + }, + "scope": "app_mentions:read,chat:write,commands", + "token_type": "bot", + "access_token": "xoxe.xoxb-1-yyy", + "bot_user_id": "UB111", + "refresh_token": "xoxe-1-yyy", + "expires_in": 43201, + "team": {"id": "T111", "name": "Testing Workspace"}, + "enterprise": {"id": "E111", "name": "Sandbox Org"}, + "is_enterprise_install": False, + } + def _handle(self): try: if self.path == "/received_requests.json": @@ -70,6 +92,12 @@ def _handle(self): if self.headers["authorization"] == "Basic MTExLjIyMjpzZWNyZXQ=": self.wfile.write("""{"ok":true}""".encode("utf-8")) return + elif ( + self.headers["authorization"] + == "Basic MTExLjIyMjp0b2tlbl9yb3RhdGlvbl9zZWNyZXQ=" + ): + self.wfile.write(json.dumps(self.token_refresh).encode("utf-8")) + return else: self.wfile.write( """{"ok":false, "error":"invalid"}""".encode("utf-8") diff --git a/tests/slack_sdk_async/oauth/installation_store/test_simple_cache.py b/tests/slack_sdk_async/oauth/installation_store/test_simple_cache.py index 12cc84602..1dab897c0 100644 --- a/tests/slack_sdk_async/oauth/installation_store/test_simple_cache.py +++ b/tests/slack_sdk_async/oauth/installation_store/test_simple_cache.py @@ -39,3 +39,54 @@ async def test_save_and_find(self): self.assertIsNone(bot) bot = await store.async_find_bot(enterprise_id="E111", team_id="T111") self.assertIsNotNone(bot) + + @async_test + async def test_save_and_find_token_rotation(self): + sqlite3_store = SQLite3InstallationStore( + database="logs/cacheable.db", client_id="111.222" + ) + sqlite3_store.init() + store = AsyncCacheableInstallationStore(sqlite3_store) + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxb-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + await store.async_save(installation) + + bot = await store.async_find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-initial") + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxb-refreshed", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-refreshed", + bot_token_expires_in=43200, + ) + await store.async_save(installation) + + bot = await store.async_find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) + self.assertEqual(bot.bot_refresh_token, "xoxe-1-refreshed") + + os.remove("logs/cacheable.db") + + bot = await sqlite3_store.async_find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNone(bot) + bot = await store.async_find_bot(enterprise_id="E111", team_id="T111") + self.assertIsNotNone(bot) diff --git a/tests/slack_sdk_async/oauth/token_rotation/__init__.py b/tests/slack_sdk_async/oauth/token_rotation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py b/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py new file mode 100644 index 000000000..c4d68f7d9 --- /dev/null +++ b/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py @@ -0,0 +1,95 @@ +import unittest + +from slack_sdk.errors import SlackTokenRotationError +from slack_sdk.oauth.installation_store import Installation +from slack_sdk.oauth.token_rotation.async_rotator import AsyncTokenRotator +from slack_sdk.web.async_client import AsyncWebClient +from tests.helpers import async_test +from tests.slack_sdk.web.mock_web_api_server import ( + setup_mock_web_api_server, + cleanup_mock_web_api_server, +) + + +class TestTokenRotator(unittest.TestCase): + def setUp(self): + setup_mock_web_api_server(self) + self.token_rotator = AsyncTokenRotator( + client=AsyncWebClient(base_url="http://localhost:8888", token=None), + client_id="111.222", + client_secret="token_rotation_secret", + ) + + def tearDown(self): + cleanup_mock_web_api_server(self) + + @async_test + async def test_refresh(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNotNone(refreshed) + + should_not_be_refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + + @async_test + async def test_token_rotation_disabled(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + ) + should_not_be_refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNone(should_not_be_refreshed) + + should_not_be_refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + + @async_test + async def test_refresh_error(self): + token_rotator = AsyncTokenRotator( + client=AsyncWebClient(base_url="http://localhost:8888", token=None), + client_id="111.222", + client_secret="invalid_value", + ) + + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + ) + with self.assertRaises(SlackTokenRotationError): + await token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) diff --git a/tests/slack_sdk_async/web/test_web_client_coverage.py b/tests/slack_sdk_async/web/test_web_client_coverage.py index 6801af339..4f699fd69 100644 --- a/tests/slack_sdk_async/web/test_web_client_coverage.py +++ b/tests/slack_sdk_async/web/test_web_client_coverage.py @@ -604,6 +604,12 @@ async def run_method(self, method_name, method, async_method): await async_method( client_id="123.123", client_secret="secret", code="123456" ) + elif method_name == "oauth_v2_exchange": + method = getattr(self.no_token_client, method_name, None) + method(client_id="123.123", client_secret="secret", token="xoxb-") + await async_method( + client_id="123.123", client_secret="secret", token="xoxb-" + ) elif method_name == "pins_add": self.api_methods_to_call.remove(method(channel="C123")["method"]) await async_method(channel="C123")