-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
test_csv.py
132 lines (111 loc) · 3.72 KB
/
test_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import textwrap
import pyarrow as pa
import pytest
from datasets import ClassLabel, Features, Image
from datasets.packaged_modules.csv.csv import Csv
from ..utils import require_pil
@pytest.fixture
def csv_file(tmp_path):
    """Create a well-formed CSV file with a two-column header and two data rows.

    Returns:
        The path of the written file, as a string.
    """
    rows = ["header1,header2", "1,2", "10,20"]
    path = tmp_path / "file.csv"
    # Every row, including the last one, is terminated by a newline.
    path.write_text("".join(f"{row}\n" for row in rows))
    return str(path)
@pytest.fixture
def malformed_csv_file(tmp_path):
    """Create a CSV file whose last row does not match the two-column header.

    Returns:
        The path of the written file, as a string.
    """
    # The trailing comma on the last row gives it one more field than the header.
    rows = ["header1,header2", "1,2", "10,20,"]
    path = tmp_path / "malformed_file.csv"
    path.write_text("".join(f"{row}\n" for row in rows))
    return str(path)
@pytest.fixture
def csv_file_with_image(tmp_path, image_file):
    """Create a single-column CSV file whose only data row is ``image_file``.

    Args:
        tmp_path: Directory in which the CSV file is created.
        image_file: Value written as the single row of the ``image`` column.

    Returns:
        The path of the written file, as a string.
    """
    filename = tmp_path / "csv_with_image.csv"
    data = textwrap.dedent(
        f"""\
        image
        {image_file}
        """
    )
    # Write explicitly as UTF-8: the test consuming this file reads it back with
    # encoding="utf-8", so relying on the locale default could corrupt an
    # ``image_file`` value containing non-ASCII characters.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(data)
    return str(filename)
@pytest.fixture
def csv_file_with_label(tmp_path):
    """Create a single-column CSV file with a ``label`` header and three string labels.

    Returns:
        The path of the written file, as a string.
    """
    rows = ["label", "good", "bad", "good"]
    path = tmp_path / "csv_with_label.csv"
    path.write_text("".join(f"{row}\n" for row in rows))
    return str(path)
@pytest.fixture
def csv_file_with_int_list(tmp_path):
    """Create a single-column CSV file whose cells are space-separated integers.

    Returns:
        The path of the written file, as a string.
    """
    rows = ["int_list", "1 2 3", "4 5 6", "7 8 9"]
    path = tmp_path / "csv_with_int_list.csv"
    path.write_text("".join(f"{row}\n" for row in rows))
    return str(path)
def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog):
    """Check that a malformed CSV aborts table generation and is reported in the log.

    The builder is given a valid file followed by a malformed one. Iterating over
    the generated tables must raise ``ValueError``, and an ERROR record naming
    the malformed file must have been captured.
    """
    builder = Csv()
    tables = builder._generate_tables([[csv_file, malformed_csv_file]])
    with pytest.raises(ValueError, match="Error tokenizing data"):
        # Only exhausting the iterable matters; the yielded items are ignored.
        for _item in tables:
            continue

    bad_name = os.path.basename(malformed_csv_file)
    error_messages = [rec.message for rec in caplog.records if rec.levelname == "ERROR"]
    assert any("Failed to read file" in msg and bad_name in msg for msg in error_messages)
@require_pil
def test_csv_cast_image(csv_file_with_image):
    """Check that the ``image`` column is cast to the ``Image`` feature type.

    The expected path is the second line of the CSV file; the first line is the
    column header.
    """
    with open(csv_file_with_image, encoding="utf-8") as f:
        lines = f.read().splitlines()
    expected_path = lines[1]

    image_feature = Image()
    builder = Csv(encoding="utf-8", features=Features({"image": image_feature}))
    tables = [table for _key, table in builder._generate_tables([[csv_file_with_image]])]
    pa_table = pa.concat_tables(tables)

    # Calling the feature yields the Arrow type the column is expected to have.
    assert pa_table.schema.field("image").type == image_feature()
    assert pa_table.to_pydict()["image"] == [{"path": expected_path, "bytes": None}]
def test_csv_cast_label(csv_file_with_label):
    """Check that the ``label`` column is cast to a ``ClassLabel`` feature.

    The expected values are the class ids of every line after the header, as
    given by ``ClassLabel.str2int``.
    """
    with open(csv_file_with_label, encoding="utf-8") as f:
        lines = f.read().splitlines()
    label_names = lines[1:]  # drop the header line

    label_feature = ClassLabel(names=["good", "bad"])
    expected_ids = [label_feature.str2int(name) for name in label_names]

    builder = Csv(encoding="utf-8", features=Features({"label": label_feature}))
    tables = [table for _key, table in builder._generate_tables([[csv_file_with_label]])]
    pa_table = pa.concat_tables(tables)

    # Calling the feature yields the Arrow type the column is expected to have.
    assert pa_table.schema.field("label").type == label_feature()
    assert pa_table.to_pydict()["label"] == expected_ids
def test_csv_convert_int_list(csv_file_with_int_list):
    """Check that a per-column converter producing lists of ints yields a list-typed column."""
    # Each cell holds whitespace-separated integers; the converter splits and parses them.
    converters = {"int_list": lambda cell: list(map(int, cell.split()))}
    builder = Csv(encoding="utf-8", sep=",", converters=converters)
    tables = [table for _key, table in builder._generate_tables([[csv_file_with_int_list]])]
    pa_table = pa.concat_tables(tables)

    assert pa.types.is_list(pa_table.schema.field("int_list").type)
    assert pa_table.to_pydict()["int_list"] == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]