Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Aug 25, 2019
1 parent 262a6f4 commit 80932de
Show file tree
Hide file tree
Showing 15 changed files with 149 additions and 38 deletions.
8 changes: 6 additions & 2 deletions sdks/python/apache_beam/coders/coder_impl.py
Expand Up @@ -750,8 +750,12 @@ class SequenceCoderImpl(StreamCoderImpl):
# Default buffer size of 64kB of handling iterables of unknown length.
_DEFAULT_BUFFER_SIZE = 64 * 1024

def __init__(self, elem_coder,
read_state=None, write_state=None, write_state_threshold=0):
def __init__(self,
elem_coder, # type: Coder
read_state=None, # type: Optional[Callable[[bytes, Coder], Iterable]]
write_state=None, # type: Optional[Callable[[Iterable, Coder], bytes]]
write_state_threshold=0 # type: int
):
self._elem_coder = elem_coder
self._read_state = read_state
self._write_state = write_state
Expand Down
25 changes: 22 additions & 3 deletions sdks/python/apache_beam/coders/coders.py
Expand Up @@ -82,8 +82,12 @@

CoderT = TypeVar('CoderT', bound='Coder')
ProtoCoderT = TypeVar('ProtoCoderT', bound='ProtoCoder')
ParameterType = Union['message.Message', bytes, None]
ConstructorFn = Callable[[ParameterType, List['Coder'], 'PipelineContext'], Any]
ParameterType = Union[Type['message.Message'], Type[bytes], None]
ConstructorFn = Callable[
[Union['message.Message', bytes],
List['Coder'],
'PipelineContext'],
Any]


def serialize_coder(coder):
Expand Down Expand Up @@ -268,11 +272,26 @@ def __hash__(self):
_known_urns = {} # type: Dict[str, Tuple[ParameterType, ConstructorFn]]

@classmethod
@typing.overload
def register_urn(cls,
urn, # type: str
parameter_type, # type: ParameterType
fn=None # type: Optional[ConstructorFn]
):
# type: (...) -> ConstructorFn
pass

@classmethod
@typing.overload
def register_urn(cls,
urn, # type: str
parameter_type, # type: ParameterType
fn # type: ConstructorFn
):
# type: (...) -> None
pass

@classmethod
def register_urn(cls, urn, parameter_type, fn=None):
"""Registers a urn with a constructor.
For example, if 'beam:fn:foo' had parameter type FooPayload, one could
Expand Down
8 changes: 7 additions & 1 deletion sdks/python/apache_beam/io/iobase.py
Expand Up @@ -99,6 +99,10 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn, Generic[T]):
"""
urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE)

def is_bounded(self):
# type: () -> bool
raise NotImplementedError


class BoundedSource(SourceBase[T]):
"""A source that reads a finite amount of input records.
Expand Down Expand Up @@ -861,7 +865,7 @@ class Read(ptransform.PTransform[pvalue.PBeginType, OutT]):
"""A transform that reads a PCollection."""

def __init__(self, source):
# type: (BoundedSource) -> None
# type: (SourceBase) -> None
"""Initializes a Read transform.
Args:
Expand Down Expand Up @@ -926,6 +930,7 @@ def display_data(self):
'source_dd': self.source}

def to_runner_api_parameter(self, context):
# type: (PipelineContext) -> Tuple[str, ptransform.ParameterType]
return (common_urns.deprecated_primitives.READ.urn,
beam_runner_api_pb2.ReadPayload(
source=self.source.to_runner_api(context),
Expand All @@ -935,6 +940,7 @@ def to_runner_api_parameter(self, context):

@staticmethod
def from_runner_api_parameter(parameter, context):
# type: (beam_runner_api_pb2.ReadPayload, PipelineContext) -> Read
return Read(SourceBase.from_runner_api(parameter.source, context))


Expand Down
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/options/pipeline_options.py
Expand Up @@ -27,6 +27,7 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider
Expand Down Expand Up @@ -157,7 +158,9 @@ def _add_argparse_args(cls, parser):
By default the options classes will use command line arguments to initialize
the options.
"""
def __init__(self, flags=None, **kwargs):
def __init__(self,
flags=None, # type: Optional[List[str]]
**kwargs):
"""Initialize an options class.
The initializer will traverse all subclasses, add all their argparse
Expand Down
25 changes: 20 additions & 5 deletions sdks/python/apache_beam/pipeline.py
Expand Up @@ -57,6 +57,7 @@
from builtins import zip
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from future.utils import with_metaclass
Expand All @@ -83,6 +84,7 @@

if typing.TYPE_CHECKING:
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import PipelineResult
from apache_beam.runners.pipeline_context import PipelineContext

__all__ = ['Pipeline', 'PTransformOverride']
Expand All @@ -103,7 +105,11 @@ class Pipeline(object):
(e.g. ``input | "label" >> my_tranform``).
"""

def __init__(self, runner=None, options=None, argv=None):
def __init__(self,
runner=None, # type: Optional[PipelineRunner]
options=None, # type: Optional[PipelineOptions]
argv=None # type: Optional[List[str]]
):
"""Initialize a pipeline object.
Args:
Expand Down Expand Up @@ -405,6 +411,7 @@ def replace_all(self, replacements):
self._check_replacement(override)

def run(self, test_runner_api=True):
# type: (bool) -> PipelineResult
"""Runs the pipeline. Returns whatever our runner returns after running."""

# When possible, invoke a round trip through the runner API.
Expand Down Expand Up @@ -435,6 +442,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.run().wait_until_finish()

def visit(self, visitor):
# type: (PipelineVisitor) -> None
"""Visits depth-first every node of a pipeline's DAG.
Runner-internal implementation detail; no backwards-compatibility guarantees
Expand Down Expand Up @@ -677,8 +685,12 @@ def visit_transform(self, transform_node):
return proto

@staticmethod
def from_runner_api(proto, runner, options, return_context=False,
allow_proto_holders=False):
def from_runner_api(proto, # type: beam_runner_api_pb2.Pipeline
runner, # type: PipelineRunner
options, # type: PipelineOptions
return_context=False,
allow_proto_holders=False
):
# type: (...) -> Pipeline
"""For internal use only; no backwards-compatibility guarantees."""
p = Pipeline(runner=runner, options=options)
Expand Down Expand Up @@ -753,7 +765,7 @@ def __init__(self,
parent,
transform, # type: ptransform.PTransform
full_label,
inputs
inputs # type: Iterable[pvalue.PCollection]
):
self.parent = parent
self.transform = transform
Expand All @@ -765,7 +777,7 @@ def __init__(self,
self.full_label = full_label
self.inputs = inputs or ()
self.side_inputs = () if transform is None else tuple(transform.side_inputs)
self.outputs = {} # type: Dict[Union[str, int, None], pvalue.PValue]
self.outputs = {} # type: Dict[Union[str, int, None], Union[pvalue.PValue, pvalue.DoOutputsTuple]]
self.parts = [] # type: List[AppliedPTransform]

def __repr__(self):
Expand Down Expand Up @@ -870,6 +882,7 @@ def visit(self, visitor, pipeline, visited):
visitor.visit_value(v, self)

def named_inputs(self):
# type: () -> Dict[str, pvalue.PCollection]
# TODO(BEAM-1833): Push names up into the sdk construction.
main_inputs = {str(ix): input
for ix, input in enumerate(self.inputs)
Expand All @@ -879,6 +892,7 @@ def named_inputs(self):
return dict(main_inputs, **side_inputs)

def named_outputs(self):
# type: () -> Dict[str, pvalue.PCollection]
return {str(tag): output for tag, output in self.outputs.items()
if isinstance(output, pvalue.PCollection)}

Expand Down Expand Up @@ -912,6 +926,7 @@ def transform_to_runner_api(transform, context):

@staticmethod
def from_runner_api(proto, context):
# type: (beam_runner_api_pb2.PTransform, PipelineContext) -> AppliedPTransform
def is_side_input(tag):
# As per named_inputs() above.
return tag.startswith('side')
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/runners/direct/bundle_factory.py
Expand Up @@ -45,6 +45,7 @@ def create_bundle(self, output_pcollection):
return _Bundle(output_pcollection, self._stacked)

def create_empty_committed_bundle(self, output_pcollection):
# type: (pvalue.PCollection) -> _Bundle
bundle = self.create_bundle(output_pcollection)
bundle.commit(None)
return bundle
Expand Down
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/runners/direct/evaluation_context.py
Expand Up @@ -40,7 +40,7 @@
if typing.TYPE_CHECKING:
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pvalue import AsSideInput, PCollection
from apache_beam.runners.direct.bundle_factory import BundleFactory
from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle
from apache_beam.utils.timestamp import Timestamp

class _ExecutionContext(object):
Expand Down Expand Up @@ -375,10 +375,12 @@ def get_execution_context(self, applied_ptransform):
self._transform_keyed_states[applied_ptransform])

def create_bundle(self, output_pcollection):
# type: (pvalue.PCollection) -> _Bundle
"""Create an uncommitted bundle for the specified PCollection."""
return self._bundle_factory.create_bundle(output_pcollection)

def create_empty_committed_bundle(self, output_pcollection):
# type: (pvalue.PCollection) -> _Bundle
"""Create empty bundle useful for triggering evaluation."""
return self._bundle_factory.create_empty_committed_bundle(
output_pcollection)
Expand Down
10 changes: 7 additions & 3 deletions sdks/python/apache_beam/runners/direct/transform_evaluator.py
Expand Up @@ -567,9 +567,13 @@ def __missing__(self, key):
class _ParDoEvaluator(_TransformEvaluator):
"""TransformEvaluator for ParDo transform."""

def __init__(self, evaluation_context, applied_ptransform,
input_committed_bundle, side_inputs,
perform_dofn_pickle_test=True):
def __init__(self,
evaluation_context, # type: EvaluationContext
applied_ptransform, # type: AppliedPTransform
input_committed_bundle,
side_inputs,
perform_dofn_pickle_test=True
):
super(_ParDoEvaluator, self).__init__(
evaluation_context, applied_ptransform, input_committed_bundle,
side_inputs)
Expand Down
9 changes: 5 additions & 4 deletions sdks/python/apache_beam/runners/portability/fn_api_runner.py
Expand Up @@ -85,10 +85,11 @@
from google.protobuf import message
from apache_beam.runners.portability import fn_api_runner

ConstructorFn = Callable[[Union['message.Message', bytes],
'FnApiRunner.StateServicer',
Optional['fn_api_runner.ExtendedProvisionInfo']],
Any]
ConstructorFn = Callable[
[Union['message.Message', bytes],
'FnApiRunner.StateServicer',
Optional['fn_api_runner.ExtendedProvisionInfo']],
Any]

# This module is experimental. No backwards-compatibility guarantees.

Expand Down

0 comments on commit 80932de

Please sign in to comment.