Skip to content

Commit

Permalink
Allow HydratedSourcesRequest to indicate which Sources types are ex…
Browse files Browse the repository at this point in the history
…pected (pantsbuild#9641)

Soon, we will add codegen. With this, we need a way to signal which language should be generated, if any. 

@stuhood proposed in pantsbuild#9634 (comment) that we can extend this setting to indicate more generally which Sources fields are valid, e.g. that we expect to work with `PythonSources` and `FilesSources` (or subclasses), but nothing else. All invalid fields would return an empty snapshot and indicate via a new `HydratedSources.output_type` field that it was an invalid sources field.

This means that call sites can still pre-filter sources fields like they typically do via `tgt.has_field()` (and configurations), but they can also use this new sugar. If they want to use codegen in the upcoming PR, they must use this new mechanism.

Further, when getting back the `HydratedSources`, call sites can switch on the type. Previously, they could do this by zipping the original `Sources` with the resulting `HydratedSources`, but this won't work once we have codegen, as the original `Sources` will be, for example, `ThriftSources`.

```python
if hydrated_sources.output_type == PythonSources:
   ...
elif hydrated_sources.output_type == FilesSources:
   ...
```
 
[ci skip-rust-tests]
[ci skip-jvm-tests]
  • Loading branch information
Eric-Arellano committed Apr 28, 2020
1 parent 10f59e0 commit 46b2c9c
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 99 deletions.
Expand Up @@ -12,7 +12,7 @@
from pants.engine.fs import Snapshot
from pants.engine.rules import RootRule, rule
from pants.engine.selectors import Get
from pants.engine.target import Sources, Target, Targets
from pants.engine.target import Sources, Targets


@dataclass(frozen=True)
Expand All @@ -33,14 +33,11 @@ class ImportablePythonSources:

@rule
async def prepare_python_sources(targets: Targets) -> ImportablePythonSources:
def is_relevant(tgt: Target) -> bool:
return any(
tgt.has_field(field) for field in (PythonSources, ResourcesSources, FilesSources)
)

stripped_sources = await Get[SourceFiles](
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets if is_relevant(tgt)), strip_source_roots=True
(tgt.get(Sources) for tgt in targets),
for_sources_types=(PythonSources, ResourcesSources, FilesSources),
strip_source_roots=True,
)
)
init_injected = await Get[InitInjectedSnapshot](InjectInitRequest(stripped_sources.snapshot))
Expand Down
10 changes: 8 additions & 2 deletions src/python/pants/backend/python/rules/run_setup_py.py
Expand Up @@ -472,7 +472,11 @@ async def get_sources(
) -> SetupPySources:
targets = request.targets
stripped_srcs_list = await MultiGet(
Get[SourceRootStrippedSources](StripSourcesFieldRequest(target.get(Sources)))
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(
target.get(Sources), for_sources_types=(PythonSources, ResourcesSources)
)
)
for target in targets
)

Expand Down Expand Up @@ -518,7 +522,9 @@ async def get_ancestor_init_py(
"""
source_roots = source_root_config.get_source_roots()
sources = await Get[SourceFiles](
AllSourceFilesRequest(tgt[PythonSources] for tgt in targets if tgt.has_field(PythonSources))
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets), for_sources_types=(PythonSources,)
)
)
# Find the ancestors of all dirs containing .py files, including those dirs themselves.
source_dir_ancestors: Set[Tuple[str, str]] = set() # Items are (src_root, path incl. src_root).
Expand Down
37 changes: 29 additions & 8 deletions src/python/pants/core/util_rules/determine_source_files.py
Expand Up @@ -2,7 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from typing import Iterable, Tuple, Union
from typing import Iterable, Tuple, Type, Union

from pants.base.specs import AddressSpec, OriginSpec
from pants.core.util_rules import strip_source_roots
Expand Down Expand Up @@ -35,28 +35,37 @@ def files(self) -> Tuple[str, ...]:
@dataclass(unsafe_hash=True)
class AllSourceFilesRequest:
sources_fields: Tuple[SourcesField, ...]
strip_source_roots: bool = False
for_sources_types: Tuple[Type[SourcesField], ...]
strip_source_roots: bool

def __init__(
self, sources_fields: Iterable[SourcesField], *, strip_source_roots: bool = False
self,
sources_fields: Iterable[SourcesField],
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
strip_source_roots: bool = False
) -> None:
self.sources_fields = tuple(sources_fields)
self.for_sources_types = tuple(for_sources_types)
self.strip_source_roots = strip_source_roots


@frozen_after_init
@dataclass(unsafe_hash=True)
class SpecifiedSourceFilesRequest:
sources_fields_with_origins: Tuple[Tuple[SourcesField, OriginSpec], ...]
strip_source_roots: bool = False
for_sources_types: Tuple[Type[SourcesField], ...]
strip_source_roots: bool

def __init__(
self,
sources_fields_with_origins: Iterable[Tuple[SourcesField, OriginSpec]],
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
strip_source_roots: bool = False
) -> None:
self.sources_fields_with_origins = tuple(sources_fields_with_origins)
self.for_sources_types = tuple(for_sources_types)
self.strip_source_roots = strip_source_roots


Expand All @@ -81,15 +90,19 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi
"""Merge all `Sources` fields into one Snapshot."""
if request.strip_source_roots:
stripped_snapshots = await MultiGet(
Get[SourceRootStrippedSources](StripSourcesFieldRequest(sources_field))
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(sources_field, for_sources_types=request.for_sources_types)
)
for sources_field in request.sources_fields
)
digests_to_merge = tuple(
stripped_snapshot.snapshot.directory_digest for stripped_snapshot in stripped_snapshots
)
else:
all_hydrated_sources = await MultiGet(
Get[HydratedSources](HydrateSourcesRequest(sources_field))
Get[HydratedSources](
HydrateSourcesRequest(sources_field, for_sources_types=request.for_sources_types)
)
for sources_field in request.sources_fields
)
digests_to_merge = tuple(
Expand All @@ -104,7 +117,11 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
"""Determine the specified `sources` for targets, possibly finding a subset of the original
`sources` fields if the user supplied file arguments."""
all_hydrated_sources = await MultiGet(
Get[HydratedSources](HydrateSourcesRequest(sources_field_with_origin[0]))
Get[HydratedSources](
HydrateSourcesRequest(
sources_field_with_origin[0], for_sources_types=request.for_sources_types
)
)
for sources_field_with_origin in request.sources_fields_with_origins
)

Expand Down Expand Up @@ -133,7 +150,11 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
all_sources_fields = (*full_snapshots.keys(), *snapshot_subset_requests.keys())
stripped_snapshots = await MultiGet(
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(sources_field, specified_files_snapshot=snapshot)
StripSourcesFieldRequest(
sources_field,
specified_files_snapshot=snapshot,
for_sources_types=request.for_sources_types,
)
)
for sources_field, snapshot in zip(all_sources_fields, all_snapshots)
)
Expand Down
24 changes: 21 additions & 3 deletions src/python/pants/core/util_rules/strip_source_roots.py
Expand Up @@ -4,7 +4,7 @@
import itertools
from dataclasses import dataclass
from pathlib import PurePath
from typing import Optional, cast
from typing import Iterable, Optional, Tuple, Type, cast

from pants.core.target_types import FilesSources
from pants.engine.addresses import Address
Expand All @@ -23,6 +23,7 @@
from pants.engine.target import Sources as SourcesField
from pants.engine.target import rules as target_rules
from pants.source.source_root import NoSourceRootError, SourceRootConfig
from pants.util.meta import frozen_after_init


@dataclass(frozen=True)
Expand All @@ -46,7 +47,8 @@ class StripSnapshotRequest:
representative_path: Optional[str] = None


@dataclass(frozen=True)
@frozen_after_init
@dataclass(unsafe_hash=True)
class StripSourcesFieldRequest:
"""A request to strip source roots for every file in a `Sources` field.
Expand All @@ -56,8 +58,20 @@ class StripSourcesFieldRequest:
"""

sources_field: SourcesField
for_sources_types: Tuple[Type[SourcesField], ...] = (SourcesField,)
specified_files_snapshot: Optional[Snapshot] = None

def __init__(
self,
sources_field: SourcesField,
*,
for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
specified_files_snapshot: Optional[Snapshot] = None,
) -> None:
self.sources_field = sources_field
self.for_sources_types = tuple(for_sources_types)
self.specified_files_snapshot = specified_files_snapshot


@rule
async def strip_source_roots_from_snapshot(
Expand Down Expand Up @@ -129,7 +143,11 @@ async def strip_source_roots_from_sources_field(
if request.specified_files_snapshot is not None:
sources_snapshot = request.specified_files_snapshot
else:
hydrated_sources = await Get[HydratedSources](HydrateSourcesRequest(request.sources_field))
hydrated_sources = await Get[HydratedSources](
HydrateSourcesRequest(
request.sources_field, for_sources_types=request.for_sources_types
)
)
sources_snapshot = hydrated_sources.snapshot

if not sources_snapshot.files:
Expand Down

0 comments on commit 46b2c9c

Please sign in to comment.