Skip to content

Commit

Permalink
Enable arrow tests, fix bug with Arrow stream message reading (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
BryanCutler authored and yongtang committed Oct 1, 2019
1 parent f1aba9b commit 94b2fed
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ arrow::Status ArrowStreamClient::Tell(int64_t* position) const {
arrow::Status ArrowStreamClient::Read(int64_t nbytes,
int64_t* bytes_read,
void* out) {
// TODO: look into why 0 bytes are requested
// TODO: 0 bytes requested when message body length == 0
if (nbytes == 0) {
*bytes_read = 0;
return arrow::Status::OK();
}

Expand Down
16 changes: 9 additions & 7 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@

import tensorflow as tf
from tensorflow import dtypes
from tensorflow.compat.v2 import data
from tensorflow.python.data.ops.dataset_ops import flat_structure
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow_io.core.python.ops import core_ops

Expand Down Expand Up @@ -88,7 +87,7 @@ def arrow_schema_to_tensor_types(schema):
return tensor_types, tensor_shapes


class ArrowBaseDataset(data.Dataset):
class ArrowBaseDataset(dataset_ops.DatasetV2):
"""Base class for Arrow Datasets to provide columns used in record batches
and corresponding output tensor types, shapes and classes.
"""
Expand Down Expand Up @@ -121,21 +120,24 @@ def __init__(self,
dtypes.string,
name="batch_mode")
if batch_size is not None or batch_mode == 'auto':
spec_batch_size = batch_size if batch_mode == 'drop_remainder' else None
# pylint: disable=protected-access
self._structure = self._structure._batch(
batch_size if batch_mode == 'drop_remainder' else None)
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(spec_batch_size),
self._structure)
print(self._flat_structure)
variant_tensor = make_variant_fn(
columns=self._columns,
batch_size=self._batch_size,
batch_mode=self._batch_mode,
**flat_structure(self))
**self._flat_structure)
super(ArrowBaseDataset, self).__init__(variant_tensor)

def _inputs(self):
return []

@property
def _element_structure(self):
def element_spec(self):
return self._structure

@property
Expand Down
2 changes: 0 additions & 2 deletions tests/test_arrow_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from tensorflow import errors # pylint: disable=wrong-import-position
from tensorflow import test # pylint: disable=wrong-import-position

pytest.skip(
"arrow test is disabled temporarily", allow_module_level=True)
import tensorflow_io.arrow as arrow_io # pylint: disable=wrong-import-position

if sys.version_info == (3, 4):
Expand Down

0 comments on commit 94b2fed

Please sign in to comment.