forked from pantsbuild/pants
-
Notifications
You must be signed in to change notification settings - Fork 0
/
determine_source_files.py
192 lines (169 loc) · 7.2 KB
/
determine_source_files.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from dataclasses import dataclass
from typing import Iterable, Tuple, Type, Union
from pants.base.specs import AddressSpec, OriginSpec
from pants.core.util_rules import strip_source_roots
from pants.core.util_rules.strip_source_roots import (
SourceRootStrippedSources,
StripSourcesFieldRequest,
)
from pants.engine.fs import DirectoriesToMerge, PathGlobs, Snapshot, SnapshotSubset
from pants.engine.rules import RootRule, rule
from pants.engine.selectors import Get, MultiGet
from pants.engine.target import HydratedSources, HydrateSourcesRequest
from pants.engine.target import Sources as SourcesField
from pants.util.meta import frozen_after_init
@dataclass(frozen=True)
class SourceFiles:
    """The single snapshot produced by merging the `sources` of several targets.

    When produced from a `SpecifiedSourceFilesRequest` (rather than an
    `AllSourceFilesRequest`), the snapshot may hold only a subset of each target's `sources`.
    """

    # The merged snapshot wrapped by this result.
    snapshot: Snapshot

    @property
    def files(self) -> Tuple[str, ...]:
        """The `files` of the wrapped snapshot."""
        merged = self.snapshot
        return merged.files
@frozen_after_init
@dataclass(unsafe_hash=True)
class AllSourceFilesRequest:
    """Request the merged sources of every given `Sources` field, with no file filtering.

    All iterable arguments are materialized into tuples by the constructor.
    """

    # The `Sources` fields whose snapshots are merged.
    sources_fields: Tuple[SourcesField, ...]
    # Forwarded unchanged to the per-field hydrate/strip requests.
    for_sources_types: Tuple[Type[SourcesField], ...]
    # Forwarded unchanged to the per-field hydrate/strip requests.
    enable_codegen: bool
    # When true, each field is resolved via a `StripSourcesFieldRequest` instead of a
    # `HydrateSourcesRequest`.
    strip_source_roots: bool

    def __init__(
        self,
        sources_fields: Iterable[SourcesField],
        *,
        for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
        enable_codegen: bool = False,
        strip_source_roots: bool = False
    ) -> None:
        # The plain flags are stored as-is; the iterables are frozen into tuples.
        self.enable_codegen = enable_codegen
        self.strip_source_roots = strip_source_roots
        self.sources_fields = tuple(sources_fields)
        self.for_sources_types = tuple(for_sources_types)
@frozen_after_init
@dataclass(unsafe_hash=True)
class SpecifiedSourceFilesRequest:
    """Request the merged sources of the given `Sources` fields, narrowed by each field's origin.

    Every field is paired with the `OriginSpec` it came from; that origin decides whether the
    whole field or only a subset of its files is used. All iterable arguments are materialized
    into tuples by the constructor.
    """

    # `(sources field, origin spec)` pairs to resolve.
    sources_fields_with_origins: Tuple[Tuple[SourcesField, OriginSpec], ...]
    # Forwarded unchanged to the per-field hydrate/strip requests.
    for_sources_types: Tuple[Type[SourcesField], ...]
    # Forwarded unchanged to the per-field hydrate/strip requests.
    enable_codegen: bool
    # When true, the resulting snapshots are passed through `StripSourcesFieldRequest` before
    # being merged.
    strip_source_roots: bool

    def __init__(
        self,
        sources_fields_with_origins: Iterable[Tuple[SourcesField, OriginSpec]],
        *,
        for_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
        enable_codegen: bool = False,
        strip_source_roots: bool = False
    ) -> None:
        # The plain flags are stored as-is; the iterables are frozen into tuples.
        self.enable_codegen = enable_codegen
        self.strip_source_roots = strip_source_roots
        self.sources_fields_with_origins = tuple(sources_fields_with_origins)
        self.for_sources_types = tuple(for_sources_types)
def calculate_specified_sources(
    sources_snapshot: Snapshot, origin: OriginSpec
) -> Union[Snapshot, SnapshotSubset]:
    """Pick which part of a target's `sources` snapshot its origin spec selects.

    Returns `sources_snapshot` itself when `origin` is an `AddressSpec`. For any other origin,
    returns a `SnapshotSubset` request (not yet a snapshot) over the same directory digest,
    limited to the files that appear both in the snapshot and in `origin.resolved_files`.
    """
    if not isinstance(origin, AddressSpec):
        # NB: only files that actually belong to this target's `sources` may be kept. A glob
        # filesystem spec can resolve files owned by other targets; intersecting with the
        # snapshot's own files drops them.
        owned_paths = frozenset(sources_snapshot.files)
        matching_paths = sorted(owned_paths.intersection(origin.resolved_files))
        return SnapshotSubset(
            directory_digest=sources_snapshot.directory_digest,
            globs=PathGlobs(matching_paths),
        )

    # An address spec selects the complete `sources` field.
    return sources_snapshot
@rule
async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFiles:
    """Resolve every requested `Sources` field and merge the results into one snapshot.

    Fields are resolved via `HydrateSourcesRequest`, or via `StripSourcesFieldRequest` when
    `request.strip_source_roots` is set; the resulting directory digests are then merged.
    """
    if not request.strip_source_roots:
        hydrated_results = await MultiGet(
            Get[HydratedSources](
                HydrateSourcesRequest(
                    field,
                    for_sources_types=request.for_sources_types,
                    enable_codegen=request.enable_codegen,
                )
            )
            for field in request.sources_fields
        )
        digests = tuple(hydrated.snapshot.directory_digest for hydrated in hydrated_results)
    else:
        stripped_results = await MultiGet(
            Get[SourceRootStrippedSources](
                StripSourcesFieldRequest(
                    field,
                    for_sources_types=request.for_sources_types,
                    enable_codegen=request.enable_codegen,
                )
            )
            for field in request.sources_fields
        )
        digests = tuple(stripped.snapshot.directory_digest for stripped in stripped_results)

    merged_snapshot = await Get[Snapshot](DirectoriesToMerge(digests))
    return SourceFiles(merged_snapshot)
@rule
async def determine_specified_source_files(request: SpecifiedSourceFilesRequest) -> SourceFiles:
    """Resolve the `sources` selected by each field's origin and merge them into one snapshot.

    Each field is hydrated, then `calculate_specified_sources` decides whether its full snapshot
    or only a subset is used. Fields whose hydrated snapshot has no files are skipped. When
    `request.strip_source_roots` is set, the chosen snapshots are passed through
    `StripSourcesFieldRequest` before merging.
    """
    fields_with_origins = request.sources_fields_with_origins
    hydrated_results = await MultiGet(
        Get[HydratedSources](
            HydrateSourcesRequest(
                field,
                for_sources_types=request.for_sources_types,
                enable_codegen=request.enable_codegen,
            )
        )
        for field, _ in fields_with_origins
    )

    # Both mappings are keyed by sources field. Their insertion order matters: it is what pairs
    # each field with its snapshot further down.
    whole_snapshot_by_field = {}
    subset_request_by_field = {}
    for hydrated, (field, origin) in zip(hydrated_results, fields_with_origins):
        hydrated_snapshot = hydrated.snapshot
        if hydrated_snapshot.files:
            specified = calculate_specified_sources(hydrated_snapshot, origin)
            if not isinstance(specified, Snapshot):
                subset_request_by_field[field] = specified
            else:
                whole_snapshot_by_field[field] = specified

    # Subset requests still have to be turned into snapshots by the engine.
    subset_snapshots: Tuple[Snapshot, ...] = ()
    if subset_request_by_field:
        subset_snapshots = await MultiGet(
            Get[Snapshot](SnapshotSubset, subset_request)
            for subset_request in subset_request_by_field.values()
        )

    # Full snapshots first, then subsets - the same order used for the fields below.
    chosen_snapshots: Iterable[Snapshot] = (*whole_snapshot_by_field.values(), *subset_snapshots)

    if request.strip_source_roots:
        chosen_fields = (*whole_snapshot_by_field, *subset_request_by_field)
        stripped_results = await MultiGet(
            Get[SourceRootStrippedSources](
                StripSourcesFieldRequest(
                    chosen_field,
                    specified_files_snapshot=chosen_snapshot,
                    for_sources_types=request.for_sources_types,
                    enable_codegen=request.enable_codegen,
                )
            )
            for chosen_field, chosen_snapshot in zip(chosen_fields, chosen_snapshots)
        )
        digests = tuple(stripped.snapshot.directory_digest for stripped in stripped_results)
    else:
        digests = tuple(chosen.directory_digest for chosen in chosen_snapshots)

    merged_snapshot = await Get[Snapshot](DirectoriesToMerge(digests))
    return SourceFiles(merged_snapshot)
def rules():
    """Return this module's rules and root rules, followed by those of `strip_source_roots`."""
    own_rules = [determine_all_source_files, determine_specified_source_files]
    root_rules = [RootRule(AllSourceFilesRequest), RootRule(SpecifiedSourceFilesRequest)]
    return [*own_rules, *root_rules, *strip_source_roots.rules()]