forked from mlcommons/GaNDLF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
patch_extraction.py
95 lines (81 loc) · 3.47 KB
/
patch_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os, warnings
from typing import Optional, Union
from functools import partial
from pathlib import Path
from PIL import Image
from GANDLF.data.patch_miner.opm.patch_manager import PatchManager
from GANDLF.data.patch_miner.opm.utils import (
alpha_rgb_2d_channel_check,
patch_size_check,
parse_config,
generate_initial_mask,
get_patch_size_in_microns,
patch_artifact_check,
# pen_marking_check,
)
from GANDLF.utils import parseTrainingCSV
def parse_gandlf_csv(fpath: str):
    """
    Yield one record per unique row of a GaNDLF data CSV.

    The CSV is parsed with ``parseTrainingCSV(fpath, train=False)`` and
    duplicate rows are dropped before iteration.

    Args:
        fpath (str): The path to the data CSV file.

    Yields:
        tuple: ``(SubjectID, Channel_0, Label)`` for each row. ``Label`` is
            ``None`` when the parsed data has no "Label" column.

    Raises:
        AssertionError: If the parsed data contains any null/nan value.
    """
    df, _ = parseTrainingCSV(fpath, train=False)
    df = df.drop_duplicates()
    # nans can be easily removed using df.dropna(axis=1, how='all')
    # we want to keep them because we want the user to check the CSV instead
    # there might be cases where labels are accidentally removed for some subjects, but not all
    # Raised explicitly instead of via `assert` so that the check is not
    # stripped under `python -O`; the exception type is kept as-is so that
    # existing callers see the same error.
    if df.isnull().values.any():
        raise AssertionError("Data CSV contains null/nan values, please check.")

    # Every row shares the frame's columns, so look for "Label" only once.
    has_label = "Label" in df.columns
    for _, row in df.iterrows():
        yield (
            row["SubjectID"],
            row["Channel_0"],
            row["Label"] if has_label else None,
        )
def patch_extraction(
    input_path: str, output_path: str, config: Optional[Union[str, dict]] = None
) -> None:
    """
    Extract patches from whole slide images.

    For every row of the input CSV a ``PatchManager`` is created, a validity
    mask and the patch rejection criteria are registered on it, and the
    patches are mined into a per-subject folder under ``output_path``. All
    subjects share a single output CSV, ``<output_path>/opm_train.csv``.

    Args:
        input_path (str): The path to the input CSV file.
        output_path (str): The path to the output directory. It is created if it does not exist.
        config (Optional[Union[str, dict]], optional): Either the path to a configuration file
            or an already-loaded configuration dict. A dict passed in is not modified.
            Missing "scale" and "patch_size" entries default to 16 and (256, 256).
            Defaults to None.
    """
    # Lift Pillow's pixel-count limit, and silence all warnings.
    # NOTE: both are process-wide settings, not scoped to this call.
    Image.MAX_IMAGE_PIXELS = None
    warnings.simplefilter("ignore")

    # Build the working config. A caller-supplied dict is shallow-copied
    # because the defaults and the per-slide patch size are written into it
    # below, and those writes should not leak back to the caller.
    if config is None:
        cfg = {}
    elif isinstance(config, str):
        cfg = parse_config(config)
    else:
        cfg = dict(config)

    cfg["scale"] = cfg.get("scale", 16)
    cfg["patch_size"] = cfg.get("patch_size", (256, 256))
    # Keep the configured value: cfg["patch_size"] is overwritten for each slide.
    original_patch_size = cfg["patch_size"]

    if not os.path.exists(output_path):
        Path(output_path).mkdir(parents=True, exist_ok=True)

    output_path = os.path.abspath(output_path)
    out_csv_path = os.path.join(output_path, "opm_train.csv")

    for sid, slide, label in parse_gandlf_csv(input_path):
        # Create new instance of slide manager
        manager = PatchManager(slide, os.path.join(output_path, str(sid)))

        # A label map is only registered when the CSV provides one.
        if label is not None:
            manager.set_label_map(label)
        manager.set_subjectID(str(sid))
        manager.set_image_header("Channel_0")
        manager.set_mask_header("Label")

        # Recompute the patch size for this slide from the configured value.
        cfg["patch_size"] = get_patch_size_in_microns(slide, original_patch_size)

        # Generate an initial validity mask
        mask, scale = generate_initial_mask(slide, cfg["scale"])
        print("Setting valid mask...")
        manager.set_valid_mask(mask, scale)

        # Reject patch if any pixels are transparent
        manager.add_patch_criteria(alpha_rgb_2d_channel_check)
        # manager.add_patch_criteria(pen_marking_check) ### will be added to main code after rigorous experimentation
        manager.add_patch_criteria(patch_artifact_check)

        # Reject patch if image dimensions are not equal to PATCH_SIZE
        patch_dims_check = partial(
            patch_size_check,
            patch_height=cfg["patch_size"][0],
            patch_width=cfg["patch_size"][1],
        )
        manager.add_patch_criteria(patch_dims_check)

        # Mine the patches for this slide; the shared output CSV path and the
        # working config are handed to the manager.
        manager.mine_patches(output_csv=out_csv_path, config=cfg)