Skip to content

Commit

Permalink
Add SpaceNet3 (microsoft#480)
Browse files Browse the repository at this point in the history
* Add SpaceNet3

* Fixes

* Replace itertools.product with zip

* Update docstring

* Remove unused options
  • Loading branch information
ashnair1 committed Mar 29, 2022
1 parent b41b7d0 commit 5d1eeeb
Show file tree
Hide file tree
Showing 6 changed files with 370 additions and 157 deletions.
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Expand Up @@ -233,6 +233,7 @@ SpaceNet
.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
.. autoclass:: SpaceNet3
.. autoclass:: SpaceNet4
.. autoclass:: SpaceNet5
.. autoclass:: SpaceNet7
Expand Down
Binary file added tests/data/spacenet/sn3_AOI_3_Paris.tar.gz
Binary file not shown.
Binary file added tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz
Binary file not shown.
83 changes: 75 additions & 8 deletions tests/datasets/test_spacenet.py
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.

import glob
import itertools
import os
import shutil
from pathlib import Path
Expand All @@ -14,7 +13,14 @@
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from torchgeo.datasets import (
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet7,
)

TEST_DATA_DIR = "tests/data/spacenet"

Expand Down Expand Up @@ -142,6 +148,71 @@ def test_plot(self, dataset: SpaceNet2) -> None:
plt.close()


class TestSpaceNet3:
@pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet3:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)
test_md5 = {
"sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27",
"sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf",
}

monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[no-untyped-call]
return SpaceNet3(
root,
image=request.param[0],
speed_mask=request.param[1],
collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"],
transforms=transforms,
download=True,
api_key="",
)

def test_getitem(self, dataset: SpaceNet3) -> None:
# Iterate over all elements to maximize coverage
samples = [dataset[i] for i in range(len(dataset))]
x = samples[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
if dataset.image == "MS":
assert x["image"].shape[0] == 8
else:
assert x["image"].shape[0] == 1

def test_len(self, dataset: SpaceNet3) -> None:
assert len(dataset) == 4

def test_already_downloaded(self, dataset: SpaceNet3) -> None:
SpaceNet3(root=dataset.root, download=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SpaceNet3(str(tmp_path))

def test_collection_checksum(self, dataset: SpaceNet3) -> None:
dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123"
with pytest.raises(
RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted"
):
SpaceNet3(root=dataset.root, download=True, checksum=True)

def test_plot(self, dataset: SpaceNet3) -> None:
x = dataset[0].copy()
x["prediction"] = x["mask"]
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
dataset.plot({"image": x["image"]})
plt.close()


class TestSpaceNet4:
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
def dataset(
Expand Down Expand Up @@ -206,9 +277,7 @@ def test_plot(self, dataset: SpaceNet4) -> None:


class TestSpaceNet5:
@pytest.fixture(
params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])
)
@pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet5:
Expand Down Expand Up @@ -239,9 +308,7 @@ def test_getitem(self, dataset: SpaceNet5) -> None:
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
if dataset.image == "PS-RGB":
assert x["image"].shape[0] == 3
elif dataset.image in ["MS", "PS-MS"]:
if dataset.image == "MS":
assert x["image"].shape[0] == 8
else:
assert x["image"].shape[0] == 1
Expand Down
11 changes: 10 additions & 1 deletion torchgeo/datasets/__init__.py
Expand Up @@ -75,7 +75,15 @@
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from .spacenet import (
SpaceNet,
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet7,
)
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
Expand Down Expand Up @@ -155,6 +163,7 @@
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
"SpaceNet3",
"SpaceNet4",
"SpaceNet5",
"SpaceNet7",
Expand Down

0 comments on commit 5d1eeeb

Please sign in to comment.