Skip to content

Commit

Permalink
Refresh credentials on ExpiredTokenException
Browse files Browse the repository at this point in the history
  • Loading branch information
jdelStrother committed Dec 18, 2023
1 parent 0c97d0e commit 6dba9d8
Showing 1 changed file with 51 additions and 11 deletions.
62 changes: 51 additions & 11 deletions lib/blazer/adapters/athena_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run_statement(statement, comment, bind_params = [])
request_token = Digest::MD5.hexdigest([statement, bind_params.to_json, data_source.id, settings["workgroup"]].compact.join("/"))
statement_name = "blazer_#{request_token}"
begin
client.create_prepared_statement({
create_prepared_statement({
statement_name: statement_name,
work_group: settings["workgroup"],
query_statement: statement
Expand Down Expand Up @@ -45,15 +45,15 @@ def run_statement(statement, comment, bind_params = [])
query_options[:work_group] = settings["workgroup"]
end

resp = client.start_query_execution(**query_options)
resp = start_query_execution(**query_options)
query_execution_id = resp.query_execution_id

timeout = data_source.timeout || 300
stop_at = Time.now + timeout
resp = nil

begin
resp = client.get_query_results(
resp = get_query_results(
query_execution_id: query_execution_id
)
rescue Aws::Athena::Errors::InvalidRequestException => e
Expand Down Expand Up @@ -119,11 +119,11 @@ def run_statement(statement, comment, bind_params = [])
end

def tables
glue.get_tables(database_name: database).table_list.map(&:name).sort
get_tables(database_name: database).table_list.map(&:name).sort
end

def schema
glue.get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
end

def preview_statement
Expand Down Expand Up @@ -154,11 +154,51 @@ def engine_version
end

def fetch_error(query_execution_id)
client.get_query_execution(
get_query_execution(
query_execution_id: query_execution_id
).query_execution.status.state_change_reason
end

def autorefresh_credentials
yield
rescue Aws::Athena::Errors::ExpiredTokenException
# Clear our cached Athena & Glue clients to force fetching new credentials, and immediately retry
@client = nil
@glue = nil
@client_credentials = nil
yield
end

def get_tables(**options)
autorefresh_credentials do
glue.get_tables(**options)
end
end

def create_prepared_statement(**options)
autorefresh_credentials do
client.create_prepared_statement(**options)
end
end

def start_query_execution(**options)
autorefresh_credentials do
client.start_query_execution(**options)
end
end

def get_query_results(**options)
autorefresh_credentials do
client.get_query_results(**options)
end
end

def get_query_execution(**options)
autorefresh_credentials do
client.get_query_execution(**options)
end
end

def client
@client ||= Aws::Athena::Client.new(**client_options)
end
Expand All @@ -168,12 +208,12 @@ def glue
end

def client_options
@client_options ||= begin
options = {}
options[:credentials] = client_credentials if client_credentials
options[:region] = settings["region"] if settings["region"]
options
options = {}
if credentials = client_credentials
options[:credentials] = credentials
end
options[:region] = settings["region"] if settings["region"]
options
end

def client_credentials
Expand Down

0 comments on commit 6dba9d8

Please sign in to comment.