Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
RNKuhns committed Dec 7, 2022
2 parents ebe1535 + 6f8f82e commit f166c1f
Show file tree
Hide file tree
Showing 23 changed files with 1,351 additions and 395 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ jobs:
- name: Generate Pytest coverage report
shell: bash -l {0}
run: |
pytest --cov=./ --cov-report=xml
python -m pytest
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
16 changes: 14 additions & 2 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ additional details.
Base Classes
============

.. currentmodule:: skbase
.. currentmodule:: skbase.base

.. autosummary::
:toctree: api_reference/auto_generated/
Expand All @@ -33,7 +33,7 @@ Base Classes
Object Retrieval
================

.. currentmodule:: skbase
.. currentmodule:: skbase.lookup

.. autosummary::
:toctree: api_reference/auto_generated/
Expand All @@ -42,6 +42,18 @@ Object Retrieval
all_objects
get_package_metadata

.. _obj_validation:

Object Validation and Comparison
================================

.. currentmodule:: skbase.validate

.. autosummary::
:toctree: api_reference/auto_generated/
:template: function.rst


.. _obj_testing:

Testing
Expand Down
15 changes: 15 additions & 0 deletions docs/source/user_documentation/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ that ``skbase`` provides, see the :ref:`api_ref`.
user_guide/overview
user_guide/base_classes
user_guide/lookup
user_guide/validate
user_guide/testing


Expand Down Expand Up @@ -77,6 +78,20 @@ that ``skbase`` provides, see the :ref:`api_ref`.

---

Validation and Comparison
^^^^^^^^^^^^^^^^^^^^^^^^^

Tools for validating and comparing ``BaseObject``-s.

+++

.. link-button:: user_guide/validate
:type: ref
:text: Validation and Comparison
:classes: btn-block btn-secondary stretched-link

---

Testing
^^^^^^^

Expand Down
28 changes: 25 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "skbase"
version = "0.2.0"
description = "Base classes for sklearn-like parametric objects"
authors = [
{name = "Franz Kiraly", email = "f.kiraly@ucl.ac.uk"},
{name = "Franz Király", email = "f.kiraly@ucl.ac.uk"},
{name = "Markus Löning"},
{name = "Ryan Kuhns", email = "rk.skbase@gmail.com"},
]
Expand Down Expand Up @@ -85,6 +85,8 @@ test = [
"coverage",
"pytest-cov",
"safety",
"numpy",
"scipy",
]

[project.urls]
Expand All @@ -97,19 +99,22 @@ download = "https://pypi.org/project/skbase/#files"
file = "LICENSE"

[build-system]
requires = ["setuptools", "wheel", "toml", "build"]
requires = ["setuptools>61", "wheel", "toml", "build"]
build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
# ignore certain folders
addopts = [
"--doctest-modules",
"--ignore=docs",
"--cov=.",
"--cov-report=xml",
"--cov-report=html",
]

[tool.isort]
profile = "black"
src_paths = ["isort", "test"]
src_paths = ["skbase/*"]
multi_line_output = 3
known_first_party = ["skbase"]

Expand All @@ -134,3 +139,20 @@ ignore_path = ["docs/_build", "docs/source/api_reference/auto_generated"]

[tool.bandit]
exclude_dirs = ["*/tests/*", "*/testing/*"]

[tool.setuptools]
zip-safe = true

[tool.setuptools.package-data]
sktime = [
"*.csv",
"*.csv.gz",
"*.arff",
"*.arff.gz",
"*.txt",
"*.ts",
"*.tsv",
]

[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
59 changes: 0 additions & 59 deletions setup.py

This file was deleted.

31 changes: 25 additions & 6 deletions skbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,37 @@
The included functionality makes it easy to re-use scikit-learn and
sktime design principles in your project.
"""
import warnings
from typing import List

__version__ = "0.2.0"
from skbase.base import BaseEstimator, BaseMetaEstimator, BaseObject
from skbase.base._meta import _HeterogenousMetaEstimator
from skbase.lookup import all_objects, get_package_metadata

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = [
__version__: str = "0.2.0"

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = [
"BaseObject",
"BaseEstimator",
"BaseMetaEstimator",
"_HeterogenousMetaEstimator",
"all_objects",
"get_package_metadata",
]

from skbase._base import BaseEstimator, BaseObject
from skbase._lookup import all_objects, get_package_metadata
from skbase._meta import _HeterogenousMetaEstimator
warnings.warn(
" ".join(
[
"Importing from the `skbase` module is deprecated as of version 0.3.0.",
"Ability to import from `skbase` will be removed in version 0.5.0.",
"Import BaseObject, BaseEstimator, and HeterogenousMetaEstimator",
"from skbase.base. Import lookup functionality ",
"(all_objects, get_package_metadata) from skbase.lookup.",
"_HeterogenousMetaEstimator has been depracated as of version 0.3.0.",
"Functionality is available as part of BaseMetaEstimator.",
"_HeterogenousMetaEstimator will be removed in version 0.5.0.",
]
),
DeprecationWarning,
)
12 changes: 10 additions & 2 deletions skbase/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
"""Custom exceptions used in ``skbase``."""
from typing import List

__author__: List[str] = ["mloning", "rnkuhns"]
__all__: List[str] = ["NotFittedError"]
__author__: List[str] = ["fkiraly", "mloning", "rnkuhns"]
__all__: List[str] = ["FixtureGenerationError", "NotFittedError"]


class FixtureGenerationError(Exception):
"""Raised when a fixture fails to generate."""

def __init__(self, fixture_name="", err=None):
self.fixture_name = fixture_name
super().__init__(f"fixture {fixture_name} failed to generate. {err}")


class NotFittedError(ValueError, AttributeError):
Expand Down
19 changes: 19 additions & 0 deletions skbase/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
""":mod:`skbase.base` contains base classes for creating parametric objects.
The included functionality makes it easy to re-use scikit-learn and
sktime design principles in your project.
"""
from typing import List

from skbase.base._base import BaseEstimator, BaseObject
from skbase.base._meta import BaseMetaEstimator

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = [
"BaseObject",
"BaseEstimator",
"BaseMetaEstimator",
]
23 changes: 9 additions & 14 deletions skbase/_base.py → skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ class name: BaseEstimator
fitted state flag - is_fitted (property)
fitted state check - check_is_fitted (raises error if not is_fitted)
"""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = ["BaseEstimator", "BaseObject"]

import inspect
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import List

from sklearn import clone
from sklearn.base import BaseEstimator as _BaseEstimator

from skbase._exceptions import NotFittedError

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]


class BaseObject(_BaseEstimator):
"""Base class for parametric objects with sktime style tag interface.
Expand Down Expand Up @@ -427,7 +427,7 @@ def get_test_params(cls, parameter_set="default"):
"""
# if non-default parameters are required, but none have been found, raise error
if hasattr(cls, "_required_parameters"):
required_parameters = getattr(cls, "required_parameters", [])
required_parameters = getattr(cls, "_required_parameters", [])
if len(required_parameters) > 0:
raise ValueError(
f"Estimator: {cls} requires "
Expand Down Expand Up @@ -466,13 +466,8 @@ def create_test_instance(cls, parameter_set="default"):
else:
params = cls.get_test_params()

if isinstance(params, list):
if isinstance(params[0], dict):
params = params[0]
else:
raise TypeError(
"get_test_params should either return a dict or list of dict."
)
if isinstance(params, list) and isinstance(params[0], dict):
params = params[0]
elif isinstance(params, dict):
pass
else:
Expand Down Expand Up @@ -853,6 +848,6 @@ def check_is_fitted(self):
"""
if not self.is_fitted:
raise NotFittedError(
f"This instance of {self.__class__.__name__} has not "
f"been fitted yet; please call `fit` first."
f"This instance of {self.__class__.__name__} has not been fitted yet. "
f"Please call `fit` first."
)

0 comments on commit f166c1f

Please sign in to comment.