Skip to content

Commit

Permalink
Fixing State.readLater() usage + Unit test fix to reflect changed expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
nbali committed Nov 17, 2022
1 parent 5840e55 commit 6d016f1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,21 @@ public void processElement(
@Timestamp Instant elementTs,
BoundedWindow window,
OutputReceiver<KV<K, Iterable<InputT>>> receiver) {

final boolean shouldCareAboutWeight = weigher != null && batchSizeBytes != Long.MAX_VALUE;
final boolean shouldCareAboutMaxBufferingDuration =
maxBufferingDuration.isLongerThan(Duration.ZERO);

if (shouldCareAboutWeight) {
storedBatchSizeBytes.readLater();
}
storedBatchSize.readLater();
if (shouldCareAboutMaxBufferingDuration) {
minBufferedTs.readLater();
}

LOG.debug("*** BATCH *** Add element for window {} ", window);
if (shouldCareAboutWeight()) {
if (shouldCareAboutWeight) {
final long elementWeight = weigher.apply(element.getValue());
if (elementWeight + storedBatchSizeBytes.read() > batchSizeBytes) {
// Firing by count and size limits behave differently.
Expand Down Expand Up @@ -494,17 +507,13 @@ public void processElement(
bufferingTimer.clear();
}
storedBatchSizeBytes.add(elementWeight);
storedBatchSizeBytes.readLater();
}
batch.add(element.getValue());
// Blind add is supported with combiningState
storedBatchSize.add(1L);

long num;
if (maxBufferingDuration.isLongerThan(Duration.ZERO)) {
minBufferedTs.readLater();
num = storedBatchSize.read();

final long num = storedBatchSize.read();
if (shouldCareAboutMaxBufferingDuration) {
long oldOutputTs =
MoreObjects.firstNonNull(
minBufferedTs.read(), BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis());
Expand All @@ -523,15 +532,14 @@ public void processElement(
.set(Instant.ofEpochMilli(targetTs));
}
}
num = storedBatchSize.read();

if (num % prefetchFrequency == 0) {
// Prefetch data and modify batch state (readLater() modifies this)
batch.readLater();
}

if (num >= batchSize
|| (shouldCareAboutWeight() && storedBatchSizeBytes.read() >= batchSizeBytes)) {
|| (shouldCareAboutWeight && storedBatchSizeBytes.read() >= batchSizeBytes)) {
LOG.debug("*** END OF BATCH *** for window {}", window.toString());
flushBatch(
receiver,
Expand All @@ -545,10 +553,6 @@ public void processElement(
}
}

/**
 * Reports whether element weight has to be considered.
 *
 * @return {@code true} only when {@code weigher} is non-null and {@code batchSizeBytes} differs
 *     from {@code Long.MAX_VALUE}; {@code false} otherwise
 */
private boolean shouldCareAboutWeight() {
  // Without a weigher there is nothing to measure, so the byte limit is not consulted at all.
  if (weigher == null) {
    return false;
  }
  return batchSizeBytes != Long.MAX_VALUE;
}

@OnTimer(END_OF_BUFFERING_ID)
public void onBufferingTimer(
OutputReceiver<KV<K, Iterable<InputT>>> receiver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,15 @@ public Void apply(Iterable<KV<String, Iterable<String>>> input) {
}
});
PAssert.thatSingleton("Incorrect collection size", collection.apply("Count", Count.globally()))
.isEqualTo(3L);
.isEqualTo(4L);
pipeline.run();
}

@Test
@Category({
ValidatesRunner.class,
NeedsRunner.class,
UsesTestStream.class,
UsesTimersInParDo.class,
UsesStatefulParDo.class,
UsesOnWindowExpiration.class
Expand All @@ -180,9 +181,25 @@ public void testInGlobalWindowBatchSizeByteSizeFn() {
}
};

// to ensure ordered processing
TestStream.Builder<KV<String, String>> streamBuilder =
TestStream.create(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
.advanceWatermarkTo(Instant.EPOCH);

long offset = 0L;
for (KV<String, String> kv : data) {
streamBuilder =
streamBuilder.addElements(
TimestampedValue.of(kv, Instant.EPOCH.plus(Duration.standardSeconds(offset))));
offset++;
}

// fire them all at once
TestStream<KV<String, String>> stream = streamBuilder.advanceWatermarkToInfinity();

PCollection<KV<String, Iterable<String>>> collection =
pipeline
.apply("Input data", Create.of(data))
.apply("Input data", stream)
.apply(GroupIntoBatches.ofByteSize(BATCH_SIZE_BYTES, getElementByteSizeFn))
// set output coder
.setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())));
Expand All @@ -192,11 +209,10 @@ public void testInGlobalWindowBatchSizeByteSizeFn() {
@Override
public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchByteSizes(input, getElementByteSizeFn));
assertEquals("Invalid batch count", 9L, Iterables.size(input));
return null;
}
});
PAssert.thatSingleton("Incorrect collection size", collection.apply("Count", Count.globally()))
.isEqualTo(5L);
pipeline.run();
}

Expand Down Expand Up @@ -662,15 +678,18 @@ public void testMultipleLimitsAtOnceInGlobalWindowBatchSizeCountAndBatchSizeByte
Lists.newArrayList(
"a-1",
"a-2",
"a-3" + Strings.repeat("-", 100),
// byte size limit reached (BATCH_SIZE_BYTES = 25)
"b-4",
"b-5",
"b-6",
"b-7",
"b-8",
// batch byte size limit would be reached with the next one so "firing" current
// batch content (BATCH_SIZE_BYTES = 25)
"b-3" + Strings.repeat("-", 100),
// batch byte size is over the limit, but we have a single element that we can't
// split to smaller batches (BATCH_SIZE_BYTES = 25)
"c-4",
"c-5",
"c-6",
"c-7",
"c-8",
// count limit reached (BATCH_SIZE = 5)
"c-9")
"d-9")
.stream()
.map(s -> KV.of("key", s))
.collect(Collectors.toList());
Expand Down Expand Up @@ -719,7 +738,7 @@ public Void apply(Iterable<KV<String, Iterable<String>>> input) {
assertTrue(checkBatchByteSizes(input));
assertExpectedBatchPrefix(input);
assertEquals(
Lists.newArrayList(3, 5, 1),
Lists.newArrayList(2, 1, 5, 1),
Streams.stream(input)
.map(KV::getValue)
.map(Iterables::size)
Expand Down

0 comments on commit 6d016f1

Please sign in to comment.