Skip to content

Commit

Permalink
add a compile-graphs mode to use less memory
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbmchuck committed Dec 5, 2023
1 parent 6c13560 commit 2fead1f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
60 changes: 44 additions & 16 deletions altimeter/core/artifact_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gzip
import os
from pathlib import Path
import tempfile
from typing import Optional, Type

import boto3
Expand Down Expand Up @@ -35,7 +36,11 @@ def write_json(self, name: str, data: BaseModel) -> str:

@abc.abstractmethod
def write_graph_set(
self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
self,
name: str,
graph_set: ValidatedGraphSet,
compression: Optional[str] = None,
high_mem: bool = True,
) -> str:
"""Write a graph artifact
Expand Down Expand Up @@ -94,7 +99,11 @@ def write_json(self, name: str, data: BaseModel) -> str:
return artifact_path

def write_graph_set(
self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
self,
name: str,
graph_set: ValidatedGraphSet,
compression: Optional[str] = None,
high_mem: bool = True,
) -> str:
"""Write a graph artifact
Expand Down Expand Up @@ -165,7 +174,11 @@ def write_json(self, name: str, data: BaseModel) -> str:
return f"s3://{self.bucket}/{output_key}"

def write_graph_set(
self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
self,
name: str,
graph_set: ValidatedGraphSet,
compression: Optional[str] = None,
high_mem: bool = True,
) -> str:
"""Write a graph artifact
Expand All @@ -187,19 +200,34 @@ def write_graph_set(
graph = graph_set.to_rdf()
with logger.bind(bucket=self.bucket, key_prefix=self.key_prefix, key=key):
logger.info(event=LogEvent.WriteToS3Start)
with io.BytesIO() as rdf_bytes_buf:
if compression is None:
graph.serialize(rdf_bytes_buf, format="xml")
elif compression == GZIP:
with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
graph.serialize(gz, format="xml")
else:
raise ValueError(f"Unknown compression arg {compression}")
rdf_bytes_buf.flush()
rdf_bytes_buf.seek(0)
session = boto3.Session()
s3_client = session.client("s3")
s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
if high_mem:
with io.BytesIO() as rdf_bytes_buf:
if compression is None:
graph.serialize(rdf_bytes_buf, format="xml")
elif compression == GZIP:
with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
graph.serialize(gz, format="xml")
else:
raise ValueError(f"Unknown compression arg {compression}")
rdf_bytes_buf.flush()
rdf_bytes_buf.seek(0)
session = boto3.Session()
s3_client = session.client("s3")
s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
else:
with tempfile.TemporaryDirectory() as graph_dir:
graph_path = Path(graph_dir, "graph.rdf")
with graph_path.open("wb") as graph_fp:
if compression is None:
graph.serialize(graph_fp, format="xml")
elif compression == GZIP:
with gzip.GzipFile(fileobj=graph_fp, mode="wb") as gz:
graph.serialize(gz, format="xml")
else:
raise ValueError(f"Unknown compression arg {compression}")
session = boto3.Session()
s3_client = session.client("s3")
s3_client.upload_file(str(graph_path), self.bucket, output_key)
s3_client.put_object_tagging(
Bucket=self.bucket,
Key=output_key,
Expand Down
33 changes: 19 additions & 14 deletions bin/sfn_compile_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CompileGraphsInput(BaseImmutableModel):
config: AWSConfig
scan_id: str
account_scan_manifests: Tuple[AccountScanManifest, ...]
high_mem: bool = True


class CompileGraphsOutput(BaseImmutableModel):
Expand Down Expand Up @@ -60,21 +61,25 @@ def lambda_handler(event: Dict[str, Any], _: Any) -> Dict[str, Any]:
if not graph_sets:
raise Exception("BUG: No graph_sets generated.")
validated_graph_set = ValidatedGraphSet.from_graph_set(GraphSet.from_graph_sets(graph_sets))
master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set)
start_time = validated_graph_set.start_time
end_time = validated_graph_set.end_time
scan_manifest = ScanManifest(
scanned_accounts=scanned_accounts,
master_artifact=master_artifact_path,
artifacts=artifacts,
errors=errors,
unscanned_accounts=list(unscanned_accounts),
start_time=start_time,
end_time=end_time,
)
artifact_writer.write_json("manifest", data=scan_manifest)
if compile_graphs_input.high_mem:
master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set)
start_time = validated_graph_set.start_time
end_time = validated_graph_set.end_time
scan_manifest = ScanManifest(
scanned_accounts=scanned_accounts,
master_artifact=master_artifact_path,
artifacts=artifacts,
errors=errors,
unscanned_accounts=list(unscanned_accounts),
start_time=start_time,
end_time=end_time,
)
artifact_writer.write_json("manifest", data=scan_manifest)
rdf_path = artifact_writer.write_graph_set(
name="master", graph_set=validated_graph_set, compression=GZIP
name="master",
graph_set=validated_graph_set,
compression=GZIP,
high_mem=compile_graphs_input.high_mem,
)

return CompileGraphsOutput(
Expand Down

0 comments on commit 2fead1f

Please sign in to comment.