diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index d92735ab9a5..8d0496dc418 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -224,7 +224,7 @@ def _get_output_signature( collate_fn_args: dict, cols_to_retain: Optional[List[str]] = None, batch_size: Optional[int] = None, - num_test_batches: int = 10, + num_test_batches: int = 200, ): """Private method used by `to_tf_dataset()` to find the shapes and dtypes of samples from this dataset after being passed through the collate_fn. Tensorflow needs an exact signature for tf.numpy_function, so @@ -253,11 +253,9 @@ def _get_output_signature( if len(dataset) == 0: raise ValueError("Unable to get the output signature because the dataset is empty.") - if batch_size is None: - test_batch_size = min(len(dataset), 8) - else: + if batch_size is not None: batch_size = min(len(dataset), batch_size) - test_batch_size = batch_size + test_batch_size = min(len(dataset), 2) test_batches = [] for _ in range(num_test_batches):