Skip to content

Commit

Permalink
more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Jul 17, 2019
1 parent d3c509a commit 59cffe3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
36 changes: 36 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import base64
import sys
import typing
from builtins import object

import google.protobuf.wrappers_pb2
Expand All @@ -36,6 +37,10 @@
from apache_beam.typehints import typehints
from apache_beam.utils import proto_utils

if typing.TYPE_CHECKING:
from apache_beam.coders.typecoders import CoderRegistry
from apache_beam.runners.pipeline_context import PipelineContext

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from .stream import get_varint_size
Expand All @@ -62,6 +67,8 @@
'StrUtf8Coder', 'TimestampCoder', 'TupleCoder',
'TupleSequenceCoder', 'VarIntCoder', 'WindowedValueCoder']

CoderT = typing.TypeVar('CoderT', bound='Coder')


def serialize_coder(coder):
from apache_beam.internal import pickler
Expand Down Expand Up @@ -169,25 +176,29 @@ def to_type_hint(self):

@classmethod
def from_type_hint(cls, unused_typehint, unused_registry):
# type: (typing.Type[CoderT], typing.Any, CoderRegistry) -> CoderT
# If not overridden, just construct the coder without arguments.
return cls()

def is_kv_coder(self):
return False

def key_coder(self):
# type: () -> Coder
if self.is_kv_coder():
raise NotImplementedError('key_coder: %s' % self)
else:
raise ValueError('Not a KV coder: %s.' % self)

def value_coder(self):
# type: () -> Coder
if self.is_kv_coder():
raise NotImplementedError('value_coder: %s' % self)
else:
raise ValueError('Not a KV coder: %s.' % self)

def _get_component_coders(self):
# type: () -> typing.Sequence[Coder]
"""For internal use only; no backwards-compatibility guarantees.
Returns the internal component coders of this coder."""
Expand Down Expand Up @@ -261,6 +272,7 @@ def register(fn):
return register

def to_runner_api(self, context):
# type: (PipelineContext) -> beam_runner_api_pb2.Coder
urn, typed_param, components = self.to_runner_api_parameter(context)
return beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
Expand All @@ -272,6 +284,7 @@ def to_runner_api(self, context):

@classmethod
def from_runner_api(cls, coder_proto, context):
# type: (typing.Type[CoderT], beam_runner_api_pb2.Coder, PipelineContext) -> CoderT
"""Converts from an FunctionSpec to a Fn object.
Prefer registering a urn with its parameter type and constructor.
Expand All @@ -290,13 +303,15 @@ def from_runner_api(cls, coder_proto, context):
raise

def to_runner_api_parameter(self, context):
# type: (PipelineContext) -> typing.Tuple[str, typing.Any, typing.Sequence[Coder]]
return (
python_urns.PICKLED_CODER,
google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)),
())

@staticmethod
def register_structured_urn(urn, cls):
# type: (str, typing.Type[Coder]) -> None
"""Register a coder that's completely defined by its urn and its
component(s), if any, which are passed to construct the instance.
"""
Expand Down Expand Up @@ -480,9 +495,11 @@ class _TimerCoder(FastCoder):
For internal use."""
def __init__(self, payload_coder):
# type: (Coder) -> None
self._payload_coder = payload_coder

def _get_component_coders(self):
# type: () -> typing.List[Coder]
return [self._payload_coder]

def _create_impl(self):
Expand Down Expand Up @@ -640,6 +657,7 @@ class FastPrimitivesCoder(FastCoder):
For unknown types, falls back to another coder (e.g. PickleCoder).
"""
def __init__(self, fallback_coder=PickleCoder()):
# type: (Coder) -> None
self._fallback_coder = fallback_coder

def _create_impl(self):
Expand Down Expand Up @@ -760,6 +778,7 @@ def __hash__(self):

@staticmethod
def from_type_hint(typehint, unused_registry):
# type: (typing.Any, CoderRegistry) -> ProtoCoder
if issubclass(typehint, google.protobuf.message.Message):
return ProtoCoder(typehint)
else:
Expand Down Expand Up @@ -790,6 +809,7 @@ class TupleCoder(FastCoder):
"""Coder of tuple objects."""

def __init__(self, components):
# type: (typing.Iterable[Coder]) -> None
self._coders = tuple(components)

def _create_impl(self):
Expand All @@ -810,6 +830,7 @@ def to_type_hint(self):

@staticmethod
def from_type_hint(typehint, registry):
# type: (typing.Any, CoderRegistry) -> TupleCoder
return TupleCoder([registry.get_coder(t) for t in typehint.tuple_types])

def as_cloud_object(self, coders_context=None):
Expand All @@ -828,20 +849,24 @@ def as_cloud_object(self, coders_context=None):
return super(TupleCoder, self).as_cloud_object(coders_context)

def _get_component_coders(self):
# type: () -> typing.Tuple[Coder, ...]
return self.coders()

def coders(self):
# type: () -> typing.Tuple[Coder, ...]
return self._coders

def is_kv_coder(self):
return len(self._coders) == 2

def key_coder(self):
# type: () -> Coder
if len(self._coders) != 2:
raise ValueError('TupleCoder does not have exactly 2 components.')
return self._coders[0]

def value_coder(self):
# type: () -> Coder
if len(self._coders) != 2:
raise ValueError('TupleCoder does not have exactly 2 components.')
return self._coders[1]
Expand Down Expand Up @@ -871,6 +896,7 @@ class TupleSequenceCoder(FastCoder):
"""Coder of homogeneous tuple objects."""

def __init__(self, elem_coder):
# type: (Coder) -> None
self._elem_coder = elem_coder

def value_coder(self):
Expand All @@ -891,9 +917,11 @@ def as_deterministic_coder(self, step_label, error_message=None):

@staticmethod
def from_type_hint(typehint, registry):
# type: (typing.Any, CoderRegistry) -> TupleSequenceCoder
return TupleSequenceCoder(registry.get_coder(typehint.inner_type))

def _get_component_coders(self):
# type: () -> typing.Tuple[Coder, ...]
return (self._elem_coder,)

def __repr__(self):
Expand All @@ -911,6 +939,7 @@ class IterableCoder(FastCoder):
"""Coder of iterables of homogeneous objects."""

def __init__(self, elem_coder):
# type: (Coder) -> None
self._elem_coder = elem_coder

def _create_impl(self):
Expand Down Expand Up @@ -945,9 +974,11 @@ def to_type_hint(self):

@staticmethod
def from_type_hint(typehint, registry):
# type: (typing.Any, CoderRegistry) -> IterableCoder
return IterableCoder(registry.get_coder(typehint.inner_type))

def _get_component_coders(self):
# type: () -> typing.Tuple[Coder, ...]
return (self._elem_coder,)

def __repr__(self):
Expand Down Expand Up @@ -1040,6 +1071,7 @@ def as_cloud_object(self, coders_context=None):
}

def _get_component_coders(self):
# type: () -> typing.List[Coder]
return [self.wrapped_value_coder, self.window_coder]

def is_kv_coder(self):
Expand Down Expand Up @@ -1075,6 +1107,7 @@ class LengthPrefixCoder(FastCoder):
Coder which prefixes the length of the encoded object in the stream."""

def __init__(self, value_coder):
# type: (Coder) -> None
self._value_coder = value_coder

def _create_impl(self):
Expand All @@ -1100,6 +1133,7 @@ def as_cloud_object(self, coders_context=None):
}

def _get_component_coders(self):
# type: () -> typing.Tuple[Coder, ...]
return (self._value_coder,)

def __repr__(self):
Expand Down Expand Up @@ -1140,6 +1174,7 @@ def is_deterministic(self):
return False

def _get_component_coders(self):
# type: () -> typing.Tuple[Coder, ...]
return (self._element_coder,)

def __repr__(self):
Expand All @@ -1154,6 +1189,7 @@ def __hash__(self):
return hash((type(self), self._element_coder, self._write_state_threshold))

def to_runner_api_parameter(self, context):
# type: (PipelineContext) -> typing.Tuple[str, typing.Any, typing.Sequence[Coder]]
return (
common_urns.coders.STATE_BACKED_ITERABLE.urn,
str(self._write_state_threshold).encode('ascii'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run_pipeline(self, pipeline, options):
if not self._job_endpoint:
self.init_dockerized_job_server()
job_endpoint = self._job_endpoint
job_service = None
job_service = None # type: typing.Union[None, beam_job_api_pb2_grpc.JobServiceStub, beam_job_api_pb2_grpc.JobServiceServicer]
elif job_endpoint == 'embed':
job_service = local_job_service.LocalJobServicer()
else:
Expand Down
10 changes: 5 additions & 5 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def _inspect_process(self):
return getfullargspec(self._process_argspec_fn())


class CombineFn(WithTypeHints[T, T], HasDisplayData, urns.RunnerApiFn):
class CombineFn(WithTypeHints[T_, T_], HasDisplayData, urns.RunnerApiFn):
"""A function object used by a Combine transform with custom processing.
A CombineFn specifies how multiple values in all or part of a PCollection can
Expand Down Expand Up @@ -836,7 +836,7 @@ def __bool__(self):
return False


class CallableWrapperCombineFn(CombineFn[T, T]):
class CallableWrapperCombineFn(CombineFn[T_]):
"""For internal use only; no backwards-compatibility guarantees.
A CombineFn (function) object wrapping a callable object.
Expand All @@ -847,7 +847,7 @@ class CallableWrapperCombineFn(CombineFn[T, T]):
_DEFAULT_BUFFER_SIZE = 10

def __init__(self, fn, buffer_size=_DEFAULT_BUFFER_SIZE):
# type: (typing.Callable[[typing.Iterable[T]], T], int) -> None
# type: (typing.Callable[[typing.Iterable[T_]], T_], int) -> None
"""Initializes a CallableFn object wrapping a callable.
Args:
Expand Down Expand Up @@ -1337,9 +1337,9 @@ def Map(fn, # type: typing.Callable[[InT], OutT]
return pardo


def Filter(fn, # type: typing.Callable[[T], bool]
def Filter(fn, # type: typing.Callable[[T_], bool]
*args, **kwargs): # pylint: disable=invalid-name
# type: (...) -> ParDo[T, T]
# type: (...) -> ParDo[T_, T_]
""":func:`Filter` is a :func:`FlatMap` with its callable filtering out
elements.
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ class Reshuffle(PTransform[T_, T_]):

def expand(self, pcoll):
# type: (pvalue.PValue[T_]) -> pvalue.PCollection[T_]
# FIXME: mypy plugin causing mypy to crash here:
return (pcoll
| 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t))
| ReshufflePerKey()
Expand Down

0 comments on commit 59cffe3

Please sign in to comment.