Skip to content

Commit

Permalink
Add generic mechanism to codegen sources in V2 (pantsbuild#9634)
Browse files Browse the repository at this point in the history
## Goals of design

See https://docs.google.com/document/d/1tJ1SL3URSXUWlrN-GJ1fA1M4jm8zqcaodBWghBWrRWM/edit?ts=5ea310fd for more info. 

tl;dr:

1) Protocols now only have one generic target, like `avro_library`. This means that call sites must declare which language should be generated from that protocol.
    * Must be declarative.
2) You can still get the original protocol sources, e.g. for `./pants filedeps`.
3) Must work with subclassing of fields.
4) Must be extensible.
     * Example: Pants only implements Thrift -> Python. A plugin author should be able to add Thrift -> Java.

## Implementation

Normally, to hydrate sources, we call `await Get[HydratedSources](HydrateSourcesRequest(my_sources_field))`. We always use the exact same rule to do this because all `sources` fields are hydrated identically.

Here, each codegen rule is unique. So, we need to use unions. This means that we also need a uniform product for each codegen rule for the union to work properly. This leads to:

```python
await Get[GeneratedSources](GenerateSourcesRequest, GeneratePythonFromAvroRequest(..))
await Get[GeneratedSources](GenerateSourcesRequest, GenerateJavaFromThriftRequest(..))
```

Each `GenerateSourcesRequest` subclass gets registered as a union rule. This achieves goal #4 of extensibility.

--

To still work with subclassing of fields (goal #3), each `GenerateSourcesRequest` declares the input type and output type, which then allows us to use `isinstance()` to accommodate subclasses:

```python
class GenerateFortranFromAvroRequest(GenerateSourcesRequest):
    input = AvroSources
    output = FortranSources
```

--

To achieve goals #1 and #2 of allowing call sites to declaratively either get the original protocol sources or generated sources, we hook up codegen to the `hydrate_sources` rule and `HydrateSourcesRequest` type:

```python
protocol_sources = await Get[HydratedSources](HydrateSourcesRequest(avro_sources, for_sources_types=[FortranSources], codegen_enabled=True))
```

[ci skip-rust-tests]
[ci skip-jvm-tests]
  • Loading branch information
Eric-Arellano committed Apr 28, 2020
1 parent 46b2c9c commit 9945df1
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 36 deletions.
Expand Up @@ -37,6 +37,7 @@ async def prepare_python_sources(targets: Targets) -> ImportablePythonSources:
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets),
for_sources_types=(PythonSources, ResourcesSources, FilesSources),
enable_codegen=True,
strip_source_roots=True,
)
)
Expand Down
36 changes: 27 additions & 9 deletions src/python/pants/backend/python/rules/run_setup_py.py
Expand Up @@ -68,6 +68,7 @@
TargetsWithOrigins,
TransitiveTargets,
)
from pants.engine.unions import UnionMembership
from pants.option.custom_types import shell_str
from pants.python.python_setup import PythonSetup
from pants.source.source_root import SourceRootConfig
Expand Down Expand Up @@ -279,6 +280,7 @@ async def run_setup_pys(
python_setup: PythonSetup,
distdir: DistDir,
workspace: Workspace,
union_membership: UnionMembership,
) -> SetupPy:
"""Run setup.py commands on all exported targets addressed."""
args = tuple(options.values.args)
Expand Down Expand Up @@ -310,7 +312,7 @@ async def run_setup_pys(
owners = await MultiGet(
Get[ExportedTarget](OwnedDependency(tgt))
for tgt in transitive_targets.closure
if is_ownable_target(tgt)
if is_ownable_target(tgt, union_membership)
)
exported_targets = list(FrozenOrderedSet(owners))

Expand Down Expand Up @@ -474,7 +476,9 @@ async def get_sources(
stripped_srcs_list = await MultiGet(
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(
target.get(Sources), for_sources_types=(PythonSources, ResourcesSources)
target.get(Sources),
for_sources_types=(PythonSources, ResourcesSources),
enable_codegen=True,
)
)
for target in targets
Expand Down Expand Up @@ -523,7 +527,9 @@ async def get_ancestor_init_py(
source_roots = source_root_config.get_source_roots()
sources = await Get[SourceFiles](
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets), for_sources_types=(PythonSources,)
(tgt.get(Sources) for tgt in targets),
for_sources_types=(PythonSources,),
enable_codegen=True,
)
)
# Find the ancestors of all dirs containing .py files, including those dirs themselves.
Expand Down Expand Up @@ -564,12 +570,16 @@ def _is_exported(target: Target) -> bool:


@named_rule(desc="Compute distribution's 3rd party requirements")
async def get_requirements(dep_owner: DependencyOwner) -> ExportedTargetRequirements:
async def get_requirements(
dep_owner: DependencyOwner, union_membership: UnionMembership
) -> ExportedTargetRequirements:
transitive_targets = await Get[TransitiveTargets](
Addresses([dep_owner.exported_target.target.address])
)

ownable_tgts = [tgt for tgt in transitive_targets.closure if is_ownable_target(tgt)]
ownable_tgts = [
tgt for tgt in transitive_targets.closure if is_ownable_target(tgt, union_membership)
]
owners = await MultiGet(Get[ExportedTarget](OwnedDependency(tgt)) for tgt in ownable_tgts)
owned_by_us: Set[Target] = set()
owned_by_others: Set[Target] = set()
Expand Down Expand Up @@ -611,15 +621,19 @@ async def get_requirements(dep_owner: DependencyOwner) -> ExportedTargetRequirem


@named_rule(desc="Find all code to be published in the distribution")
async def get_owned_dependencies(dependency_owner: DependencyOwner) -> OwnedDependencies:
async def get_owned_dependencies(
dependency_owner: DependencyOwner, union_membership: UnionMembership
) -> OwnedDependencies:
"""Find the dependencies of dependency_owner that are owned by it.
Includes dependency_owner itself.
"""
transitive_targets = await Get[TransitiveTargets](
Addresses([dependency_owner.exported_target.target.address])
)
ownable_targets = [tgt for tgt in transitive_targets.closure if is_ownable_target(tgt)]
ownable_targets = [
tgt for tgt in transitive_targets.closure if is_ownable_target(tgt, union_membership)
]
owners = await MultiGet(Get[ExportedTarget](OwnedDependency(tgt)) for tgt in ownable_targets)
owned_dependencies = [
tgt
Expand Down Expand Up @@ -692,8 +706,12 @@ async def setup_setuptools(setuptools: Setuptools) -> SetuptoolsSetup:
return SetuptoolsSetup(requirements_pex=requirements_pex,)


def is_ownable_target(tgt: Target) -> bool:
return tgt.has_field(PythonSources) or tgt.has_field(ResourcesSources)
def is_ownable_target(tgt: Target, union_membership: UnionMembership) -> bool:
return (
tgt.has_field(PythonSources)
or tgt.has_field(ResourcesSources)
or tgt.get(Sources).can_generate(PythonSources, union_membership)
)


def rules():
Expand Down
23 changes: 20 additions & 3 deletions src/python/pants/core/util_rules/determine_source_files.py
Expand Up @@ -36,17 +36,20 @@ def files(self) -> Tuple[str, ...]:
class AllSourceFilesRequest:
sources_fields: Tuple[SourcesField, ...]
for_sources_types: Tuple[Type[SourcesField], ...]
enable_codegen: bool
strip_source_roots: bool

def __init__(
self,
sources_fields: Iterable[SourcesField],
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
enable_codegen: bool = False,
strip_source_roots: bool = False
) -> None:
self.sources_fields = tuple(sources_fields)
self.for_sources_types = tuple(for_sources_types)
self.enable_codegen = enable_codegen
self.strip_source_roots = strip_source_roots


Expand All @@ -55,17 +58,20 @@ def __init__(
class SpecifiedSourceFilesRequest:
sources_fields_with_origins: Tuple[Tuple[SourcesField, OriginSpec], ...]
for_sources_types: Tuple[Type[SourcesField], ...]
enable_codegen: bool
strip_source_roots: bool

def __init__(
self,
sources_fields_with_origins: Iterable[Tuple[SourcesField, OriginSpec]],
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
enable_codegen: bool = False,
strip_source_roots: bool = False
) -> None:
self.sources_fields_with_origins = tuple(sources_fields_with_origins)
self.for_sources_types = tuple(for_sources_types)
self.enable_codegen = enable_codegen
self.strip_source_roots = strip_source_roots


Expand All @@ -91,7 +97,11 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi
if request.strip_source_roots:
stripped_snapshots = await MultiGet(
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(sources_field, for_sources_types=request.for_sources_types)
StripSourcesFieldRequest(
sources_field,
for_sources_types=request.for_sources_types,
enable_codegen=request.enable_codegen,
)
)
for sources_field in request.sources_fields
)
Expand All @@ -101,7 +111,11 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi
else:
all_hydrated_sources = await MultiGet(
Get[HydratedSources](
HydrateSourcesRequest(sources_field, for_sources_types=request.for_sources_types)
HydrateSourcesRequest(
sources_field,
for_sources_types=request.for_sources_types,
enable_codegen=request.enable_codegen,
)
)
for sources_field in request.sources_fields
)
Expand All @@ -119,7 +133,9 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
all_hydrated_sources = await MultiGet(
Get[HydratedSources](
HydrateSourcesRequest(
sources_field_with_origin[0], for_sources_types=request.for_sources_types
sources_field_with_origin[0],
for_sources_types=request.for_sources_types,
enable_codegen=request.enable_codegen,
)
)
for sources_field_with_origin in request.sources_fields_with_origins
Expand Down Expand Up @@ -154,6 +170,7 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
sources_field,
specified_files_snapshot=snapshot,
for_sources_types=request.for_sources_types,
enable_codegen=request.enable_codegen,
)
)
for sources_field, snapshot in zip(all_sources_fields, all_snapshots)
Expand Down
11 changes: 8 additions & 3 deletions src/python/pants/core/util_rules/strip_source_roots.py
Expand Up @@ -58,18 +58,21 @@ class StripSourcesFieldRequest:
"""

sources_field: SourcesField
for_sources_types: Tuple[Type[SourcesField], ...] = (SourcesField,)
specified_files_snapshot: Optional[Snapshot] = None
for_sources_types: Tuple[Type[SourcesField], ...]
enable_codegen: bool
specified_files_snapshot: Optional[Snapshot]

def __init__(
self,
sources_field: SourcesField,
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
enable_codegen: bool = False,
specified_files_snapshot: Optional[Snapshot] = None,
) -> None:
self.sources_field = sources_field
self.for_sources_types = tuple(for_sources_types)
self.enable_codegen = enable_codegen
self.specified_files_snapshot = specified_files_snapshot


Expand Down Expand Up @@ -145,7 +148,9 @@ async def strip_source_roots_from_sources_field(
else:
hydrated_sources = await Get[HydratedSources](
HydrateSourcesRequest(
request.sources_field, for_sources_types=request.for_sources_types
request.sources_field,
for_sources_types=request.for_sources_types,
enable_codegen=request.enable_codegen,
)
)
sources_snapshot = hydrated_sources.snapshot
Expand Down

0 comments on commit 9945df1

Please sign in to comment.