import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.moto_api._internal import mock_random
from moto.s3.models import s3_backends
from moto.s3.utils import bucket_and_name_from_url
from moto.utilities.paginator import paginate
from moto.utilities.utils import get_partition
class TaggableResourceMixin:
    """Give an Athena resource an ARN and a list of ``{"Key": ..., "Value": ...}`` tags.

    This mixin was copied from the Redshift implementation when Athena was
    first mocked; whether it is worth the overhead is still an open question.
    """

    def __init__(
        self,
        account_id: str,
        region_name: str,
        resource_name: str,
        tags: List[Dict[str, str]],
    ):
        partition = get_partition(region_name)
        self.region = region_name
        self.resource_name = resource_name
        # A missing (None) or empty tag list is normalised to a fresh empty list.
        self.tags = tags if tags else []
        self.arn = f"arn:{partition}:athena:{region_name}:{account_id}:{resource_name}"

    def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Add ``tags``, replacing any existing tag that has the same key.

        Returns the updated tag list (existing survivors first, then ``tags``).
        """
        replaced_keys = [tag["Key"] for tag in tags]
        retained = [tag for tag in self.tags if tag["Key"] not in replaced_keys]
        self.tags = retained + list(tags)
        return self.tags

    def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
        """Remove every tag whose key appears in ``tag_keys``; return what is left."""
        self.tags = list(filter(lambda tag: tag["Key"] not in tag_keys, self.tags))
        return self.tags
class WorkGroup(TaggableResourceMixin, BaseModel):
    """An Athena workgroup.

    Any configuration option the caller did not supply is filled in with a
    default value (see ``__init__``).
    """

    resource_type = "workgroup"
    state = "ENABLED"

    def __init__(
        self,
        athena_backend: "AthenaBackend",
        name: str,
        configuration: Optional[Dict[str, Any]],
        description: str,
        tags: List[Dict[str, str]],
    ):
        """Create a workgroup owned by ``athena_backend``.

        :param configuration: workgroup configuration; may be ``None`` or
            partial. It is copied, so the caller's dict is never modified.
        """
        self.region_name = athena_backend.region_name
        super().__init__(
            athena_backend.account_id,
            self.region_name,
            f"workgroup/{name}",
            tags,
        )
        self.athena_backend = athena_backend
        self.name = name
        self.description = description
        # Copy (rather than alias) so filling in defaults does not mutate the
        # caller's dict, and accept a missing configuration instead of raising
        # TypeError on the membership checks.
        self.configuration: Dict[str, Any] = dict(configuration or {})
        self.configuration.setdefault("EnableMinimumEncryptionConfiguration", False)
        self.configuration.setdefault("EnforceWorkGroupConfiguration", True)
        self.configuration.setdefault(
            "EngineVersion",
            {
                "EffectiveEngineVersion": "Athena engine version 3",
                "SelectedEngineVersion": "AUTO",
            },
        )
        self.configuration.setdefault("PublishCloudWatchMetricsEnabled", False)
        self.configuration.setdefault("RequesterPaysEnabled", False)
class DataCatalog(TaggableResourceMixin, BaseModel):
    """An Athena data catalog; its ARN and tags come from ``TaggableResourceMixin``."""

    def __init__(
        self,
        athena_backend: "AthenaBackend",
        name: str,
        catalog_type: str,
        description: str,
        parameters: str,
        tags: List[Dict[str, str]],
    ):
        region = athena_backend.region_name
        self.region_name = region
        super().__init__(
            account_id=athena_backend.account_id,
            region_name=region,
            resource_name=f"datacatalog/{name}",
            tags=tags,
        )
        self.athena_backend = athena_backend
        # Catalog definition exactly as supplied by the caller.
        self.name, self.type = name, catalog_type
        self.description, self.parameters = description, parameters
class Execution(BaseModel):
    """A single query execution.

    Moto does not run the query; the execution is created directly in the
    ``SUCCEEDED`` state.
    """

    def __init__(
        self,
        query: str,
        context: str,
        config: Optional[Dict[str, Any]],
        workgroup: Optional[WorkGroup],
        execution_parameters: Optional[List[str]],
    ):
        """Record an execution with a fresh random id.

        :param config: result configuration, or ``None``. If it has an
            ``OutputLocation``, this execution's copy has ``<id>.csv``
            appended to it. The caller's dict is left untouched.
        """
        self.id = str(mock_random.uuid4())
        self.query = query
        self.context = context
        # Shallow-copy before appending the file name below. Mutating the
        # caller's dict would make a reused config accumulate names across
        # executions ("<loc>/<id1>.csv/<id2>.csv").
        self.config = dict(config) if config is not None else None
        self.workgroup = workgroup
        self.execution_parameters = execution_parameters
        self.start_time = time.time()
        self.status = "SUCCEEDED"
        if self.config is not None and "OutputLocation" in self.config:
            location = self.config["OutputLocation"]
            if not location.endswith("/"):
                location += "/"
            # OutputLocation now names this execution's own result file.
            self.config["OutputLocation"] = f"{location}{self.id}.csv"
class QueryResults(BaseModel):
    """Canned rows and column metadata handed back for a query execution."""

    def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]):
        self.rows = rows
        self.column_info = column_info

    def to_dict(self) -> Dict[str, Any]:
        """Return the results shaped as a ``{"ResultSet": {...}}`` response body."""
        metadata = {"ColumnInfo": self.column_info}
        result_set = {"Rows": self.rows, "ResultSetMetadata": metadata}
        return {"ResultSet": result_set}
class NamedQuery(BaseModel):
    """A saved query, identified by a randomly generated UUID string."""

    def __init__(
        self,
        name: str,
        description: str,
        database: str,
        query_string: str,
        workgroup: WorkGroup,
    ):
        generated_id = mock_random.uuid4()
        self.id = str(generated_id)
        # Definition as supplied by the caller.
        self.name, self.description = name, description
        self.database, self.query_string = database, query_string
        self.workgroup = workgroup
class PreparedStatement(BaseModel):
    """A named prepared statement belonging to a workgroup."""

    def __init__(
        self,
        statement_name: str,
        workgroup: WorkGroup,
        query_statement: str,
        description: str,
    ):
        # Naive local timestamp taken when the statement object is built.
        self.last_modified_time = datetime.now()
        self.statement_name, self.workgroup = statement_name, workgroup
        self.query_statement, self.description = query_statement, description
class AthenaBackend(BaseBackend):
    """In-memory state of the mocked Athena service for one account and region.

    Stores workgroups, query executions, named queries, data catalogs,
    prepared statements and canned query results. Queries are never actually
    executed; see ``get_query_results`` for how results are supplied.
    """

    PAGINATION_MODEL = {
        "list_named_queries": {
            "input_token": "next_token",
            "limit_key": "max_results",
            "limit_default": 50,
            "unique_attribute": "id",
        }
    }

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self.work_groups: Dict[str, WorkGroup] = {}
        self.executions: Dict[str, Execution] = {}
        self.named_queries: Dict[str, NamedQuery] = {}
        self.data_catalogs: Dict[str, DataCatalog] = {}
        # Results already assigned to a specific execution id.
        self.query_results: Dict[str, QueryResults] = {}
        # Results waiting to be assigned, consumed first-in first-out.
        self.query_results_queue: List[QueryResults] = []
        # NOTE(review): keyed by statement name only, so statements with the
        # same name in different workgroups overwrite each other.
        self.prepared_statements: Dict[str, PreparedStatement] = {}

        # Initialise with the primary workgroup
        self.create_work_group(
            name="primary",
            description="",
            configuration={
                "ResultConfiguration": {},
                "EnforceWorkGroupConfiguration": False,
            },
            tags=[],
        )

    def create_work_group(
        self,
        name: str,
        configuration: Dict[str, Any],
        description: str,
        tags: List[Dict[str, str]],
    ) -> Optional[WorkGroup]:
        """Create and register a workgroup.

        Returns the new ``WorkGroup``, or ``None`` if ``name`` is already taken.
        """
        if name in self.work_groups:
            return None
        work_group = WorkGroup(self, name, configuration, description, tags)
        self.work_groups[name] = work_group
        return work_group

    def list_work_groups(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every workgroup."""
        # NOTE(review): CreationTime is the time of this call, not when the
        # workgroup was created; no creation time is recorded anywhere.
        return [
            {
                "Name": wg.name,
                "State": wg.state,
                "Description": wg.description,
                "CreationTime": time.time(),
            }
            for wg in self.work_groups.values()
        ]

    def get_work_group(self, name: str) -> Optional[Dict[str, Any]]:
        """Return details of workgroup ``name``, or ``None`` if it does not exist."""
        wg = self.work_groups.get(name)
        if wg is None:
            return None
        # NOTE(review): CreationTime is the time of this call, not the real
        # creation time of the workgroup.
        return {
            "Name": wg.name,
            "State": wg.state,
            "Configuration": wg.configuration,
            "Description": wg.description,
            "CreationTime": time.time(),
        }

    def delete_work_group(self, name: str) -> None:
        """Delete workgroup ``name``; an unknown name is silently ignored."""
        self.work_groups.pop(name, None)

    def start_query_execution(
        self,
        query: str,
        context: str,
        config: Dict[str, Any],
        workgroup: str,
        execution_parameters: Optional[List[str]],
    ) -> str:
        """Record a new execution and return its id.

        If canned results are queued, the next one is assigned to this
        execution immediately and, when possible, written to S3.
        An unknown ``workgroup`` name is stored as ``None``.
        """
        execution = Execution(
            query=query,
            context=context,
            config=config,
            workgroup=self.work_groups.get(workgroup),
            execution_parameters=execution_parameters,
        )
        self.executions[execution.id] = execution
        self._store_predefined_query_results(execution.id)
        return execution.id

    def _store_predefined_query_results(self, exec_id: str) -> None:
        """Assign the next queued result to ``exec_id`` if it has none yet."""
        if exec_id not in self.query_results and self.query_results_queue:
            self.query_results[exec_id] = self.query_results_queue.pop(0)
            self._store_query_result_in_s3(exec_id)

    def get_query_execution(self, exec_id: str) -> Execution:
        """Return execution ``exec_id``; raises ``KeyError`` if it is unknown."""
        return self.executions[exec_id]

    def list_query_executions(self, workgroup: Optional[str]) -> Dict[str, Execution]:
        """Return all executions, or only those run in the named ``workgroup``."""
        if workgroup is not None:
            return {
                exec_id: execution
                for exec_id, execution in self.executions.items()
                if execution.workgroup and execution.workgroup.name == workgroup
            }
        return self.executions

    def get_query_results(self, exec_id: str) -> QueryResults:
        """
        Queries are not executed by Moto, so this call will always return 0 rows by default.

        You can use a dedicated API to override this, by configuring a queue of expected results.

        A request to `get_query_results` will take the first result from that queue, and assign it to the provided QueryExecutionId. Subsequent requests using the same QueryExecutionId will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return an empty result if the queue is empty. Note that `start_query_execution` also takes the next result from the queue (if any) and assigns it to the newly created execution.

        Configure this queue by making an HTTP request to `/moto-api/static/athena/query-results`. An example invocation looks like this:

        .. sourcecode:: python

            expected_results = {
                "account_id": "123456789012",  # This is the default - can be omitted
                "region": "us-east-1",  # This is the default - can be omitted
                "results": [
                    {
                        "rows": [{"Data": [{"VarCharValue": "1"}]}],
                        "column_info": [{
                            "CatalogName": "string",
                            "SchemaName": "string",
                            "TableName": "string",
                            "Name": "string",
                            "Label": "string",
                            "Type": "string",
                            "Precision": 123,
                            "Scale": 123,
                            "Nullable": "NOT_NULL",
                            "CaseSensitive": True,
                        }],
                    },
                    # other results as required
                ],
            }
            resp = requests.post(
                "http://motoapi.amazonaws.com/moto-api/static/athena/query-results",
                json=expected_results,
            )
            assert resp.status_code == 201

            client = boto3.client("athena", region_name="us-east-1")
            details = client.get_query_results(QueryExecutionId="any_id")["ResultSet"]

        .. note:: The exact QueryExecutionId is not relevant here, but will likely be whatever value is returned by start_query_execution

        Query results will also be stored in the S3 output location (in CSV format).
        """
        self._store_predefined_query_results(exec_id)
        results = (
            self.query_results[exec_id]
            if exec_id in self.query_results
            else QueryResults(rows=[], column_info=[])
        )
        return results

    @staticmethod
    def _format_csv_row(row: Dict[str, Any]) -> str:
        """Render one result row as a newline-terminated CSV line.

        Every present ``VarCharValue`` is double-quoted, with embedded double
        quotes doubled as RFC 4180 requires. Cells without a value become
        empty, unquoted fields.
        """
        cells = []
        for cell in row.get("Data", []):
            if "VarCharValue" in cell:
                escaped = str(cell["VarCharValue"]).replace('"', '""')
                cells.append(f'"{escaped}"')
            else:
                cells.append("")
        return ",".join(cells) + "\n"

    def _store_query_result_in_s3(self, exec_id: str) -> None:
        """Write the results of ``exec_id`` as CSV to its ``OutputLocation``.

        This is best effort. It does nothing when any of the following is
        missing: the execution, its configuration, its OutputLocation, or its
        results. It also does nothing when the location is not an ``s3://``
        URL or the S3 write fails (e.g. the bucket does not exist).
        """
        execution = self.executions.get(exec_id)
        results = self.query_results.get(exec_id)
        if execution is None or results is None or not execution.config:
            return
        output_location = execution.config.get("OutputLocation")
        if not output_location:
            return
        body = "".join(self._format_csv_row(row) for row in results.rows)
        try:
            bucket, key = bucket_and_name_from_url(output_location)
            if bucket is None or key is None:
                return
            s3_backends[self.account_id][self.partition].put_object(
                bucket_name=bucket,
                key_name=key,
                value=body.encode("utf-8"),
            )
        except Exception:  # noqa: BLE001
            # Deliberately best effort: a missing bucket or malformed location
            # must not break the Athena call that triggered this write.
            pass

    def stop_query_execution(self, exec_id: str) -> None:
        """Mark execution ``exec_id`` as CANCELLED; raises ``KeyError`` if unknown."""
        execution = self.executions[exec_id]
        execution.status = "CANCELLED"

    def create_named_query(
        self,
        name: str,
        description: str,
        database: str,
        query_string: str,
        workgroup: str,
    ) -> str:
        """Save a named query and return its id.

        Raises ``KeyError`` if ``workgroup`` does not exist.
        """
        nq = NamedQuery(
            name=name,
            description=description,
            database=database,
            query_string=query_string,
            workgroup=self.work_groups[workgroup],
        )
        self.named_queries[nq.id] = nq
        return nq.id

    def get_named_query(self, query_id: str) -> Optional[NamedQuery]:
        """Return the named query with id ``query_id``, or ``None``."""
        return self.named_queries.get(query_id)

    def list_data_catalogs(self) -> List[Dict[str, str]]:
        """Return the name and type of every data catalog."""
        return [
            {"CatalogName": dc.name, "Type": dc.type}
            for dc in self.data_catalogs.values()
        ]

    def get_data_catalog(self, name: str) -> Optional[Dict[str, str]]:
        """Return details of data catalog ``name``, or ``None`` if it does not exist."""
        dc = self.data_catalogs.get(name)
        if dc is None:
            return None
        return {
            "Name": dc.name,
            "Description": dc.description,
            "Type": dc.type,
            "Parameters": dc.parameters,
        }

    def create_data_catalog(
        self,
        name: str,
        catalog_type: str,
        description: str,
        parameters: str,
        tags: List[Dict[str, str]],
    ) -> Optional[DataCatalog]:
        """Create and register a data catalog.

        Returns the new ``DataCatalog``, or ``None`` if ``name`` is already taken.
        """
        if name in self.data_catalogs:
            return None
        data_catalog = DataCatalog(
            self, name, catalog_type, description, parameters, tags
        )
        self.data_catalogs[name] = data_catalog
        return data_catalog

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_named_queries(self, work_group: str) -> List[str]:
        """Return the ids of all named queries that belong to ``work_group``."""
        named_query_ids = [
            q.id for q in self.named_queries.values() if q.workgroup.name == work_group
        ]
        return named_query_ids

    def create_prepared_statement(
        self,
        statement_name: str,
        workgroup: WorkGroup,
        query_statement: str,
        description: str,
    ) -> None:
        """Store a prepared statement, replacing any with the same name."""
        ps = PreparedStatement(
            statement_name=statement_name,
            workgroup=workgroup,
            query_statement=query_statement,
            description=description,
        )
        # NOTE(review): the key ignores the workgroup; see prepared_statements.
        self.prepared_statements[ps.statement_name] = ps
        return None

    def get_prepared_statement(
        self, statement_name: str, work_group: WorkGroup
    ) -> Optional[PreparedStatement]:
        """Return statement ``statement_name`` if it belongs to ``work_group``, else ``None``."""
        ps = self.prepared_statements.get(statement_name)
        if ps is not None and ps.workgroup == work_group:
            return ps
        return None
# Module-level registry of AthenaBackend instances for the "athena" service.
# Presumably BackendDict builds one backend per account/region on demand
# (AthenaBackend takes region_name and account_id); confirm in
# moto.core.base_backend.
athena_backends = BackendDict(AthenaBackend, "athena")