Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow role-based authentication for Athena #447

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,14 @@ data_sources:
# optional settings
output_location: s3://some-bucket/
workgroup: primary
access_key_id: ...
secret_access_key: ...
region: ...
credentials:
access_key_id: ...
secret_access_key: ...
# optional credential settings for role-based authentication:
role_arn: ...
region: ...

```

Here’s an example IAM policy:
Expand Down
87 changes: 75 additions & 12 deletions lib/blazer/adapters/athena_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run_statement(statement, comment, bind_params = [])
request_token = Digest::MD5.hexdigest([statement, bind_params.to_json, data_source.id, settings["workgroup"]].compact.join("/"))
statement_name = "blazer_#{request_token}"
begin
client.create_prepared_statement({
create_prepared_statement({
statement_name: statement_name,
work_group: settings["workgroup"],
query_statement: statement
Expand Down Expand Up @@ -45,15 +45,15 @@ def run_statement(statement, comment, bind_params = [])
query_options[:work_group] = settings["workgroup"]
end

resp = client.start_query_execution(**query_options)
resp = start_query_execution(**query_options)
query_execution_id = resp.query_execution_id

timeout = data_source.timeout || 300
stop_at = Time.now + timeout
resp = nil

begin
resp = client.get_query_results(
resp = get_query_results(
query_execution_id: query_execution_id
)
rescue Aws::Athena::Errors::InvalidRequestException => e
Expand Down Expand Up @@ -119,11 +119,11 @@ def run_statement(statement, comment, bind_params = [])
end

# Returns the names of the tables in `database`, sorted alphabetically.
#
# The lookup goes through the `get_tables` helper exactly once; a second,
# direct `glue.get_tables` call would repeat the request and bypass the
# helper's handling.
def tables
  get_tables(database_name: database).table_list.map(&:name).sort
end

# Describes every table in `database`.
#
# Returns an Array of Hashes shaped like
#   {table: "name", columns: [{name: "col", data_type: "type"}, ...]}
# built from each table's storage descriptor columns.
#
# The lookup goes through the `get_tables` helper exactly once; a second,
# direct `glue.get_tables` call would repeat the request and bypass the
# helper's handling.
def schema
  get_tables(database_name: database).table_list.map do |t|
    columns = t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }
    {table: t.name, columns: columns}
  end
end

def preview_statement
Expand Down Expand Up @@ -154,11 +154,51 @@ def engine_version
end

# Looks up why a query execution changed state.
#
# query_execution_id - the id of the query execution to inspect.
#
# Returns the `state_change_reason` found on the execution's status, as
# reported by the `get_query_execution` helper.
def fetch_error(query_execution_id)
  execution = get_query_execution(query_execution_id: query_execution_id)
  execution.query_execution.status.state_change_reason
end

# Runs the given block, retrying it once if AWS reports an expired token.
#
# On an ExpiredTokenException from either Athena or Glue (both services
# are called through this helper), the memoized Athena client, Glue client
# and credentials are cleared so the next access rebuilds them, and the
# block is yielded to a second time. The retry is not itself rescued, so a
# second failure propagates to the caller.
#
# Returns whatever the block returns.
def autorefresh_credentials
  yield
rescue Aws::Athena::Errors::ExpiredTokenException, Aws::Glue::Errors::ExpiredTokenException
  # Drop every cached object derived from the stale credentials.
  @client = nil
  @glue = nil
  @client_credentials = nil
  yield
end

# Calls `glue.get_tables` with the given keyword arguments, wrapped in
# `autorefresh_credentials`. Returns that call's result.
def get_tables(**params)
  autorefresh_credentials { glue.get_tables(**params) }
end

# Calls `client.create_prepared_statement` with the given keyword
# arguments, wrapped in `autorefresh_credentials`. Returns that call's
# result.
def create_prepared_statement(**params)
  autorefresh_credentials { client.create_prepared_statement(**params) }
end

# Calls `client.start_query_execution` with the given keyword arguments,
# wrapped in `autorefresh_credentials`. Returns that call's result.
def start_query_execution(**params)
  autorefresh_credentials { client.start_query_execution(**params) }
end

# Calls `client.get_query_results` with the given keyword arguments,
# wrapped in `autorefresh_credentials`. Returns that call's result.
def get_query_results(**params)
  autorefresh_credentials { client.get_query_results(**params) }
end

# Calls `client.get_query_execution` with the given keyword arguments,
# wrapped in `autorefresh_credentials`. Returns that call's result.
def get_query_execution(**params)
  autorefresh_credentials { client.get_query_execution(**params) }
end

# Returns the memoized Aws::Athena::Client, building it from
# `client_options` the first time it is needed (or after @client has been
# reset to nil).
def client
  @client ||= begin
    options = client_options
    Aws::Athena::Client.new(**options)
  end
end
Expand All @@ -168,13 +208,36 @@ def glue
end

# Builds the keyword options used to construct the Athena client
# (see `client`).
#
# Returns a Hash containing :credentials when `client_credentials` yields
# a value, and :region when the top-level "region" setting is present.
# The hash is rebuilt on every call so that cleared credentials are picked
# up again.
def client_options
  options = {}
  credentials = client_credentials
  options[:credentials] = credentials if credentials
  options[:region] = settings["region"] if settings["region"]
  options
end

# Resolves and memoizes the credentials object handed to the AWS clients.
#
# Settings are read from the "credentials" sub-hash; the top-level
# "access_key_id", "secret_access_key" and "region" keys are still honoured
# for backwards compatibility and act as defaults for missing sub-hash keys.
#
# Returns:
# - the result of STS `assume_role` when "role_arn" is configured (the
#   session name comes from "role_session_name", defaulting to "blazer");
#   NOTE(review): this is the raw assume_role response and is presumably
#   accepted by the service clients because it exposes #credentials —
#   confirm against the SDK version in use;
# - an Aws::Credentials built from the key pair when both parts are set;
# - nil otherwise, so that no explicit :credentials option is passed.
def client_credentials
  @client_credentials ||= begin
    # Plain Hash#merge keeps the sub-hash values on conflict, which is the
    # same precedence as ActiveSupport's Hash#with_defaults.
    legacy = settings.slice("access_key_id", "secret_access_key", "region")
    creds = legacy.merge(settings["credentials"] || {})

    access_key_id = creds["access_key_id"]
    secret_access_key = creds["secret_access_key"]
    role_arn = creds["role_arn"]

    if role_arn
      Aws::STS::Client.new(
        access_key_id: access_key_id,
        secret_access_key: secret_access_key,
        region: creds["region"]
      ).assume_role(
        role_arn: role_arn,
        role_session_name: creds["role_session_name"] || "blazer"
      )
    elsif access_key_id && secret_access_key
      Aws::Credentials.new(access_key_id, secret_access_key)
    end
  end
end
end
Expand Down