Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-7746] Add python type hints (part 2) #10367

Merged
merged 14 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,8 @@ def encode_to_stream(self, value, out, nested):
buffer = create_OutputStream()
if (self._write_state is not None
and out.size() - start_size > self._write_state_threshold):
tail = (value_iter[index + 1:] if isinstance(value, (list, tuple))
tail = (value_iter[index + 1:]
if isinstance(value_iter, (list, tuple))
chadrik marked this conversation as resolved.
Show resolved Hide resolved
else value_iter)
state_token = self._write_state(tail, self._elem_coder)
out.write_var_int64(-1)
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/coders/row_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def to_type_hint(self):
def as_cloud_object(self, coders_context=None):
raise NotImplementedError("as_cloud_object not supported for RowCoder")

__hash__ = None
__hash__ = None # type: ignore[assignment]

def __eq__(self, other):
return type(self) == type(other) and self.schema == other.schema
Expand Down
5 changes: 0 additions & 5 deletions sdks/python/apache_beam/examples/complete/game/game_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,6 @@ def get_schema(self):
return ', '.join(
'%s:%s' % (col, self.schema[col]) for col in self.schema)

def get_schema(self):
"""Build the output table schema."""
return ', '.join(
'%s:%s' % (col, self.schema[col]) for col in self.schema)

def expand(self, pcoll):
return (
pcoll
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/io/fileio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_match_all_two_directories(self):

assert_that(files_pc, equal_to(files))

def test_match_files_one_directory_failure(self):
def test_match_files_one_directory_failure1(self):
directories = [
'%s%s' % (self._new_tempdir(), os.sep),
'%s%s' % (self._new_tempdir(), os.sep)]
Expand All @@ -114,7 +114,7 @@ def test_match_files_one_directory_failure(self):

assert_that(files_pc, equal_to(files))

def test_match_files_one_directory_failure(self):
def test_match_files_one_directory_failure2(self):
directories = [
'%s%s' % (self._new_tempdir(), os.sep),
'%s%s' % (self._new_tempdir(), os.sep)]
Expand Down
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/io/filesystems_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import logging
import os
import shutil
import sys
import tempfile
import unittest

Expand All @@ -49,6 +50,12 @@ def _join(first_path, *paths):

class FileSystemsTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Method has been renamed in Python 3
if sys.version_info[0] < 3:
cls.assertCountEqual = cls.assertItemsEqual

def setUp(self):
self.tmpdir = tempfile.mkdtemp()

Expand Down Expand Up @@ -132,7 +139,7 @@ def test_match_file_exception(self):
FileSystems.match([None])
self.assertEqual(list(error.exception.exception_details), [None])

def test_match_directory(self):
def test_match_directory_with_files(self):
path1 = os.path.join(self.tmpdir, 'f1')
path2 = os.path.join(self.tmpdir, 'f2')
open(path1, 'a').close()
Expand All @@ -142,7 +149,7 @@ def test_match_directory(self):
path = os.path.join(self.tmpdir, '*')
result = FileSystems.match([path])[0]
files = [f.path for f in result.metadata_list]
self.assertEqual(files, [path1, path2])
self.assertCountEqual(files, [path1, path2])

def test_match_directory(self):
result = FileSystems.match([self.tmpdir])[0]
Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/io/iobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,7 @@ class ThreadsafeRestrictionTracker(object):
"""

def __init__(self, restriction_tracker):
# type: (RestrictionTracker) -> None
if not isinstance(restriction_tracker, RestrictionTracker):
raise ValueError(
'Initialize ThreadsafeRestrictionTracker requires'
Expand Down Expand Up @@ -1379,31 +1380,36 @@ def __repr__(self):

@property
def completed_work(self):
# type: () -> float
if self._completed:
return self._completed
elif self._remaining and self._fraction:
return self._remaining * self._fraction / (1 - self._fraction)

@property
def remaining_work(self):
# type: () -> float
if self._remaining:
return self._remaining
elif self._completed:
return self._completed * (1 - self._fraction) / self._fraction

@property
def total_work(self):
# type: () -> float
return self.completed_work + self.remaining_work

@property
def fraction_completed(self):
# type: () -> float
if self._fraction is not None:
return self._fraction
else:
return float(self._completed) / self.total_work

@property
def fraction_remaining(self):
# type: () -> float
if self._fraction is not None:
return 1 - self._fraction
else:
Expand Down
7 changes: 0 additions & 7 deletions sdks/python/apache_beam/io/restriction_trackers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,6 @@ def test_check_done_after_try_claim_past_end_of_range(self):
self.assertFalse(tracker.try_claim(220))
tracker.check_done()

def test_check_done_after_try_claim_past_end_of_range(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(150))
self.assertTrue(tracker.try_claim(175))
self.assertFalse(tracker.try_claim(200))
tracker.check_done()

def test_check_done_after_try_claim_right_before_end_of_range(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(150))
Expand Down
19 changes: 10 additions & 9 deletions sdks/python/apache_beam/io/tfrecordio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_process_deflate(self):
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))

def test_process_gzip(self):
def test_process_gzip_with_coder(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
Expand All @@ -303,27 +303,28 @@ def test_process_gzip(self):
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))

def test_process_auto(self):
def test_process_gzip_without_coder(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result.gz')
path = temp_dir.create_temp_file('result')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True))
compression_type=CompressionTypes.GZIP))
assert_that(result, equal_to([b'foo', b'bar']))

def test_process_gzip(self):
def test_process_auto(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
path = temp_dir.create_temp_file('result.gz')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path, compression_type=CompressionTypes.GZIP))
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))

def test_process_gzip_auto(self):
Expand Down
3 changes: 0 additions & 3 deletions sdks/python/apache_beam/io/vcfio.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,6 @@ def __le__(self, other):

return self < other or self == other

def __ne__(self, other):
chadrik marked this conversation as resolved.
Show resolved Hide resolved
return not self == other

def __gt__(self, other):
if not isinstance(other, Variant):
return NotImplemented
Expand Down
25 changes: 24 additions & 1 deletion sdks/python/apache_beam/metrics/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import threading
import time
from builtins import object
from typing import Optional

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import metrics_pb2
Expand Down Expand Up @@ -86,6 +87,7 @@ def reset(self):
self.value = CounterAggregator.identity_element()

def combine(self, other):
# type: (CounterCell) -> CounterCell
result = CounterCell()
result.inc(self.value + other.value)
return result
Expand All @@ -106,6 +108,7 @@ def update(self, value):
self.value += value

def get_cumulative(self):
# type: () -> int
with self._lock:
return self.value

Expand Down Expand Up @@ -144,6 +147,7 @@ def reset(self):
self.data = DistributionAggregator.identity_element()

def combine(self, other):
# type: (DistributionCell) -> DistributionCell
result = DistributionCell()
result.data = self.data.combine(other.data)
return result
Expand All @@ -169,6 +173,7 @@ def _update(self, value):
self.data.max = ivalue

def get_cumulative(self):
# type: () -> DistributionData
with self._lock:
return self.data.get_cumulative()

Expand Down Expand Up @@ -204,6 +209,7 @@ def reset(self):
self.data = GaugeAggregator.identity_element()

def combine(self, other):
# type: (GaugeCell) -> GaugeCell
result = GaugeCell()
result.data = self.data.combine(other.data)
return result
Expand All @@ -220,6 +226,7 @@ def update(self, value):
self.data.timestamp = time.time()

def get_cumulative(self):
# type: () -> GaugeData
with self._lock:
return self.data.get_cumulative()

Expand All @@ -239,6 +246,7 @@ def to_runner_api_monitoring_info(self, name, transform_id):
class DistributionResult(object):
"""The result of a Distribution metric."""
def __init__(self, data):
# type: (DistributionData) -> None
self.data = data

def __eq__(self, other):
Expand Down Expand Up @@ -290,6 +298,7 @@ def mean(self):

class GaugeResult(object):
def __init__(self, data):
# type: (GaugeData) -> None
self.data = data

def __eq__(self, other):
Expand Down Expand Up @@ -349,9 +358,11 @@ def __repr__(self):
self.timestamp)

def get_cumulative(self):
# type: () -> GaugeData
return GaugeData(self.value, timestamp=self.timestamp)

def combine(self, other):
# type: (Optional[GaugeData]) -> GaugeData
if other is None:
return self

Expand All @@ -362,6 +373,7 @@ def combine(self, other):

@staticmethod
def singleton(value, timestamp=None):
# type: (...) -> GaugeData
return GaugeData(value, timestamp=timestamp)

def to_runner_api(self):
Expand Down Expand Up @@ -427,9 +439,11 @@ def __repr__(self):
self.max)

def get_cumulative(self):
# type: () -> DistributionData
return DistributionData(self.sum, self.count, self.min, self.max)

def combine(self, other):
# type: (Optional[DistributionData]) -> DistributionData
if other is None:
return self

Expand Down Expand Up @@ -474,7 +488,7 @@ def identity_element(self):
"""
raise NotImplementedError

def combine(self, updates):
def combine(self, x, y):
raise NotImplementedError

def result(self, x):
Expand All @@ -490,12 +504,15 @@ class CounterAggregator(MetricAggregator):
"""
@staticmethod
def identity_element():
# type: () -> int
return 0

def combine(self, x, y):
# type: (...) -> int
return int(x) + int(y)

def result(self, x):
# type: (...) -> int
return int(x)


Expand All @@ -508,12 +525,15 @@ class DistributionAggregator(MetricAggregator):
"""
@staticmethod
def identity_element():
# type: () -> DistributionData
return DistributionData(0, 0, 2**63 - 1, -2**63)

def combine(self, x, y):
# type: (DistributionData, DistributionData) -> DistributionData
return x.combine(y)

def result(self, x):
# type: (DistributionData) -> DistributionResult
return DistributionResult(x.get_cumulative())


Expand All @@ -526,11 +546,14 @@ class GaugeAggregator(MetricAggregator):
"""
@staticmethod
def identity_element():
# type: () -> GaugeData
return GaugeData(None, timestamp=0)

def combine(self, x, y):
# type: (GaugeData, GaugeData) -> GaugeData
result = x.combine(y)
return result

def result(self, x):
# type: (GaugeData) -> GaugeResult
return GaugeResult(x.get_cumulative())