Skip to content

Commit

Permalink
[BEAM-7746] Add typing for try_split
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Feb 7, 2020
1 parent b4f0288 commit a2451d4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
1 change: 1 addition & 0 deletions sdks/python/apache_beam/io/iobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,7 @@ def deferred_status(self):
self._deferred_watermark -= (
timestamp.Timestamp.now() - self._deferred_timestamp)
return self._deferred_residual, self._deferred_watermark
return None


class RestrictionTrackerView(object):
Expand Down
16 changes: 11 additions & 5 deletions sdks/python/apache_beam/runners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
from apache_beam.transforms import sideinputs
from apache_beam.transforms.core import TimerSpec

SplitResultType = Tuple[Tuple[WindowedValue, Optional[Timestamp]],
Optional[Timestamp]]


class NameContext(object):
"""Holds the name information for a step."""
Expand Down Expand Up @@ -415,7 +418,7 @@ def invoke_process(self,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]

"""Invokes the DoFn.process() function.
Expand Down Expand Up @@ -619,7 +622,7 @@ def invoke_process(self,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]
if not additional_args:
additional_args = []
if not additional_kwargs:
Expand Down Expand Up @@ -682,7 +685,7 @@ def _invoke_process_per_window(self,
additional_args,
additional_kwargs,
):
# type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (...) -> Optional[SplitResultType]
if self.has_windowed_inputs:
window, = windowed_value.windows
side_inputs = [si[window] for si in self.side_inputs]
Expand Down Expand Up @@ -804,6 +807,7 @@ def try_split(self, fraction):
((element, residual), residual_size)),
current_watermark),
None))
return None

def current_element_progress(self):
# type: () -> Optional[iobase.RestrictionProgress]
Expand Down Expand Up @@ -900,22 +904,24 @@ def receive(self, windowed_value):
self.process(windowed_value)

def process(self, windowed_value):
# type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (WindowedValue) -> Optional[SplitResultType]
try:
return self.do_fn_invoker.invoke_process(windowed_value)
except BaseException as exn:
self._reraise_augmented(exn)
return None

def process_with_sized_restriction(self, windowed_value):
# type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
# type: (WindowedValue) -> Optional[SplitResultType]
(element, restriction), _ = windowed_value.value
return self.do_fn_invoker.invoke_process(
windowed_value.with_value(element),
restriction_tracker=self.do_fn_invoker.invoke_create_tracker(
restriction))

def try_split(self, fraction):
# type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.try_split(fraction)

def current_element_progress(self):
Expand Down
16 changes: 10 additions & 6 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,17 @@ def process_encoded(self, encoded_windowed_values):
self.output(decoded_value)

def try_split(self, fraction_of_remainder, total_buffer_size):
# type: (...) -> Optional[Tuple[int, Optional[Tuple[operations.DoOperation, common.SplitResultType]], Optional[Tuple[operations.DoOperation, common.SplitResultType]], int]]
with self.splitting_lock:
if not self.started:
return
return None
if total_buffer_size < self.index + 1:
total_buffer_size = self.index + 1
elif self.stop and total_buffer_size > self.stop:
total_buffer_size = self.stop
if self.index == -1:
# We are "finished" with the (non-existent) previous element.
current_element_progress = 1
current_element_progress = 1.0
else:
current_element_progress_object = (
self.receivers[0].current_element_progress())
Expand Down Expand Up @@ -900,7 +901,7 @@ def try_split(self, bundle_split_request):

def delayed_bundle_application(self,
op, # type: operations.DoOperation
deferred_remainder # type: Tuple[windowed_value.WindowedValue, Timestamp]
deferred_remainder # type: common.SplitResultType
):
# type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
assert op.input_info is not None
Expand All @@ -918,7 +919,10 @@ def delayed_bundle_application(self,
application=self.construct_bundle_application(
op, output_watermark, element_and_restriction))

def bundle_application(self, op, primary):
def bundle_application(self,
op, # type: operations.DoOperation
primary # type: common.SplitResultType
):
((element_and_restriction, output_watermark), _) = primary
return self.construct_bundle_application(
op, output_watermark, element_and_restriction)
Expand Down Expand Up @@ -1043,7 +1047,7 @@ def shutdown(self):
class ExecutionContext(object):
def __init__(self):
self.delayed_applications = [
] # type: List[Tuple[operations.DoOperation, Tuple[windowed_value.WindowedValue, Timestamp]]]
] # type: List[Tuple[operations.DoOperation, common.SplitResultType]]


class BeamTransformFactory(object):
Expand Down Expand Up @@ -1412,7 +1416,7 @@ def create_par_do(


def _create_pardo_operation(
factory,
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
consumers,
Expand Down
16 changes: 14 additions & 2 deletions sdks/python/apache_beam/runners/worker/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
from builtins import object
from builtins import zip
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Hashable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from apache_beam import pvalue
Expand Down Expand Up @@ -132,6 +134,7 @@ def receive(self, windowed_value):
self.update_counters_finish()

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
# TODO(SDF): Consider supporting splitting each consumer individually.
# This would never come up in the existing SDF expansion, but might
# be useful to support fused SDF nodes.
Expand Down Expand Up @@ -169,8 +172,13 @@ def __repr__(self):


class SingletonConsumerSet(ConsumerSet):
def __init__(
self, counter_factory, step_name, output_index, consumers, coder):
def __init__(self,
counter_factory,
step_name,
output_index,
consumers, # type: List[Operation]
coder
):
assert len(consumers) == 1
super(SingletonConsumerSet, self).__init__(
counter_factory, step_name, output_index, consumers, coder)
Expand All @@ -183,6 +191,7 @@ def receive(self, windowed_value):
self.update_counters_finish()

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
return self.consumer.try_split(fraction_of_remainder)

def current_element_progress(self):
Expand Down Expand Up @@ -288,6 +297,7 @@ def needs_finalization(self):
return False

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
return None

def current_element_progress(self):
Expand Down Expand Up @@ -774,10 +784,12 @@ def process(self, o):
self.element_start_output_bytes = None

def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Tuple[Tuple[DoOperation, common.SplitResultType], Tuple[DoOperation, common.SplitResultType]]]
split = self.dofn_runner.try_split(fraction_of_remainder)
if split:
primary, residual = split
return (self, primary), (self, residual)
return None

def current_element_progress(self):
with self.lock:
Expand Down

0 comments on commit a2451d4

Please sign in to comment.