Skip to content

Commit

Permalink
Refresh credentials on ExpiredTokenException
Browse files Browse the repository at this point in the history
  • Loading branch information
jdelStrother committed Dec 18, 2023
1 parent 0c97d0e commit b09a3c9
Showing 1 changed file with 69 additions and 32 deletions.
101 changes: 69 additions & 32 deletions lib/blazer/adapters/athena_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run_statement(statement, comment, bind_params = [])
request_token = Digest::MD5.hexdigest([statement, bind_params.to_json, data_source.id, settings["workgroup"]].compact.join("/"))
statement_name = "blazer_#{request_token}"
begin
client.create_prepared_statement({
create_prepared_statement({
statement_name: statement_name,
work_group: settings["workgroup"],
query_statement: statement
Expand Down Expand Up @@ -45,15 +45,15 @@ def run_statement(statement, comment, bind_params = [])
query_options[:work_group] = settings["workgroup"]
end

resp = client.start_query_execution(**query_options)
resp = start_query_execution(**query_options)
query_execution_id = resp.query_execution_id

timeout = data_source.timeout || 300
stop_at = Time.now + timeout
resp = nil

begin
resp = client.get_query_results(
resp = get_query_results(
query_execution_id: query_execution_id
)
rescue Aws::Athena::Errors::InvalidRequestException => e
Expand Down Expand Up @@ -119,11 +119,11 @@ def run_statement(statement, comment, bind_params = [])
end

def tables
glue.get_tables(database_name: database).table_list.map(&:name).sort
get_tables(database_name: database).table_list.map(&:name).sort
end

def schema
glue.get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
end

def preview_statement
Expand Down Expand Up @@ -154,11 +154,50 @@ def engine_version
end

def fetch_error(query_execution_id)
client.get_query_execution(
get_query_execution(
query_execution_id: query_execution_id
).query_execution.status.state_change_reason
end

def autorefresh_credentials
yield
rescue Aws::Athena::Errors::ExpiredTokenException
# Clear our cached Athena & Glue clients to force fetching new credentials, and immediately retry
@client = nil
@glue = nil
yield
end

def get_tables(**options)
autorefresh_credentials do
glue.get_tables(**options)
end
end

def create_prepared_statement(**options)
autorefresh_credentials do
client.create_prepared_statement(**options)
end
end

def start_query_execution(**options)
autorefresh_credentials do
client.start_query_execution(**options)
end
end

def get_query_results(**options)
autorefresh_credentials do
client.get_query_results(**options)
end
end

def get_query_execution(**options)
autorefresh_credentials do
client.get_query_execution(**options)
end
end

def client
@client ||= Aws::Athena::Client.new(**client_options)
end
Expand All @@ -168,36 +207,34 @@ def glue
end

def client_options
@client_options ||= begin
options = {}
options[:credentials] = client_credentials if client_credentials
options[:region] = settings["region"] if settings["region"]
options
options = {}
if credentials = client_credentials
options[:credentials] = credentials
end
options[:region] = settings["region"] if settings["region"]
options
end

def client_credentials
@client_credentials ||= begin
# Loading the access key & secret from the top-level settings is supported for backwards compatibility,
# but prefer loading them from the 'credentials' sub-hash.
creds = (settings["credentials"] || {}).with_defaults(settings.slice("access_key_id", "secret_access_key", "region"))
access_key_id = creds["access_key_id"]
secret_access_key = creds["secret_access_key"]
region = creds["region"]
role_arn = creds["role_arn"]
role_session_name = creds["role_session_name"] || "blazer"
if role_arn
Aws::STS::Client.new(
access_key_id: access_key_id,
secret_access_key: secret_access_key,
region: region,
).assume_role(
role_arn: role_arn,
role_session_name: role_session_name,
)
elsif access_key_id && secret_access_key
Aws::Credentials.new(access_key_id, secret_access_key)
end
# Loading the access key & secret from the top-level settings is supported for backwards compatibility,
# but prefer loading them from the 'credentials' sub-hash.
creds = (settings["credentials"] || {}).with_defaults(settings.slice("access_key_id", "secret_access_key", "region"))
access_key_id = creds["access_key_id"]
secret_access_key = creds["secret_access_key"]
region = creds["region"]
role_arn = creds["role_arn"]
role_session_name = creds["role_session_name"] || "blazer"
if role_arn
Aws::STS::Client.new(
access_key_id: access_key_id,
secret_access_key: secret_access_key,
region: region,
).assume_role(
role_arn: role_arn,
role_session_name: role_session_name,
)
elsif access_key_id && secret_access_key
Aws::Credentials.new(access_key_id, secret_access_key)
end
end
end
Expand Down

0 comments on commit b09a3c9

Please sign in to comment.