Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for #22951 #22953

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ public class GroupIntoBatches<K, InputT>
*/
@AutoValue
public abstract static class BatchingParams<InputT> implements Serializable {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you're adding support for GroupIntoBatches to limit on count and byte size at the same time.

Can you add tests that cover this new scenario to:

  • GroupIntoBatchesTest
  • GroupIntoBatchesTranslationTest

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public static <InputT> BatchingParams<InputT> createDefault() {
lukecwik marked this conversation as resolved.
Show resolved Hide resolved
return new AutoValue_GroupIntoBatches_BatchingParams(
Long.MAX_VALUE, Long.MAX_VALUE, null, Duration.ZERO);
}

public static <InputT> BatchingParams<InputT> create(
long batchSize,
long batchSizeBytes,
Expand Down Expand Up @@ -170,8 +175,7 @@ private GroupIntoBatches(BatchingParams<InputT> params) {
/** Aim to create batches each with the specified element count. */
public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
Preconditions.checkState(batchSize < Long.MAX_VALUE);
return new GroupIntoBatches<>(
BatchingParams.create(batchSize, Long.MAX_VALUE, null, Duration.ZERO));
return new GroupIntoBatches<K, InputT>(BatchingParams.createDefault()).withSize(batchSize);
}

/**
Expand All @@ -185,9 +189,8 @@ public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
* {@link #ofByteSize(long, SerializableFunction)} to specify code to calculate the byte size.
*/
public static <K, InputT> GroupIntoBatches<K, InputT> ofByteSize(long batchSizeBytes) {
Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
return new GroupIntoBatches<>(
BatchingParams.create(Long.MAX_VALUE, batchSizeBytes, null, Duration.ZERO));
return new GroupIntoBatches<K, InputT>(BatchingParams.createDefault())
.withByteSize(batchSizeBytes);
}

/**
Expand All @@ -196,16 +199,49 @@ public static <K, InputT> GroupIntoBatches<K, InputT> ofByteSize(long batchSizeB
*/
public static <K, InputT> GroupIntoBatches<K, InputT> ofByteSize(
long batchSizeBytes, SerializableFunction<InputT, Long> getElementByteSize) {
Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
return new GroupIntoBatches<>(
BatchingParams.create(Long.MAX_VALUE, batchSizeBytes, getElementByteSize, Duration.ZERO));
return new GroupIntoBatches<K, InputT>(BatchingParams.createDefault())
.withByteSize(batchSizeBytes, getElementByteSize);
}

/**
 * Exposes the batching parameters this transform was configured with.
 *
 * @return the {@link BatchingParams} held by this transform
 */
public BatchingParams<InputT> getBatchingParams() {
  return this.params;
}

/**
 * Returns a new transform that uses {@code batchSize} as its element-count limit. The byte-size
 * limit, the element byte-size function and the max buffering duration are copied from this
 * transform's current parameters.
 *
 * <p>The call fails the {@code Preconditions.checkState} check when {@code batchSize} is {@code
 * Long.MAX_VALUE}.
 *
 * @param batchSize the element-count limit; must be less than {@code Long.MAX_VALUE}
 * @return a new {@link GroupIntoBatches} carrying the updated parameters
 * @see #ofSize(long)
 */
public GroupIntoBatches<K, InputT> withSize(long batchSize) {
  Preconditions.checkState(batchSize < Long.MAX_VALUE);
  // Only the count limit changes; everything else is carried over as-is.
  BatchingParams<InputT> updatedParams =
      BatchingParams.create(
          batchSize,
          params.getBatchSizeBytes(),
          params.getElementByteSize(),
          params.getMaxBufferingDuration());
  return new GroupIntoBatches<>(updatedParams);
}

/**
 * Returns a new transform that uses {@code batchSizeBytes} as its byte-size limit. The
 * element-count limit, the element byte-size function and the max buffering duration are copied
 * from this transform's current parameters.
 *
 * <p>The call fails the {@code Preconditions.checkState} check when {@code batchSizeBytes} is
 * {@code Long.MAX_VALUE}.
 *
 * @param batchSizeBytes the byte-size limit; must be less than {@code Long.MAX_VALUE}
 * @return a new {@link GroupIntoBatches} carrying the updated parameters
 * @see #ofByteSize(long)
 */
public GroupIntoBatches<K, InputT> withByteSize(long batchSizeBytes) {
  Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
  // Only the byte limit changes; the existing byte-size function is kept.
  BatchingParams<InputT> updatedParams =
      BatchingParams.create(
          params.getBatchSize(),
          batchSizeBytes,
          params.getElementByteSize(),
          params.getMaxBufferingDuration());
  return new GroupIntoBatches<>(updatedParams);
}

/**
 * Returns a new transform that uses {@code batchSizeBytes} as its byte-size limit and {@code
 * getElementByteSize} as its element byte-size function. The element-count limit and the max
 * buffering duration are copied from this transform's current parameters.
 *
 * <p>The call fails the {@code Preconditions.checkState} check when {@code batchSizeBytes} is
 * {@code Long.MAX_VALUE}.
 *
 * @param batchSizeBytes the byte-size limit; must be less than {@code Long.MAX_VALUE}
 * @param getElementByteSize function stored in the parameters as the element byte-size function
 * @return a new {@link GroupIntoBatches} carrying the updated parameters
 * @see #ofByteSize(long, SerializableFunction)
 */
public GroupIntoBatches<K, InputT> withByteSize(
    long batchSizeBytes, SerializableFunction<InputT, Long> getElementByteSize) {
  Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
  // Replaces both the byte limit and the byte-size function; the rest is carried over.
  BatchingParams<InputT> updatedParams =
      BatchingParams.create(
          params.getBatchSize(),
          batchSizeBytes,
          getElementByteSize,
          params.getMaxBufferingDuration());
  return new GroupIntoBatches<>(updatedParams);
}

/**
* Sets a time limit (in processing time) on how long an incomplete batch of elements is allowed
* to be buffered. Once a batch is flushed to output, the timer is reset. The provided limit must
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
Expand Down Expand Up @@ -51,8 +53,13 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
Expand Down Expand Up @@ -116,16 +123,6 @@ public void testInGlobalWindowBatchSizeCount() {
PAssert.that("Incorrect batch size in one or more elements", collection)
.satisfies(
new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {

private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) {
for (KV<String, Iterable<String>> element : listToCheck) {
if (Iterables.size(element.getValue()) != BATCH_SIZE) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did notice that it's != and not > here, but the test is still valid with > (we have 10 elements, and 5 batch size, so it can't be anything but 5, and we check the batch count at the end with EVEN_NUM_ELEMENTS / BATCH_SIZE)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that this previous test was too strict and your update makes sense. GroupIntoBatches ensures that the batches aren't bigger than BATCH_SIZE elements.

Unfortunately I think the GroupIntoBatches specification is too loose since it uses words like "Aim to create batches". It would be great if we could make it a strict guarantee, for example that batches will never be bigger than the element count, or that they will never be bigger than the byte size (except for the case where a single element is bigger than the byte size, in which case it will show up in its own group). I wouldn't try to solve this here, but it would make sense to have a bug and/or follow-up PR to make this explicit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think I did solve that. I mean, apart from the inaccuracy of the weigher.

return false;
}
}
return true;
}

@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
Expand Down Expand Up @@ -155,28 +152,9 @@ public void testInGlobalWindowBatchSizeByteSize() {
PAssert.that("Incorrect batch size in one or more elements", collection)
.satisfies(
new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {

private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) {
for (KV<String, Iterable<String>> element : listToCheck) {
long byteSize = 0;
for (String str : element.getValue()) {
if (byteSize >= BATCH_SIZE_BYTES) {
// We already reached the batch size, so extra elements are not expected.
return false;
}
try {
byteSize += StringUtf8Coder.of().getEncodedElementByteSize(str);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
return true;
}

@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
assertTrue(checkBatchByteSizes(input));
return null;
}
});
Expand All @@ -194,46 +172,27 @@ public Void apply(Iterable<KV<String, Iterable<String>>> input) {
UsesOnWindowExpiration.class
})
public void testInGlobalWindowBatchSizeByteSizeFn() {
SerializableFunction<String, Long> getElementByteSizeFn =
s -> {
try {
return 2 * StringUtf8Coder.of().getEncodedElementByteSize(s);
} catch (Exception e) {
throw new RuntimeException(e);
}
};

PCollection<KV<String, Iterable<String>>> collection =
pipeline
.apply("Input data", Create.of(data))
.apply(
GroupIntoBatches.ofByteSize(
BATCH_SIZE_BYTES,
s -> {
try {
return 2 * StringUtf8Coder.of().getEncodedElementByteSize(s);
} catch (Exception e) {
throw new RuntimeException(e);
}
}))
.apply(GroupIntoBatches.ofByteSize(BATCH_SIZE_BYTES, getElementByteSizeFn))
// set output coder
.setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())));
PAssert.that("Incorrect batch size in one or more elements", collection)
.satisfies(
new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {

private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) {
for (KV<String, Iterable<String>> element : listToCheck) {
long byteSize = 0;
for (String str : element.getValue()) {
if (byteSize >= BATCH_SIZE_BYTES) {
// We already reached the batch size, so extra elements are not expected.
return false;
}
try {
byteSize += 2 * StringUtf8Coder.of().getEncodedElementByteSize(str);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
return true;
}

@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
assertTrue(checkBatchByteSizes(input, getElementByteSizeFn));
return null;
}
});
Expand Down Expand Up @@ -267,20 +226,9 @@ public void testWithShardedKeyInGlobalWindow() {
PAssert.that("Incorrect batch size in one or more elements", collection)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should move the comment just above into checkBatchSizes:

    // Since with default sharding, the number of subshards of a key is nondeterministic, create
    // a large number of input elements and a small batch size and check there is no batch larger
    // than the specified size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Erhm, isn't this comment only valid for .withShardedKey()?

.satisfies(
new SerializableFunction<Iterable<KV<ShardedKey<String>, Iterable<String>>>, Void>() {

private boolean checkBatchSizes(
Iterable<KV<ShardedKey<String>, Iterable<String>>> listToCheck) {
for (KV<ShardedKey<String>, Iterable<String>> element : listToCheck) {
if (Iterables.size(element.getValue()) > batchSize) {
return false;
}
}
return true;
}

@Override
public Void apply(Iterable<KV<ShardedKey<String>, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
assertTrue(checkBatchSizes(input, batchSize));
return null;
}
});
Expand Down Expand Up @@ -353,17 +301,6 @@ public void testWithUnevenBatches() {
PAssert.that("Incorrect batch size in one or more elements", collection)
.satisfies(
new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {

private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) {
for (KV<String, Iterable<String>> element : listToCheck) {
// number of elements should be less than or equal to BATCH_SIZE
if (Iterables.size(element.getValue()) > BATCH_SIZE) {
return false;
}
}
return true;
}

@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
Expand Down Expand Up @@ -709,4 +646,118 @@ public void processElement(ProcessContext c, BoundedWindow window) {

pipeline.run().waitUntilFinish();
}

@Test
@Category({
ValidatesRunner.class,
NeedsRunner.class,
UsesTimersInParDo.class,
UsesStatefulParDo.class,
UsesOnWindowExpiration.class
})
public void testMultipleLimitsAtOnceInGlobalWindowBatchSizeCountAndBatchSizeByteSize() {
// with using only one of the limits the result would be only 2 batches,
// if we have 3 both limit works
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// if we have 3 both limit works
// if we have 3 both limits are exercised

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List<KV<String, String>> dataToUse =
Lists.newArrayList(
"a-1",
"a-2",
"a-3" + Strings.repeat("-", 100),
// byte size limit reached (BATCH_SIZE_BYTES = 25)
"b-4",
"b-5",
"b-6",
"b-7",
"b-8",
// count limit reached (BATCH_SIZE = 5)
"c-9")
.stream()
.map(s -> KV.of("key", s))
.collect(Collectors.toList());
PCollection<KV<String, Iterable<String>>> collection =
pipeline
.apply("Input data", Create.of(dataToUse))
.apply(
GroupIntoBatches.<String, String>ofSize(BATCH_SIZE).withByteSize(BATCH_SIZE_BYTES))
// set output coder
.setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())));
PAssert.that("Incorrect batch size in one or more elements", collection)
.satisfies(
new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {

private void assertExpectedBatchPrefix(
Iterable<KV<String, Iterable<String>>> listToCheck) {
for (KV<String, Iterable<String>> element : listToCheck) {
Set<String> batchPrefixes =
Streams.stream(element.getValue())
.map(s -> s.split("-")[0])
.collect(Collectors.toSet());
assertEquals("Found invalid batching: " + batchPrefixes, 1, batchPrefixes.size());
}
}

@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchSizes(input));
assertTrue(checkBatchByteSizes(input));
assertExpectedBatchPrefix(input);
return null;
}
});

PAssert.thatSingleton("Incorrect batching", collection.apply("Count", Count.globally()))
.satisfies(
numberOfBatches -> {
MatcherAssert.assertThat(numberOfBatches, Matchers.equalTo(3L));
return null;
});
pipeline.run();
}

/**
 * Checks every batch against the test-wide {@code BATCH_SIZE} limit by delegating to {@link
 * #checkBatchSizes(Iterable, int)}.
 *
 * @param listToCheck the batches to inspect
 * @return the result of the delegated check
 */
private static <K> boolean checkBatchSizes(Iterable<KV<K, Iterable<String>>> listToCheck) {
  final boolean allWithinDefaultLimit = checkBatchSizes(listToCheck, BATCH_SIZE);
  return allWithinDefaultLimit;
}

/**
 * Checks that no batch holds more than {@code batchSize} elements.
 *
 * @param listToCheck the batches to inspect
 * @param batchSize the maximum number of elements allowed in a single batch
 * @return {@code false} as soon as a batch with more than {@code batchSize} elements is found,
 *     {@code true} otherwise
 */
private static <K> boolean checkBatchSizes(
    Iterable<KV<K, Iterable<String>>> listToCheck, int batchSize) {
  for (KV<K, Iterable<String>> batch : listToCheck) {
    int elementCount = Iterables.size(batch.getValue());
    // A batch may be smaller than the limit, but never larger.
    boolean withinLimit = elementCount <= batchSize;
    if (!withinLimit) {
      return false;
    }
  }
  return true;
}

/**
 * Checks batch byte sizes, measuring each element with {@code
 * StringUtf8Coder.of().getEncodedElementByteSize}, by delegating to {@link
 * #checkBatchByteSizes(Iterable, SerializableFunction)}.
 *
 * @param listToCheck the batches to inspect
 * @return the result of the delegated check
 */
private static <K> boolean checkBatchByteSizes(Iterable<KV<K, Iterable<String>>> listToCheck) {
  SerializableFunction<String, Long> utf8EncodedSize =
      value -> {
        try {
          return StringUtf8Coder.of().getEncodedElementByteSize(value);
        } catch (Exception e) {
          // Rethrown unchecked so the lambda fits the function's signature.
          throw new RuntimeException(e);
        }
      };
  return checkBatchByteSizes(listToCheck, utf8EncodedSize);
}

/**
 * Checks that no batch contains further elements after its running byte size has reached {@code
 * BATCH_SIZE_BYTES}. The running size is the sum of {@code getElementByteSizeFn} over the
 * elements seen so far in that batch, so the element that crosses the limit is itself accepted.
 *
 * @param listToCheck the batches to inspect
 * @param getElementByteSizeFn computes the byte size attributed to a single element
 * @return {@code false} as soon as an element is found after the limit was already reached,
 *     {@code true} otherwise
 */
private static <K> boolean checkBatchByteSizes(
    Iterable<KV<K, Iterable<String>>> listToCheck,
    SerializableFunction<String, Long> getElementByteSizeFn) {
  for (KV<K, Iterable<String>> batch : listToCheck) {
    long accumulatedBytes = 0;
    for (String value : batch.getValue()) {
      // The limit is tested before the current element is counted: once it has been reached,
      // no extra elements are expected in this batch.
      boolean limitAlreadyReached = accumulatedBytes >= BATCH_SIZE_BYTES;
      if (limitAlreadyReached) {
        return false;
      }
      long elementBytes;
      try {
        elementBytes = getElementByteSizeFn.apply(value);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
      accumulatedBytes += elementBytes;
    }
  }
  return true;
}
}