/
ComputeAffectedTests.kt
371 lines (331 loc) · 14.9 KB
/
ComputeAffectedTests.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
package org.oppia.android.scripts.ci
import org.oppia.android.scripts.common.BazelClient
import org.oppia.android.scripts.common.CommandExecutor
import org.oppia.android.scripts.common.CommandExecutorImpl
import org.oppia.android.scripts.common.GitClient
import org.oppia.android.scripts.common.ProtoStringEncoder.Companion.toCompressedBase64
import org.oppia.android.scripts.common.ScriptBackgroundCoroutineDispatcher
import org.oppia.android.scripts.proto.AffectedTestsBucket
import java.io.File
import java.util.Locale
import kotlin.system.exitProcess
/** Prefix that the final command-line argument must carry before its 'true'/'false' value. */
private const val COMPUTE_ALL_TESTS_PREFIX: String = "compute_all_tests="

/** Default maximum number of test targets per shard for buckets using large partitions. */
private const val MAX_TEST_COUNT_PER_LARGE_SHARD: Int = 50

/** Default maximum number of test targets per shard for buckets using medium partitions. */
private const val MAX_TEST_COUNT_PER_MEDIUM_SHARD: Int = 25

/** Default maximum number of test targets per shard for buckets using small partitions. */
private const val MAX_TEST_COUNT_PER_SMALL_SHARD: Int = 15
/**
 * Entry point that computes which test targets are affected by changes in the local Oppia Android
 * Git repository and writes them, bucketed and sharded, to an output file.
 *
 * Usage:
 *   bazel run //scripts:compute_affected_tests -- \\
 *     <path_to_directory_root> <path_to_output_file> <base_develop_branch_reference> \\
 *     <compute_all_tests=true/false>
 *
 * Arguments:
 * - path_to_directory_root: directory path to the root of the Oppia Android repository.
 * - path_to_output_file: path to the file in which the affected test targets will be printed.
 * - base_develop_branch_reference: the reference to the local develop branch that should be used.
 *     Generally, this is 'origin/develop'.
 * - compute_all_tests: whether to compute a list of all tests to run. The value after the
 *     'compute_all_tests=' prefix must be 'true' or 'false' (compared case-insensitively).
 *
 * Only the first four arguments are read; any additional arguments are ignored. If fewer than four
 * arguments are given, the usage line is printed and the process exits with status 1. A malformed
 * final argument results in an [IllegalStateException].
 *
 * Example:
 *   bazel run //scripts:compute_affected_tests -- $(pwd) /tmp/affected_test_buckets.proto64 \\
 *     origin/develop compute_all_tests=false
 */
fun main(args: Array<String>) {
  if (args.size < 4) {
    println(
      "Usage: bazel run //scripts:compute_affected_tests --" +
        " <path_to_directory_root> <path_to_output_file> <base_develop_branch_reference>" +
        " <compute_all_tests=true/false>"
    )
    exitProcess(1)
  }

  val (pathToRoot, pathToOutputFile, baseDevelopBranchReference, computeAllTestsArgument) = args

  check(computeAllTestsArgument.startsWith(COMPUTE_ALL_TESTS_PREFIX)) {
    "Expected last argument to start with '$COMPUTE_ALL_TESTS_PREFIX'"
  }
  val rawComputeAllTestsValue = computeAllTestsArgument.removePrefix(COMPUTE_ALL_TESTS_PREFIX)
  val computeAllTestsSetting = rawComputeAllTestsValue.toBooleanStrictOrNull()
    ?: error(
      "Expected last argument to have 'true' or 'false' passed to it, not:" +
        " '$rawComputeAllTestsValue'"
    )

  ScriptBackgroundCoroutineDispatcher().use { dispatcher ->
    val affectedTestsComputer = ComputeAffectedTests(dispatcher)
    affectedTestsComputer.compute(
      pathToRoot = pathToRoot,
      pathToOutputFile = pathToOutputFile,
      baseDevelopBranchReference = baseDevelopBranchReference,
      computeAllTestsSetting = computeAllTestsSetting
    )
  }
}
/**
 * Parses this string as a boolean, returning `true` for "true", `false` for "false", and `null`
 * for anything else.
 *
 * Needed since the codebase isn't yet using Kotlin 1.5, so the standard library function isn't
 * available. Note that, unlike the Kotlin 1.5 standard library version, the comparison here is
 * case-insensitive (the string is lowercased using [Locale.US] first), so e.g. "TRUE" also parses.
 */
private fun String.toBooleanStrictOrNull(): Boolean? {
  val normalized = toLowerCase(Locale.US)
  return when {
    normalized == "true" -> true
    normalized == "false" -> false
    else -> null
  }
}
/**
 * Utility used to compute affected test targets.
 *
 * Depending on the current branch and the settings passed to [compute], either every test target
 * in the workspace or only those affected by local changes is collected. The targets are then
 * filtered, grouped into buckets, split into shards, and written out as Base64-encoded
 * [AffectedTestsBucket] protos.
 *
 * @param scriptBgDispatcher the dispatcher used to create the default [commandExecutor]
 * @property maxTestCountPerLargeShard the maximum number of test targets in a single shard for
 *     buckets using [ShardingStrategy.LARGE_PARTITIONS]
 * @property maxTestCountPerMediumShard the maximum number of test targets in a single shard for
 *     buckets using [ShardingStrategy.MEDIUM_PARTITIONS]
 * @property maxTestCountPerSmallShard the maximum number of test targets in a single shard for
 *     buckets using [ShardingStrategy.SMALL_PARTITIONS]
 * @property commandExecutor the [CommandExecutor] handed to the Git and Bazel clients created in
 *     [compute]
 */
class ComputeAffectedTests(
  private val scriptBgDispatcher: ScriptBackgroundCoroutineDispatcher,
  val maxTestCountPerLargeShard: Int = MAX_TEST_COUNT_PER_LARGE_SHARD,
  val maxTestCountPerMediumShard: Int = MAX_TEST_COUNT_PER_MEDIUM_SHARD,
  val maxTestCountPerSmallShard: Int = MAX_TEST_COUNT_PER_SMALL_SHARD,
  val commandExecutor: CommandExecutor = CommandExecutorImpl(scriptBgDispatcher)
) {
  private companion object {
    /** Cache bucket name shared by all buckets using [GroupingStrategy.BUCKET_GENERICALLY]. */
    private const val GENERIC_TEST_BUCKET_NAME = "generic"
  }

  /**
   * Computes a list of tests to run.
   *
   * Each line written to [pathToOutputFile] has the form
   * `<cache bucket name>-shard<index>;<compressed Base64 AffectedTestsBucket>`. The order of the
   * lines is shuffled on every run.
   *
   * @param pathToRoot the absolute path to the working root directory
   * @param pathToOutputFile the absolute path to the file in which the encoded Base64 test bucket
   *     protos should be printed
   * @param baseDevelopBranchReference see [GitClient]
   * @param computeAllTestsSetting whether all tests should be outputted versus only the ones which
   *     are affected by local changes in the repository
   * @throws IllegalStateException if [pathToRoot] is not a readable directory containing a
   *     WORKSPACE file
   */
  fun compute(
    pathToRoot: String,
    pathToOutputFile: String,
    baseDevelopBranchReference: String,
    computeAllTestsSetting: Boolean
  ) {
    val rootDirectory = File(pathToRoot).absoluteFile
    check(rootDirectory.isDirectory) { "Expected '$pathToRoot' to be a directory" }
    // File.list() returns null when the directory's contents can't be read (e.g. an I/O or
    // permission error). Treat that like a missing WORKSPACE file instead of throwing an NPE.
    check(rootDirectory.list()?.contains("WORKSPACE") == true) {
      "Expected script to be run from the workspace's root directory"
    }
    println("Running from directory root: $rootDirectory")

    val gitClient = GitClient(rootDirectory, baseDevelopBranchReference, commandExecutor)
    val bazelClient = BazelClient(rootDirectory, commandExecutor)
    println("Current branch: ${gitClient.currentBranch}")
    println("Most recent common commit: ${gitClient.branchMergeBase}")

    // All tests are computed either when explicitly requested or when running on develop itself.
    val currentBranch = gitClient.currentBranch.toLowerCase(Locale.US)
    val affectedTestTargets = if (computeAllTestsSetting || currentBranch == "develop") {
      computeAllTestTargets(bazelClient)
    } else {
      computeAffectedTargetsForNonDevelopBranch(gitClient, bazelClient, rootDirectory)
    }
    val filteredTestTargets = filterTargets(affectedTestTargets)
    println()
    println("Affected test targets:")
    println(filteredTestTargets.joinToString(separator = "\n") { "- $it" })

    // Bucket the targets & then shuffle them so that shards are run in different orders each time
    // (to avoid situations where the longest/most expensive tests are run last). Keying by the
    // encoded form means byte-identical buckets would collapse into a single entry.
    val affectedTestBuckets = bucketTargets(filteredTestTargets)
    val encodedTestBucketEntries =
      affectedTestBuckets.associateBy { it.toCompressedBase64() }.entries.shuffled()
    File(pathToOutputFile).printWriter().use { writer ->
      encodedTestBucketEntries.forEachIndexed { index, (encoded, bucket) ->
        writer.println("${bucket.cacheBucketName}-shard$index;$encoded")
      }
    }
  }

  /** Returns every test target in the workspace, as reported by [bazelClient]. */
  private fun computeAllTestTargets(bazelClient: BazelClient): List<String> {
    println("Computing all test targets")
    return bazelClient.retrieveAllTestTargets()
  }

  /**
   * Returns the distinct test targets affected by files changed on the current branch. These are
   * tests related to the Bazel targets of changed files, plus tests transitively connected to any
   * changed Bazel support files (*.bzl, *.bazel, WORKSPACE).
   *
   * @param rootDirectory the workspace root, used to drop changed files that no longer exist
   */
  private fun computeAffectedTargetsForNonDevelopBranch(
    gitClient: GitClient,
    bazelClient: BazelClient,
    rootDirectory: File
  ): List<String> {
    // Compute the list of changed files, but exclude files which no longer exist (since bazel query
    // can't handle these well).
    val changedFiles = gitClient.changedFiles.filter { filepath ->
      File(rootDirectory, filepath).exists()
    }.toSet()
    println("Changed files (per Git): $changedFiles")

    // Compute the changed targets 100 files at a time to avoid unnecessarily long-running Bazel
    // commands.
    val changedFileTargets =
      changedFiles.chunked(size = 100).fold(initial = setOf<String>()) { allTargets, filesChunk ->
        allTargets + bazelClient.retrieveBazelTargets(filesChunk).toSet()
      }
    println("Changed Bazel file targets: $changedFileTargets")

    // Similarly, compute the affected test targets list 100 file targets at a time.
    val affectedTestTargets =
      changedFileTargets.chunked(size = 100)
        .fold(initial = setOf<String>()) { allTargets, targetChunk ->
          allTargets + bazelClient.retrieveRelatedTestTargets(targetChunk).toSet()
        }
    println("Affected Bazel test targets: $affectedTestTargets")

    // Compute the list of Bazel files that were changed.
    val changedBazelFiles = changedFiles.filter { file ->
      file.endsWith(".bzl", ignoreCase = true) ||
        file.endsWith(".bazel", ignoreCase = true) ||
        file == "WORKSPACE"
    }
    println("Changed Bazel-specific support files: $changedBazelFiles")

    // Compute the list of affected tests based on BUILD/Bazel/WORKSPACE files. These are generally
    // framed as: if a BUILD file changes, run all tests transitively connected to it.
    val transitiveTestTargets = bazelClient.retrieveTransitiveTestTargets(changedBazelFiles)
    println("Affected test targets due to transitive build deps: $transitiveTestTargets")
    return (affectedTestTargets + transitiveTestTargets).toSet().toList()
  }

  /**
   * Returns [testTargets] without the targets under the instrumentation player test package
   * (matched case-insensitively by prefix), which are excluded from the computed list.
   */
  private fun filterTargets(testTargets: List<String>): List<String> {
    // Filtering out the targets to be ignored.
    return testTargets.filter { targetPath ->
      !targetPath.startsWith(
        "//instrumentation/src/javatests/org/oppia/android/instrumentation/player",
        ignoreCase = true
      )
    }
  }

  /**
   * Groups [testTargets] into [TestBucket]s, merges all generically grouped buckets under
   * [GENERIC_TEST_BUCKET_NAME], and splits each resulting group into randomly ordered shards whose
   * maximum size depends on the group's [ShardingStrategy].
   *
   * @throws IllegalStateException if a target's bucket can't be determined, or if buckets merged
   *     into the same group don't share a sharding strategy
   */
  private fun bucketTargets(testTargets: List<String>): List<AffectedTestsBucket> {
    // Group first by the bucket, then by the grouping strategy. Here's what's happening here:
    // 1. Create: Map<TestBucket, List<String>>
    // 2. Convert to: Iterable<Pair<TestBucket, List<String>>>
    // 3. Convert to: Map<GroupingStrategy, List<Pair<TestBucket, List<String>>>>
    // 4. Convert to: Map<GroupingStrategy, Map<TestBucket, List<String>>>
    val groupedBuckets: Map<GroupingStrategy, Map<TestBucket, List<String>>> =
      testTargets.groupBy { TestBucket.retrieveCorrespondingTestBucket(it) }
        .entries.groupBy(
          keySelector = { it.key.groupingStrategy },
          valueTransform = { it.key to it.value }
        ).mapValues { (_, bucketLists) -> bucketLists.toMap() }

    // Next, properly segment buckets by splitting out individual ones and collecting like ones:
    // 5. Convert to: Map<String, Map<TestBucket, List<String>>>
    val partitionedBuckets: Map<String, Map<TestBucket, List<String>>> =
      groupedBuckets.entries.flatMap { (strategy, buckets) ->
        when (strategy) {
          GroupingStrategy.BUCKET_SEPARATELY -> {
            // Each entry in the combined map should be a separate entry in the segmented map:
            // 1. Start with: Map<TestBucket, List<String>>
            // 2. Convert to: Map<TestBucket, Map<TestBucket, List<String>>>
            // 3. Convert to: Map<String, Map<TestBucket, List<String>>>
            // 4. Convert to: Iterable<Pair<String, Map<TestBucket, List<String>>>>
            buckets.mapValues { (testBucket, targets) -> mapOf(testBucket to targets) }
              .mapKeys { (testBucket, _) -> testBucket.cacheBucketName }
              .entries.map { (cacheName, bucket) -> cacheName to bucket }
          }
          GroupingStrategy.BUCKET_GENERICALLY -> listOf(GENERIC_TEST_BUCKET_NAME to buckets)
        }
      }.toMap()

    // Next, collapse the test bucket lists & partition them based on the common sharding strategy
    // for each group:
    // 6. Convert to: Map<String, List<List<String>>>
    val shardedBuckets: Map<String, List<List<String>>> =
      partitionedBuckets.mapValues { (_, bucketMap) ->
        val shardingStrategies = bucketMap.keys.map { it.shardingStrategy }.toSet()
        check(shardingStrategies.size == 1) {
          "Error: expected all buckets in the same partition to share a sharding strategy:" +
            " ${bucketMap.keys} (strategies: $shardingStrategies)"
        }
        val maxTestCountPerShard = when (shardingStrategies.first()) {
          ShardingStrategy.LARGE_PARTITIONS -> maxTestCountPerLargeShard
          ShardingStrategy.MEDIUM_PARTITIONS -> maxTestCountPerMediumShard
          ShardingStrategy.SMALL_PARTITIONS -> maxTestCountPerSmallShard
        }
        val allPartitionTargets = bucketMap.values.flatten()
        // Use randomization to encourage cache breadth & potentially improve workflow performance.
        allPartitionTargets.shuffled().chunked(maxTestCountPerShard)
      }

    // Finally, compile into a list of protos:
    // 7. Convert to List<AffectedTestsBucket>
    return shardedBuckets.entries.flatMap { (bucketName, shardedTargets) ->
      shardedTargets.map { targets ->
        AffectedTestsBucket.newBuilder().apply {
          cacheBucketName = bucketName
          addAllAffectedTestTargets(targets)
        }.build()
      }
    }
  }

  /**
   * A bucket of test targets, determined by the first path segment of a target's Bazel label
   * (e.g. `//app/...` belongs to [APP]).
   *
   * @property cacheBucketName the name matched against a target's first label segment, and used as
   *     the output cache bucket name when the bucket is grouped separately
   * @property groupingStrategy whether this bucket is sharded alone or with other generic buckets
   * @property shardingStrategy which per-shard maximum test count applies to this bucket
   */
  private enum class TestBucket(
    val cacheBucketName: String,
    val groupingStrategy: GroupingStrategy,
    val shardingStrategy: ShardingStrategy
  ) {
    /** Corresponds to app layer tests. */
    APP(
      cacheBucketName = "app",
      groupingStrategy = GroupingStrategy.BUCKET_SEPARATELY,
      shardingStrategy = ShardingStrategy.SMALL_PARTITIONS
    ),

    /** Corresponds to data layer tests. */
    DATA(
      cacheBucketName = "data",
      groupingStrategy = GroupingStrategy.BUCKET_GENERICALLY,
      shardingStrategy = ShardingStrategy.LARGE_PARTITIONS
    ),

    /** Corresponds to domain layer tests. */
    DOMAIN(
      cacheBucketName = "domain",
      groupingStrategy = GroupingStrategy.BUCKET_SEPARATELY,
      shardingStrategy = ShardingStrategy.LARGE_PARTITIONS
    ),

    /** Corresponds to instrumentation tests. */
    INSTRUMENTATION(
      cacheBucketName = "instrumentation",
      groupingStrategy = GroupingStrategy.BUCKET_GENERICALLY,
      shardingStrategy = ShardingStrategy.LARGE_PARTITIONS
    ),

    /** Corresponds to scripts tests. */
    SCRIPTS(
      cacheBucketName = "scripts",
      groupingStrategy = GroupingStrategy.BUCKET_SEPARATELY,
      shardingStrategy = ShardingStrategy.MEDIUM_PARTITIONS
    ),

    /** Corresponds to testing utility tests. */
    TESTING(
      cacheBucketName = "testing",
      groupingStrategy = GroupingStrategy.BUCKET_GENERICALLY,
      shardingStrategy = ShardingStrategy.LARGE_PARTITIONS
    ),

    /** Corresponds to production utility tests. */
    UTILITY(
      cacheBucketName = "utility",
      groupingStrategy = GroupingStrategy.BUCKET_GENERICALLY,
      shardingStrategy = ShardingStrategy.LARGE_PARTITIONS
    );

    companion object {
      /**
       * Captures the first path segment of a Bazel label: everything after the leading `//` up to
       * the first `/` or `:`. Used with [Regex.matchEntire], so no explicit anchors are needed.
       */
      private val EXTRACT_BUCKET_REGEX = """//([^/:]+)[/:].+""".toRegex()

      /**
       * Returns the [TestBucket] that corresponds to the specific [testTarget].
       *
       * @throws IllegalStateException if no bucket name can be extracted from [testTarget], or if
       *     the extracted name doesn't match any known bucket
       */
      fun retrieveCorrespondingTestBucket(testTarget: String): TestBucket {
        val bucketName = EXTRACT_BUCKET_REGEX.matchEntire(testTarget)?.groupValues?.getOrNull(1)
          ?: error("Invalid target: $testTarget (could not extract bucket name)")
        return values().find { it.cacheBucketName == bucketName }
          ?: error(
            "Invalid bucket name: $bucketName (expected one of:" +
              " ${values().map { it.cacheBucketName }})"
          )
      }
    }
  }

  /** How a [TestBucket] is combined with other buckets before sharding. */
  private enum class GroupingStrategy {
    /** Indicates that a particular test bucket should be sharded by itself. */
    BUCKET_SEPARATELY,

    /**
     * Indicates that a particular test bucket should be combined with all other generically grouped
     * buckets.
     */
    BUCKET_GENERICALLY
  }

  /** Which per-shard maximum test count applies when splitting a group of buckets into shards. */
  private enum class ShardingStrategy {
    /**
     * Indicates that the tests for a test bucket run very quickly and don't need as much
     * parallelization.
     */
    LARGE_PARTITIONS,

    /**
     * Indicates that the tests for a test bucket are somewhere between [LARGE_PARTITIONS] and
     * [SMALL_PARTITIONS].
     */
    MEDIUM_PARTITIONS,

    /**
     * Indicates that the tests for a test bucket run slowly and require more parallelization for
     * faster CI runs.
     */
    SMALL_PARTITIONS
  }
}