diff --git a/src/python/pants/backend/python/rules/importable_python_sources.py b/src/python/pants/backend/python/rules/importable_python_sources.py index bed7665bf6c..e65f86b2511 100644 --- a/src/python/pants/backend/python/rules/importable_python_sources.py +++ b/src/python/pants/backend/python/rules/importable_python_sources.py @@ -37,6 +37,7 @@ async def prepare_python_sources(targets: Targets) -> ImportablePythonSources: AllSourceFilesRequest( (tgt.get(Sources) for tgt in targets), for_sources_types=(PythonSources, ResourcesSources, FilesSources), + enable_codegen=True, strip_source_roots=True, ) ) diff --git a/src/python/pants/backend/python/rules/run_setup_py.py b/src/python/pants/backend/python/rules/run_setup_py.py index f9d16446993..74db8a9cf9b 100644 --- a/src/python/pants/backend/python/rules/run_setup_py.py +++ b/src/python/pants/backend/python/rules/run_setup_py.py @@ -68,6 +68,7 @@ TargetsWithOrigins, TransitiveTargets, ) +from pants.engine.unions import UnionMembership from pants.option.custom_types import shell_str from pants.python.python_setup import PythonSetup from pants.source.source_root import SourceRootConfig @@ -279,6 +280,7 @@ async def run_setup_pys( python_setup: PythonSetup, distdir: DistDir, workspace: Workspace, + union_membership: UnionMembership, ) -> SetupPy: """Run setup.py commands on all exported targets addressed.""" args = tuple(options.values.args) @@ -310,7 +312,7 @@ async def run_setup_pys( owners = await MultiGet( Get[ExportedTarget](OwnedDependency(tgt)) for tgt in transitive_targets.closure - if is_ownable_target(tgt) + if is_ownable_target(tgt, union_membership) ) exported_targets = list(FrozenOrderedSet(owners)) @@ -474,7 +476,9 @@ async def get_sources( stripped_srcs_list = await MultiGet( Get[SourceRootStrippedSources]( StripSourcesFieldRequest( - target.get(Sources), for_sources_types=(PythonSources, ResourcesSources) + target.get(Sources), + for_sources_types=(PythonSources, ResourcesSources), + enable_codegen=True, ) ) for target in targets @@ -523,7 +527,9 @@ async def get_ancestor_init_py( source_roots = source_root_config.get_source_roots() sources = await Get[SourceFiles]( AllSourceFilesRequest( - (tgt.get(Sources) for tgt in targets), for_sources_types=(PythonSources,) + (tgt.get(Sources) for tgt in targets), + for_sources_types=(PythonSources,), + enable_codegen=True, ) ) # Find the ancestors of all dirs containing .py files, including those dirs themselves. @@ -564,12 +570,16 @@ def _is_exported(target: Target) -> bool: @named_rule(desc="Compute distribution's 3rd party requirements") -async def get_requirements(dep_owner: DependencyOwner) -> ExportedTargetRequirements: +async def get_requirements( + dep_owner: DependencyOwner, union_membership: UnionMembership +) -> ExportedTargetRequirements: transitive_targets = await Get[TransitiveTargets]( Addresses([dep_owner.exported_target.target.address]) ) - ownable_tgts = [tgt for tgt in transitive_targets.closure if is_ownable_target(tgt)] + ownable_tgts = [ + tgt for tgt in transitive_targets.closure if is_ownable_target(tgt, union_membership) + ] owners = await MultiGet(Get[ExportedTarget](OwnedDependency(tgt)) for tgt in ownable_tgts) owned_by_us: Set[Target] = set() owned_by_others: Set[Target] = set() @@ -611,7 +621,9 @@ async def get_requirements(dep_owner: DependencyOwner) -> ExportedTargetRequirem @named_rule(desc="Find all code to be published in the distribution") -async def get_owned_dependencies(dependency_owner: DependencyOwner) -> OwnedDependencies: +async def get_owned_dependencies( + dependency_owner: DependencyOwner, union_membership: UnionMembership +) -> OwnedDependencies: """Find the dependencies of dependency_owner that are owned by it. Includes dependency_owner itself. @@ -619,7 +631,9 @@ async def get_owned_dependencies(dependency_owner: DependencyOwner) -> OwnedDepe transitive_targets = await Get[TransitiveTargets]( Addresses([dependency_owner.exported_target.target.address]) ) - ownable_targets = [tgt for tgt in transitive_targets.closure if is_ownable_target(tgt)] + ownable_targets = [ + tgt for tgt in transitive_targets.closure if is_ownable_target(tgt, union_membership) + ] owners = await MultiGet(Get[ExportedTarget](OwnedDependency(tgt)) for tgt in ownable_targets) owned_dependencies = [ tgt @@ -692,8 +706,12 @@ async def setup_setuptools(setuptools: Setuptools) -> SetuptoolsSetup: return SetuptoolsSetup(requirements_pex=requirements_pex,) -def is_ownable_target(tgt: Target) -> bool: - return tgt.has_field(PythonSources) or tgt.has_field(ResourcesSources) +def is_ownable_target(tgt: Target, union_membership: UnionMembership) -> bool: + return ( + tgt.has_field(PythonSources) + or tgt.has_field(ResourcesSources) + or tgt.get(Sources).can_generate(PythonSources, union_membership) + ) def rules(): diff --git a/src/python/pants/core/util_rules/determine_source_files.py b/src/python/pants/core/util_rules/determine_source_files.py index 403d1bcef7a..eae3f7ebcbf 100644 --- a/src/python/pants/core/util_rules/determine_source_files.py +++ b/src/python/pants/core/util_rules/determine_source_files.py @@ -36,6 +36,7 @@ def files(self) -> Tuple[str, ...]: class AllSourceFilesRequest: sources_fields: Tuple[SourcesField, ...] for_sources_types: Tuple[Type[SourcesField], ...] + enable_codegen: bool strip_source_roots: bool def __init__( @@ -43,10 +44,12 @@ def __init__( sources_fields: Iterable[SourcesField], *, for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,), + enable_codegen: bool = False, strip_source_roots: bool = False ) -> None: self.sources_fields = tuple(sources_fields) self.for_sources_types = tuple(for_sources_types) + self.enable_codegen = enable_codegen self.strip_source_roots = strip_source_roots @@ -55,6 +58,7 @@ def __init__( class SpecifiedSourceFilesRequest: sources_fields_with_origins: Tuple[Tuple[SourcesField, OriginSpec], ...] for_sources_types: Tuple[Type[SourcesField], ...] + enable_codegen: bool strip_source_roots: bool def __init__( @@ -62,10 +66,12 @@ def __init__( sources_fields_with_origins: Iterable[Tuple[SourcesField, OriginSpec]], *, for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,), + enable_codegen: bool = False, strip_source_roots: bool = False ) -> None: self.sources_fields_with_origins = tuple(sources_fields_with_origins) self.for_sources_types = tuple(for_sources_types) + self.enable_codegen = enable_codegen self.strip_source_roots = strip_source_roots @@ -91,7 +97,11 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi if request.strip_source_roots: stripped_snapshots = await MultiGet( Get[SourceRootStrippedSources]( - StripSourcesFieldRequest(sources_field, for_sources_types=request.for_sources_types) + StripSourcesFieldRequest( + sources_field, + for_sources_types=request.for_sources_types, + enable_codegen=request.enable_codegen, + ) ) for sources_field in request.sources_fields ) @@ -101,7 +111,11 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi else: all_hydrated_sources = await MultiGet( Get[HydratedSources]( - HydrateSourcesRequest(sources_field, for_sources_types=request.for_sources_types) + HydrateSourcesRequest( + sources_field, + for_sources_types=request.for_sources_types, + enable_codegen=request.enable_codegen, + ) ) for sources_field in request.sources_fields ) @@ -119,7 +133,9 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest) all_hydrated_sources = await MultiGet( Get[HydratedSources]( HydrateSourcesRequest( - sources_field_with_origin[0], for_sources_types=request.for_sources_types + sources_field_with_origin[0], + for_sources_types=request.for_sources_types, + enable_codegen=request.enable_codegen, ) ) for sources_field_with_origin in request.sources_fields_with_origins @@ -154,6 +170,7 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest) sources_field, specified_files_snapshot=snapshot, for_sources_types=request.for_sources_types, + enable_codegen=request.enable_codegen, ) ) for sources_field, snapshot in zip(all_sources_fields, all_snapshots) diff --git a/src/python/pants/core/util_rules/strip_source_roots.py b/src/python/pants/core/util_rules/strip_source_roots.py index a490efd6c20..24758cf631c 100644 --- a/src/python/pants/core/util_rules/strip_source_roots.py +++ b/src/python/pants/core/util_rules/strip_source_roots.py @@ -58,18 +58,21 @@ class StripSourcesFieldRequest: """ sources_field: SourcesField - for_sources_types: Tuple[Type[SourcesField], ...] = (SourcesField,) - specified_files_snapshot: Optional[Snapshot] = None + for_sources_types: Tuple[Type[SourcesField], ...] + enable_codegen: bool + specified_files_snapshot: Optional[Snapshot] def __init__( self, sources_field: SourcesField, *, for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,), + enable_codegen: bool = False, specified_files_snapshot: Optional[Snapshot] = None, ) -> None: self.sources_field = sources_field self.for_sources_types = tuple(for_sources_types) + self.enable_codegen = enable_codegen self.specified_files_snapshot = specified_files_snapshot @@ -145,7 +148,9 @@ async def strip_source_roots_from_sources_field( else: hydrated_sources = await Get[HydratedSources]( HydrateSourcesRequest( - request.sources_field, for_sources_types=request.for_sources_types + request.sources_field, + for_sources_types=request.for_sources_types, + enable_codegen=request.enable_codegen, ) ) sources_snapshot = hydrated_sources.snapshot diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 6fec652edb8..2fa645ccad6 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -38,7 +38,7 @@ from pants.engine.legacy.structs import BundleAdaptor from pants.engine.rules import RootRule, rule from pants.engine.selectors import Get -from pants.engine.unions import UnionMembership +from pants.engine.unions import UnionMembership, union from pants.source.wrapped_globs import EagerFilesetWithSpec, FilesetRelPathWrapper, Filespec from pants.util.collections import ensure_list, ensure_str_list from pants.util.frozendict import FrozenDict @@ -306,6 +306,7 @@ def __init__( def field_types(self) -> Tuple[Type[Field], ...]: return (*self.core_fields, *self.plugin_fields) + @union @final class PluginField: """A sentinel class to allow plugin authors to add additional fields to this target type. @@ -915,12 +916,56 @@ def __init__( bulleted_list_sep = "\n * " super().__init__( f"Multiple of the registered implementations for {goal_description} work for " - f"{target.address} (target type {repr(target.alias)}).\n\n" - "It is ambiguous which implementation to use. Possible implementations:" + f"{target.address} (target type {repr(target.alias)}). It is ambiguous which " + "implementation to use.\n\nPossible implementations:" f"{bulleted_list_sep}{bulleted_list_sep.join(possible_config_types)}" ) +class AmbiguousCodegenImplementationsException(Exception): + """Exception for when there are multiple codegen implementations and it is ambiguous which to + use.""" + + def __init__( + self, + generators: Iterable[Type["GenerateSourcesRequest"]], + *, + for_sources_types: Iterable[Type["Sources"]], + ) -> None: + bulleted_list_sep = "\n * " + all_same_generator_paths = ( + len(set((generator.input, generator.output) for generator in generators)) == 1 + ) + example_generator = list(generators)[0] + input = example_generator.input.__name__ + if all_same_generator_paths: + output = example_generator.output.__name__ + possible_generators = sorted(generator.__name__ for generator in generators) + super().__init__( + f"Multiple of the registered code generators can generate {output} from {input}. " + "It is ambiguous which implementation to use.\n\nPossible implementations:" + f"{bulleted_list_sep}{bulleted_list_sep.join(possible_generators)}" + ) + else: + possible_output_types = sorted( + generator.output.__name__ + for generator in generators + if issubclass(generator.output, tuple(for_sources_types)) + ) + possible_generators_with_output = [ + f"{generator.__name__} -> {generator.output.__name__}" + for generator in sorted(generators, key=lambda generator: generator.output.__name__) + ] + super().__init__( + f"Multiple of the registered code generators can generate one of " + f"{possible_output_types} from {input}. It is ambiguous which implementation to " + f"use. This can happen when the call site requests too many different output types " + f"from the same original protocol sources.\n\nPossible implementations with their " + f"output type: {bulleted_list_sep}" + f"{bulleted_list_sep.join(possible_generators_with_output)}" + ) + + # ----------------------------------------------------------------------------------------------- # Field templates # ----------------------------------------------------------------------------------------------- @@ -1162,7 +1207,7 @@ def compute_value( # ----------------------------------------------------------------------------------------------- -# Sources +# Sources and codegen # ----------------------------------------------------------------------------------------------- @@ -1269,65 +1314,191 @@ def filespec(self) -> Filespec: args=includes, exclude=[excludes], root=self.address.spec_path ) + @final + @classmethod + def can_generate(cls, output_type: Type["Sources"], union_membership: UnionMembership) -> bool: + """Can this Sources field be used to generate the output_type? + + Generally, this method does not need to be used. Most call sites can simply use the below, + and the engine will generate the sources if possible or will return an instance of + HydratedSources with an empty snapshot if not possible: + + await Get[HydratedSources]( + HydrateSourcesRequest( + sources_field, + for_sources_types=[FortranSources], + enable_codegen=True, + ) + ) + + This method is useful when you need to filter targets before hydrating them, such as how + you may filter targets via `tgt.has_field(MyField)`. + """ + generate_request_types: Iterable[ + Type[GenerateSourcesRequest] + ] = union_membership.union_rules.get(GenerateSourcesRequest, ()) + return any( + issubclass(cls, generate_request_type.input) + and issubclass(generate_request_type.output, output_type) + for generate_request_type in generate_request_types + ) + @frozen_after_init @dataclass(unsafe_hash=True) class HydrateSourcesRequest: field: Sources for_sources_types: Tuple[Type[Sources], ...] + enable_codegen: bool def __init__( - self, field: Sources, *, for_sources_types: Iterable[Type[Sources]] = (Sources,) + self, + field: Sources, + *, + for_sources_types: Iterable[Type[Sources]] = (Sources,), + enable_codegen: bool = False, ) -> None: """Convert raw sources globs into an instance of HydratedSources. If you only want to handle certain Sources fields, such as only PythonSources, set `for_sources_types`. Any invalid sources will return a `HydratedSources` instance with an - empty snapshot and `output_type = None`. + empty snapshot and `sources_type = None`. + + If `enable_codegen` is set to `True`, any codegen sources will try to be converted to one + of the `for_sources_types`. """ self.field = field self.for_sources_types = tuple(for_sources_types) + self.enable_codegen = enable_codegen + self.__post_init__() + + def __post_init__(self) -> None: + if self.enable_codegen and self.for_sources_types == (Sources,): + raise ValueError( + "When setting `enable_codegen=True` on `HydrateSourcesRequest`, you must also " + "explicitly set `for_source_types`. Why? `for_source_types` is used to " + "determine which language(s) to try to generate. For example, " + "`for_source_types=(PythonSources,)` will hydrate `PythonSources` like normal, " + "and, if it encounters codegen sources that can be converted into Python, it will " + "generate Python files." + ) @dataclass(frozen=True) class HydratedSources: """The result of hydrating a SourcesField. - The `output_type` will indicate which of the `HydrateSourcesRequest.valid_sources_type` the + The `sources_type` will indicate which of the `HydrateSourcesRequest.for_sources_type` the result corresponds to, e.g. if the result comes from `FilesSources` vs. `PythonSources`. If this - value is None, then the input `Sources` field was not one of the expected types. This property - allows for switching on the result, e.g. handling hydrated files() sources differently than - hydrated Python sources. + value is None, then the input `Sources` field was not one of the expected types; or, when + codegen was enabled in the request, there was no valid code generator to generate the requested + language from the original input. This property allows for switching on the result, e.g. + handling hydrated files() sources differently than hydrated Python sources. """ snapshot: Snapshot filespec: Filespec - output_type: Optional[Type[Sources]] + sources_type: Optional[Type[Sources]] def eager_fileset_with_spec(self, *, address: Address) -> EagerFilesetWithSpec: return EagerFilesetWithSpec(address.spec_path, self.filespec, self.snapshot) +@union +@dataclass(frozen=True) +class GenerateSourcesRequest: + """A request to go from protocol sources -> a particular language. + + This should be subclassed for each distinct codegen implementation. The subclasses must define + the class properties `input` and `output`. The subclass must also be registered via + `UnionRule(GenerateSourcesRequest, GenerateFortranFromAvroRequest)`, for example. + + The rule to actually implement the codegen should take the subclass as input, and it must + return `GeneratedSources`. + + For example: + + class GenerateFortranFromAvroRequest: + input = AvroSources + output = FortranSources + + @rule + def generate_fortran_from_avro(request: GenerateFortranFromAvroRequest) -> GeneratedSources: + ... + + def rules(): + return [ + generate_fortran_from_avro, + UnionRule(GenerateSourcesRequest, GenerateFortranFromAvroRequest), + ] + """ + + protocol_sources: Snapshot + protocol_target: Target + + input: ClassVar[Type[Sources]] + output: ClassVar[Type[Sources]] + + +@dataclass(frozen=True) +class GeneratedSources: + snapshot: Snapshot + + @rule async def hydrate_sources( - request: HydrateSourcesRequest, glob_match_error_behavior: GlobMatchErrorBehavior + request: HydrateSourcesRequest, + glob_match_error_behavior: GlobMatchErrorBehavior, + union_membership: UnionMembership, ) -> HydratedSources: sources_field = request.field - output_type = next( + # First, find if there are any code generators for the input `sources_field`. This will be used + # to determine if the sources_field is valid or not. + # We could alternatively use `sources_field.can_generate()`, but we want to error if there are + # 2+ generators due to ambiguity. + generate_request_types: Iterable[ + Type[GenerateSourcesRequest] + ] = union_membership.union_rules.get(GenerateSourcesRequest, ()) + relevant_generate_request_types = [ + generate_request_type + for generate_request_type in generate_request_types + if isinstance(sources_field, generate_request_type.input) + and issubclass(generate_request_type.output, request.for_sources_types) + ] + if request.enable_codegen and len(relevant_generate_request_types) > 1: + raise AmbiguousCodegenImplementationsException( + relevant_generate_request_types, for_sources_types=request.for_sources_types + ) + generate_request_type = next(iter(relevant_generate_request_types), None) + + # Now, determine if any of the `for_sources_types` may be used, either because the + # sources_field is a direct subclass or can be generated into one of the valid types. + def compatible_with_sources_field(valid_type: Type[Sources]) -> bool: + is_instance = isinstance(sources_field, valid_type) + can_be_generated = ( + request.enable_codegen + and generate_request_type is not None + and issubclass(generate_request_type.output, valid_type) + ) + return is_instance or can_be_generated + + sources_type = next( ( valid_type for valid_type in request.for_sources_types - if isinstance(sources_field, valid_type) + if compatible_with_sources_field(valid_type) ), None, ) - if output_type is None: - return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, output_type=None) + if sources_type is None: + return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=None) + # Now, hydrate the `globs`. Even if we are going to use codegen, we will need the original + # protocol sources to be hydrated. globs = sources_field.sanitized_raw_value if globs is None: - return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, output_type=output_type) + return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=sources_type) conjunction = ( GlobExpansionConjunction.all_match @@ -1349,7 +1520,17 @@ async def hydrate_sources( ) ) sources_field.validate_snapshot(snapshot) - return HydratedSources(snapshot, sources_field.filespec, output_type=output_type) + + # Finally, return if codegen is not in use; otherwise, run the relevant code generator. + if not request.enable_codegen or generate_request_type is None: + return HydratedSources(snapshot, sources_field.filespec, sources_type=sources_type) + wrapped_protocol_target = await Get[WrappedTarget](Address, sources_field.address) + generated_sources = await Get[GeneratedSources]( + GenerateSourcesRequest, generate_request_type(snapshot, wrapped_protocol_target.target) + ) + return HydratedSources( + generated_sources.snapshot, sources_field.filespec, sources_type=sources_type + ) # ----------------------------------------------------------------------------------------------- diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index a7af86753cb..b95494fe2d2 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -11,11 +11,18 @@ from pants.base.specs import FilesystemLiteralSpec from pants.engine.addresses import Address -from pants.engine.fs import EMPTY_DIRECTORY_DIGEST, PathGlobs, Snapshot +from pants.engine.fs import ( + EMPTY_DIRECTORY_DIGEST, + FileContent, + InputFilesContent, + PathGlobs, + Snapshot, +) from pants.engine.internals.scheduler import ExecutionError from pants.engine.rules import RootRule, rule from pants.engine.selectors import Get, Params from pants.engine.target import ( + AmbiguousCodegenImplementationsException, AmbiguousImplementationsException, AsyncField, BoolField, @@ -23,6 +30,8 @@ ConfigurationWithOrigin, DictStringToStringField, DictStringToStringSequenceField, + GeneratedSources, + GenerateSourcesRequest, HydratedSources, HydrateSourcesRequest, InvalidFieldChoiceException, @@ -43,6 +52,7 @@ TargetsWithOrigins, TargetWithOrigin, TooManyTargetsException, + WrappedTarget, ) from pants.engine.target import rules as target_rules from pants.engine.unions import UnionMembership, UnionRule, union @@ -804,7 +814,7 @@ class SourcesSubclass(Sources): HydrateSourcesRequest(valid_sources, for_sources_types=[SourcesSubclass]), ) assert hydrated_valid_sources.snapshot.files == ("f1.f95",) - assert hydrated_valid_sources.output_type == SourcesSubclass + assert hydrated_valid_sources.sources_type == SourcesSubclass invalid_sources = Sources(["*"], address=addr) hydrated_invalid_sources = self.request_single_product( @@ -812,7 +822,7 @@ class SourcesSubclass(Sources): HydrateSourcesRequest(invalid_sources, for_sources_types=[SourcesSubclass]), ) assert hydrated_invalid_sources.snapshot.files == () - assert hydrated_invalid_sources.output_type is None + assert hydrated_invalid_sources.sources_type is None def test_unmatched_globs(self) -> None: self.create_files("", files=["f1.f95"]) @@ -889,3 +899,156 @@ def hydrate(sources_cls: Type[Sources], sources: Iterable[str]) -> HydratedSourc "f2.txt", "f3.txt", ) + + +# ----------------------------------------------------------------------------------------------- +# Test Codegen +# ----------------------------------------------------------------------------------------------- + + +class AvroSources(Sources): + pass + + +class AvroLibrary(Target): + alias = "avro_library" + core_fields = (AvroSources,) + + +class GenerateFortranFromAvroRequest(GenerateSourcesRequest): + input = AvroSources + output = FortranSources + + +@rule +async def generate_fortran_from_avro(request: GenerateFortranFromAvroRequest) -> GeneratedSources: + protocol_files = request.protocol_sources.files + + def generate_fortran(fp: str) -> FileContent: + parent = str(PurePath(fp).parent).replace("src/avro", "src/fortran") + file_name = f"{PurePath(fp).stem}.f95" + return FileContent(str(PurePath(parent, file_name)), b"Generated") + + result = await Get[Snapshot](InputFilesContent([generate_fortran(fp) for fp in protocol_files])) + return GeneratedSources(result) + + +class TestCodegen(TestBase): + @classmethod + def rules(cls): + return ( + *super().rules(), + *target_rules(), + generate_fortran_from_avro, + RootRule(GenerateFortranFromAvroRequest), + RootRule(HydrateSourcesRequest), + UnionRule(GenerateSourcesRequest, GenerateFortranFromAvroRequest), + ) + + @classmethod + def target_types(cls): + return [AvroLibrary] + + def setUp(self) -> None: + self.address = Address.parse("src/avro:lib") + self.create_files("src/avro", files=["f.avro"]) + self.add_to_build_file("src/avro", "avro_library(name='lib', sources=['*.avro'])") + self.union_membership = self.request_single_product(UnionMembership, Params()) + + def test_generate_sources(self) -> None: + protocol_sources = AvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(FortranSources, self.union_membership) is True + + # First, get the original protocol sources. + hydrated_protocol_sources = self.request_single_product( + HydratedSources, HydrateSourcesRequest(protocol_sources) + ) + assert hydrated_protocol_sources.snapshot.files == ("src/avro/f.avro",) + + # Test directly feeding the protocol sources into the codegen rule. + wrapped_tgt = self.request_single_product(WrappedTarget, self.address) + generated_sources = self.request_single_product( + GeneratedSources, + GenerateFortranFromAvroRequest(hydrated_protocol_sources.snapshot, wrapped_tgt.target), + ) + assert generated_sources.snapshot.files == ("src/fortran/f.f95",) + + # Test that HydrateSourcesRequest can also be used. + generated_via_hydrate_sources = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[FortranSources], enable_codegen=True + ), + ) + assert generated_via_hydrate_sources.snapshot.files == ("src/fortran/f.f95",) + assert generated_via_hydrate_sources.sources_type == FortranSources + + def test_works_with_subclass_fields(self) -> None: + class CustomAvroSources(AvroSources): + pass + + protocol_sources = CustomAvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(FortranSources, self.union_membership) is True + generated = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[FortranSources], enable_codegen=True + ), + ) + assert generated.snapshot.files == ("src/fortran/f.f95",) + + def test_cannot_generate_language(self) -> None: + class SmalltalkSources(Sources): + pass + + protocol_sources = AvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(SmalltalkSources, self.union_membership) is False + generated = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[SmalltalkSources], enable_codegen=True + ), + ) + assert generated.snapshot.files == () + assert generated.sources_type is None + + def test_ambiguous_implementations_exception(self) -> None: + # This error message is quite complex. We test that it correctly generates the message. + class FortranGenerator1(GenerateSourcesRequest): + input = AvroSources + output = FortranSources + + class FortranGenerator2(GenerateSourcesRequest): + input = AvroSources + output = FortranSources + + class SmalltalkSources(Sources): + pass + + class SmalltalkGenerator(GenerateSourcesRequest): + input = AvroSources + output = SmalltalkSources + + class IrrelevantSources(Sources): + pass + + # Test when all generators have the same input and output. + exc = AmbiguousCodegenImplementationsException( + [FortranGenerator1, FortranGenerator2], for_sources_types=[FortranSources] + ) + assert "can generate FortranSources from AvroSources" in str(exc) + assert "* FortranGenerator1" in str(exc) + assert "* FortranGenerator2" in str(exc) + + # Test when the generators have different input and output, which usually happens because + # the call site used too expansive of a `for_sources_types` argument. + exc = AmbiguousCodegenImplementationsException( + [FortranGenerator1, SmalltalkGenerator], + for_sources_types=[FortranSources, SmalltalkSources, IrrelevantSources], + ) + assert "can generate one of ['FortranSources', 'SmalltalkSources'] from AvroSources" in str( + exc + ) + assert "IrrelevantSources" not in str(exc) + assert "* FortranGenerator1 -> FortranSources" in str(exc) + assert "* SmalltalkGenerator -> SmalltalkSources" in str(exc)