Skip to content

Commit

Permalink
Merge pull request #158 from seisbench/phasenet_fix
Browse files Browse the repository at this point in the history
Implemented original PhaseNet model and converted weights
  • Loading branch information
yetinam committed Feb 20, 2023
2 parents 4352761 + 774cd53 commit a857fd4
Show file tree
Hide file tree
Showing 10 changed files with 1,014 additions and 1,299 deletions.
1,923 changes: 680 additions & 1,243 deletions contrib/model_conversion/phasenet_conversion.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"h5py>=3.1",
"obspy>=1.2",
"tqdm>=4.52",
"torch>=1.7.0",
"torch>=1.10.0",
"scipy>=1.5",
"nest_asyncio>=1.5.3"
]
Expand Down
Empty file added requirements.txt
Empty file.
18 changes: 17 additions & 1 deletion seisbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,34 @@
import logging as _logging
import os as _os
from pathlib import Path as _Path
from urllib.parse import urljoin as _urljoin

import pkg_resources

__all__ = ["cache_root", "__version__", "config"]
__all__ = [
"cache_root",
"cache_data_root",
"cache_model_root",
"remote_root",
"remote_data_root",
"remote_model_root",
"__version__",
"config",
]

# global variable: cache_root
cache_root = _Path(
_os.getenv("SEISBENCH_CACHE_ROOT", _Path(_Path.home(), ".seisbench"))
)

cache_data_root = cache_root / "datasets"
cache_model_root = cache_root / "models" / "v3"

remote_root = "https://dcache-demo.desy.de:2443/Helmholtz/HelmholtzAI/SeisBench/"

remote_data_root = _urljoin(remote_root, "datasets/")
remote_model_root = _urljoin(remote_root, "models/v3/")

if not cache_root.is_dir():
cache_root.mkdir(parents=True, exist_ok=True)

Expand Down
7 changes: 3 additions & 4 deletions seisbench/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from urllib.parse import urljoin

import h5py
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -2058,7 +2059,7 @@ def _path_internal(cls):
"""
Path to the dataset location in the SeisBench cache. This class method is required for technical reasons.
"""
return Path(seisbench.cache_root, "datasets", cls._name_internal().lower())
return Path(seisbench.cache_data_root, cls._name_internal().lower())

@property
def path(self):
Expand Down Expand Up @@ -2087,9 +2088,7 @@ def _remote_path(cls):
Path within the remote repository. Does only generate the pass without checking actual availability.
Can be overwritten for datasets stored in the correct format but at a different location.
"""
return os.path.join(
seisbench.remote_root, "datasets", cls._name_internal().lower()
)
return urljoin(seisbench.remote_data_root, cls._name_internal().lower())

@classmethod
def available_chunks(cls, force=False, wait_for_file=False):
Expand Down
3 changes: 2 additions & 1 deletion seisbench/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
from abc import ABC
from urllib.parse import urljoin

import pandas as pd

Expand Down Expand Up @@ -73,7 +74,7 @@ def callback_download_original(path):

@staticmethod
def _add_split(metadata_path):
split_url = os.path.join(seisbench.remote_root, "auxiliary/instance_split.csv")
split_url = urljoin(seisbench.remote_root, "auxiliary/instance_split.csv")
split_path = seisbench.cache_root / "auxiliary" / "instance_split.csv"

split_path.parent.mkdir(parents=True, exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion seisbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .dpp import DeepPhasePick, DPPDetector, DPPPicker
from .eqtransformer import EQTransformer
from .gpd import GPD
from .phasenet import PhaseNet
from .phasenet import PhaseNet, PhaseNetLight
41 changes: 36 additions & 5 deletions seisbench/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import defaultdict
from pathlib import Path
from queue import PriorityQueue
from urllib.parse import urljoin

import nest_asyncio
import numpy as np
Expand All @@ -25,6 +26,33 @@
from seisbench.util import log_lifecycle


def _cache_migration_v0_v3():
"""
Migrates model cache from v0 to v3 if necessary
"""
if seisbench.cache_model_root.is_dir():
return # Migration already done

if not (seisbench.cache_root / "models").is_dir():
return # No legacy cache

seisbench.logger.info("Migrating model cache to version 3")

# Move cache
seisbench.cache_model_root.mkdir(parents=True)
for path in (seisbench.cache_root / "models").iterdir():
if path.name == "v3":
continue

path.rename(seisbench.cache_model_root / path.name)

if (seisbench.cache_model_root / "phasenet").is_dir():
# Rename phasenet to phasenetlight
(seisbench.cache_model_root / "phasenet").rename(
seisbench.cache_model_root / "phasenetlight"
)


@log_lifecycle(logging.DEBUG)
def _watchdog(queue_watchdog, tasks):
"""
Expand Down Expand Up @@ -109,11 +137,11 @@ def weights_version(self):

@classmethod
def _model_path(cls):
return Path(seisbench.cache_root, "models", cls._name_internal().lower())
return Path(seisbench.cache_model_root, cls._name_internal().lower())

@classmethod
def _remote_path(cls):
return "/".join((seisbench.remote_root, "models", cls._name_internal().lower()))
return urljoin(seisbench.remote_model_root, cls._name_internal().lower())

@classmethod
def _pretrained_path(cls, name, version_str=""):
Expand Down Expand Up @@ -165,6 +193,8 @@ def from_pretrained(
:rtype: SeisBenchModel
"""
cls._cleanup_local_repository()
_cache_migration_v0_v3()

if version_str == "latest":
versions = cls.list_versions(name, remote=update)
# Always query remote versions if cache is empty
Expand Down Expand Up @@ -293,6 +323,7 @@ def list_pretrained(cls, details=False, remote=True):
:rtype: list or dict
"""
cls._cleanup_local_repository()
_cache_migration_v0_v3()

# Idea: If details, copy all "latest" configs to a temp directory

Expand Down Expand Up @@ -395,6 +426,7 @@ def list_versions(cls, name, remote=True):
:rtype: list[str]
"""
cls._cleanup_local_repository()
_cache_migration_v0_v3()

if cls._model_path().is_dir():
files = [x.name for x in cls._model_path().iterdir()]
Expand Down Expand Up @@ -606,7 +638,6 @@ def _parse_metadata(self):
seisbench_requirement = self._weights_metadata.get(
"seisbench_requirement", None
)
# Ignore version requirements when in dev branch
if seisbench_requirement is not None:
if version.parse(seisbench_requirement) > version.parse(
seisbench.__version__
Expand Down Expand Up @@ -2514,8 +2545,8 @@ def _name_internal(cls):

@classmethod
def _remote_path(cls):
return os.path.join(
seisbench.remote_root, "pipelines", cls._name_internal().lower()
return urljoin(
seisbench.remote_root, "pipelines/" + cls._name_internal().lower()
)

@classmethod
Expand Down

0 comments on commit a857fd4

Please sign in to comment.