Skip to content

Commit

Permalink
[BEAM-7746] Add typing for try_split
Browse files: browse the repository at this point in the history
  • Loading branch information
chadrik committed Jan 15, 2020
1 parent e44a1b6 commit cafbbf5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
1 change: 1 addition & 0 deletions sdks/python/apache_beam/io/iobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,7 @@ def deferred_status(self):
self._deferred_watermark -= (timestamp.Timestamp.now() -
self._deferred_timestamp)
return self._deferred_residual, self._deferred_watermark
return None


class RestrictionTrackerView(object):
Expand Down
16 changes: 11 additions & 5 deletions sdks/python/apache_beam/runners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
from apache_beam.transforms import sideinputs
from apache_beam.transforms.core import TimerSpec

SplitResultType = Tuple[
Tuple[WindowedValue, Optional[Timestamp]], Optional[Timestamp]]


class NameContext(object):
"""Holds the name information for a step."""
Expand Down Expand Up @@ -414,7 +417,7 @@ def invoke_process(self,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]
"""Invokes the DoFn.process() function.
Args:
Expand Down Expand Up @@ -618,7 +621,7 @@ def invoke_process(self,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]
if not additional_args:
additional_args = []
if not additional_kwargs:
Expand Down Expand Up @@ -684,7 +687,7 @@ def _invoke_process_per_window(self,
additional_kwargs,
output_processor # type: OutputProcessor
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]
if self.has_windowed_inputs:
window, = windowed_value.windows
side_inputs = [si[window] for si in self.side_inputs]
Expand Down Expand Up @@ -800,6 +803,7 @@ def try_split(self, fraction):
(element, primary), primary_size)), None), None),
((self.current_windowed_value.with_value((
(element, residual), residual_size)), current_watermark), None))
return None

def current_element_progress(self):
# type: () -> Optional[iobase.RestrictionProgress]
Expand Down Expand Up @@ -889,22 +893,24 @@ def receive(self, windowed_value):
self.process(windowed_value)

def process(self, windowed_value):
# type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (WindowedValue) -> Optional[SplitResultType]
try:
return self.do_fn_invoker.invoke_process(windowed_value)
except BaseException as exn:
self._reraise_augmented(exn)
return None

def process_with_sized_restriction(self, windowed_value):
# type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (WindowedValue) -> Optional[SplitResultType]
(element, restriction), _ = windowed_value.value
return self.do_fn_invoker.invoke_process(
windowed_value.with_value(element),
restriction_tracker=self.do_fn_invoker.invoke_create_tracker(
restriction))

def try_split(self, fraction):
# type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.try_split(fraction)

def current_element_progress(self):
Expand Down
16 changes: 10 additions & 6 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,17 @@ def process_encoded(self, encoded_windowed_values):
self.output(decoded_value)

def try_split(self, fraction_of_remainder, total_buffer_size):
# type: (...) -> Optional[Tuple[int, Optional[Tuple[operations.DoOperation, common.SplitResultType]], Optional[Tuple[operations.DoOperation, common.SplitResultType]], int]]
with self.splitting_lock:
if not self.started:
return
return None
if total_buffer_size < self.index + 1:
total_buffer_size = self.index + 1
elif self.stop and total_buffer_size > self.stop:
total_buffer_size = self.stop
if self.index == -1:
# We are "finished" with the (non-existent) previous element.
current_element_progress = 1
current_element_progress = 1.0
else:
current_element_progress_object = (
self.receivers[0].current_element_progress())
Expand Down Expand Up @@ -866,7 +867,7 @@ def try_split(self, bundle_split_request):

def delayed_bundle_application(self,
op, # type: operations.DoOperation
deferred_remainder # type: Tuple[windowed_value.WindowedValue, Timestamp]
deferred_remainder # type: common.SplitResultType
):
# type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
assert op.input_info is not None
Expand All @@ -884,7 +885,10 @@ def delayed_bundle_application(self,
application=self.construct_bundle_application(
op, output_watermark, element_and_restriction))

def bundle_application(self, op, primary):
def bundle_application(self,
op, # type: operations.DoOperation
primary # type: common.SplitResultType
):
((element_and_restriction, output_watermark),
_) = primary
return self.construct_bundle_application(
Expand Down Expand Up @@ -1003,7 +1007,7 @@ def shutdown(self):

class ExecutionContext(object):
def __init__(self):
self.delayed_applications = [] # type: List[Tuple[operations.DoOperation, Tuple[windowed_value.WindowedValue, Timestamp]]]
self.delayed_applications = [] # type: List[Tuple[operations.DoOperation, common.SplitResultType]]


class BeamTransformFactory(object):
Expand Down Expand Up @@ -1370,7 +1374,7 @@ def create_par_do(


def _create_pardo_operation(
factory,
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
consumers,
Expand Down
15 changes: 13 additions & 2 deletions sdks/python/apache_beam/runners/worker/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from builtins import object
from builtins import zip
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
Expand Down Expand Up @@ -132,6 +133,7 @@ def receive(self, windowed_value):
self.update_counters_finish()

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
# TODO(SDF): Consider supporting splitting each consumer individually.
# This would never come up in the existing SDF expansion, but might
# be useful to support fused SDF nodes.
Expand Down Expand Up @@ -165,8 +167,13 @@ def __repr__(self):


class SingletonConsumerSet(ConsumerSet):
def __init__(
self, counter_factory, step_name, output_index, consumers, coder):
def __init__(self,
counter_factory,
step_name,
output_index,
consumers, # type: List[Operation]
coder
):
assert len(consumers) == 1
super(SingletonConsumerSet, self).__init__(
counter_factory, step_name, output_index, consumers, coder)
Expand All @@ -179,6 +186,7 @@ def receive(self, windowed_value):
self.update_counters_finish()

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
return self.consumer.try_split(fraction_of_remainder)

def current_element_progress(self):
Expand Down Expand Up @@ -278,6 +286,7 @@ def needs_finalization(self):
return False

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
return None

def current_element_progress(self):
Expand Down Expand Up @@ -768,10 +777,12 @@ def process(self, o):
self.element_start_output_bytes = None

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Tuple[Tuple[DoOperation, common.SplitResultType], Tuple[DoOperation, common.SplitResultType]]]
split = self.dofn_runner.try_split(fraction_of_remainder)
if split:
primary, residual = split
return (self, primary), (self, residual)
return None

def current_element_progress(self):
with self.lock:
Expand Down

0 comments on commit cafbbf5

Please sign in to comment.