From 272e75fde2558f7bb0037fda06b75c3a10c05479 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 15 Dec 2021 00:58:45 +0100 Subject: [PATCH 01/55] callback API --- sklearn/__init__.py | 1 + sklearn/base.py | 136 +++++++++ sklearn/callback/__init__.py | 25 ++ sklearn/callback/_base.py | 126 ++++++++ sklearn/callback/_computation_tree.py | 268 ++++++++++++++++++ sklearn/callback/_convergence_monitor.py | 118 ++++++++ sklearn/callback/_early_stopping.py | 48 ++++ sklearn/callback/_progressbar.py | 257 +++++++++++++++++ sklearn/callback/_snapshot.py | 82 ++++++ sklearn/callback/_text_verbose.py | 44 +++ .../callback/tests/test_computation_tree.py | 98 +++++++ sklearn/decomposition/_nmf.py | 95 ++++++- sklearn/linear_model/_logistic.py | 62 +++- sklearn/linear_model/_sag.py | 4 + sklearn/linear_model/_sag_fast.pyx.tp | 21 +- sklearn/pipeline.py | 30 +- sklearn/utils/optimize.py | 21 +- 17 files changed, 1416 insertions(+), 20 deletions(-) create mode 100644 sklearn/callback/__init__.py create mode 100644 sklearn/callback/_base.py create mode 100644 sklearn/callback/_computation_tree.py create mode 100644 sklearn/callback/_convergence_monitor.py create mode 100644 sklearn/callback/_early_stopping.py create mode 100644 sklearn/callback/_progressbar.py create mode 100644 sklearn/callback/_snapshot.py create mode 100644 sklearn/callback/_text_verbose.py create mode 100644 sklearn/callback/tests/test_computation_tree.py diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 77ee28271bfaf..0e667babf1cee 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -84,6 +84,7 @@ __all__ = [ "calibration", + "callback", "cluster", "covariance", "cross_decomposition", diff --git a/sklearn/base.py b/sklearn/base.py index 06e9a63630923..4f6b63cb2add1 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -9,6 +9,7 @@ import platform import inspect import re +import pickle import numpy as np @@ -28,6 +29,9 @@ from .utils.validation import check_is_fitted from .utils._estimator_html_repr import estimator_html_repr from .utils.validation import _get_feature_names +from .callback import BaseCallback +from .callback import AutoPropagatedMixin +from .callback import ComputationTree def clone(estimator, *, safe=True): @@ -84,6 +88,10 @@ def clone(estimator, *, safe=True): new_object = klass(**new_object_params) params_set = new_object.get_params(deep=False) + # copy callbacks + if hasattr(estimator, "_callbacks"): + new_object._callbacks = clone(estimator._callbacks, safe=False) + # quick sanity check of the parameters of the clone for name in new_object_params: param1 = new_object_params[name] @@ -597,6 +605,134 @@ def _validate_data( return out + def _set_callbacks(self, callbacks): + """Set callbacks for the estimator. + + Parameters + ---------- + callbacks : callback or list of callbacks + the callbacks to set. + """ + if not isinstance(callbacks, list): + callbacks = [callbacks] + + if not all(isinstance(callback, BaseCallback) for callback in callbacks): + raise TypeError(f"callbacks must be subclasses of BaseCallback.") + + self._callbacks = callbacks + + # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all + # meta-estimators. + def _propagate_callbacks(self, sub_estimator, parent_node): + """Propagate the auto-propagated callbacks to a sub-estimator + + Parameters + ---------- + sub_estimator : estimator instance + The sub-estimator to propagate the callbacks to. + + parent_node : ComputationNode instance + The computation node in this estimator to set as parent_node to the + computation tree of the sub-estimator. It must be the node where the fit + method of the sub-estimator is called. + """ + if not hasattr(self, "_callbacks"): + return + + if hasattr(sub_estimator, "_callbacks") and any( + isinstance(callback, AutoPropagatedMixin) + for callback in sub_estimator._callbacks + ): + bad_callbacks = [ + callback.__class__.__name__ + for callback in sub_estimator._callbacks + if isinstance(callback, AutoPropagatedMixin) + ] + raise TypeError( + f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" + f" meta-estimator ({self.__class__.__name__}) can't have" + f" auto-propagated callbacks ({bad_callbacks})." + " Set them directly on the meta-estimator." + ) + + propagated_callbacks = [ + callback + for callback in self._callbacks + if isinstance(callback, AutoPropagatedMixin) + ] + + if not propagated_callbacks: + return + + sub_estimator._parent_node = parent_node + + if not hasattr(sub_estimator, "_callbacks"): + sub_estimator._callbacks = propagated_callbacks + else: + sub_estimator._callbacks.extend(propagated_callbacks) + + def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): + """Evaluate the on_fit_begin method of the callbacks + + The computation tree is also built at this point. + + This method should be called after all data and parameters validation. + + Parameters + ---------- + X : ndarray or sparse matrix, default=None + The training data. + + y : ndarray, default=None + The target. + + levels : list of dict + A description of the nested levels of computation of the estimator to build + the computation tree. It's a list of dict with "descr" and "max_iter" keys. + + Returns + ------- + root : ComputationNode instance + The root of the computation tree. + """ + self._computation_tree = ComputationTree( + estimator_name=self.__class__.__name__, + levels=levels, + parent_node=getattr(self, "_parent_node", None), + ) + + if hasattr(self, "_callbacks"): + file_path = self._computation_tree.tree_dir / "computation_tree.pkl" + with open(file_path, "wb") as f: + pickle.dump(self._computation_tree, f) + + for callback in self._callbacks: + is_propagated = hasattr(self, "_parent_node") and isinstance( + callback, AutoPropagatedMixin + ) + if not is_propagated: + # Only call the on_fit_begin method of callbacks that are not + # propagated from a meta-estimator. + callback.on_fit_begin(estimator=self, X=X, y=y) + + return self._computation_tree.root + + def _eval_callbacks_on_fit_end(self): + """Evaluate the on_fit_end method of the callbacks""" + if not hasattr(self, "_callbacks"): + return + + self._computation_tree._tree_status[0] = True + + for callback in self._callbacks: + is_propagated = isinstance(callback, AutoPropagatedMixin) and hasattr( + self, "_parent_node" + ) + if not is_propagated: + # Only call the on_fit_end method of callbacks that are not + # propagated from a meta-estimator. + callback.on_fit_end() + @property def _repr_html_(self): """HTML representation of estimator. diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py new file mode 100644 index 0000000000000..1f0f3f7215a18 --- /dev/null +++ b/sklearn/callback/__init__.py @@ -0,0 +1,25 @@ +# License: BSD 3 clause + +from ._base import AutoPropagatedMixin +from ._base import BaseCallback +from ._computation_tree import ComputationNode +from ._computation_tree import ComputationTree +from ._computation_tree import load_computation_tree +from ._convergence_monitor import ConvergenceMonitor +from ._early_stopping import EarlyStopping +from ._progressbar import ProgressBar +from ._snapshot import Snapshot +from ._text_verbose import TextVerbose + +__all__ = [ + "AutoPropagatedMixin", + "Basecallback", + "ComputationNode", + "ComputationTree", + "load_computation_tree", + "ConvergenceMonitor", + "EarlyStopping", + "ProgressBar", + "Snapshot", + "TextVerbose", +] diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py new file mode 100644 index 0000000000000..604a450336610 --- /dev/null +++ b/sklearn/callback/_base.py @@ -0,0 +1,126 @@ +# License: BSD 3 clause + +from abc import ABC, abstractmethod + + +# Not a method of BaseEstimator because it might be called from an extern function +def _eval_callbacks_on_fit_iter_end(**kwargs): + """Evaluate the on_fit_iter_end method of the callbacks + + This function should be called at the end of each computation node. + + Parameters + ---------- + kwargs : dict + arguments passed to the callback. + + Returns + ------- + stop : bool + Whether or not to stop the fit at this node. + """ + estimator = kwargs.get("estimator") + node = kwargs.get("node") + + if not hasattr(estimator, "_callbacks") or node is None: + return False + + estimator._computation_tree._tree_status[node.tree_status_idx] = True + + # stopping_criterion and reconstruction_attributes can be costly to compute. They + # are passed as lambdas for lazy evaluation. We only actually compute them if a + # callback requests it. + if any( + getattr(callback, "request_stopping_criterion", False) + for callback in estimator._callbacks + ): + kwarg = kwargs.pop("stopping_criterion", lambda: None)() + kwargs["stopping_criterion"] = kwarg + + if any( + getattr(callback, "request_reconstruction_attributes", False) + for callback in estimator._callbacks + ): + kwarg = kwargs.pop("reconstruction_attributes", lambda: None)() + kwargs["reconstruction_attributes"] = kwarg + + return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) + + +class BaseCallback(ABC): + """Abstract class for the callbacks""" + + @abstractmethod + def on_fit_begin(self, estimator, *, X=None, y=None): + """Method called at the beginning of the fit method of the estimator + + Parameters + ---------- + estimator: estimator instance + The estimator the callback is set on. + X: ndarray or sparse matrix, default=None + The training data. + y: ndarray, default=None + The target. + """ + pass + + @abstractmethod + def on_fit_end(self): + """Method called at the end of the fit method of the estimator""" + pass + + @abstractmethod + def on_fit_iter_end(self, estimator, node, **kwargs): + """Method called at the end of each computation node of the estimator + + Parameters + ---------- + estimator : estimator instance + The caller estimator. It might differ from the estimator passed to the + `on_fit_begin` method for auto-propagated callbacks. + + node : ComputationNode instance + The caller computation node. + + kwargs : dict + arguments passed to the callback. Possible keys are + + - stopping_criterion: float + Usually iterations stop when `stopping_criterion <= tol`. + This is only provided at the innermost level of iterations. + + - tol: float + Tolerance for the stopping criterion. + This is only provided at the innermost level of iterations. + + - reconstruction_attributes: dict + Necessary attributes to construct an estimator (by copying this + estimator and setting these as attributes) which will behave as if + the fit stopped at this node. + This is only provided at the outermost level of iterations. + + - fit_state: dict + Model specific quantities updated during fit. This is not meant to be + used by generic callbacks but by a callback designed for a specific + estimator instead. + + Returns + ------- + stop : bool or None + Whether or not to stop the current level of iterations at this node. + """ + pass + + +class AutoPropagatedMixin: + """Mixin for auto-propagated callbacks + + An auto-propagated callback (from a meta-estimator to its sub-estimators) must be + set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will only be + called at the beginning and end of the fit method of the meta-estimator, while its + `on_fit_iter_end` method will be called at each computation node of the + meta-estimator and its sub-estimators. + """ + + pass diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py new file mode 100644 index 0000000000000..edd3c8f1f657f --- /dev/null +++ b/sklearn/callback/_computation_tree.py @@ -0,0 +1,268 @@ +# License: BSD 3 clause + +from tempfile import mkdtemp +from pathlib import Path +import pickle +import os + +import numpy as np + + +class ComputationNode: + """A node in a ComputationTree + + Parameters + ---------- + computation_tree : ComputationTree instance + The computation tree it belongs to. + + parent : ComputationNode instance, default=None + The parent node. None means this is the root. + + max_iter : int, default=None + The number of its children. None means it's a leaf. + + description : str, default=None + A description of this computation node. None means it's a leaf. + + tree_status_idx : int, default=0 + The index of the status of this node in the `tree_status` array of its + computation tree. + + idx : int, default=0 + The index of this node in the children list of its parent. + + Attributes + ---------- + children : list + The list of its children nodes. For a leaf, it's an empty list + + depth : int + The depth of this node in its computation tree. The root has a depth of 0. + """ + + def __init__( + self, + computation_tree, + parent=None, + max_iter=None, + description=None, + tree_status_idx=0, + idx=0, + ): + self.computation_tree = computation_tree + self.parent = parent + self.max_iter = max_iter + self.description = description + self.tree_status_idx = tree_status_idx + self.idx = idx + self.children = [] + self.depth = 0 if self.parent is None else self.parent.depth + 1 + + def get_ancestors(self, include_ancestor_trees=True): + """Get the list of all nodes in the path from the node to the root + + Parameters + ---------- + include_ancestor_trees : bool, default=True + If True, propagate to the tree of the `parent_node` of this tree if it + exists and so on. + + Returns + ------- + ancestors : list + The list of ancestors of this node (included). + """ + node = self + ancestors = [node] + + while node.parent is not None: + node = node.parent + ancestors.append(node) + + if include_ancestor_trees: + node_parent_tree = node.computation_tree.parent_node + if node_parent_tree is not None: + ancestors.extend(node_parent_tree.get_ancestors()) + + return ancestors + + +class ComputationTree: + """Data structure to store the computation tree of an estimator + + Parameters + ---------- + estimator_name : str + The name of the estimator. + + levels : list of dict + A description of the nested levels of computation of the estimator to build the + tree. It's a list of dict with "descr" and "max_iter" keys. + + parent_node : ComputationNode, default=None + The node where the estimator is used in the computation tree of a + meta-estimator. This node is not set to be the parent of the root of this tree. + + Attributes + ---------- + depth : int + The depth of the tree. It corresponds to the depth of its deepest leaf. + + root : ComputationNode instance + The root of the computation tree. + + tree_dir : pathlib.Path instance + The path of the directory where the computation tree is dumped during the fit of + its estimator. If it has a parent tree, this is a sub-directory of the + `tree_dir` of its parent. + """ + + def __init__(self, estimator_name, levels, *, parent_node=None): + self.estimator_name = estimator_name + self.parent_node = parent_node + + self.depth = len(levels) - 1 + self.root, self.n_nodes = self._build_tree(levels) + + parent_tree_dir = ( + None + if self.parent_node is None + else self.parent_node.computation_tree.tree_dir + ) + if parent_tree_dir is None: + self.tree_dir = Path(mkdtemp()) + else: + # This tree has a parent tree. Place it in a subdir of its parent dir + # and give it a name that allows from the parent tree to find the sub dir + # of the sub tree of a given leaf. + self.tree_dir = parent_tree_dir / str(parent_node.tree_status_idx) + self.tree_dir.mkdir() + self._filename = self.tree_dir / "tree_status.memmap" + + self._set_tree_status(mode="w+") + self._tree_status[:] = False + + def _build_tree(self, levels): + """Build the computation tree from the description of the levels""" + root = ComputationNode( + computation_tree=self, + max_iter=levels[0]["max_iter"], + description=levels[0]["descr"], + ) + + n_nodes = self._recursive_build_tree(root, levels) + + return root, n_nodes + + def _recursive_build_tree(self, parent, levels, n_nodes=1): + """Recursively build the tree from the root the leaves""" + if parent.depth == self.depth: + return n_nodes + + for i in range(parent.max_iter): + children_max_iter = levels[parent.depth + 1]["max_iter"] + description = levels[parent.depth + 1]["descr"] + + node = ComputationNode( + computation_tree=self, + parent=parent, + max_iter=children_max_iter, + description=description, + tree_status_idx=n_nodes, + idx=i, + ) + parent.children.append(node) + + n_nodes = self._recursive_build_tree(node, levels, n_nodes + 1) + + return n_nodes + + def _set_tree_status(self, mode): + """Create a memory-map to the tree_status array stored on the disk""" + # This has to be done each time we unpickle the tree + self._tree_status = np.memmap( + self._filename, dtype=bool, mode=mode, shape=(self.n_nodes,) + ) + + def get_progress(self, node): + """Return the number of finished child nodes of this node""" + if self._tree_status[node.tree_status_idx]: + return node.max_iter + + # Since the children of a node are not ordered (to account for parallel + # execution), we can't rely on the highest index for which the status is True. + return sum( + [self._tree_status[child.tree_status_idx] for child in node.children] + ) + + def iterate(self, include_leaves=False): + """Return an iterable over the nodes of the computation tree + + Nodes are discovered in a depth first search manner. + + Parameters + ---------- + include_leaves : bool + Whether or not to include the leaves of the tree in the iterable + + Returns + ------- + nodes_list : list + A list of the nodes of the computation tree. + """ + return self._recursive_iterate(include_leaves=include_leaves) + + def _recursive_iterate(self, node=None, include_leaves=False, node_list=None): + """Recursively constructs the iterable""" + # TODO make it a generator + if node is None: + node = self.root + node_list = [] + + if node.children or include_leaves: + node_list.append(node) + + for child in node.children: + self._recursive_iterate(child, include_leaves, node_list) + + return node_list + + def __repr__(self): + res = ( + f"[{self.estimator_name}] {self.root.description} : progress " + f"{self.get_progress(self.root)} / {self.root.max_iter}\n" + ) + for node in self.iterate(include_leaves=False): + if node is not self.root: + res += ( + f"{' ' * node.depth}{node.description} {node.idx}: progress " + f"{self.get_progress(node)} / {node.max_iter}\n" + ) + return res + + +def load_computation_tree(directory): + """load the computation tree of a directory + + Parameters + ---------- + directory : pathlib.Path instance + The directory where the computation tree is dumped + + Returns + ------- + computation_tree : ComputationTree instance + The loaded computation tree + """ + file_path = directory / "computation_tree.pkl" + if not file_path.exists() or not os.path.getsize(file_path) > 0: + # Do not try to load the tree when it's created but not yet written + return + + with open(file_path, "rb") as f: + computation_tree = pickle.load(f) + + computation_tree._set_tree_status(mode="r") + + return computation_tree diff --git a/sklearn/callback/_convergence_monitor.py b/sklearn/callback/_convergence_monitor.py new file mode 100644 index 0000000000000..9f53d657cc75a --- /dev/null +++ b/sklearn/callback/_convergence_monitor.py @@ -0,0 +1,118 @@ +# License: BSD 3 clause + +from copy import copy +from pathlib import Path +from tempfile import mkdtemp +import time + +import matplotlib.pyplot as plt +import pandas as pd + +from . import BaseCallback + + +class ConvergenceMonitor(BaseCallback): + """Monitor model convergence. + + Parameters + ---------- + monitor : + + X_val : ndarray, default=None + Validation data + + y_val : ndarray, default=None + Validation target + + Attributes + ---------- + data : pandas.DataFrame + The monitored quantities at each iteration. + """ + + request_reconstruction_attributes = True + + def __init__(self, *, monitor="objective_function", X_val=None, y_val=None): + self.X_val = X_val + self.y_val = y_val + self._data_file = Path(mkdtemp()) / "convergence_monitor.csv" + + def on_fit_begin(self, estimator, *, X=None, y=None): + self.estimator = estimator + self.X_train = X + self.y_train = y + self._start_time = {} + + def on_fit_iter_end(self, *, node, **kwargs): + if node.depth != node.computation_tree.depth: + return + + reconstruction_attributes = kwargs.get("reconstruction_attributes", None) + if reconstruction_attributes is None: + return + + new_estimator = copy(self.estimator) + for key, val in reconstruction_attributes.items(): + setattr(new_estimator, key, val) + + if node.idx == 0: + self._start_time[node.parent] = time.perf_counter() + curr_time = 0 + else: + curr_time = time.perf_counter() - self._start_time[node.parent] + + obj_train, *_ = new_estimator.objective_function(self.X_train, self.y_train, normalize=True) + if self.X_val is not None: + obj_val, *_ = new_estimator.objective_function(self.X_val, self.y_val, normalize=True) + else: + obj_val = None + + ancestors = node.get_ancestors()[:0:-1] + ancestors_desc = [ + f"{n.computation_tree.estimator_name}-{n.description}" for n in ancestors + ] + ancestors_idx = [f"{n.idx}" for n in ancestors] + + if not self._data_file.exists(): + with open(self._data_file, "w") as f: + f.write( + f"{','.join(ancestors_desc)},iteration,time,obj_train,obj_val\n" + ) + + with open(self._data_file, "a") as f: + f.write( + f"{','.join(ancestors_idx)},{node.idx},{curr_time},{obj_train},{obj_val}\n" + ) + + def on_fit_end(self): + pass + + def get_data(self): + if not hasattr(self, "data"): + self.data = pd.read_csv(self._data_file) + return self.data + + def plot(self, x="iteration"): + data = self.get_data() + + # all columns but iteration, time, obj_train, obj_val + group_by_columns = list(data.columns[:-4]) + groups = data.groupby(group_by_columns) + + for key in groups.groups.keys(): + group = groups.get_group(key) + fig, ax = plt.subplots() + + ax.plot(group[x], group["obj_train"], label="obj_train") + if self.X_val is not None: + ax.plot(group[x], group["obj_val"], label="obj_val") + + if x == "iteration": + x_label = "Number of iterations" + elif x == "time": + x_label = "Time (s)" + ax.set_xlabel(x_label) + ax.set_ylabel("objective function") + + ax.legend() + plt.show() diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py new file mode 100644 index 0000000000000..44a0108e04b26 --- /dev/null +++ b/sklearn/callback/_early_stopping.py @@ -0,0 +1,48 @@ +# License: BSD 3 clause + +from . import BaseCallback + + +class EarlyStopping(BaseCallback): + def __init__( + self, + X_val=None, + y_val=None, + monitor="objective_function", + max_no_improvement=10, + tol=1e-2, + ): + self.X_val = X_val + self.y_val = y_val + self.monitor = monitor + self.max_no_improvement = max_no_improvement + self.tol = tol + + def on_fit_begin(self, estimator, X=None, y=None): + self.estimator = estimator + self._no_improvement = {} + self._last_monitored = {} + + def on_fit_iter_end(self, *, node, **kwargs): + if node.depth != self.estimator._computation_tree.depth: + return + + if self.monitor == "objective_function": + objective_function = kwargs.get("objective_function", None) + monitored, *_ = objective_function(self.X_val) + elif self.monitor == "TODO": + pass + + if node.parent not in self._last_monitored or monitored < self._last_monitored[ + node.parent + ] * (1 - self.tol): + self._no_improvement[node.parent] = 0 + self._last_monitored[node.parent] = monitored + else: + self._no_improvement[node.parent] += 1 + + if self._no_improvement[node.parent] >= self.max_no_improvement: + return True + + def on_fit_end(self): + pass diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py new file mode 100644 index 0000000000000..ae11e67d59f57 --- /dev/null +++ b/sklearn/callback/_progressbar.py @@ -0,0 +1,257 @@ +# License: BSD 3 clause + +from copy import copy +import pickle +from threading import Thread, Event + +import numpy as np +from tqdm import tqdm +from rich.progress import Progress +from rich.progress import BarColumn, TimeRemainingColumn, TextColumn +from rich.style import Style + +from . import BaseCallback +from . import AutoPropagatedMixin +from . import load_computation_tree + + +class ProgressBar(BaseCallback, AutoPropagatedMixin): + """Callback that displays progress bars for each iterative steps of the estimator + + Parameters + ---------- + backend: {"rich"}, default="rich" + The backend for the progress bars display. + + max_depth_show : int, default=None + The maximum nested level of progress bars to display. + + max_depth_keep : int, default=None + The maximum nested level of progress bars to keep displayed when they are + finished. + """ + + def __init__(self, backend="rich", max_depth_show=None, max_depth_keep=None): + self.backend = backend + if max_depth_show is not None and max_depth_show < 0: + raise ValueError(f"max_depth_show should be >= 0.") + if max_depth_keep is not None and max_depth_keep < 0: + raise ValueError(f"max_depth_keep should be >= 0.") + self.max_depth_show = max_depth_show + self.max_depth_keep = max_depth_keep + + def on_fit_begin(self, estimator, X=None, y=None): + self._stop_event = Event() + + if self.backend == "rich": + self.progress_monitor = _RichProgressMonitor( + estimator=estimator, + event=self._stop_event, + max_depth_show=self.max_depth_show, + max_depth_keep=self.max_depth_keep, + ) + else: + raise ValueError(f"backend should be 'rich', got {self.backend} instead.") + + self.progress_monitor.start() + + def on_fit_iter_end(self, *, estimator, node, **kwargs): + pass + + def on_fit_end(self): + self._stop_event.set() + self.progress_monitor.join() + + def __getstate__(self): + state = self.__dict__.copy() + if "_stop_event" in state: + del state["_stop_event"] + if "progress_monitor" in state: + del state["progress_monitor"] + return state + + +# Custom Progress class to allow showing the tasks in a given order (given by setting +# the _ordered_tasks attribute). In particular it allows to dynamically create and +# insert tasks between existing tasks. +class _Progress(Progress): + def get_renderables(self): + table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) + yield table + + +class _RichProgressMonitor(Thread): + """Thread monitoring the progress of an estimator with rich based display + + The display is a list of nested rich tasks using rich.Progress. There is one for + each node in the computation tree of the estimator and in the computation trees of + estimators used in the estimator. + + Parameters + ---------- + estimator : estimator instance + The estimator to monitor + + event : threading.Event instance + This thread will run until event is set. + + max_depth_show : int, default=None + The maximum nested level of progress bars to display. + + max_depth_keep : int, default=None + The maximum nested level of progress bars to keep displayed when they are + finished. + """ + + def __init__(self, estimator, event, max_depth_show=None, max_depth_keep=None): + Thread.__init__(self) + self.estimator = estimator + self.event = event + self.max_depth_show = max_depth_show + self.max_depth_keep = max_depth_keep + + # _computation_trees is a dict `directory: tuple` where + # - tuple[0] is the computation tree of the directory + # - tuple[1] is a dict `node.tree_status_idx: task_id` + self._computation_trees = {} + + def run(self): + with _Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn( + complete_style=Style(color="dark_orange"), + finished_style=Style(color="cyan"), + ), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + auto_refresh=False, + ) as progress_ctx: + self._progress_ctx = progress_ctx + + while not self.event.wait(0.05): + self._recursive_update_tasks() + self._progress_ctx.refresh() + + self._recursive_update_tasks() + self._progress_ctx.refresh() + + def _recursive_update_tasks(self, this_dir=None, depth=0): + """Recursively loop through directories and init or update tasks + + Parameters + ---------- + this_dir : pathlib.Path instance + The directory to + + depth : int + The current depth + """ + if self.max_depth_show is not None and depth > self.max_depth_show: + # Fast exit if this dir is deeper than what we want to show anyway + return + + if this_dir is None: + this_dir = self.estimator._computation_tree.tree_dir + # _ordered_tasks holds the list of the tasks in the order we want them to + # be displayed. + self._progress_ctx._ordered_tasks = [] + + if this_dir not in self._computation_trees: + # First time we discover this directory -> store the computation tree + # If the computation tree is not readable yet, skip and try again next time + computation_tree = load_computation_tree(this_dir) + if computation_tree is None: + return + + self._computation_trees[this_dir] = (computation_tree, {}) + + computation_tree, task_ids = self._computation_trees[this_dir] + + for node in computation_tree.iterate(include_leaves=True): + if node.children: + # node is not a leaf, create or update its task + if node.tree_status_idx not in task_ids: + visible = True + if ( + self.max_depth_show is not None + and depth + node.depth > self.max_depth_show + ): + # If this node is deeper than what we want to show, we create + # the task anyway but make it not visible + visible = False + + task_ids[node.tree_status_idx] = self._progress_ctx.add_task( + self._format_task_description(node, computation_tree, depth), + total=node.max_iter, + visible=visible, + ) + + task_id = task_ids[node.tree_status_idx] + task = self._progress_ctx.tasks[task_id] + self._progress_ctx._ordered_tasks.append(task) + + parent_task = self._get_parent_task(node, computation_tree, task_ids) + if parent_task is not None and parent_task.finished: + # If the task of the parent node is finished, make this task + # finished. It can happen if some computations are stopped + # before reaching max_iter. + visible = True + if ( + self.max_depth_keep is not None + and depth + node.depth > self.max_depth_keep + ): + # If this node is deeper than what we want to keep in the output + # make it not visible + visible = False + self._progress_ctx.update( + task_id, completed=node.max_iter, visible=visible, refresh=False + ) + else: + node_progress = computation_tree.get_progress(node) + if node_progress != task.completed: + self._progress_ctx.update( + task_id, completed=node_progress, refresh=False + ) + else: + # node is a leaf, look for tasks of its sub computation tree before + # going to the next node + child_dir = this_dir / str(node.tree_status_idx) + if child_dir.exists(): + self._recursive_update_tasks( + child_dir, depth + computation_tree.depth + ) + + def _format_task_description(self, node, computation_tree, depth): + """Return a formatted description for the task of the node""" + colors = ["red", "green", "blue", "yellow"] + + indent = f"{' ' * (depth + node.depth)}" + style = f"[{colors[(depth + node.depth)%len(colors)]}]" + + description = f"{computation_tree.estimator_name} - {node.description}" + if node.parent is None and computation_tree.parent_node is not None: + description = ( + f"{computation_tree.parent_node.description} {computation_tree.parent_node.idx} |" + f" {description}" + ) + if node.parent is not None: + description = f"{description} {node.idx}" + + return f"{style}{indent}{description}" + + def _get_parent_task(self, node, computation_tree, task_ids): + """Get the task of the parent node""" + if node.parent is not None: + # node is not the root, return the task of its parent + task_id = task_ids[node.parent.tree_status_idx] + return self._progress_ctx.tasks[task_id] + if computation_tree.parent_node is not None: + # node is the root, return the task of the parent of the parent_node of + # its computation tree + parent_dir = computation_tree.parent_node.computation_tree.tree_dir + _, parent_tree_task_ids = self._computation_trees[parent_dir] + task_id = parent_tree_task_ids[ + computation_tree.parent_node.parent.tree_status_idx + ] + return self._progress_ctx._tasks[task_id] + return diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py new file mode 100644 index 0000000000000..231eafc8cbb9e --- /dev/null +++ b/sklearn/callback/_snapshot.py @@ -0,0 +1,82 @@ +# License: BSD 3 clause + +from copy import copy +from datetime import datetime +from pathlib import Path +import pickle + +import numpy as np + +from . import BaseCallback + + +class Snapshot(BaseCallback): + """Take regular snapshots of an estimator + + Parameters + ---------- + keep_last_n : int or None, default=1 + Only the last `keep_last_n` snapshots are kept on the disk. None means all + snapshots are kept. + + base_dir : str or pathlib.Path instance, default=None + The directory where the snapshots should be stored. If None, they are stored in + the current directory. + + Attributes + ---------- + directory : pathlib.Path instance + The directory where the snapshots are saved. It's a sub-directory of `base_dir`. + """ + + request_reconstruction_attributes = True + + def __init__(self, keep_last_n=1, base_dir=None): + self.keep_last_n = keep_last_n + if keep_last_n is not None and keep_last_n <= 0: + raise ValueError( + "keep_last_n must be a positive integer, got" + f" {self.keep_last_n} instead." + ) + + self.base_dir = Path("." if base_dir is None else base_dir) + + def on_fit_begin(self, estimator, X=None, y=None): + self.estimator = estimator + + # Use a hash in the name of this directory to avoid name collision if several + # clones of this estimator are fitted in parallel in a meta-estimator for + # instance. + dir_name = ( + "snapshots_" + f"{self.estimator.__class__.__name__}_" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}_" + f"{hash(self.estimator._computation_tree)}" + ) + + self.directory = self.base_dir / dir_name + self.directory.mkdir() + + def on_fit_iter_end(self, *, node, **kwargs): + reconstruction_attributes = kwargs.get("reconstruction_attributes", None) + if reconstruction_attributes is None: + return + + new_estimator = copy(self.estimator) + for key, val in reconstruction_attributes.items(): + setattr(new_estimator, key, val) + + file_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" + file_path = self.directory / file_name + + with open(file_path, "wb") as f: + pickle.dump(new_estimator, f) + + if self.keep_last_n is not None: + for snapshot in sorted(self.directory.iterdir())[: -self.keep_last_n]: + snapshot.unlink(missing_ok=True) + + def on_fit_end(self): + if self.keep_last_n is not None: + for snapshot in sorted(self.directory.iterdir())[: -self.keep_last_n]: + snapshot.unlink() diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py new file mode 100644 index 0000000000000..b857ff592c87c --- /dev/null +++ b/sklearn/callback/_text_verbose.py @@ -0,0 +1,44 @@ +# License: BSD 3 clause + +import time + +from . import BaseCallback +from . import AutoPropagatedMixin + + +class TextVerbose(BaseCallback, AutoPropagatedMixin): + request_stopping_criterion = True + + def __init__(self, min_time_between_calls=0): + self.min_time_between_calls = min_time_between_calls + + def on_fit_begin(self, estimator, X=None, y=None): + self.estimator = estimator + self._start_time = time.perf_counter() + + def on_fit_iter_end(self, *, node, **kwargs): + if node.depth != node.computation_tree.depth: + return + + stopping_criterion = kwargs.get("stopping_criterion", None) + tol = kwargs.get("tol", None) + + current_time = time.perf_counter() - self._start_time + + s = f"{node.description} {node.idx}" + parent = node.parent + while parent is not None and parent.parent is not None: + s = f"{parent.description} {parent.idx} - {s}" + parent = parent.parent + + msg = ( + f"[{parent.computation_tree.estimator_name}] {s} | time {current_time:.5f}s" + ) + + if stopping_criterion is not None and tol is not None: + msg += f" | stopping_criterion={stopping_criterion:.3E} | tol={tol:.3E}" + + print(msg) + + def on_fit_end(self): + pass diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py new file mode 100644 index 0000000000000..b726177a342ec --- /dev/null +++ b/sklearn/callback/tests/test_computation_tree.py @@ -0,0 +1,98 @@ +# License: BSD 3 clause + +import numpy as np +import pytest + +from sklearn.callback import ComputationTree +from sklearn.callback import ComputationNode +from sklearn.callback import load_computation_tree + + +levels = [ + {"descr": "level0", "max_iter": 3}, + {"descr": "level1", "max_iter": 5}, + {"descr": "level2", "max_iter": 7}, + {"descr": "level3", "max_iter": None}, +] + + +def test_computation_tree(): + # Check the construction of the computation tree + computation_tree = ComputationTree(estimator_name="estimator", levels=levels) + assert computation_tree.estimator_name == "estimator" + + root = computation_tree.root + assert root.parent is None + assert root.idx == 0 + + assert len(root.children) == root.max_iter == 3 + assert [node.idx for node in root.children] == list(range(3)) + + for node1 in root.children: + assert len(node1.children) == 5 + assert [n.idx for n in node1.children] == list(range(5)) + + for node2 in node1.children: + assert len(node2.children) == 7 + assert [n.idx for n in node2.children] == list(range(7)) + + for node3 in node2.children: + assert not node3.children + + +def test_n_nodes(): + # Check that the number of node in a comutation tree corresponds to what we expect + # from the level descriptions + computation_tree = ComputationTree(estimator_name="", levels=levels) + + max_iter_per_level = [level["max_iter"] for level in levels[:-1]] + expected_n_nodes = 1 + np.sum(np.cumprod(max_iter_per_level)) + + assert computation_tree.n_nodes == expected_n_nodes + assert len(computation_tree.iterate(include_leaves=True)) == expected_n_nodes + assert computation_tree._tree_status.shape == (expected_n_nodes,) + + +def test_tree_status_idx(): + # Check that each node has a unique index in the _tree_status array and that their + # order corresponds to the order given by a depth first search. + computation_tree = ComputationTree(estimator_name="", levels=levels) + + indexes = [ + node.tree_status_idx for node in computation_tree.iterate(include_leaves=True) + ] + assert indexes == list(range(computation_tree.n_nodes)) + + +def test_get_ancestors(): + # Check that the ancestor search excludes the root and can propagate to parent trees + parent_levels = [ + {"descr": "parent_level0", "max_iter": 2}, + {"descr": "parent_level1", "max_iter": 4}, + {"descr": "parent_level2", "max_iter": None}, + ] + + parent_computation_tree = ComputationTree( + estimator_name="parent_estimator", levels=parent_levels + ) + parent_node = parent_computation_tree.root.children[0].children[2] + + computation_tree = ComputationTree( + estimator_name="estimator", levels=levels, parent_node=parent_node + ) + node = computation_tree.root.children[1].children[3].children[5] + + ancestors = node.get_ancestors(include_ancestor_trees=False) + assert ancestors == [node, node.parent, node.parent.parent] + assert [n.idx for n in ancestors] == [5, 3, 1] + assert computation_tree.root not in ancestors + + ancestors = node.get_ancestors(include_ancestor_trees=True) + assert ancestors == [ + node, + node.parent, + node.parent.parent, + parent_node, + parent_node.parent, + ] + assert [n.idx for n in ancestors] == [5, 3, 1, 2, 0] diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index cc1451be54567..f53f33c6b804a 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -6,6 +6,7 @@ # Tom Dupre la Tour # License: BSD 3 clause +from functools import partial import numbers import numpy as np import scipy.sparse as sp @@ -23,6 +24,7 @@ check_is_fitted, check_non_negative, ) +from ..callback._base import _eval_callbacks_on_fit_iter_end EPSILON = np.finfo(np.float32).eps @@ -424,6 +426,8 @@ def _fit_coordinate_descent( verbose=0, shuffle=False, random_state=None, + estimator=None, + parent_node=None, ): """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent @@ -500,7 +504,9 @@ def _fit_coordinate_descent( rng = check_random_state(random_state) - for n_iter in range(1, max_iter + 1): + nodes = parent_node.children if parent_node is not None else [None] * max_iter + + for n_iter, node in enumerate(nodes, 1): violation = 0.0 # Update W @@ -519,6 +525,21 @@ def _fit_coordinate_descent( if violation_init == 0: break + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=node, + stopping_criterion=lambda: violation / violation_init, + tol=tol, + fit_state={"H": Ht.T, "W": W}, + reconstruction_attributes=lambda: { + "n_components_": Ht.T.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), + }, + ): + break + if verbose: print("violation:", violation / violation_init) @@ -731,6 +752,8 @@ def _fit_multiplicative_update( l2_reg_H=0, update_H=True, verbose=0, + estimator=None, + parent_node=None, ): """Compute Non-negative Matrix Factorization with Multiplicative Update. @@ -815,8 +838,10 @@ def _fit_multiplicative_update( error_at_init = _beta_divergence(X, W, H, beta_loss, square_root=True) previous_error = error_at_init + nodes = parent_node.children if parent_node is not None else [None] * max_iter + H_sum, HHt, XHt = None, None, None - for n_iter in range(1, max_iter + 1): + for n_iter, node in enumerate(nodes, 1): # update W # H_sum, HHt and XHt are saved and reused if not update_H delta_W, H_sum, HHt, XHt = _multiplicative_update_w( @@ -842,6 +867,27 @@ def _fit_multiplicative_update( if beta_loss <= 1: H[H < np.finfo(np.float64).eps] = 0.0 + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=node, + stopping_criterion=lambda: ( + ( + previous_error + - _beta_divergence(X, W, H, beta_loss, square_root=True) + ) + / error_at_init + ), + tol=tol, + fit_state={"H": H, "W": W}, + reconstruction_attributes=lambda: { + "n_components_": H.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, H, 2, True), + }, + ): + break + # test convergence criterion every 10 iterations if tol > 0 and n_iter % 10 == 0: error = _beta_divergence(X, W, H, beta_loss, square_root=True) @@ -1538,20 +1584,27 @@ def fit_transform(self, X, y=None, W=None, H=None): X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32] ) - with config_context(assume_finite=True): - W, H, n_iter = self._fit_transform(X, W=W, H=H) - - self.reconstruction_err_ = _beta_divergence( - X, W, H, self._beta_loss, square_root=True + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ], + X=X, ) + W, H, n_iter = self._fit_transform(X, W=W, H=H, parent_node=root) + self.n_components_ = H.shape[0] self.components_ = H self.n_iter_ = n_iter + self._eval_callbacks_on_fit_end() + return W - def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): + def _fit_transform( + self, X, y=None, W=None, H=None, update_H=True, parent_node=None + ): """Learn a NMF model for the data X and returns the transformed data. Parameters @@ -1618,6 +1671,8 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): verbose=self.verbose, shuffle=self.shuffle, random_state=self.random_state, + estimator=self, + parent_node=parent_node, ) elif self.solver == "mu": W, H, n_iter = _fit_multiplicative_update( @@ -1633,6 +1688,8 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): l2_reg_H, update_H=update_H, verbose=self.verbose, + estimator=self, + parent_node=parent_node, ) else: raise ValueError("Invalid solver parameter '%s'." % self.solver) @@ -1713,6 +1770,28 @@ def inverse_transform(self, W): check_is_fitted(self) return np.dot(W, self.components_) + def objective_function(self, X, y=None, *, W=None, H=None, normalize=False): + if W is None: + W = self.transform(X) + if H is None: + H = self.components_ + + data_fit = _beta_divergence(X, W, H, self._beta_loss) + + l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._scale_regularization(X) + penalization = ( + l1_reg_W * W.sum() + + l1_reg_H * H.sum() + + l2_reg_W * (W ** 2).sum() + + l2_reg_H * (H ** 2).sum() + ) + + if normalize: + data_fit /= X.shape[0] + penalization /= X.shape[0] + + return data_fit + penalization, data_fit, penalization + @property def _n_features_out(self): """Number of transformed output features.""" diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 08e71edbc69ab..82063f36d0434 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -33,6 +33,7 @@ from ..utils.fixes import delayed from ..model_selection import check_cv from ..metrics import get_scorer +from ..callback._base import _eval_callbacks_on_fit_iter_end _LOGISTIC_SOLVER_CONVERGENCE_MSG = ( @@ -505,6 +506,8 @@ def _logistic_regression_path( max_squared_sum=None, sample_weight=None, l1_ratio=None, + estimator=None, + parent_node=None, ): """Compute a Logistic Regression model for a list of regularization parameters. @@ -796,13 +799,20 @@ def grad(x, *args): hess = _logistic_grad_hess warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} + # Distinguish between LogReg and LogRegCV + if parent_node is not None: + nodes = [parent_node] if len(Cs) == 1 else parent_node.children + else: + nodes = [None] * len(Cs) + coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) - for i, C in enumerate(Cs): + for i, (C, node) in enumerate(zip(Cs, nodes)): if solver == "lbfgs": iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) ] + children = iter(node.children) if node is not None else None opt_res = optimize.minimize( func, w0, @@ -810,6 +820,10 @@ def grad(x, *args): jac=True, args=(X, target, 1.0 / C, sample_weight), options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, + callback=lambda xk: _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=next(children) if children is not None else None, + ), ) n_iter_i = _check_optimize_result( solver, @@ -821,7 +835,15 @@ def grad(x, *args): elif solver == "newton-cg": args = (X, target, 1.0 / C, sample_weight) w0, n_iter_i = _newton_cg( - hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol + hess, + func, + grad, + w0, + args=args, + maxiter=max_iter, + tol=tol, + estimator=estimator, + parent_node=node, ) elif solver == "liblinear": coef_, intercept_, n_iter_i, = _fit_liblinear( @@ -876,6 +898,8 @@ def grad(x, *args): max_squared_sum, warm_start_sag, is_saga=(solver == "saga"), + estimator=estimator, + parent_node=node, ) else: @@ -893,8 +917,20 @@ def grad(x, *args): else: coefs.append(w0.copy()) + if len(Cs) > 1: + _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=node, + ) + n_iter[i] = n_iter_i + if multi_class == "ovr": + _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node, + ) + return np.array(coefs), np.array(Cs), n_iter @@ -1578,6 +1614,22 @@ def fit(self, X, y, sample_weight=None): if warm_start_coef is None: warm_start_coef = [None] * n_classes + if len(classes_) == 1: + levels = [ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ] + else: + levels = [ + {"descr": "fit", "max_iter": len(classes_)}, + {"descr": "class", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ] + root = self._eval_callbacks_on_fit_begin(levels=levels, X=X, y=y) + + # distinguish between multinomial and ovr + nodes = [root] if len(classes_) == 1 else root.children + path_func = delayed(_logistic_regression_path) # The SAG solver releases the GIL so it's more efficient to use @@ -1610,8 +1662,10 @@ def fit(self, X, y, sample_weight=None): penalty=penalty, max_squared_sum=max_squared_sum, sample_weight=sample_weight, + estimator=self, + parent_node=node, ) - for class_, warm_start_coef_ in zip(classes_, warm_start_coef) + for class_, warm_start_coef_, node in zip(classes_, warm_start_coef, nodes) ) fold_coefs_, _, n_iter_ = zip(*fold_coefs_) @@ -1632,6 +1686,8 @@ def fit(self, X, y, sample_weight=None): else: self.intercept_ = np.zeros(n_classes) + self._eval_callbacks_on_fit_end() + return self def predict_proba(self, X): diff --git a/sklearn/linear_model/_sag.py b/sklearn/linear_model/_sag.py index 48dcd7aef8ad3..7307ca76c4408 100644 --- a/sklearn/linear_model/_sag.py +++ b/sklearn/linear_model/_sag.py @@ -101,6 +101,8 @@ def sag_solver( max_squared_sum=None, warm_start_mem=None, is_saga=False, + estimator=None, + parent_node=None, ): """SAG solver for Ridge and LogisticRegression. @@ -346,6 +348,8 @@ def sag_solver( intercept_decay, is_saga, verbose, + estimator=estimator, + parent_node=parent_node, ) if n_iter_ == max_iter: diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 756a048eea999..8144c98df3012 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -47,6 +47,7 @@ from ._sgd_fast cimport LossFunction from ._sgd_fast cimport Log, SquaredLoss from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 +from ..callback._base import _eval_callbacks_on_fit_iter_end from libc.stdio cimport printf @@ -231,7 +232,9 @@ def sag{{name_suffix}}(SequentialDataset{{name_suffix}} dataset, np.ndarray[{{c_type}}, ndim=1, mode='c'] intercept_sum_gradient_init, double intercept_decay, bint saga, - bint verbose): + bint verbose, + estimator, + parent_node): """Stochastic Average Gradient (SAG) and SAGA solvers. Used in Ridge and LogisticRegression. @@ -515,6 +518,22 @@ def sag{{name_suffix}}(SequentialDataset{{name_suffix}} dataset, fabs(weights[idx] - previous_weights[idx])) previous_weights[idx] = weights[idx] + + with gil: + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node.children[n_iter] if parent_node is not None else None, + stopping_criterion = ( + lambda: max_change / max_weight + if max_weight != 0 + else 0 + if max_weight == max_change == 0 + else np.inf + ), + tol=tol, + ): + break + if ((max_weight != 0 and max_change / max_weight <= tol) or max_weight == 0 and max_change == 0): if verbose: diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 6134b6318c838..47553d07ac169 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -26,12 +26,13 @@ ) from .utils.deprecation import deprecated from .utils._tags import _safe_tags +from .utils.metaestimators import _BaseComposition from .utils.validation import check_memory from .utils.validation import check_is_fitted from .utils.fixes import delayed from .exceptions import NotFittedError +from .callback._base import _eval_callbacks_on_fit_iter_end -from .utils.metaestimators import _BaseComposition __all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"] @@ -318,15 +319,24 @@ def _fit(self, X, y=None, **fit_params_steps): # Setup the memory memory = check_memory(self.memory) + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": len(self.steps)}, + {"descr": "step", "max_iter": None}, + ], + X=X, + y=y, + ) + fit_transform_one_cached = memory.cache(_fit_transform_one) - for (step_idx, name, transformer) in self._iter( + for (step_idx, name, transformer), node in zip(self._iter( with_final=False, filter_passthrough=False - ): + ), root.children[:-1]): if transformer is None or transformer == "passthrough": + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) with _print_elapsed_time("Pipeline", self._log_message(step_idx)): continue - if hasattr(memory, "location"): # joblib >= 0.12 if memory.location is None: @@ -346,6 +356,7 @@ def _fit(self, X, y=None, **fit_params_steps): else: cloned_transformer = clone(transformer) # Fit or load from cache the current transformer + self._propagate_callbacks(cloned_transformer, parent_node=node) X, fitted_transformer = fit_transform_one_cached( cloned_transformer, X, @@ -359,6 +370,9 @@ def _fit(self, X, y=None, **fit_params_steps): # transformer. This is necessary when loading the transformer # from the cache. self.steps[step_idx] = (name, fitted_transformer) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + return X def fit(self, X, y=None, **fit_params): @@ -388,12 +402,20 @@ def fit(self, X, y=None, **fit_params): Pipeline with fitted steps. """ fit_params_steps = self._check_fit_params(**fit_params) + Xt = self._fit(X, y, **fit_params_steps) with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": + node = self._computation_tree.root.children[-1] + self._propagate_callbacks(self._final_estimator, parent_node=node) + fit_params_last_step = fit_params_steps[self.steps[-1][0]] self._final_estimator.fit(Xt, y, **fit_params_last_step) + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + + self._eval_callbacks_on_fit_end() + return self def fit_transform(self, X, y=None, **fit_params): diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py index bd2ac8bdfd27d..2e3b6eb1c125b 100644 --- a/sklearn/utils/optimize.py +++ b/sklearn/utils/optimize.py @@ -18,6 +18,7 @@ from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1 from ..exceptions import ConvergenceWarning +from ..callback._base import _eval_callbacks_on_fit_iter_end class _LineSearchError(RuntimeError): @@ -120,6 +121,8 @@ def _newton_cg( maxinner=200, line_search=True, warn=True, + estimator=None, + parent_node=None, ): """ Minimization of scalar function of one or more variables using the @@ -168,20 +171,31 @@ def _newton_cg( """ x0 = np.asarray(x0).flatten() xk = x0 - k = 0 if line_search: old_fval = func(x0, *args) old_old_fval = None + nodes = parent_node.children if parent_node is not None else [None] * maxiter + # Outer loop: our Newton iteration - while k < maxiter: + for k, node in enumerate(nodes, 1): # Compute a search direction pk by applying the CG method to # del2 f(xk) p = - fgrad f(xk) starting from 0. fgrad, fhess_p = grad_hess(xk, *args) absgrad = np.abs(fgrad) - if np.max(absgrad) <= tol: + max_absgrad = np.max(absgrad) + + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=node, + stopping_criterion=lambda: max_absgrad, + tol=tol, + ): + break + + if max_absgrad <= tol: break maggrad = np.sum(absgrad) @@ -204,7 +218,6 @@ def _newton_cg( break xk = xk + alphak * xsupi # upcast if necessary - k += 1 if warn and k >= maxiter: warnings.warn( From 584bdf72f1dfa969eceb0b7ab3ffbba5cfcf6aea Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Dec 2021 18:18:35 +0100 Subject: [PATCH 02/55] cln nmf and test reconstruction attributes --- sklearn/decomposition/_nmf.py | 18 +++++++-------- sklearn/decomposition/tests/test_nmf.py | 29 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index f53f33c6b804a..4fa46dd2cb12c 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -6,7 +6,6 @@ # Tom Dupre la Tour # License: BSD 3 clause -from functools import partial import numbers import numpy as np import scipy.sparse as sp @@ -504,9 +503,7 @@ def _fit_coordinate_descent( rng = check_random_state(random_state) - nodes = parent_node.children if parent_node is not None else [None] * max_iter - - for n_iter, node in enumerate(nodes, 1): + for n_iter in range(1, max_iter + 1): violation = 0.0 # Update W @@ -527,7 +524,7 @@ def _fit_coordinate_descent( if _eval_callbacks_on_fit_iter_end( estimator=estimator, - node=node, + node=parent_node.children[n_iter - 1] if parent_node is not None else None, stopping_criterion=lambda: violation / violation_init, tol=tol, fit_state={"H": Ht.T, "W": W}, @@ -838,10 +835,8 @@ def _fit_multiplicative_update( error_at_init = _beta_divergence(X, W, H, beta_loss, square_root=True) previous_error = error_at_init - nodes = parent_node.children if parent_node is not None else [None] * max_iter - H_sum, HHt, XHt = None, None, None - for n_iter, node in enumerate(nodes, 1): + for n_iter in range(1, max_iter + 1): # update W # H_sum, HHt and XHt are saved and reused if not update_H delta_W, H_sum, HHt, XHt = _multiplicative_update_w( @@ -869,7 +864,7 @@ def _fit_multiplicative_update( if _eval_callbacks_on_fit_iter_end( estimator=estimator, - node=node, + node=parent_node.children[n_iter - 1] if parent_node is not None else None, stopping_criterion=lambda: ( ( previous_error @@ -883,7 +878,7 @@ def _fit_multiplicative_update( "n_components_": H.shape[0], "components_": H, "n_iter_": n_iter, - "reconstruction_err_": _beta_divergence(X, W, H, 2, True), + "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), }, ): break @@ -1594,6 +1589,9 @@ def fit_transform(self, X, y=None, W=None, H=None): W, H, n_iter = self._fit_transform(X, W=W, H=H, parent_node=root) + self.reconstruction_err_ = _beta_divergence( + X, W, H, self._beta_loss, square_root=True + ) self.n_components_ = H.shape[0] self.components_ = H self.n_iter_ = n_iter diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index c95b7ceb737db..7a58b64d6464d 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -1,4 +1,6 @@ +import pickle import re +import tempfile import numpy as np import scipy.sparse as sp @@ -18,6 +20,7 @@ from sklearn.utils.extmath import squared_norm from sklearn.base import clone from sklearn.exceptions import ConvergenceWarning +from sklearn.callback import Snapshot @pytest.mark.parametrize("solver", ["cd", "mu"]) @@ -719,3 +722,29 @@ def test_feature_names_out(): names = nmf.get_feature_names_out() assert_array_equal([f"nmf{i}" for i in range(3)], names) + + +@pytest.mark.parametrize("solver, beta_loss", [("mu", 0), ("mu", 2), ("cd", 2)]) +def test_nmf_callback_reconstruction_attributes(solver, beta_loss): + # Check that the reconstruction attributes passed to the callback allow to make + # a new estimator as if the fit ended when the callback is called. + X = np.random.RandomState(0).random_sample((100, 100)) + + nmf = NMF(n_components=3, solver=solver, beta_loss=beta_loss, random_state=0) + nmf.fit(X) + + with tempfile.TemporaryDirectory() as tmp_dir: + callback = Snapshot(base_dir=tmp_dir) + nmf._set_callbacks(callback) + nmf.fit(X) + + # load model from last iteration + snapshot = sorted(callback.directory.iterdir())[-1] + with open(snapshot, "rb") as f: + loaded_nmf = pickle.load(f) + + # The model loaded from the last iteration is the same as the original model + assert nmf.n_iter_ == loaded_nmf.n_iter_ + assert_allclose(nmf.components_, loaded_nmf.components_) + assert_allclose(nmf.reconstruction_err_, loaded_nmf.reconstruction_err_) + assert_allclose(nmf.transform(X), loaded_nmf.transform(X)) From bb32ff3bbcd798f1cd2e204c2437dc38359a36a0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 20 Dec 2021 19:12:27 +0100 Subject: [PATCH 03/55] cln snapshot + test snapshot + uuid for computation tree --- sklearn/callback/_computation_tree.py | 10 +- sklearn/callback/_snapshot.py | 46 ++++----- sklearn/callback/tests/test_callbacks.py | 120 +++++++++++++++++++++++ sklearn/decomposition/tests/test_nmf.py | 9 +- 4 files changed, 151 insertions(+), 34 deletions(-) create mode 100644 sklearn/callback/tests/test_callbacks.py diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index edd3c8f1f657f..161891ca32004 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -1,9 +1,10 @@ # License: BSD 3 clause -from tempfile import mkdtemp +import os from pathlib import Path import pickle -import os +from tempfile import mkdtemp +from uuid import uuid4 import numpy as np @@ -116,6 +117,9 @@ class ComputationTree: The path of the directory where the computation tree is dumped during the fit of its estimator. If it has a parent tree, this is a sub-directory of the `tree_dir` of its parent. + + uid : uuid.UUID + Unique indentifier for a ComputationTree instance. """ def __init__(self, estimator_name, levels, *, parent_node=None): @@ -125,6 +129,8 @@ def __init__(self, estimator_name, levels, *, parent_node=None): self.depth = len(levels) - 1 self.root, self.n_nodes = self._build_tree(levels) + self.uid = uuid4() + parent_tree_dir = ( None if self.parent_node is None diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py index 231eafc8cbb9e..99a1bcc0ce68a 100644 --- a/sklearn/callback/_snapshot.py +++ b/sklearn/callback/_snapshot.py @@ -22,11 +22,6 @@ class Snapshot(BaseCallback): base_dir : str or pathlib.Path instance, default=None The directory where the snapshots should be stored. If None, they are stored in the current directory. - - Attributes - ---------- - directory : pathlib.Path instance - The directory where the snapshots are saved. It's a sub-directory of `base_dir`. """ request_reconstruction_attributes = True @@ -42,41 +37,36 @@ def __init__(self, keep_last_n=1, base_dir=None): self.base_dir = Path("." if base_dir is None else base_dir) def on_fit_begin(self, estimator, X=None, y=None): - self.estimator = estimator - - # Use a hash in the name of this directory to avoid name collision if several - # clones of this estimator are fitted in parallel in a meta-estimator for - # instance. - dir_name = ( - "snapshots_" - f"{self.estimator.__class__.__name__}_" - f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}_" - f"{hash(self.estimator._computation_tree)}" - ) - - self.directory = self.base_dir / dir_name - self.directory.mkdir() + subdir = self._get_subdir(estimator._computation_tree) + subdir.mkdir() - def on_fit_iter_end(self, *, node, **kwargs): + def on_fit_iter_end(self, *, estimator, node, **kwargs): reconstruction_attributes = kwargs.get("reconstruction_attributes", None) if reconstruction_attributes is None: return - new_estimator = copy(self.estimator) + new_estimator = copy(estimator) for key, val in reconstruction_attributes.items(): setattr(new_estimator, key, val) - file_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" - file_path = self.directory / file_name + subdir = self._get_subdir(node.computation_tree) + snapshot_filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" - with open(file_path, "wb") as f: + with open(subdir / snapshot_filename, "wb") as f: pickle.dump(new_estimator, f) if self.keep_last_n is not None: - for snapshot in sorted(self.directory.iterdir())[: -self.keep_last_n]: + for snapshot in sorted(subdir.iterdir())[: -self.keep_last_n]: snapshot.unlink(missing_ok=True) def on_fit_end(self): - if self.keep_last_n is not None: - for snapshot in sorted(self.directory.iterdir())[: -self.keep_last_n]: - snapshot.unlink() + pass + + def _get_subdir(self, computation_tree): + """Return the sub directory containing the snapshots of the estimator""" + subdir = ( + self.base_dir + / f"snapshots_{computation_tree.estimator_name}_{str(computation_tree.uid)}" + ) + + return subdir diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py new file mode 100644 index 0000000000000..17dafc616f457 --- /dev/null +++ b/sklearn/callback/tests/test_callbacks.py @@ -0,0 +1,120 @@ +# License: BSD 3 clause + +import pickle +import pytest +import tempfile +from time import sleep + +from joblib import Parallel, delayed + +from sklearn.base import BaseEstimator, clone +from sklearn.callback import Snapshot +from sklearn.callback._base import _eval_callbacks_on_fit_iter_end +from sklearn.datasets import make_classification + + +class Estimator(BaseEstimator): + def __init__(self, max_iter=20): + self.max_iter = max_iter + + def fit(self, X, y): + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + y=y, + ) + + for i in range(self.max_iter): + if _eval_callbacks_on_fit_iter_end( + estimator=self, + node=root.children[i], + reconstruction_attributes=lambda: {"n_iter_": i + 1}, + ): + break + + self.n_iter_ = i + 1 + + self._eval_callbacks_on_fit_end() + + return self + + +class MetaEstimator(BaseEstimator): + def __init__( + self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" + ): + self.estimator = estimator + self.n_outer = n_outer + self.n_inner = n_inner + self.n_jobs = n_jobs + self.prefer = prefer + + def fit(self, X, y): + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.n_outer}, + {"descr": "outer", "max_iter": self.n_inner}, + {"descr": "inner", "max_iter": None}, + ], + X=X, + y=y, + ) + + res = Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( + delayed(self._func)(self.estimator, X, y, node, i) + for i, node in enumerate(root.children) + ) + + self._eval_callbacks_on_fit_end() + + return self + + def _func(self, estimator, X, y, parent_node, i): + for j, node in enumerate(parent_node.children): + est = clone(estimator) + self._propagate_callbacks(est, parent_node=node) + est.fit(X, y) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) + + return + + +@pytest.mark.parametrize("n_jobs", (1, 2)) +@pytest.mark.parametrize("prefer", ("threads", "processes")) +def test_snapshot_meta_estimator(n_jobs, prefer): + # Test for the Snapshot callback + X, y = make_classification() + estimator = Estimator(max_iter=20) + + with tempfile.TemporaryDirectory() as tmp_dir: + keep_last_n = 5 + callback = Snapshot(keep_last_n=keep_last_n, base_dir=tmp_dir) + estimator._set_callbacks(callback) + metaestimator = MetaEstimator( + estimator=estimator, n_outer=4, n_inner=3, n_jobs=n_jobs, prefer=prefer + ) + + metaestimator.fit(X, y) + + # There's a subdir of base_dir for each clone of estimator fitted in + # metaestimator. There are n_outer * n_inner such clones + snapshot_dirs = list(callback.base_dir.iterdir()) + assert len(snapshot_dirs) == metaestimator.n_outer * metaestimator.n_inner + + for snapshot_dir in snapshot_dirs: + snapshots = sorted(snapshot_dir.iterdir()) + assert len(snapshots) == keep_last_n + + for i, snapshot in enumerate(snapshots): + with open(snapshot, "rb") as f: + loaded_estimator = pickle.load(f) + + # We kept last 5 snapshots out of 20 iterations. + # This one is the 16 + i-th. + assert loaded_estimator.n_iter_ == 16 + i diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index 7a58b64d6464d..a1ef1e90792af 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -728,9 +728,9 @@ def test_feature_names_out(): def test_nmf_callback_reconstruction_attributes(solver, beta_loss): # Check that the reconstruction attributes passed to the callback allow to make # a new estimator as if the fit ended when the callback is called. - X = np.random.RandomState(0).random_sample((100, 100)) + X = np.random.RandomState(0).random_sample((100, 20)) - nmf = NMF(n_components=3, solver=solver, beta_loss=beta_loss, random_state=0) + nmf = NMF(n_components=5, solver=solver, beta_loss=beta_loss, random_state=0) nmf.fit(X) with tempfile.TemporaryDirectory() as tmp_dir: @@ -739,11 +739,12 @@ def test_nmf_callback_reconstruction_attributes(solver, beta_loss): nmf.fit(X) # load model from last iteration - snapshot = sorted(callback.directory.iterdir())[-1] + snapshot_dir = next(callback.base_dir.iterdir()) + snapshot = sorted(snapshot_dir.iterdir())[-1] with open(snapshot, "rb") as f: loaded_nmf = pickle.load(f) - # The model loaded from the last iteration is the same as the original model + # The model saved during the last iteration is the same as the original model assert nmf.n_iter_ == loaded_nmf.n_iter_ assert_allclose(nmf.components_, loaded_nmf.components_) assert_allclose(nmf.reconstruction_err_, loaded_nmf.reconstruction_err_) From 7a1825db4c9d2a3a7170235fd95fdd7747c3ff96 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 31 Dec 2021 17:20:55 +0100 Subject: [PATCH 04/55] cln --- sklearn/base.py | 14 +++++++++++ sklearn/callback/_base.py | 16 ++++++------ sklearn/callback/_snapshot.py | 11 +++------ sklearn/callback/tests/test_callbacks.py | 6 ++++- sklearn/decomposition/_nmf.py | 31 +++++++++++++++--------- sklearn/linear_model/_logistic.py | 11 +++------ sklearn/utils/optimize.py | 8 +++--- 7 files changed, 57 insertions(+), 40 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 4f6b63cb2add1..7823e61f63c1e 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -733,6 +733,20 @@ def _eval_callbacks_on_fit_end(self): # propagated from a meta-estimator. callback.on_fit_end() + def _from_reconstruction_attributes(self, *, reconstruction_attributes): + """ + + Parameters + ---------- + reconstruction_attributes : callable + The necessary fitted attributes to create a working fitted estimator from + this instance. + """ + new_estimator = copy.copy(self) + for key, val in reconstruction_attributes().items(): + setattr(new_estimator, key, val) + return new_estimator + @property def _repr_html_(self): """HTML representation of estimator. diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 604a450336610..a473f172fd575 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -38,11 +38,11 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): kwargs["stopping_criterion"] = kwarg if any( - getattr(callback, "request_reconstruction_attributes", False) + getattr(callback, "request_from_reconstruction_attributes", False) for callback in estimator._callbacks ): - kwarg = kwargs.pop("reconstruction_attributes", lambda: None)() - kwargs["reconstruction_attributes"] = kwarg + kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() + kwargs["from_reconstruction_attributes"] = kwarg return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) @@ -94,11 +94,11 @@ def on_fit_iter_end(self, estimator, node, **kwargs): Tolerance for the stopping criterion. This is only provided at the innermost level of iterations. - - reconstruction_attributes: dict - Necessary attributes to construct an estimator (by copying this - estimator and setting these as attributes) which will behave as if - the fit stopped at this node. - This is only provided at the outermost level of iterations. + - from_reconstruction_attributes: estimator instance + A ready to predict, transform, etc ... estimator as if the fit stopped + at this node. Usually it's a copy of the caller estimator with the + necessary attributes set but it can sometimes be an instance of another + class (e.g. LogisticRegressionCV -> LogisticRegression) - fit_state: dict Model specific quantities updated during fit. This is not meant to be diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py index 99a1bcc0ce68a..cbf200336c749 100644 --- a/sklearn/callback/_snapshot.py +++ b/sklearn/callback/_snapshot.py @@ -1,6 +1,5 @@ # License: BSD 3 clause -from copy import copy from datetime import datetime from pathlib import Path import pickle @@ -24,7 +23,7 @@ class Snapshot(BaseCallback): the current directory. """ - request_reconstruction_attributes = True + request_from_reconstruction_attributes = True def __init__(self, keep_last_n=1, base_dir=None): self.keep_last_n = keep_last_n @@ -41,14 +40,10 @@ def on_fit_begin(self, estimator, X=None, y=None): subdir.mkdir() def on_fit_iter_end(self, *, estimator, node, **kwargs): - reconstruction_attributes = kwargs.get("reconstruction_attributes", None) - if reconstruction_attributes is None: + new_estimator = kwargs.get("from_reconstruction_attributes", None) + if new_estimator is None: return - new_estimator = copy(estimator) - for key, val in reconstruction_attributes.items(): - setattr(new_estimator, key, val) - subdir = self._get_subdir(node.computation_tree) snapshot_filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index 17dafc616f457..c43241f469f8c 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -1,5 +1,6 @@ # License: BSD 3 clause +from functools import partial import pickle import pytest import tempfile @@ -31,7 +32,10 @@ def fit(self, X, y): if _eval_callbacks_on_fit_iter_end( estimator=self, node=root.children[i], - reconstruction_attributes=lambda: {"n_iter_": i + 1}, + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda : {"n_iter_": i + 1}, + ) ): break diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 4fa46dd2cb12c..154dbb3db6532 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -6,6 +6,7 @@ # Tom Dupre la Tour # License: BSD 3 clause +from functools import partial import numbers import numpy as np import scipy.sparse as sp @@ -528,12 +529,15 @@ def _fit_coordinate_descent( stopping_criterion=lambda: violation / violation_init, tol=tol, fit_state={"H": Ht.T, "W": W}, - reconstruction_attributes=lambda: { - "n_components_": Ht.T.shape[0], - "components_": H, - "n_iter_": n_iter, - "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), - }, + from_reconstruction_attributes=partial( + estimator._from_reconstruction_attributes, + reconstruction_attributes=lambda : { + "n_components_": Ht.T.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), + } + ), ): break @@ -874,12 +878,15 @@ def _fit_multiplicative_update( ), tol=tol, fit_state={"H": H, "W": W}, - reconstruction_attributes=lambda: { - "n_components_": H.shape[0], - "components_": H, - "n_iter_": n_iter, - "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), - }, + from_reconstruction_attributes=partial( + estimator._from_reconstruction_attributes, + reconstruction_attributes=lambda { + "n_components_": H.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), + } + ), ): break diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 82063f36d0434..540f1a656c077 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -799,15 +799,12 @@ def grad(x, *args): hess = _logistic_grad_hess warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} - # Distinguish between LogReg and LogRegCV - if parent_node is not None: - nodes = [parent_node] if len(Cs) == 1 else parent_node.children - else: - nodes = [None] * len(Cs) - coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) - for i, (C, node) in enumerate(zip(Cs, nodes)): + for i, C in enumerate(Cs): + # Distinguish between LogReg and LogRegCV + node = None if parent_node is None else parent_node if len(Cs) == 1 else parent_node.children + if solver == "lbfgs": iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py index 2e3b6eb1c125b..b634f457bd287 100644 --- a/sklearn/utils/optimize.py +++ b/sklearn/utils/optimize.py @@ -171,15 +171,14 @@ def _newton_cg( """ x0 = np.asarray(x0).flatten() xk = x0 + k = 0 if line_search: old_fval = func(x0, *args) old_old_fval = None - nodes = parent_node.children if parent_node is not None else [None] * maxiter - # Outer loop: our Newton iteration - for k, node in enumerate(nodes, 1): + while k < maxiter: # Compute a search direction pk by applying the CG method to # del2 f(xk) p = - fgrad f(xk) starting from 0. fgrad, fhess_p = grad_hess(xk, *args) @@ -189,7 +188,7 @@ def _newton_cg( if _eval_callbacks_on_fit_iter_end( estimator=estimator, - node=node, + node=None if parent_node is None else parent_node.children[k], stopping_criterion=lambda: max_absgrad, tol=tol, ): @@ -218,6 +217,7 @@ def _newton_cg( break xk = xk + alphak * xsupi # upcast if necessary + k += 1 if warn and k >= maxiter: warnings.warn( From 3e3b25f3d5202a3a56a7fdfcc22e373a538a30bd Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 31 Dec 2021 17:25:57 +0100 Subject: [PATCH 05/55] black --- sklearn/callback/tests/test_callbacks.py | 4 ++-- sklearn/decomposition/_nmf.py | 8 ++++---- sklearn/decomposition/tests/test_nmf.py | 2 +- sklearn/linear_model/_logistic.py | 8 +++++++- sklearn/pipeline.py | 5 +++-- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index c43241f469f8c..1f5fcf6bdd3c4 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -34,8 +34,8 @@ def fit(self, X, y): node=root.children[i], from_reconstruction_attributes=partial( self._from_reconstruction_attributes, - reconstruction_attributes=lambda : {"n_iter_": i + 1}, - ) + reconstruction_attributes=lambda: {"n_iter_": i + 1}, + ), ): break diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 154dbb3db6532..f63146dc11250 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -531,12 +531,12 @@ def _fit_coordinate_descent( fit_state={"H": Ht.T, "W": W}, from_reconstruction_attributes=partial( estimator._from_reconstruction_attributes, - reconstruction_attributes=lambda : { + reconstruction_attributes=lambda: { "n_components_": Ht.T.shape[0], "components_": H, "n_iter_": n_iter, "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), - } + }, ), ): break @@ -880,12 +880,12 @@ def _fit_multiplicative_update( fit_state={"H": H, "W": W}, from_reconstruction_attributes=partial( estimator._from_reconstruction_attributes, - reconstruction_attributes=lambda { + reconstruction_attributes=lambda: { "n_components_": H.shape[0], "components_": H, "n_iter_": n_iter, "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), - } + }, ), ): break diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index a1ef1e90792af..c84ee43175df4 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -746,6 +746,6 @@ def test_nmf_callback_reconstruction_attributes(solver, beta_loss): # The model saved during the last iteration is the same as the original model assert nmf.n_iter_ == loaded_nmf.n_iter_ - assert_allclose(nmf.components_, loaded_nmf.components_) + assert_allclose(nmf.components_, loaded_nmf.components_) assert_allclose(nmf.reconstruction_err_, loaded_nmf.reconstruction_err_) assert_allclose(nmf.transform(X), loaded_nmf.transform(X)) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 540f1a656c077..1d4bbc815bb3d 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -803,7 +803,13 @@ def grad(x, *args): n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): # Distinguish between LogReg and LogRegCV - node = None if parent_node is None else parent_node if len(Cs) == 1 else parent_node.children + node = ( + None + if parent_node is None + else parent_node + if len(Cs) == 1 + else parent_node.children + ) if solver == "lbfgs": iprint = [-1, 50, 1, 100, 101][ diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 47553d07ac169..657ba79307ce3 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -330,9 +330,10 @@ def _fit(self, X, y=None, **fit_params_steps): fit_transform_one_cached = memory.cache(_fit_transform_one) - for (step_idx, name, transformer), node in zip(self._iter( + for (step_idx, name, transformer), in self._iter( with_final=False, filter_passthrough=False - ), root.children[:-1]): + ): + node = root.children[step_idx] if transformer is None or transformer == "passthrough": _eval_callbacks_on_fit_iter_end(estimator=self, node=node) with _print_elapsed_time("Pipeline", self._log_message(step_idx)): From 26dbb6954c4155daf2ce7b09b91911379bf4705f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 31 Dec 2021 17:33:47 +0100 Subject: [PATCH 06/55] lint --- sklearn/base.py | 4 ++-- sklearn/callback/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 7823e61f63c1e..c14a5d314a502 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -617,7 +617,7 @@ def _set_callbacks(self, callbacks): callbacks = [callbacks] if not all(isinstance(callback, BaseCallback) for callback in callbacks): - raise TypeError(f"callbacks must be subclasses of BaseCallback.") + raise TypeError("callbacks must be subclasses of BaseCallback.") self._callbacks = callbacks @@ -734,7 +734,7 @@ def _eval_callbacks_on_fit_end(self): callback.on_fit_end() def _from_reconstruction_attributes(self, *, reconstruction_attributes): - """ + """Return a as if fitted copy of this estimator Parameters ---------- diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 1f0f3f7215a18..c8d5ea0bf0606 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -13,7 +13,7 @@ __all__ = [ "AutoPropagatedMixin", - "Basecallback", + "BaseCallback", "ComputationNode", "ComputationTree", "load_computation_tree", From eb7b8246d5fc1e770cc0d5d98b1f6130d6fba461 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 14 Feb 2022 16:02:42 +0100 Subject: [PATCH 07/55] wip --- sklearn/callback/_convergence_monitor.py | 35 ++++++++++++++---------- sklearn/pipeline.py | 2 +- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/sklearn/callback/_convergence_monitor.py b/sklearn/callback/_convergence_monitor.py index 9f53d657cc75a..ac04335e04661 100644 --- a/sklearn/callback/_convergence_monitor.py +++ b/sklearn/callback/_convergence_monitor.py @@ -3,20 +3,21 @@ from copy import copy from pathlib import Path from tempfile import mkdtemp -import time import matplotlib.pyplot as plt import pandas as pd from . import BaseCallback +# import ..metrics as metrics + class ConvergenceMonitor(BaseCallback): """Monitor model convergence. Parameters ---------- - monitor : + monitor : X_val : ndarray, default=None Validation data @@ -33,37 +34,41 @@ class ConvergenceMonitor(BaseCallback): request_reconstruction_attributes = True def __init__(self, *, monitor="objective_function", X_val=None, y_val=None): + if monitor == "objective_function": + self._monitor = "objective_function" + else: + self._monitor = getattr(metrics, monitor, None) + if self._monitor is None: + raise ValueError(f"unknown metric {monitor}") + self.X_val = X_val self.y_val = y_val + self._data_file = Path(mkdtemp()) / "convergence_monitor.csv" def on_fit_begin(self, estimator, *, X=None, y=None): self.estimator = estimator self.X_train = X self.y_train = y - self._start_time = {} - - def on_fit_iter_end(self, *, node, **kwargs): - if node.depth != node.computation_tree.depth: - return + def on_fit_iter_end(self, *, estimator, node, **kwargs): reconstruction_attributes = kwargs.get("reconstruction_attributes", None) if reconstruction_attributes is None: return - new_estimator = copy(self.estimator) + new_estimator = copy(estimator) for key, val in reconstruction_attributes.items(): setattr(new_estimator, key, val) - if node.idx == 0: - self._start_time[node.parent] = time.perf_counter() - curr_time = 0 - else: - curr_time = time.perf_counter() - self._start_time[node.parent] + # if self._monitor = - obj_train, *_ = new_estimator.objective_function(self.X_train, self.y_train, normalize=True) + obj_train, *_ = new_estimator.objective_function( + self.X_train, self.y_train, normalize=True + ) if self.X_val is not None: - obj_val, *_ = new_estimator.objective_function(self.X_val, self.y_val, normalize=True) + obj_val, *_ = new_estimator.objective_function( + self.X_val, self.y_val, normalize=True + ) else: obj_val = None diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 657ba79307ce3..96a4738a9196a 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -330,7 +330,7 @@ def _fit(self, X, y=None, **fit_params_steps): fit_transform_one_cached = memory.cache(_fit_transform_one) - for (step_idx, name, transformer), in self._iter( + for (step_idx, name, transformer) in self._iter( with_final=False, filter_passthrough=False ): node = root.children[step_idx] From f78442ebc9895210f34905c96e6ca7fd4d2b6e3a Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 23 Feb 2022 14:46:17 +0100 Subject: [PATCH 08/55] class --- sklearn/model_selection/_search.py | 73 +++++++++++++++++++++++--- sklearn/model_selection/_validation.py | 8 +++ 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 5ceb71569b932..fc16eefe8070f 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -15,6 +15,7 @@ from collections.abc import Mapping, Sequence, Iterable from functools import partial, reduce from itertools import product +from itertools import cycle import numbers import operator import time @@ -23,6 +24,7 @@ import numpy as np from numpy.ma import MaskedArray from scipy.stats import rankdata +from joblib import Parallel from ..base import BaseEstimator, is_classifier, clone from ..base import MetaEstimatorMixin @@ -33,7 +35,6 @@ from ._validation import _normalize_score_results from ._validation import _warn_or_raise_about_fit_failures from ..exceptions import NotFittedError -from joblib import Parallel from ..utils import check_random_state from ..utils.random import sample_without_replacement from ..utils._tags import _safe_tags @@ -783,7 +784,7 @@ def fit(self, X, y=None, *, groups=None, **fit_params): X, y, groups = indexable(X, y, groups) fit_params = _check_fit_params(X, fit_params) - cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator)) + cv_orig = self._checked_cv_orig n_splits = cv_orig.get_n_splits(X, y, groups) base_estimator = clone(self.estimator) @@ -806,7 +807,7 @@ def fit(self, X, y=None, *, groups=None, **fit_params): all_out = [] all_more_results = defaultdict(list) - def evaluate_candidates(candidate_params, cv=None, more_results=None): + def evaluate_candidates(candidate_params, cv=None, more_results=None, parent_node=None): cv = cv or cv_orig candidate_params = list(candidate_params) n_candidates = len(candidate_params) @@ -819,6 +820,11 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None): ) ) + if parent_node is not None: + nodes = parent_node.children + else: + nodes = cycle([None]) + out = parallel( delayed(_fit_and_score)( clone(base_estimator), @@ -830,10 +836,11 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None): split_progress=(split_idx, n_splits), candidate_progress=(cand_idx, n_candidates), **fit_and_score_kwargs, + caller=self, + node=node, ) - for (cand_idx, parameters), (split_idx, (train, test)) in product( - enumerate(candidate_params), enumerate(cv.split(X, y, groups)) - ) + for ((cand_idx, parameters), (split_idx, (train, test))), node in zip(product( + enumerate(candidate_params), enumerate(cv.split(X, y, groups))), nodes) ) if len(out) < 1: @@ -1370,10 +1377,60 @@ def __init__( ) self.param_grid = param_grid + def fit(self, X, y=None, *, groups=None, **fit_params): + """Run fit with all sets of parameters. + + Parameters + ---------- + + X : array-like of shape (n_samples, n_features) + Training vector, where `n_samples` is the number of samples and + `n_features` is the number of features. + + y : array-like of shape (n_samples, n_output) or (n_samples,), default=None + Target relative to X for classification or regression; + None for unsupervised learning. + + groups : array-like of shape (n_samples,), default=None + Group labels for the samples used while splitting the dataset into + train/test set. Only used in conjunction with a "Group" :term:`cv` + instance (e.g., :class:`~sklearn.model_selection.GroupKFold`). + + **fit_params : dict of str -> object + Parameters passed to the `fit` method of the estimator. + + If a fit parameter is an array-like whose length is equal to + `num_samples` then it will be split across CV groups along with `X` + and `y`. For example, the :term:`sample_weight` parameter is split + because `len(sample_weights) = len(X)`. + + Returns + ------- + self : object + Instance of fitted estimator. + """ + self._param_grid = ParameterGrid(self.param_grid) + + self._checked_cv_orig = check_cv( + self.cv, y, classifier=is_classifier(self.estimator) + ) + n_splits = self._checked_cv_orig.get_n_splits(X, y, groups) + + self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": len(self._param_grid) * n_splits}, + {"descr": "param - fold", "max_iter": None}, + ], + X=X, + y=y, + ) + super().fit(X, y=y, groups=groups, **fit_params) + + self._eval_callbacks_on_fit_end() + def _run_search(self, evaluate_candidates): """Search all candidates in param_grid""" - evaluate_candidates(ParameterGrid(self.param_grid)) - + evaluate_candidates(self._param_grid, parent_node=self._computation_tree.root) class RandomizedSearchCV(BaseSearchCV): """Randomized search on hyper parameters. diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 927fe7a2cc452..6bf61bf246302 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -33,6 +33,7 @@ from ..exceptions import FitFailedWarning from ._split import check_cv from ..preprocessing import LabelEncoder +from ..callback._base import _eval_callbacks_on_fit_iter_end __all__ = [ @@ -547,6 +548,8 @@ def _fit_and_score( split_progress=None, candidate_progress=None, error_score=np.nan, + caller=None, + node=None, ): """Fit estimator and compute scores for a given dataset split. @@ -673,6 +676,9 @@ def _fit_and_score( cloned_parameters[k] = clone(v, safe=False) estimator = estimator.set_params(**cloned_parameters) + + if caller is not None: + caller._propagate_callbacks(estimator, parent_node=node) start_time = time.time() @@ -736,6 +742,8 @@ def _fit_and_score( end_msg += result_msg print(end_msg) + _eval_callbacks_on_fit_iter_end(estimator=caller, node=node) + result["test_scores"] = test_scores if return_train_score: result["train_scores"] = train_scores From 34bab15d7feb1dc191df80d731f03c5454244011 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 23 Feb 2022 17:57:36 +0100 Subject: [PATCH 09/55] more tests --- sklearn/base.py | 47 +++++---- sklearn/callback/_computation_tree.py | 8 +- .../test_base_estimator_callback_methods.py | 95 +++++++++++++++++++ sklearn/callback/tests/test_callbacks.py | 89 ++--------------- .../callback/tests/test_computation_tree.py | 33 ++++--- 5 files changed, 161 insertions(+), 111 deletions(-) create mode 100644 sklearn/callback/tests/test_base_estimator_callback_methods.py diff --git a/sklearn/base.py b/sklearn/base.py index cf3459267d13a..c542332280a07 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -616,6 +616,11 @@ def _set_callbacks(self, callbacks): ---------- callbacks : callback or list of callbacks the callbacks to set. + + Returns + ------- + self : estimator instance + The estimator instance itself. """ if not isinstance(callbacks, list): callbacks = [callbacks] @@ -625,9 +630,11 @@ def _set_callbacks(self, callbacks): self._callbacks = callbacks + return self + # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all # meta-estimators. - def _propagate_callbacks(self, sub_estimator, parent_node): + def _propagate_callbacks(self, sub_estimator, *, parent_node): """Propagate the auto-propagated callbacks to a sub-estimator Parameters @@ -640,9 +647,6 @@ def _propagate_callbacks(self, sub_estimator, parent_node): computation tree of the sub-estimator. It must be the node where the fit method of the sub-estimator is called. """ - if not hasattr(self, "_callbacks"): - return - if hasattr(sub_estimator, "_callbacks") and any( isinstance(callback, AutoPropagatedMixin) for callback in sub_estimator._callbacks @@ -659,6 +663,9 @@ def _propagate_callbacks(self, sub_estimator, parent_node): " Set them directly on the meta-estimator." ) + if not hasattr(self, "_callbacks"): + return + propagated_callbacks = [ callback for callback in self._callbacks @@ -668,7 +675,7 @@ def _propagate_callbacks(self, sub_estimator, parent_node): if not propagated_callbacks: return - sub_estimator._parent_node = parent_node + sub_estimator._parent_ct_node = parent_node if not hasattr(sub_estimator, "_callbacks"): sub_estimator._callbacks = propagated_callbacks @@ -702,7 +709,7 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): self._computation_tree = ComputationTree( estimator_name=self.__class__.__name__, levels=levels, - parent_node=getattr(self, "_parent_node", None), + parent_node=getattr(self, "_parent_ct_node", None), ) if hasattr(self, "_callbacks"): @@ -710,13 +717,13 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): with open(file_path, "wb") as f: pickle.dump(self._computation_tree, f) + # Only call the on_fit_begin method of callbacks that are not + # propagated from a meta-estimator. for callback in self._callbacks: - is_propagated = hasattr(self, "_parent_node") and isinstance( + is_propagated = hasattr(self, "_parent_ct_node") and isinstance( callback, AutoPropagatedMixin ) if not is_propagated: - # Only call the on_fit_begin method of callbacks that are not - # propagated from a meta-estimator. callback.on_fit_begin(estimator=self, X=X, y=y) return self._computation_tree.root @@ -728,25 +735,33 @@ def _eval_callbacks_on_fit_end(self): self._computation_tree._tree_status[0] = True + # Only call the on_fit_end method of callbacks that are not + # propagated from a meta-estimator. for callback in self._callbacks: is_propagated = isinstance(callback, AutoPropagatedMixin) and hasattr( - self, "_parent_node" + self, "_parent_ct_node" ) if not is_propagated: - # Only call the on_fit_end method of callbacks that are not - # propagated from a meta-estimator. callback.on_fit_end() def _from_reconstruction_attributes(self, *, reconstruction_attributes): - """Return a as if fitted copy of this estimator + """Return an as if fitted copy of this estimator Parameters ---------- reconstruction_attributes : callable - The necessary fitted attributes to create a working fitted estimator from - this instance. + A callable that has no arguments and returns the necessary fitted attributes + to create a working fitted estimator from this instance. + + Using a callable allows lazy evaluation of the potentially costly + reconstruction attributes. + + Returns + ------- + fitted_estimator : estimator instance + The fitted copy of this estimator. """ - new_estimator = copy.copy(self) + new_estimator = copy.copy(self) # XXX deepcopy ? for key, val in reconstruction_attributes().items(): setattr(new_estimator, key, val) return new_estimator diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index 161891ca32004..a69a8788e26c5 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -88,6 +88,12 @@ def get_ancestors(self, include_ancestor_trees=True): return ancestors + def __repr__(self): + return ( + f"ComputationNode(description={self.description}, " + f"depth={self.depth}, idx={self.idx})" + ) + class ComputationTree: """Data structure to store the computation tree of an estimator @@ -221,7 +227,7 @@ def iterate(self, include_leaves=False): def _recursive_iterate(self, node=None, include_leaves=False, node_list=None): """Recursively constructs the iterable""" - # TODO make it a generator + # TODO make it an iterator ? if node is None: node = self.root node_list = [] diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py new file mode 100644 index 0000000000000..676f0a5cfdd0e --- /dev/null +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -0,0 +1,95 @@ +# License: BSD 3 clause + +from pathlib import Path +import pytest + +from sklearn.callback.tests._utils import TestingCallback +from sklearn.callback.tests._utils import TestingAutoPropagatedCallback +from sklearn.callback.tests._utils import NotValidCallback +from sklearn.callback.tests._utils import Estimator +from sklearn.callback.tests._utils import MetaEstimator + + +@pytest.mark.parametrize("callbacks", + [ + TestingCallback(), + [TestingCallback()], + [TestingCallback(), TestingAutoPropagatedCallback()], + ] +) +def test_set_callbacks(callbacks): + """Sanity check for the _set_callbacks method""" + estimator = Estimator() + + set_callbacks_return = estimator._set_callbacks(callbacks) + assert hasattr(estimator, "_callbacks") + assert estimator._callbacks in (callbacks, [callbacks]) + assert set_callbacks_return is estimator + + +@pytest.mark.parametrize("callbacks", [None, NotValidCallback()]) +def test_set_callbacks_error(callbacks): + """Check the error message when not passing a valid callback to _set_callbacks""" + estimator = Estimator() + + with pytest.raises(TypeError, match="callbacks must be subclasses of BaseCallback"): + estimator._set_callbacks(callbacks) + + +def test_propagate_callbacks(): + """Sanity check for the _propagate_callbacks method""" + not_propagated_callback = TestingCallback() + propagated_callback = TestingAutoPropagatedCallback() + + estimator = Estimator() + estimator._set_callbacks([not_propagated_callback, propagated_callback]) + + sub_estimator = Estimator() + estimator._propagate_callbacks(sub_estimator, parent_node=None) + + assert hasattr(sub_estimator, "_parent_ct_node") + assert not_propagated_callback not in sub_estimator._callbacks + assert propagated_callback in sub_estimator._callbacks + + +def test_propagate_callback_no_callback(): + """Check that no callback is propagated if there's no callback""" + estimator = Estimator() + sub_estimator = Estimator() + estimator._propagate_callbacks(sub_estimator, parent_node=None) + + assert not hasattr(estimator, "_callbacks") + assert not hasattr(sub_estimator, "_callbacks") + + +def test_auto_propagated_callbacks(): + """Check that it's not possible to set an auto-propagated callback on the + sub-estimator of a meta-estimator. + """ + estimator = Estimator() + estimator._set_callbacks(TestingAutoPropagatedCallback()) + + meta_estimator = MetaEstimator(estimator=estimator) + + match = ( + r"sub-estimators .*of a meta-estimator .*can't have auto-propagated callbacks" + ) + with pytest.raises(TypeError, match=match): + meta_estimator.fit(X=None, y=None) + + +def test_eval_callbacks_on_fit_begin(): + """Check that _eval_callbacks_on_fit_begin creates and dumps the computation tree""" + estimator = Estimator()._set_callbacks(TestingCallback()) + assert not hasattr(estimator, "_computation_tree") + + levels = [ + {"descr": "fit", "max_iter": 10}, + {"descr": "iter", "max_iter": None}, + ] + ct_root = estimator._eval_callbacks_on_fit_begin(levels=levels) + assert hasattr(estimator, "_computation_tree") + assert ct_root is estimator._computation_tree.root + + ct_pickle = Path(estimator._computation_tree.tree_dir) / "computation_tree.pkl" + assert ct_pickle.exists() diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index 1f5fcf6bdd3c4..a87cdbcbf3199 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -1,99 +1,24 @@ # License: BSD 3 clause -from functools import partial import pickle import pytest import tempfile -from time import sleep -from joblib import Parallel, delayed +import numpy as np -from sklearn.base import BaseEstimator, clone from sklearn.callback import Snapshot -from sklearn.callback._base import _eval_callbacks_on_fit_iter_end -from sklearn.datasets import make_classification +from sklearn.callback.tests._utils import Estimator +from sklearn.callback.tests._utils import MetaEstimator -class Estimator(BaseEstimator): - def __init__(self, max_iter=20): - self.max_iter = max_iter - - def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.max_iter}, - {"descr": "iter", "max_iter": None}, - ], - X=X, - y=y, - ) - - for i in range(self.max_iter): - if _eval_callbacks_on_fit_iter_end( - estimator=self, - node=root.children[i], - from_reconstruction_attributes=partial( - self._from_reconstruction_attributes, - reconstruction_attributes=lambda: {"n_iter_": i + 1}, - ), - ): - break - - self.n_iter_ = i + 1 - - self._eval_callbacks_on_fit_end() - - return self - - -class MetaEstimator(BaseEstimator): - def __init__( - self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" - ): - self.estimator = estimator - self.n_outer = n_outer - self.n_inner = n_inner - self.n_jobs = n_jobs - self.prefer = prefer - - def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.n_outer}, - {"descr": "outer", "max_iter": self.n_inner}, - {"descr": "inner", "max_iter": None}, - ], - X=X, - y=y, - ) - - res = Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( - delayed(self._func)(self.estimator, X, y, node, i) - for i, node in enumerate(root.children) - ) - - self._eval_callbacks_on_fit_end() - - return self - - def _func(self, estimator, X, y, parent_node, i): - for j, node in enumerate(parent_node.children): - est = clone(estimator) - self._propagate_callbacks(est, parent_node=node) - est.fit(X, y) - - _eval_callbacks_on_fit_iter_end(estimator=self, node=node) - - _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) - - return +X = np.zeros((100, 3)) +y = np.zeros(100, dtype=int) @pytest.mark.parametrize("n_jobs", (1, 2)) @pytest.mark.parametrize("prefer", ("threads", "processes")) def test_snapshot_meta_estimator(n_jobs, prefer): - # Test for the Snapshot callback - X, y = make_classification() + """Test for the Snapshot callback""" estimator = Estimator(max_iter=20) with tempfile.TemporaryDirectory() as tmp_dir: @@ -122,3 +47,5 @@ def test_snapshot_meta_estimator(n_jobs, prefer): # We kept last 5 snapshots out of 20 iterations. # This one is the 16 + i-th. assert loaded_estimator.n_iter_ == 16 + i + + diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index b726177a342ec..902175b71a250 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -1,11 +1,8 @@ # License: BSD 3 clause import numpy as np -import pytest from sklearn.callback import ComputationTree -from sklearn.callback import ComputationNode -from sklearn.callback import load_computation_tree levels = [ @@ -17,7 +14,7 @@ def test_computation_tree(): - # Check the construction of the computation tree + """Check the construction of the computation tree""" computation_tree = ComputationTree(estimator_name="estimator", levels=levels) assert computation_tree.estimator_name == "estimator" @@ -41,8 +38,9 @@ def test_computation_tree(): def test_n_nodes(): - # Check that the number of node in a comutation tree corresponds to what we expect - # from the level descriptions + """Check that the number of node in a comutation tree corresponds to what we expect + from the level descriptions + """ computation_tree = ComputationTree(estimator_name="", levels=levels) max_iter_per_level = [level["max_iter"] for level in levels[:-1]] @@ -54,8 +52,9 @@ def test_n_nodes(): def test_tree_status_idx(): - # Check that each node has a unique index in the _tree_status array and that their - # order corresponds to the order given by a depth first search. + """Check that each node has a unique index in the _tree_status array and that their + order corresponds to the order given by a depth first search. + """ computation_tree = ComputationTree(estimator_name="", levels=levels) indexes = [ @@ -65,7 +64,7 @@ def test_tree_status_idx(): def test_get_ancestors(): - # Check that the ancestor search excludes the root and can propagate to parent trees + """Check the ancestor search and its propagation to parent trees""" parent_levels = [ {"descr": "parent_level0", "max_iter": 2}, {"descr": "parent_level1", "max_iter": 4}, @@ -76,23 +75,31 @@ def test_get_ancestors(): estimator_name="parent_estimator", levels=parent_levels ) parent_node = parent_computation_tree.root.children[0].children[2] + # indices of each node (in its parent children) in this chain are 0, 0, 2. + # (root is always 0). + expected_parent_indices = [2, 0, 0] computation_tree = ComputationTree( estimator_name="estimator", levels=levels, parent_node=parent_node ) node = computation_tree.root.children[1].children[3].children[5] + expected_node_indices = [5, 3, 1, 0] ancestors = node.get_ancestors(include_ancestor_trees=False) - assert ancestors == [node, node.parent, node.parent.parent] - assert [n.idx for n in ancestors] == [5, 3, 1] - assert computation_tree.root not in ancestors + assert ancestors == [ + node, node.parent, node.parent.parent, node.parent.parent.parent + ] + assert [n.idx for n in ancestors] == expected_node_indices + assert computation_tree.root in ancestors ancestors = node.get_ancestors(include_ancestor_trees=True) assert ancestors == [ node, node.parent, node.parent.parent, + node.parent.parent.parent, parent_node, parent_node.parent, + parent_node.parent.parent, ] - assert [n.idx for n in ancestors] == [5, 3, 1, 2, 0] + assert [n.idx for n in ancestors] == expected_node_indices + expected_parent_indices From 596a58ef39815701346c47769d2bb14ab7814da9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 23 Feb 2022 17:57:49 +0100 Subject: [PATCH 10/55] cln --- sklearn/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index c845f684a7945..433b9e4d57c56 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -344,8 +344,10 @@ def _fit(self, X, y=None, **fit_params_steps): cloned_transformer = transformer else: cloned_transformer = clone(transformer) - # Fit or load from cache the current transformer + self._propagate_callbacks(cloned_transformer, parent_node=node) + + # Fit or load from cache the current transformer X, fitted_transformer = fit_transform_one_cached( cloned_transformer, X, From 4f9363cf7ec622bab3cd320c9fcadc128ffcbb47 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 12 Sep 2022 18:29:10 +0200 Subject: [PATCH 11/55] wip --- sklearn/callback/_base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index a473f172fd575..0e11acd4f54ef 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod -# Not a method of BaseEstimator because it might be called from an extern function +# Not a method of BaseEstimator because it might not be directly called from fit but +# by a non-method function called by fit def _eval_callbacks_on_fit_iter_end(**kwargs): """Evaluate the on_fit_iter_end method of the callbacks @@ -54,6 +55,8 @@ class BaseCallback(ABC): def on_fit_begin(self, estimator, *, X=None, y=None): """Method called at the beginning of the fit method of the estimator + Only called + Parameters ---------- estimator: estimator instance @@ -105,6 +108,11 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) used by generic callbacks but by a callback designed for a specific estimator instead. + - extra_verbose: dict + Model specific . This is not meant to be + used by generic callbacks but by a callback designed for a specific + estimator instead. + Returns ------- stop : bool or None From 35c5284239faf6f962d8d3f66436889da0020291 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 16 Sep 2022 11:05:15 +0200 Subject: [PATCH 12/55] wip --- sklearn/base.py | 8 ++ sklearn/callback/_base.py | 11 ++ sklearn/callback/tests/_utils.py | 109 ++++++++++++++++++ .../test_base_estimator_callback_methods.py | 29 +++++ sklearn/callback/tests/test_callbacks.py | 17 +++ 5 files changed, 174 insertions(+) create mode 100644 sklearn/callback/tests/_utils.py diff --git a/sklearn/base.py b/sklearn/base.py index 11de1ecfb1fd2..e8938f1c134e8 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -10,6 +10,8 @@ import inspect import re import pickle +from shutil import rmtree +from functools import partial import numpy as np @@ -32,6 +34,7 @@ from .callback import BaseCallback from .callback import AutoPropagatedMixin from .callback import ComputationTree +from .callback._base import CallbackContext def clone(estimator, *, safe=True): @@ -678,6 +681,11 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): ) if hasattr(self, "_callbacks"): + # + #if self._computation_tree.parent_node is None: + CallbackContext(self._callbacks, finalizer=partial(rmtree, ignore_errors=True), finalizer_args=self._computation_tree.tree_dir) + + # file_path = self._computation_tree.tree_dir / "computation_tree.pkl" with open(file_path, "wb") as f: pickle.dump(self._computation_tree, f) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 0e11acd4f54ef..ea0b28be5f937 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -1,6 +1,7 @@ # License: BSD 3 clause from abc import ABC, abstractmethod +import weakref # Not a method of BaseEstimator because it might not be directly called from fit but @@ -120,6 +121,9 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) """ pass + def _set_context(self, context): + self._callback_context = context + class AutoPropagatedMixin: """Mixin for auto-propagated callbacks @@ -132,3 +136,10 @@ class AutoPropagatedMixin: """ pass + + +class CallbackContext: + def __init__(self, callbacks, finalizer, finalizer_args): + for callback in callbacks: + callback._set_context(self) + weakref.finalize(self, finalizer, finalizer_args) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py new file mode 100644 index 0000000000000..84e94fce16e7c --- /dev/null +++ b/sklearn/callback/tests/_utils.py @@ -0,0 +1,109 @@ +from functools import partial + +from joblib.parallel import Parallel, delayed + +from sklearn.base import BaseEstimator, clone +from sklearn.callback import BaseCallback +from sklearn.callback import AutoPropagatedMixin +from sklearn.callback._base import _eval_callbacks_on_fit_iter_end + + +class TestingCallback(BaseCallback): + def on_fit_begin(self, estimator, *, X=None, y=None): + pass + + def on_fit_end(self): + pass + + def on_fit_iter_end(self, estimator, node, **kwargs): + pass + + +class TestingAutoPropagatedCallback(TestingCallback, AutoPropagatedMixin): + pass + + +class NotValidCallback: + def on_fit_begin(self, estimator, *, X=None, y=None): + pass + + def on_fit_end(self): + pass + + def on_fit_iter_end(self, estimator, node, **kwargs): + pass + + +class Estimator(BaseEstimator): + def __init__(self, max_iter=20): + self.max_iter = max_iter + + def fit(self, X, y): + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + y=y, + ) + + for i in range(self.max_iter): + if _eval_callbacks_on_fit_iter_end( + estimator=self, + node=root.children[i], + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda: {"n_iter_": i + 1}, + ), + ): + break + + self.n_iter_ = i + 1 + + self._eval_callbacks_on_fit_end() + + return self + + +class MetaEstimator(BaseEstimator): + def __init__( + self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" + ): + self.estimator = estimator + self.n_outer = n_outer + self.n_inner = n_inner + self.n_jobs = n_jobs + self.prefer = prefer + + def fit(self, X, y): + root = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.n_outer}, + {"descr": "outer", "max_iter": self.n_inner}, + {"descr": "inner", "max_iter": None}, + ], + X=X, + y=y, + ) + + res = Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( + delayed(self._func)(self.estimator, X, y, node, i) + for i, node in enumerate(root.children) + ) + + self._eval_callbacks_on_fit_end() + + return self + + def _func(self, estimator, X, y, parent_node, i): + for j, node in enumerate(parent_node.children): + est = clone(estimator) + self._propagate_callbacks(est, parent_node=node) + est.fit(X, y) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) + + return \ No newline at end of file diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 676f0a5cfdd0e..ea750abbcf890 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -9,6 +9,8 @@ from sklearn.callback.tests._utils import Estimator from sklearn.callback.tests._utils import MetaEstimator +from sklearn.callback import ProgressBar + @pytest.mark.parametrize("callbacks", [ @@ -93,3 +95,30 @@ def test_eval_callbacks_on_fit_begin(): ct_pickle = Path(estimator._computation_tree.tree_dir) / "computation_tree.pkl" assert ct_pickle.exists() + + +def test_callback_context_finalize(): + """Check that the folder containing the computation tree of the estimator is + deleted when there are no reference left to its callbacks. + """ + callback = TestingCallback() + + # estimator is not fitted, its computation tree is not built yet + est = Estimator()._set_callbacks(callbacks=callback) + assert not hasattr(est, "_computation_tree") + + # estimator is fitted, a folder has been created to hold its computation tree + est.fit(X=None, y=None) + assert hasattr(est, "_computation_tree") + tree_dir = est._computation_tree.tree_dir + assert tree_dir.is_dir() + + # there is no more reference to the estimator, but there is still a reference to the + # callback which might need to access the computation tree + del est + assert tree_dir.is_dir() + + # there is no more reference to the callback, the computation tree folder must be + # deleted + del callback + assert not tree_dir.is_dir() diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index a87cdbcbf3199..bd76325e0af28 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -6,7 +6,11 @@ import numpy as np +from sklearn.callback import ConvergenceMonitor +from sklearn.callback import EarlyStopping +from sklearn.callback import ProgressBar from sklearn.callback import Snapshot +from sklearn.callback import TextVerbose from sklearn.callback.tests._utils import Estimator from sklearn.callback.tests._utils import MetaEstimator @@ -15,6 +19,19 @@ y = np.zeros(100, dtype=int) +@pytest.mark.parametrize("Callback", [ConvergenceMonitor, EarlyStopping, ProgressBar, Snapshot, TextVerbose,]) +def test_callback_doesnt_hold_ref_to_estimator(Callback): + callback = Callback() + est = Estimator()._set_callbacks(callbacks=callback) + est.fit(X, y) + + tree_dir = est._computation_tree.tree_dir + + del est + del callback + assert not tree_dir.is_dir() + + @pytest.mark.parametrize("n_jobs", (1, 2)) @pytest.mark.parametrize("prefer", ("threads", "processes")) def test_snapshot_meta_estimator(n_jobs, prefer): From 115e1840122e542fc940c2dbec056273e6440965 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 16 Sep 2022 17:48:42 +0200 Subject: [PATCH 13/55] wip --- sklearn/base.py | 15 ++++++++------- sklearn/callback/_base.py | 5 ++++- sklearn/callback/_progressbar.py | 4 ++-- sklearn/callback/tests/test_callbacks.py | 16 +++++++++------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index e8938f1c134e8..14da63c1b9cb2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -3,6 +3,7 @@ # Author: Gael Varoquaux # License: BSD 3 clause +from codecs import ignore_errors import copy import warnings from collections import defaultdict @@ -645,10 +646,11 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): sub_estimator._parent_ct_node = parent_node - if not hasattr(sub_estimator, "_callbacks"): - sub_estimator._callbacks = propagated_callbacks - else: - sub_estimator._callbacks.extend(propagated_callbacks) + # if not hasattr(sub_estimator, "_callbacks"): + # sub_estimator._callbacks = propagated_callbacks + # else: + # sub_estimator._callbacks.extend(propagated_callbacks) + sub_estimator._set_callbacks(getattr(sub_estimator, "_callbacks", []) + propagated_callbacks) def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): """Evaluate the on_fit_begin method of the callbacks @@ -681,11 +683,10 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): ) if hasattr(self, "_callbacks"): - # - #if self._computation_tree.parent_node is None: + # CallbackContext(self._callbacks, finalizer=partial(rmtree, ignore_errors=True), finalizer_args=self._computation_tree.tree_dir) - # + # file_path = self._computation_tree.tree_dir / "computation_tree.pkl" with open(file_path, "wb") as f: pickle.dump(self._computation_tree, f) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index ea0b28be5f937..045065801cbbd 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -122,7 +122,10 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) pass def _set_context(self, context): - self._callback_context = context + if not hasattr(self, "_callback_contexts"): + self._callback_contexts = [] + + self._callback_contexts.append(context) class AutoPropagatedMixin: diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index ae11e67d59f57..fd7201de5c918 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -105,7 +105,7 @@ class _RichProgressMonitor(Thread): def __init__(self, estimator, event, max_depth_show=None, max_depth_keep=None): Thread.__init__(self) - self.estimator = estimator + self.computation_tree = estimator._computation_tree self.event = event self.max_depth_show = max_depth_show self.max_depth_keep = max_depth_keep @@ -151,7 +151,7 @@ def _recursive_update_tasks(self, this_dir=None, depth=0): return if this_dir is None: - this_dir = self.estimator._computation_tree.tree_dir + this_dir = self.computation_tree.tree_dir # _ordered_tasks holds the list of the tasks in the order we want them to # be displayed. self._progress_ctx._ordered_tasks = [] diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index bd76325e0af28..2a457d354077e 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -2,6 +2,7 @@ import pickle import pytest +import sys import tempfile import numpy as np @@ -22,14 +23,15 @@ @pytest.mark.parametrize("Callback", [ConvergenceMonitor, EarlyStopping, ProgressBar, Snapshot, TextVerbose,]) def test_callback_doesnt_hold_ref_to_estimator(Callback): callback = Callback() - est = Estimator()._set_callbacks(callbacks=callback) + est = Estimator() + callback_refcount = sys.getrefcount(callback) + est_refcount = sys.getrefcount(est) + + est._set_callbacks(callbacks=callback) est.fit(X, y) - - tree_dir = est._computation_tree.tree_dir - - del est - del callback - assert not tree_dir.is_dir() + # estimator has a ref on the callback but the callback has no ref to the estimator + assert sys.getrefcount(est) == est_refcount + assert sys.getrefcount(callback) == callback_refcount + 1 @pytest.mark.parametrize("n_jobs", (1, 2)) From bdb49901b778faf670e08a8025e8e0d727466608 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 21 Sep 2022 12:11:49 +0200 Subject: [PATCH 14/55] wip --- sklearn/base.py | 7 +++++++ sklearn/callback/_base.py | 21 +++++++++++++++++++++ sklearn/callback/_progressbar.py | 3 +++ sklearn/decomposition/_nmf.py | 5 +++++ sklearn/pipeline.py | 2 ++ 5 files changed, 38 insertions(+) diff --git a/sklearn/base.py b/sklearn/base.py index 14da63c1b9cb2..78dfa06178bc7 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -718,6 +718,13 @@ def _eval_callbacks_on_fit_end(self): if not is_propagated: callback.on_fit_end() + def _eval_callbacks_on_fit_exception(self): + if not hasattr(self, "_callbacks"): + return + + for callback in self._callbacks: + callback.on_fit_exception() + def _from_reconstruction_attributes(self, *, reconstruction_attributes): """Return an as if fitted copy of this estimator diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 045065801cbbd..c78fb3c773b61 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -1,6 +1,7 @@ # License: BSD 3 clause from abc import ABC, abstractmethod +from functools import wraps import weakref @@ -121,6 +122,10 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) """ pass + @abstractmethod + def on_fit_exception(self): + pass + def _set_context(self, context): if not hasattr(self, "_callback_contexts"): self._callback_contexts = [] @@ -146,3 +151,19 @@ def __init__(self, callbacks, finalizer, finalizer_args): for callback in callbacks: callback._set_context(self) weakref.finalize(self, finalizer, finalizer_args) + + +def callback_aware(fit_method): + """Decorator ... + """ + @wraps(fit_method) + def inner(self, *args, **kwargs): + try: + return fit_method(self, *args, **kwargs) + except BaseException: + self._eval_callbacks_on_fit_exception() + raise + finally: + self._eval_callbacks_on_fit_end() + + return inner diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index fd7201de5c918..713ab995e169a 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -62,6 +62,9 @@ def on_fit_end(self): self._stop_event.set() self.progress_monitor.join() + def on_fit_exception(self): + pass + def __getstate__(self): state = self.__dict__.copy() if "_stop_event" in state: diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index f54fd2d18e690..759946afc8c5a 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -9,6 +9,7 @@ from abc import ABC from functools import partial from numbers import Integral, Real +from subprocess import call import numpy as np import scipy.sparse as sp import time @@ -34,6 +35,7 @@ validate_params, ) from ..callback._base import _eval_callbacks_on_fit_iter_end +from ..callback._base import callback_aware EPSILON = np.finfo(np.float32).eps @@ -869,6 +871,8 @@ def _fit_multiplicative_update( H_sum, HHt, XHt = None, None, None for n_iter in range(1, max_iter + 1): + if n_iter == 30: + raise ValueError("eh ouais") # update W # H_sum, HHt and XHt are saved and reused if not update_H W, H_sum, HHt, XHt = _multiplicative_update_w( @@ -1726,6 +1730,7 @@ def _check_params(self, X): return self + @callback_aware def fit_transform(self, X, y=None, W=None, H=None): """Learn a NMF model for the data X and returns the transformed data. diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index b93d412020bef..81f4500726ff3 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -32,6 +32,7 @@ from .utils.fixes import delayed from .exceptions import NotFittedError from .callback._base import _eval_callbacks_on_fit_iter_end +from .callback._base import callback_aware __all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"] @@ -366,6 +367,7 @@ def _fit(self, X, y=None, **fit_params_steps): return X + @callback_aware def fit(self, X, y=None, **fit_params): """Fit the model. From 7a43c306b6f57b847b3f7348905c47d2786757fc Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 23 Sep 2022 18:15:57 +0200 Subject: [PATCH 15/55] wip --- sklearn/base.py | 26 +++------ sklearn/callback/__init__.py | 2 - sklearn/callback/_base.py | 49 +++++++++------- sklearn/callback/_computation_tree.py | 5 ++ sklearn/callback/_progressbar.py | 84 +++++++++++++++++++++------ sklearn/callback/_snapshot.py | 2 - sklearn/callback/_text_verbose.py | 5 +- sklearn/callback/tests/_utils.py | 6 +- sklearn/decomposition/_nmf.py | 2 - 9 files changed, 112 insertions(+), 69 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 78dfa06178bc7..9b4e659d8647a 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -33,7 +33,6 @@ from .utils._estimator_html_repr import estimator_html_repr from .utils._param_validation import validate_parameter_constraints from .callback import BaseCallback -from .callback import AutoPropagatedMixin from .callback import ComputationTree from .callback._base import CallbackContext @@ -617,13 +616,12 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): method of the sub-estimator is called. """ if hasattr(sub_estimator, "_callbacks") and any( - isinstance(callback, AutoPropagatedMixin) - for callback in sub_estimator._callbacks + callback.auto_propagate for callback in sub_estimator._callbacks ): bad_callbacks = [ callback.__class__.__name__ for callback in sub_estimator._callbacks - if isinstance(callback, AutoPropagatedMixin) + if callback.auto_propagate ] raise TypeError( f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" @@ -638,7 +636,7 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): propagated_callbacks = [ callback for callback in self._callbacks - if isinstance(callback, AutoPropagatedMixin) + if callback.auto_propagate ] if not propagated_callbacks: @@ -646,11 +644,9 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): sub_estimator._parent_ct_node = parent_node - # if not hasattr(sub_estimator, "_callbacks"): - # sub_estimator._callbacks = propagated_callbacks - # else: - # sub_estimator._callbacks.extend(propagated_callbacks) - sub_estimator._set_callbacks(getattr(sub_estimator, "_callbacks", []) + propagated_callbacks) + sub_estimator._set_callbacks( + getattr(sub_estimator, "_callbacks", []) + propagated_callbacks + ) def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): """Evaluate the on_fit_begin method of the callbacks @@ -694,10 +690,7 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. for callback in self._callbacks: - is_propagated = hasattr(self, "_parent_ct_node") and isinstance( - callback, AutoPropagatedMixin - ) - if not is_propagated: + if not callback._is_propagated(estimator=self): callback.on_fit_begin(estimator=self, X=X, y=y) return self._computation_tree.root @@ -712,10 +705,7 @@ def _eval_callbacks_on_fit_end(self): # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. for callback in self._callbacks: - is_propagated = isinstance(callback, AutoPropagatedMixin) and hasattr( - self, "_parent_ct_node" - ) - if not is_propagated: + if not callback._is_propagated(estimator=self): callback.on_fit_end() def _eval_callbacks_on_fit_exception(self): diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index c8d5ea0bf0606..9767411b6c934 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -1,6 +1,5 @@ # License: BSD 3 clause -from ._base import AutoPropagatedMixin from ._base import BaseCallback from ._computation_tree import ComputationNode from ._computation_tree import ComputationTree @@ -12,7 +11,6 @@ from ._text_verbose import TextVerbose __all__ = [ - "AutoPropagatedMixin", "BaseCallback", "ComputationNode", "ComputationTree", diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index c78fb3c773b61..65b24dc85e9bb 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -33,17 +33,11 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): # stopping_criterion and reconstruction_attributes can be costly to compute. They # are passed as lambdas for lazy evaluation. We only actually compute them if a # callback requests it. - if any( - getattr(callback, "request_stopping_criterion", False) - for callback in estimator._callbacks - ): + if any(cb.request_stopping_criterion for cb in estimator._callbacks): kwarg = kwargs.pop("stopping_criterion", lambda: None)() kwargs["stopping_criterion"] = kwarg - if any( - getattr(callback, "request_from_reconstruction_attributes", False) - for callback in estimator._callbacks - ): + if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks): kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() kwargs["from_reconstruction_attributes"] = kwarg @@ -126,6 +120,32 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) def on_fit_exception(self): pass + @property + def auto_propagate(self): + """Whether or not this callback should be propagated to sub-estimators. + + An auto-propagated callback (from a meta-estimator to its sub-estimators) must + be set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will + only be called at the beginning and end of the fit method of the meta-estimator, + while its `on_fit_iter_end` method will be called at each computation node of + the meta-estimator and its sub-estimators. + """ + return False + + def _is_propagated(self, estimator): + """Check if this callback attached to estimator has been propagated from a + meta-estimator. + """ + return self.auto_propagate and hasattr(estimator, "_parent_ct_node") + + @property + def request_stopping_criterion(self): + return False + + @property + def request_from_reconstruction_attributes(self): + return False + def _set_context(self, context): if not hasattr(self, "_callback_contexts"): self._callback_contexts = [] @@ -133,19 +153,6 @@ def _set_context(self, context): self._callback_contexts.append(context) -class AutoPropagatedMixin: - """Mixin for auto-propagated callbacks - - An auto-propagated callback (from a meta-estimator to its sub-estimators) must be - set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will only be - called at the beginning and end of the fit method of the meta-estimator, while its - `on_fit_iter_end` method will be called at each computation node of the - meta-estimator and its sub-estimators. - """ - - pass - - class CallbackContext: def __init__(self, callbacks, finalizer, finalizer_args): for callback in callbacks: diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index a69a8788e26c5..a6eb739580446 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -208,6 +208,11 @@ def get_progress(self, node): [self._tree_status[child.tree_status_idx] for child in node.children] ) + def get_child_computation_tree_dir(self, node): + if node.children: + raise ValueError("node is not a leaf") + return self.tree_dir / str(node.tree_status_idx) + def iterate(self, include_leaves=False): """Return an iterable over the nodes of the computation tree diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 713ab995e169a..1de13c87f2a8f 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -1,21 +1,30 @@ # License: BSD 3 clause -from copy import copy -import pickle +import importlib from threading import Thread, Event -import numpy as np -from tqdm import tqdm -from rich.progress import Progress -from rich.progress import BarColumn, TimeRemainingColumn, TextColumn -from rich.style import Style - from . import BaseCallback -from . import AutoPropagatedMixin from . import load_computation_tree -class ProgressBar(BaseCallback, AutoPropagatedMixin): +def _check_backend_support(backend, caller_name): + """Raise ImportError with detailed error message if backend is not installed. + + Parameters + ---------- + backend : {"rich", "tqdm"} + The requested backend. + + caller_name : str + The name of the caller that requires the backend. + """ + try: + importlib.import_module(backend) # noqa + except ImportError as e: + raise ImportError(f"{caller_name} requires {backend} installed.") from e + + +class ProgressBar(BaseCallback): """Callback that displays progress bars for each iterative steps of the estimator Parameters @@ -31,13 +40,20 @@ class ProgressBar(BaseCallback, AutoPropagatedMixin): finished. """ + auto_propagate = True + def __init__(self, backend="rich", max_depth_show=None, max_depth_keep=None): + if backend not in ("rich", "tqdm"): + raise ValueError(f"backend should be 'rich' or 'tqdm', got {self.backend} instead.") + _check_backend_support(backend, caller_name="Progressbar") self.backend = backend + if max_depth_show is not None and max_depth_show < 0: raise ValueError(f"max_depth_show should be >= 0.") + self.max_depth_show = max_depth_show + if max_depth_keep is not None and max_depth_keep < 0: raise ValueError(f"max_depth_keep should be >= 0.") - self.max_depth_show = max_depth_show self.max_depth_keep = max_depth_keep def on_fit_begin(self, estimator, X=None, y=None): @@ -50,8 +66,11 @@ def on_fit_begin(self, estimator, X=None, y=None): max_depth_show=self.max_depth_show, max_depth_keep=self.max_depth_keep, ) - else: - raise ValueError(f"backend should be 'rich', got {self.backend} instead.") + elif self.backend == "tqdm": + self.progress_monitor = _TqdmProgressMonitor( + estimator=estimator, + event=self._stop_event, + ) self.progress_monitor.start() @@ -77,10 +96,15 @@ def __getstate__(self): # Custom Progress class to allow showing the tasks in a given order (given by setting # the _ordered_tasks attribute). In particular it allows to dynamically create and # insert tasks between existing tasks. -class _Progress(Progress): - def get_renderables(self): - table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) - yield table + +try: + from rich.progress import Progress + class _Progress(Progress): + def get_renderables(self): + table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) + yield table +except: + pass class _RichProgressMonitor(Thread): @@ -119,6 +143,9 @@ def __init__(self, estimator, event, max_depth_show=None, max_depth_keep=None): self._computation_trees = {} def run(self): + from rich.progress import BarColumn, TimeRemainingColumn, TextColumn + from rich.style import Style + with _Progress( TextColumn("[progress.description]{task.description}"), BarColumn( @@ -218,7 +245,8 @@ def _recursive_update_tasks(self, this_dir=None, depth=0): else: # node is a leaf, look for tasks of its sub computation tree before # going to the next node - child_dir = this_dir / str(node.tree_status_idx) + child_dir = computation_tree.get_child_computation_tree_dir(node) + # child_dir = this_dir / str(node.tree_status_idx) if child_dir.exists(): self._recursive_update_tasks( child_dir, depth + computation_tree.depth @@ -258,3 +286,23 @@ def _get_parent_task(self, node, computation_tree, task_ids): ] return self._progress_ctx._tasks[task_id] return + + +class _TqdmProgressMonitor(Thread): + def __init__(self, estimator, event): + Thread.__init__(self) + self.computation_tree = estimator._computation_tree + self.event = event + + def run(self): + from tqdm import tqdm + + root = self.computation_tree.root + + with tqdm(total=len(root.children)) as pbar: + while not self.event.wait(0.05): + node_progress = self.computation_tree.get_progress(root) + if node_progress != pbar.total: + pbar.update(node_progress - pbar.n) + + pbar.update(pbar.total - pbar.n) diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py index cbf200336c749..238bc29cf8543 100644 --- a/sklearn/callback/_snapshot.py +++ b/sklearn/callback/_snapshot.py @@ -4,8 +4,6 @@ from pathlib import Path import pickle -import numpy as np - from . import BaseCallback diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py index b857ff592c87c..0064ec97f2052 100644 --- a/sklearn/callback/_text_verbose.py +++ b/sklearn/callback/_text_verbose.py @@ -3,10 +3,11 @@ import time from . import BaseCallback -from . import AutoPropagatedMixin -class TextVerbose(BaseCallback, AutoPropagatedMixin): +class TextVerbose(BaseCallback): + + auto_propagate = True request_stopping_criterion = True def __init__(self, min_time_between_calls=0): diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 84e94fce16e7c..888a5649c19bb 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -4,7 +4,6 @@ from sklearn.base import BaseEstimator, clone from sklearn.callback import BaseCallback -from sklearn.callback import AutoPropagatedMixin from sklearn.callback._base import _eval_callbacks_on_fit_iter_end @@ -19,9 +18,8 @@ def on_fit_iter_end(self, estimator, node, **kwargs): pass -class TestingAutoPropagatedCallback(TestingCallback, AutoPropagatedMixin): - pass - +class TestingAutoPropagatedCallback(TestingCallback): + auto_propagate = True class NotValidCallback: def on_fit_begin(self, estimator, *, X=None, y=None): diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 7d0ae56f09c31..75185df49de5f 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -847,8 +847,6 @@ def _fit_multiplicative_update( H_sum, HHt, XHt = None, None, None for n_iter in range(1, max_iter + 1): - if n_iter == 30: - raise ValueError("eh ouais") # update W # H_sum, HHt and XHt are saved and reused if not update_H W, H_sum, HHt, XHt = _multiplicative_update_w( From a218068ef0a6627e1e1436b6f7fd6c02186b1dd8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 13 Oct 2022 09:55:20 +0200 Subject: [PATCH 16/55] wip --- sklearn/base.py | 7 ------- sklearn/callback/_base.py | 7 ------- sklearn/callback/_early_stopping.py | 8 +++++++- sklearn/callback/_progressbar.py | 5 +---- .../tests/test_base_estimator_callback_methods.py | 2 -- sklearn/callback/tests/test_callbacks.py | 2 -- 6 files changed, 8 insertions(+), 23 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 9b4e659d8647a..687c1a9954ab8 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -708,13 +708,6 @@ def _eval_callbacks_on_fit_end(self): if not callback._is_propagated(estimator=self): callback.on_fit_end() - def _eval_callbacks_on_fit_exception(self): - if not hasattr(self, "_callbacks"): - return - - for callback in self._callbacks: - callback.on_fit_exception() - def _from_reconstruction_attributes(self, *, reconstruction_attributes): """Return an as if fitted copy of this estimator diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 65b24dc85e9bb..96cc1619651dd 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -116,10 +116,6 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) """ pass - @abstractmethod - def on_fit_exception(self): - pass - @property def auto_propagate(self): """Whether or not this callback should be propagated to sub-estimators. @@ -167,9 +163,6 @@ def callback_aware(fit_method): def inner(self, *args, **kwargs): try: return fit_method(self, *args, **kwargs) - except BaseException: - self._eval_callbacks_on_fit_exception() - raise finally: self._eval_callbacks_on_fit_end() diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py index 44a0108e04b26..dc45da6379a52 100644 --- a/sklearn/callback/_early_stopping.py +++ b/sklearn/callback/_early_stopping.py @@ -1,9 +1,13 @@ # License: BSD 3 clause +from urllib import request from . import BaseCallback class EarlyStopping(BaseCallback): + + request_from_reconstruction_attributes = True + def __init__( self, X_val=None, @@ -23,7 +27,9 @@ def on_fit_begin(self, estimator, X=None, y=None): self._no_improvement = {} self._last_monitored = {} - def on_fit_iter_end(self, *, node, **kwargs): + def on_fit_iter_end(self, *, estimator, node, **kwargs): + new_estimator = kwargs.get("from_reconstruction_attributes", None) + if node.depth != self.estimator._computation_tree.depth: return diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 1de13c87f2a8f..bd371bc1c3a7c 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -29,7 +29,7 @@ class ProgressBar(BaseCallback): Parameters ---------- - backend: {"rich"}, default="rich" + backend: {"rich", "tqdm"}, default="rich" The backend for the progress bars display. max_depth_show : int, default=None @@ -81,9 +81,6 @@ def on_fit_end(self): self._stop_event.set() self.progress_monitor.join() - def on_fit_exception(self): - pass - def __getstate__(self): state = self.__dict__.copy() if "_stop_event" in state: diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index ea750abbcf890..c77d88b68ce3d 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -9,8 +9,6 @@ from sklearn.callback.tests._utils import Estimator from sklearn.callback.tests._utils import MetaEstimator -from sklearn.callback import ProgressBar - @pytest.mark.parametrize("callbacks", [ diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index 2a457d354077e..fb99003eb3b09 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -66,5 +66,3 @@ def test_snapshot_meta_estimator(n_jobs, prefer): # We kept last 5 snapshots out of 20 iterations. # This one is the 16 + i-th. assert loaded_estimator.n_iter_ == 16 + i - - From f794694ce9fea32213503d926a314511bade5e8e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 13 Oct 2022 16:27:27 +0200 Subject: [PATCH 17/55] update poor_score --- doc/developers/develop.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 0e4b8258476da..3329649d20513 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -553,8 +553,9 @@ preserves_dtype (default=``[np.float64]``) poor_score (default=False) whether the estimator fails to provide a "reasonable" test-set score, which - currently for regression is an R2 of 0.5 on a subset of the boston housing - dataset, and for classification an accuracy of 0.83 on + currently for regression is an R2 of 0.5 on ``make_regression(n_samples=200, + n_features=10, n_informative=1, bias=5.0, noise=20, random_state=42)``, and + for classification an accuracy of 0.83 on ``make_blobs(n_samples=300, random_state=0)``. These datasets and values are based on current estimators in sklearn and might be replaced by something more systematic. From 37e569b13a7202aa79d7b5aa1b8d30de323139ca Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 21 Jun 2023 10:24:55 +0200 Subject: [PATCH 18/55] wip --- sklearn/base.py | 60 +++++++++----- sklearn/callback/_base.py | 27 +++---- sklearn/callback/_convergence_monitor.py | 11 ++- sklearn/callback/_early_stopping.py | 65 +++++++++++----- sklearn/callback/_progressbar.py | 6 +- sklearn/callback/_text_verbose.py | 1 - sklearn/callback/tests/_utils.py | 13 +++- .../test_base_estimator_callback_methods.py | 7 +- sklearn/callback/tests/test_callbacks.py | 13 +++- .../callback/tests/test_computation_tree.py | 5 +- sklearn/decomposition/_nmf.py | 78 ++++++++++++------- sklearn/linear_model/_logistic.py | 8 +- sklearn/model_selection/_search.py | 17 +++- sklearn/model_selection/_validation.py | 2 +- sklearn/pipeline.py | 4 +- 15 files changed, 205 insertions(+), 112 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index feb0fb4e31a57..9c802b536f89d 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -3,7 +3,6 @@ # Author: Gael Varoquaux # License: BSD 3 clause -from codecs import ignore_errors import copy import functools import warnings @@ -659,7 +658,7 @@ def _set_callbacks(self, callbacks): Returns ------- self : estimator instance - The estimator instance itself. + The estimator instance itself. """ if not isinstance(callbacks, list): callbacks = [callbacks] @@ -705,9 +704,7 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): return propagated_callbacks = [ - callback - for callback in self._callbacks - if callback.auto_propagate + callback for callback in self._callbacks if callback.auto_propagate ] if not propagated_callbacks: @@ -749,28 +746,50 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): parent_node=getattr(self, "_parent_ct_node", None), ) - if hasattr(self, "_callbacks"): - # - CallbackContext(self._callbacks, finalizer=partial(rmtree, ignore_errors=True), finalizer_args=self._computation_tree.tree_dir) + if not hasattr(self, "_callbacks"): + return self._computation_tree.root, None, None, None, None + + X_val, y_val = None, None - # - file_path = self._computation_tree.tree_dir / "computation_tree.pkl" - with open(file_path, "wb") as f: - pickle.dump(self._computation_tree, f) + if any(callback.request_validation_split for callback in self._callbacks): + splitter = next( + callback.validation_split for callback in self._callbacks if hasattr(callback, "validation_split") + ) - # Only call the on_fit_begin method of callbacks that are not - # propagated from a meta-estimator. - for callback in self._callbacks: - if not callback._is_propagated(estimator=self): - callback.on_fit_begin(estimator=self, X=X, y=y) + train, val = next(splitter.split(X)) + if X is not None: + X, X_val = X[train], X[val] + if y is not None: + y, y_val = y[train], y[val] + + # + CallbackContext( + self._callbacks, + finalizer=partial(rmtree, ignore_errors=True), + finalizer_args=self._computation_tree.tree_dir, + ) - return self._computation_tree.root + # + file_path = self._computation_tree.tree_dir / "computation_tree.pkl" + with open(file_path, "wb") as f: + pickle.dump(self._computation_tree, f) + + # Only call the on_fit_begin method of callbacks that are not + # propagated from a meta-estimator. + for callback in self._callbacks: + if not callback._is_propagated(estimator=self): + callback.on_fit_begin(estimator=self, X=X, y=y) + + return self._computation_tree.root, X, y, X_val, y_val def _eval_callbacks_on_fit_end(self): """Evaluate the on_fit_end method of the callbacks""" if not hasattr(self, "_callbacks"): return + if not hasattr(self, "_computation_tree"): + return + self._computation_tree._tree_status[0] = True # Only call the on_fit_end method of callbacks that are not @@ -1309,7 +1328,10 @@ def wrapper(estimator, *args, **kwargs): prefer_skip_nested_validation or global_skip_validation ) ): - return fit_method(estimator, *args, **kwargs) + try: + return fit_method(estimator, *args, **kwargs) + finally: + estimator._eval_callbacks_on_fit_end() return wrapper diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 96cc1619651dd..a07115f7e4e0c 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -30,9 +30,9 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): estimator._computation_tree._tree_status[node.tree_status_idx] = True - # stopping_criterion and reconstruction_attributes can be costly to compute. They - # are passed as lambdas for lazy evaluation. We only actually compute them if a - # callback requests it. + # stopping_criterion and reconstruction_attributes can be costly to compute. + # They are passed as lambdas for lazy evaluation. We only actually + # compute them if a callback requests it. if any(cb.request_stopping_criterion for cb in estimator._callbacks): kwarg = kwargs.pop("stopping_criterion", lambda: None)() kwargs["stopping_criterion"] = kwarg @@ -51,7 +51,7 @@ class BaseCallback(ABC): def on_fit_begin(self, estimator, *, X=None, y=None): """Method called at the beginning of the fit method of the estimator - Only called + Only called Parameters ---------- @@ -141,11 +141,15 @@ def request_stopping_criterion(self): @property def request_from_reconstruction_attributes(self): return False + + @property + def request_validation_split(self): + return False def _set_context(self, context): if not hasattr(self, "_callback_contexts"): self._callback_contexts = [] - + self._callback_contexts.append(context) @@ -154,16 +158,3 @@ def __init__(self, callbacks, finalizer, finalizer_args): for callback in callbacks: callback._set_context(self) weakref.finalize(self, finalizer, finalizer_args) - - -def callback_aware(fit_method): - """Decorator ... - """ - @wraps(fit_method) - def inner(self, *args, **kwargs): - try: - return fit_method(self, *args, **kwargs) - finally: - self._eval_callbacks_on_fit_end() - - return inner diff --git a/sklearn/callback/_convergence_monitor.py b/sklearn/callback/_convergence_monitor.py index ac04335e04661..98fec496d6eb7 100644 --- a/sklearn/callback/_convergence_monitor.py +++ b/sklearn/callback/_convergence_monitor.py @@ -33,7 +33,13 @@ class ConvergenceMonitor(BaseCallback): request_reconstruction_attributes = True - def __init__(self, *, monitor="objective_function", X_val=None, y_val=None): + def __init__( + self, + *, + monitor="objective_function", + on="val", + higher_is_better=False, + ): if monitor == "objective_function": self._monitor = "objective_function" else: @@ -41,9 +47,6 @@ def __init__(self, *, monitor="objective_function", X_val=None, y_val=None): if self._monitor is None: raise ValueError(f"unknown metric {monitor}") - self.X_val = X_val - self.y_val = y_val - self._data_file = Path(mkdtemp()) / "convergence_monitor.csv" def on_fit_begin(self, estimator, *, X=None, y=None): diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py index dc45da6379a52..6d408dda8c960 100644 --- a/sklearn/callback/_early_stopping.py +++ b/sklearn/callback/_early_stopping.py @@ -1,54 +1,77 @@ # License: BSD 3 clause -from urllib import request from . import BaseCallback class EarlyStopping(BaseCallback): - request_from_reconstruction_attributes = True def __init__( self, - X_val=None, - y_val=None, monitor="objective_function", + on="validation_set", + higher_is_better=False, + validation_split="auto", max_no_improvement=10, - tol=1e-2, + threshold=1e-2, ): - self.X_val = X_val - self.y_val = y_val + from ..model_selection import KFold + self.validation_split = validation_split + if validation_split == "auto": + self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) self.monitor = monitor + self.on = on + self.higher_is_better = higher_is_better self.max_no_improvement = max_no_improvement - self.tol = tol + self.threshold = threshold def on_fit_begin(self, estimator, X=None, y=None): - self.estimator = estimator self._no_improvement = {} self._last_monitored = {} + self.early_stopped_ = None def on_fit_iter_end(self, *, estimator, node, **kwargs): - new_estimator = kwargs.get("from_reconstruction_attributes", None) - - if node.depth != self.estimator._computation_tree.depth: + if node.depth != estimator._computation_tree.depth: return + reconstructed_estimator = kwargs.pop("from_reconstruction_attributes") + data = kwargs.pop("data") + + X = data["X_val"] if self.on == "validation_set" else data["X"] + y = data["y_val"] if self.on == "validation_set" else data["y"] + if self.monitor == "objective_function": - objective_function = kwargs.get("objective_function", None) - monitored, *_ = objective_function(self.X_val) - elif self.monitor == "TODO": - pass - - if node.parent not in self._last_monitored or monitored < self._last_monitored[ - node.parent - ] * (1 - self.tol): + new_monitored, *_ = reconstructed_estimator.objective_function(X, y, normalize=True) + elif callable(self.monitor): + new_monitored = self.monitor(reconstructed_estimator, X, y) + elif self.monitor is None or isinstance(self.monitor, str): + from ..metrics import check_scoring + scorer = check_scoring(estimator, self.monitor) + new_monitored = scorer(estimator, X, y) + + if self._score_improved(node, new_monitored): self._no_improvement[node.parent] = 0 - self._last_monitored[node.parent] = monitored + self._last_monitored[node.parent] = new_monitored else: self._no_improvement[node.parent] += 1 if self._no_improvement[node.parent] >= self.max_no_improvement: + self.early_stopped_ = node.idx + return True + + def _score_improved(self, node, new_monitored): + if node.parent not in self._last_monitored: return True + + last_monitored = self._last_monitored[node.parent] + if self.higher_is_better: + return new_monitored > last_monitored * (1 + self.threshold) + else: + return new_monitored < last_monitored * (1 - self.threshold) def on_fit_end(self): pass + + @property + def request_validation_split(self): + return self.on == "val" diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index bd371bc1c3a7c..738e8f897ce4a 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -44,7 +44,9 @@ class ProgressBar(BaseCallback): def __init__(self, backend="rich", max_depth_show=None, max_depth_keep=None): if backend not in ("rich", "tqdm"): - raise ValueError(f"backend should be 'rich' or 'tqdm', got {self.backend} instead.") + raise ValueError( + f"backend should be 'rich' or 'tqdm', got {self.backend} instead." + ) _check_backend_support(backend, caller_name="Progressbar") self.backend = backend @@ -96,10 +98,12 @@ def __getstate__(self): try: from rich.progress import Progress + class _Progress(Progress): def get_renderables(self): table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) yield table + except: pass diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py index 0064ec97f2052..93f783a297d30 100644 --- a/sklearn/callback/_text_verbose.py +++ b/sklearn/callback/_text_verbose.py @@ -6,7 +6,6 @@ class TextVerbose(BaseCallback): - auto_propagate = True request_stopping_criterion = True diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 888a5649c19bb..f61ffc4077dff 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -21,6 +21,7 @@ def on_fit_iter_end(self, estimator, node, **kwargs): class TestingAutoPropagatedCallback(TestingCallback): auto_propagate = True + class NotValidCallback: def on_fit_begin(self, estimator, *, X=None, y=None): pass @@ -37,7 +38,7 @@ def __init__(self, max_iter=20): self.max_iter = max_iter def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( + root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.max_iter}, {"descr": "iter", "max_iter": None}, @@ -54,6 +55,7 @@ def fit(self, X, y): self._from_reconstruction_attributes, reconstruction_attributes=lambda: {"n_iter_": i + 1}, ), + data={"X": X, "y": y, "X_val": X_val, "y_val": y_val"}, ): break @@ -63,6 +65,9 @@ def fit(self, X, y): return self + def objective_function(self, X, y=None): + return 0, 0, 0 + class MetaEstimator(BaseEstimator): def __init__( @@ -75,7 +80,7 @@ def __init__( self.prefer = prefer def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( + root, *_ = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.n_outer}, {"descr": "outer", "max_iter": self.n_inner}, @@ -93,7 +98,7 @@ def fit(self, X, y): self._eval_callbacks_on_fit_end() return self - + def _func(self, estimator, X, y, parent_node, i): for j, node in enumerate(parent_node.children): est = clone(estimator) @@ -104,4 +109,4 @@ def _func(self, estimator, X, y, parent_node, i): _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) - return \ No newline at end of file + return diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index c77d88b68ce3d..01669a5494dde 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -10,12 +10,13 @@ from sklearn.callback.tests._utils import MetaEstimator -@pytest.mark.parametrize("callbacks", +@pytest.mark.parametrize( + "callbacks", [ TestingCallback(), [TestingCallback()], [TestingCallback(), TestingAutoPropagatedCallback()], - ] + ], ) def test_set_callbacks(callbacks): """Sanity check for the _set_callbacks method""" @@ -49,7 +50,7 @@ def test_propagate_callbacks(): assert hasattr(sub_estimator, "_parent_ct_node") assert not_propagated_callback not in sub_estimator._callbacks - assert propagated_callback in sub_estimator._callbacks + assert propagated_callback in sub_estimator._callbacks def test_propagate_callback_no_callback(): diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index fb99003eb3b09..aa79503545acb 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -20,13 +20,22 @@ y = np.zeros(100, dtype=int) -@pytest.mark.parametrize("Callback", [ConvergenceMonitor, EarlyStopping, ProgressBar, Snapshot, TextVerbose,]) +@pytest.mark.parametrize( + "Callback", + [ + ConvergenceMonitor, + EarlyStopping, + ProgressBar, + Snapshot, + TextVerbose, + ], +) def test_callback_doesnt_hold_ref_to_estimator(Callback): callback = Callback() est = Estimator() callback_refcount = sys.getrefcount(callback) est_refcount = sys.getrefcount(est) - + est._set_callbacks(callbacks=callback) est.fit(X, y) # estimator has a ref on the callback but the callback has no ref to the estimator diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index 902175b71a250..2fe3766eba489 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -87,7 +87,10 @@ def test_get_ancestors(): ancestors = node.get_ancestors(include_ancestor_trees=False) assert ancestors == [ - node, node.parent, node.parent.parent, node.parent.parent.parent + node, + node.parent, + node.parent.parent, + node.parent.parent.parent, ] assert [n.idx for n in ancestors] == expected_node_indices assert computation_tree.root in ancestors diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 003dbec919033..a3eac9c7e3468 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -36,7 +36,6 @@ ) from ..utils import metadata_routing from ..callback._base import _eval_callbacks_on_fit_iter_end -from ..callback._base import callback_aware EPSILON = np.finfo(np.float32).eps @@ -403,6 +402,7 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle, random_state): def _fit_coordinate_descent( X, + X_val, W, H, tol=1e-4, @@ -429,6 +429,9 @@ def _fit_coordinate_descent( X : array-like of shape (n_samples, n_features) Constant matrix. + X_val : array-like of shape (n_samples_val, n_features) + Constant validation matrix. + W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -469,6 +472,12 @@ def _fit_coordinate_descent( results across multiple function calls. See :term:`Glossary `. + estimator : estimator instance, default=None + The estimator calling this function. Used by callbacks. + + parent_node : ComputationNode instance, default=None + The parent node of the current node. Used by callbacks. + Returns ------- W : ndarray of shape (n_samples, n_components) @@ -490,6 +499,8 @@ def _fit_coordinate_descent( # so W and Ht are both in C order in memory Ht = check_array(H.T, order="C") X = check_array(X, accept_sparse="csr") + if X_val is not None: + X_val = check_array(X_val, accept_sparse="csr") rng = check_random_state(random_state) @@ -527,6 +538,7 @@ def _fit_coordinate_descent( "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), }, ), + data={"X": X, "y": None, "X_val": X_val, "y_val": None}, ): break @@ -748,6 +760,7 @@ def _multiplicative_update_h( def _fit_multiplicative_update( X, + X_val, W, H, beta_loss="frobenius", @@ -773,6 +786,9 @@ def _fit_multiplicative_update( X : array-like of shape (n_samples, n_features) Constant input matrix. + X_val : array-like of shape (n_samples_val, n_features) + Constant validation matrix. + W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -813,6 +829,12 @@ def _fit_multiplicative_update( verbose : int, default=0 The verbosity level. + estimator : estimator instance, default=None + The estimator calling this function. Used by callbacks. + + parent_node : ComputationNode instance, default=None + The parent node of the current node. Used by callbacks. + Returns ------- W : ndarray of shape (n_samples, n_components) @@ -909,6 +931,7 @@ def _fit_multiplicative_update( "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), }, ), + data={"X": X, "y": None, "X_val": X_val, "y_val": None}, ): break @@ -1340,6 +1363,28 @@ def inverse_transform(self, Xt=None, W=None): check_is_fitted(self) return Xt @ self.components_ + + def objective_function(self, X, y=None, *, W=None, H=None, normalize=False): + if W is None: + W = self.transform(X) + if H is None: + H = self.components_ + + data_fit = _beta_divergence(X, W, H, self._beta_loss) + + l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._compute_regularization(X) + penalization = ( + l1_reg_W * W.sum() + + l1_reg_H * H.sum() + + l2_reg_W * (W**2).sum() + + l2_reg_H * (H**2).sum() + ) + + if normalize: + data_fit /= X.shape[0] + penalization /= X.shape[0] + + return data_fit + penalization, data_fit, penalization @property def _n_features_out(self): @@ -1617,7 +1662,6 @@ def _check_params(self, X): return self - @callback_aware @_fit_context(prefer_skip_nested_validation=True) def fit_transform(self, X, y=None, W=None, H=None): """Learn a NMF model for the data X and returns the transformed data. @@ -1650,7 +1694,7 @@ def fit_transform(self, X, y=None, W=None, H=None): X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32] ) - root = self._eval_callbacks_on_fit_begin( + root, X, _, X_val, _ = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.max_iter}, {"descr": "iter", "max_iter": None}, @@ -1658,7 +1702,7 @@ def fit_transform(self, X, y=None, W=None, H=None): X=X, ) - W, H, n_iter = self._fit_transform(X, W=W, H=H, parent_node=root) + W, H, n_iter = self._fit_transform(X, X_val, W=W, H=H, parent_node=root) self.reconstruction_err_ = _beta_divergence( X, W, H, self._beta_loss, square_root=True @@ -1672,7 +1716,7 @@ def fit_transform(self, X, y=None, W=None, H=None): return W def _fit_transform( - self, X, y=None, W=None, H=None, update_H=True, parent_node=None + self, X, X_val=None, W=None, H=None, update_H=True, parent_node=None ): """Learn a NMF model for the data X and returns the transformed data. @@ -1733,6 +1777,7 @@ def _fit_transform( if self.solver == "cd": W, H, n_iter = _fit_coordinate_descent( X, + X_val, W, H, self.tol, @@ -1751,6 +1796,7 @@ def _fit_transform( elif self.solver == "mu": W, H, n_iter, *_ = _fit_multiplicative_update( X, + X_val, W, H, self._beta_loss, @@ -2439,28 +2485,6 @@ def partial_fit(self, X, y=None, W=None, H=None): return self - def objective_function(self, X, y=None, *, W=None, H=None, normalize=False): - if W is None: - W = self.transform(X) - if H is None: - H = self.components_ - - data_fit = _beta_divergence(X, W, H, self._beta_loss) - - l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._scale_regularization(X) - penalization = ( - l1_reg_W * W.sum() - + l1_reg_H * H.sum() - + l2_reg_W * (W ** 2).sum() - + l2_reg_H * (H ** 2).sum() - ) - - if normalize: - data_fit /= X.shape[0] - penalization /= X.shape[0] - - return data_fit + penalization, data_fit, penalization - @property def _n_features_out(self): """Number of transformed output features.""" diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 732f3c25a93eb..b949a35b28d02 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -449,9 +449,7 @@ def _logistic_regression_path( node = ( None if parent_node is None - else parent_node - if len(Cs) == 1 - else parent_node.children + else parent_node if len(Cs) == 1 else parent_node.children ) if solver == "lbfgs": @@ -1324,7 +1322,9 @@ def fit(self, X, y, sample_weight=None): {"descr": "class", "max_iter": self.max_iter}, {"descr": "iter", "max_iter": None}, ] - root = self._eval_callbacks_on_fit_begin(levels=levels, X=X, y=y) + root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( + levels=levels, X=X, y=y + ) # distinguish between multinomial and ovr nodes = [root] if len(classes_) == 1 else root.children diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 9b9450cfee0ec..9ec5ce8414201 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -831,7 +831,9 @@ def fit(self, X, y=None, *, groups=None, **fit_params): all_out = [] all_more_results = defaultdict(list) - def evaluate_candidates(candidate_params, cv=None, more_results=None, parent_node=None): + def evaluate_candidates( + candidate_params, cv=None, more_results=None, parent_node=None + ): cv = cv or cv_orig candidate_params = list(candidate_params) n_candidates = len(candidate_params) @@ -863,8 +865,16 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None, parent_nod caller=self, node=node, ) - for ((cand_idx, parameters), (split_idx, (train, test))), node in zip(product( - enumerate(candidate_params), enumerate(cv.split(X, y, groups))), nodes) + for ( + (cand_idx, parameters), + (split_idx, (train, test)), + ), node in zip( + product( + enumerate(candidate_params), + enumerate(cv.split(X, y, groups)), + ), + nodes, + ) ) if len(out) < 1: @@ -1477,6 +1487,7 @@ def _run_search(self, evaluate_candidates): """Search all candidates in param_grid""" evaluate_candidates(self._param_grid, parent_node=self._computation_tree.root) + class RandomizedSearchCV(BaseSearchCV): """Randomized search on hyper parameters. diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index cb3563723027c..30fc160880d89 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -723,7 +723,7 @@ def _fit_and_score( cloned_parameters[k] = clone(v, safe=False) estimator = estimator.set_params(**cloned_parameters) - + if caller is not None: caller._propagate_callbacks(estimator, parent_node=node) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9002bfcb0d8ad..0eb02009ecf91 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -34,7 +34,6 @@ from .utils.parallel import delayed, Parallel from .exceptions import NotFittedError from .callback._base import _eval_callbacks_on_fit_iter_end -from .callback._base import callback_aware __all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"] @@ -357,7 +356,7 @@ def _fit(self, X, y=None, **fit_params_steps): # Setup the memory memory = check_memory(self.memory) - root = self._eval_callbacks_on_fit_begin( + root, *_ = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": len(self.steps)}, {"descr": "step", "max_iter": None}, @@ -405,7 +404,6 @@ def _fit(self, X, y=None, **fit_params_steps): return X - @callback_aware @_fit_context( # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False From d7208facafece078b0c8e687dc066b432eac2cbc Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Jun 2023 15:04:29 +0200 Subject: [PATCH 19/55] wip --- sklearn/base.py | 2 +- sklearn/callback/__init__.py | 4 +- sklearn/callback/_convergence_monitor.py | 126 ------------------ sklearn/callback/_early_stopping.py | 16 +-- sklearn/callback/_monitoring.py | 111 +++++++++++++++ sklearn/callback/_text_verbose.py | 4 - sklearn/callback/tests/_utils.py | 18 +-- .../test_base_estimator_callback_methods.py | 2 +- sklearn/callback/tests/test_callbacks.py | 4 +- sklearn/decomposition/_nmf.py | 2 - .../gradient_boosting.py | 46 +++++++ sklearn/linear_model/_logistic.py | 2 - sklearn/model_selection/_search.py | 2 - sklearn/pipeline.py | 2 - 14 files changed, 181 insertions(+), 160 deletions(-) delete mode 100644 sklearn/callback/_convergence_monitor.py create mode 100644 sklearn/callback/_monitoring.py diff --git a/sklearn/base.py b/sklearn/base.py index 9c802b536f89d..09c76277b986e 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -119,7 +119,7 @@ def _clone_parametrized(estimator, *, safe=True): # copy callbacks if hasattr(estimator, "_callbacks"): - new_object._callbacks = clone(estimator._callbacks, safe=False) + new_object._callbacks = estimator._callbacks # quick sanity check of the parameters of the clone for name in new_object_params: diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 9767411b6c934..b74126e1ce327 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -4,7 +4,7 @@ from ._computation_tree import ComputationNode from ._computation_tree import ComputationTree from ._computation_tree import load_computation_tree -from ._convergence_monitor import ConvergenceMonitor +from ._monitoring import Monitoring from ._early_stopping import EarlyStopping from ._progressbar import ProgressBar from ._snapshot import Snapshot @@ -15,7 +15,7 @@ "ComputationNode", "ComputationTree", "load_computation_tree", - "ConvergenceMonitor", + "Monitoring", "EarlyStopping", "ProgressBar", "Snapshot", diff --git a/sklearn/callback/_convergence_monitor.py b/sklearn/callback/_convergence_monitor.py deleted file mode 100644 index 98fec496d6eb7..0000000000000 --- a/sklearn/callback/_convergence_monitor.py +++ /dev/null @@ -1,126 +0,0 @@ -# License: BSD 3 clause - -from copy import copy -from pathlib import Path -from tempfile import mkdtemp - -import matplotlib.pyplot as plt -import pandas as pd - -from . import BaseCallback - -# import ..metrics as metrics - - -class ConvergenceMonitor(BaseCallback): - """Monitor model convergence. - - Parameters - ---------- - monitor : - - X_val : ndarray, default=None - Validation data - - y_val : ndarray, default=None - Validation target - - Attributes - ---------- - data : pandas.DataFrame - The monitored quantities at each iteration. - """ - - request_reconstruction_attributes = True - - def __init__( - self, - *, - monitor="objective_function", - on="val", - higher_is_better=False, - ): - if monitor == "objective_function": - self._monitor = "objective_function" - else: - self._monitor = getattr(metrics, monitor, None) - if self._monitor is None: - raise ValueError(f"unknown metric {monitor}") - - self._data_file = Path(mkdtemp()) / "convergence_monitor.csv" - - def on_fit_begin(self, estimator, *, X=None, y=None): - self.estimator = estimator - self.X_train = X - self.y_train = y - - def on_fit_iter_end(self, *, estimator, node, **kwargs): - reconstruction_attributes = kwargs.get("reconstruction_attributes", None) - if reconstruction_attributes is None: - return - - new_estimator = copy(estimator) - for key, val in reconstruction_attributes.items(): - setattr(new_estimator, key, val) - - # if self._monitor = - - obj_train, *_ = new_estimator.objective_function( - self.X_train, self.y_train, normalize=True - ) - if self.X_val is not None: - obj_val, *_ = new_estimator.objective_function( - self.X_val, self.y_val, normalize=True - ) - else: - obj_val = None - - ancestors = node.get_ancestors()[:0:-1] - ancestors_desc = [ - f"{n.computation_tree.estimator_name}-{n.description}" for n in ancestors - ] - ancestors_idx = [f"{n.idx}" for n in ancestors] - - if not self._data_file.exists(): - with open(self._data_file, "w") as f: - f.write( - f"{','.join(ancestors_desc)},iteration,time,obj_train,obj_val\n" - ) - - with open(self._data_file, "a") as f: - f.write( - f"{','.join(ancestors_idx)},{node.idx},{curr_time},{obj_train},{obj_val}\n" - ) - - def on_fit_end(self): - pass - - def get_data(self): - if not hasattr(self, "data"): - self.data = pd.read_csv(self._data_file) - return self.data - - def plot(self, x="iteration"): - data = self.get_data() - - # all columns but iteration, time, obj_train, obj_val - group_by_columns = list(data.columns[:-4]) - groups = data.groupby(group_by_columns) - - for key in groups.groups.keys(): - group = groups.get_group(key) - fig, ax = plt.subplots() - - ax.plot(group[x], group["obj_train"], label="obj_train") - if self.X_val is not None: - ax.plot(group[x], group["obj_val"], label="obj_val") - - if x == "iteration": - x_label = "Number of iterations" - elif x == "time": - x_label = "Time (s)" - ax.set_xlabel(x_label) - ax.set_ylabel("objective function") - - ax.legend() - plt.show() diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py index 6d408dda8c960..1ad9ad8437d37 100644 --- a/sklearn/callback/_early_stopping.py +++ b/sklearn/callback/_early_stopping.py @@ -25,13 +25,14 @@ def __init__( self.max_no_improvement = max_no_improvement self.threshold = threshold - def on_fit_begin(self, estimator, X=None, y=None): self._no_improvement = {} self._last_monitored = {} - self.early_stopped_ = None + + def on_fit_begin(self, estimator, X=None, y=None): + pass def on_fit_iter_end(self, *, estimator, node, **kwargs): - if node.depth != estimator._computation_tree.depth: + if node.depth != node.computation_tree.depth: return reconstructed_estimator = kwargs.pop("from_reconstruction_attributes") @@ -46,8 +47,8 @@ def on_fit_iter_end(self, *, estimator, node, **kwargs): new_monitored = self.monitor(reconstructed_estimator, X, y) elif self.monitor is None or isinstance(self.monitor, str): from ..metrics import check_scoring - scorer = check_scoring(estimator, self.monitor) - new_monitored = scorer(estimator, X, y) + scorer = check_scoring(reconstructed_estimator, self.monitor) + new_monitored = scorer(reconstructed_estimator, X, y) if self._score_improved(node, new_monitored): self._no_improvement[node.parent] = 0 @@ -56,9 +57,8 @@ def on_fit_iter_end(self, *, estimator, node, **kwargs): self._no_improvement[node.parent] += 1 if self._no_improvement[node.parent] >= self.max_no_improvement: - self.early_stopped_ = node.idx return True - + def _score_improved(self, node, new_monitored): if node.parent not in self._last_monitored: return True @@ -74,4 +74,4 @@ def on_fit_end(self): @property def request_validation_split(self): - return self.on == "val" + return self.on == "validation_set" diff --git a/sklearn/callback/_monitoring.py b/sklearn/callback/_monitoring.py new file mode 100644 index 0000000000000..2eb7c8de5885f --- /dev/null +++ b/sklearn/callback/_monitoring.py @@ -0,0 +1,111 @@ +# License: BSD 3 clause + +# import os +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryDirectory +from tempfile import mkdtemp + +import matplotlib.pyplot as plt +import pandas as pd + +from . import BaseCallback + + +class Monitoring(BaseCallback): + """Monitor model convergence. + + Parameters + ---------- + monitor : + + X_val : ndarray, default=None + Validation data + + y_val : ndarray, default=None + Validation target + + Attributes + ---------- + data : pandas.DataFrame + The monitored quantities at each iteration. + """ + + request_from_reconstruction_attributes = True + + def __init__( + self, + *, + monitor="objective_function", + on="validation_set", + validation_split="auto", + ): + from ..model_selection import KFold + self.validation_split = validation_split + if validation_split == "auto": + self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) + self.monitor = monitor + self.on = on + + self._data_dir = TemporaryDirectory() + self._data_files = {} + + if isinstance(self.monitor, str): + self.monitor_name = self.monitor + elif callable(self.monitor): + self.monitor_name = self.monitor.__name__ + + def on_fit_begin(self, estimator, *, X=None, y=None): + fname = Path(self._data_dir.name) / f"{estimator._computation_tree.uid}.csv" + with open(fname, "w") as file: + file.write(f"iteration,{self.monitor_name}_train,{self.monitor_name}_val\n") + self._data_files[estimator._computation_tree] = fname + + def on_fit_iter_end(self, *, estimator, node, from_reconstruction_attributes, data, **kwargs): + if node.depth != node.computation_tree.depth: + return + + new_estimator = from_reconstruction_attributes + + X, y, X_val, y_val = data["X"], data["y"], data["X_val"], data["y_val"] + + if self.monitor == "objective_function": + new_monitored_train, *_ = new_estimator.objective_function(X, y, normalize=True) + if X_val is not None: + new_monitored_val, *_ = new_estimator.objective_function(X_val, y_val, normalize=True) + elif callable(self.monitor): + new_monitored_train = self.monitor(new_estimator, X, y) + if X_val is not None: + new_monitored_val = self.monitor(new_estimator, X_val, y_val) + elif self.monitor is None or isinstance(self.monitor, str): + from ..metrics import check_scoring + scorer = check_scoring(new_estimator, self.monitor) + new_monitored_train = scorer(new_estimator, X, y) + if X_val is not None: + new_monitored_val = scorer(new_estimator, X_val, y_val) + + if X_val is None: + new_monitored_val = None + + with open(self._data_files[node.computation_tree], "a") as f: + f.write(f"{node.idx},{new_monitored_train},{new_monitored_val}\n") + + def on_fit_end(self): + pass + + # @property + # def data(self): + + def plot(self): + data_files = [p for p in Path(self._data_dir.name).iterdir() if p.is_file()] + for f in data_files: + data = pd.read_csv(f) + fig, ax = plt.subplots() + ax.plot(data["iteration"], data[f"{self.monitor_name}_train"], label=f"train set") + if self.on != "train_set": + ax.plot(data["iteration"], data[f"{self.monitor_name}_val"], label=f"validation set") + + ax.set_xlabel("Number of iterations") + ax.set_ylabel(self.monitor_name) + + ax.legend() + plt.show() diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py index 93f783a297d30..9773f1c8a6f51 100644 --- a/sklearn/callback/_text_verbose.py +++ b/sklearn/callback/_text_verbose.py @@ -9,11 +9,7 @@ class TextVerbose(BaseCallback): auto_propagate = True request_stopping_criterion = True - def __init__(self, min_time_between_calls=0): - self.min_time_between_calls = min_time_between_calls - def on_fit_begin(self, estimator, X=None, y=None): - self.estimator = estimator self._start_time = time.perf_counter() def on_fit_iter_end(self, *, node, **kwargs): diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index f61ffc4077dff..4144ba3ddf3f3 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -2,7 +2,7 @@ from joblib.parallel import Parallel, delayed -from sklearn.base import BaseEstimator, clone +from sklearn.base import BaseEstimator, clone, _fit_context from sklearn.callback import BaseCallback from sklearn.callback._base import _eval_callbacks_on_fit_iter_end @@ -34,9 +34,12 @@ def on_fit_iter_end(self, estimator, node, **kwargs): class Estimator(BaseEstimator): + _parameter_constraints = {} + def __init__(self, max_iter=20): self.max_iter = max_iter + @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( levels=[ @@ -55,21 +58,21 @@ def fit(self, X, y): self._from_reconstruction_attributes, reconstruction_attributes=lambda: {"n_iter_": i + 1}, ), - data={"X": X, "y": y, "X_val": X_val, "y_val": y_val"}, + data={"X": X, "y": y, "X_val": X_val, "y_val": y_val}, ): break self.n_iter_ = i + 1 - self._eval_callbacks_on_fit_end() - return self - def objective_function(self, X, y=None): + def objective_function(self, X, y=None, normalize=False): return 0, 0, 0 class MetaEstimator(BaseEstimator): + _parameter_constraints = {} + def __init__( self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" ): @@ -79,8 +82,9 @@ def __init__( self.n_jobs = n_jobs self.prefer = prefer + @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): - root, *_ = self._eval_callbacks_on_fit_begin( + root, X, y, _, _ = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.n_outer}, {"descr": "outer", "max_iter": self.n_inner}, @@ -95,8 +99,6 @@ def fit(self, X, y): for i, node in enumerate(root.children) ) - self._eval_callbacks_on_fit_end() - return self def _func(self, estimator, X, y, parent_node, i): diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 01669a5494dde..2f554101dcfa3 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -88,7 +88,7 @@ def test_eval_callbacks_on_fit_begin(): {"descr": "fit", "max_iter": 10}, {"descr": "iter", "max_iter": None}, ] - ct_root = estimator._eval_callbacks_on_fit_begin(levels=levels) + ct_root, *_ = estimator._eval_callbacks_on_fit_begin(levels=levels) assert hasattr(estimator, "_computation_tree") assert ct_root is estimator._computation_tree.root diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py index aa79503545acb..e453705b637e1 100644 --- a/sklearn/callback/tests/test_callbacks.py +++ b/sklearn/callback/tests/test_callbacks.py @@ -7,7 +7,7 @@ import numpy as np -from sklearn.callback import ConvergenceMonitor +from sklearn.callback import Monitoring from sklearn.callback import EarlyStopping from sklearn.callback import ProgressBar from sklearn.callback import Snapshot @@ -23,7 +23,7 @@ @pytest.mark.parametrize( "Callback", [ - ConvergenceMonitor, + Monitoring, EarlyStopping, ProgressBar, Snapshot, diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index a3eac9c7e3468..8cd485114ac9c 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -1711,8 +1711,6 @@ def fit_transform(self, X, y=None, W=None, H=None): self.components_ = H self.n_iter_ = n_iter - self._eval_callbacks_on_fit_end() - return W def _fit_transform( diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index e44b6428f8f4e..e5df230279b59 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -19,6 +19,7 @@ ) from ...base import BaseEstimator, RegressorMixin, ClassifierMixin, is_classifier from ...base import _fit_context +from ...callback._base import _eval_callbacks_on_fit_iter_end from ...utils import check_random_state, resample, compute_sample_weight from ...utils.validation import ( check_is_fitted, @@ -462,6 +463,17 @@ def fit(self, X, y, sample_weight=None): X_train, y_train, sample_weight_train = X, y, sample_weight X_val = y_val = sample_weight_val = None + begin_at_stage = 0 if not (self._is_fitted() and self.warm_start) else self.n_iter_ + + root, X_train, y_train, X_val, y_val = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter - begin_at_stage}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + y=y, + ) + # Bin the data # For ease of use of the API, the user-facing GBDT classes accept the # parameter max_bins, which doesn't take into account the bin for @@ -756,6 +768,26 @@ def fit(self, X, y, sample_weight=None): if should_early_stop: break + if _eval_callbacks_on_fit_iter_end( + estimator=self, + node=root.children[iteration - begin_at_stage], + fit_state={}, + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda: { + "train_score_": np.asarray(self.train_score_), + "validation_score_": np.asarray(self.validation_score_), + }, + ), + data={ + "X": X_binned_train, + "y": y_train, + "X_val": X_binned_val, + "y_val": y_val + }, + ): + break + if self.verbose: duration = time() - fit_start_time n_total_leaves = sum( @@ -794,8 +826,22 @@ def fit(self, X, y, sample_weight=None): self.train_score_ = np.asarray(self.train_score_) self.validation_score_ = np.asarray(self.validation_score_) del self._in_fit # hard delete so we're sure it can't be used anymore + return self + def objective_function(self, X, y, *, raw_predictions=None, normalize=False): + if raw_predictions is None: + raw_predictions = self._raw_predict(X) + + loss = self._loss( + y_true=y, + raw_prediction=raw_predictions, + ) + if normalize: + loss /= raw_predictions.shape[0] + + return loss, loss, 0 + def _is_fitted(self): return len(getattr(self, "_predictors", [])) > 0 diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index b949a35b28d02..dcaf6377dbe1f 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -1397,8 +1397,6 @@ def fit(self, X, y, sample_weight=None): else: self.intercept_ = np.zeros(n_classes) - self._eval_callbacks_on_fit_end() - return self def predict_proba(self, X): diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 9ec5ce8414201..c6c718ca0d4fa 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -1481,8 +1481,6 @@ def fit(self, X, y=None, *, groups=None, **fit_params): ) super().fit(X, y=y, groups=groups, **fit_params) - self._eval_callbacks_on_fit_end() - def _run_search(self, evaluate_candidates): """Search all candidates in param_grid""" evaluate_candidates(self._param_grid, parent_node=self._computation_tree.root) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 0eb02009ecf91..0d563cbb10c12 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -447,8 +447,6 @@ def fit(self, X, y=None, **fit_params): _eval_callbacks_on_fit_iter_end(estimator=self, node=node) - self._eval_callbacks_on_fit_end() - return self def _can_fit_transform(self): From b8ac1a5e86aeee791675aebd36758c23831a3efa Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 18 Oct 2023 12:10:21 +0200 Subject: [PATCH 20/55] cln --- sklearn/callback/_progressbar.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index f802cc63b3b9a..f8ed251add34a 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -103,7 +103,7 @@ def get_renderables(self): table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) yield table -except: +except ImportError: pass @@ -262,7 +262,8 @@ def _format_task_description(self, node, computation_tree, depth): description = f"{computation_tree.estimator_name} - {node.description}" if node.parent is None and computation_tree.parent_node is not None: description = ( - f"{computation_tree.parent_node.description} {computation_tree.parent_node.idx} |" + f"{computation_tree.parent_node.description} " + f"{computation_tree.parent_node.idx} |" f" {description}" ) if node.parent is not None: From b644430691691fb74aa92d1c53ffcd537aa1710f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 25 Oct 2023 17:51:09 +0200 Subject: [PATCH 21/55] wip --- sklearn/base.py | 49 +-- sklearn/callback/__init__.py | 14 +- sklearn/callback/_base.py | 34 +- sklearn/callback/_computation_tree.py | 310 +++------------- sklearn/callback/_early_stopping.py | 81 ---- sklearn/callback/_monitoring.py | 124 ------- sklearn/callback/_progressbar.py | 351 ++++++------------ sklearn/callback/_snapshot.py | 65 ---- sklearn/callback/_text_verbose.py | 40 -- sklearn/callback/tests/_utils.py | 15 +- .../test_base_estimator_callback_methods.py | 40 +- sklearn/callback/tests/test_callbacks.py | 77 ---- 12 files changed, 197 insertions(+), 1003 deletions(-) delete mode 100644 sklearn/callback/_early_stopping.py delete mode 100644 sklearn/callback/_monitoring.py delete mode 100644 sklearn/callback/_snapshot.py delete mode 100644 sklearn/callback/_text_verbose.py delete mode 100644 sklearn/callback/tests/test_callbacks.py diff --git a/sklearn/base.py b/sklearn/base.py index 62c99025b37b3..bc3073527250e 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -11,15 +11,12 @@ import re import warnings from collections import defaultdict -from functools import partial -from shutil import rmtree import numpy as np from . import __version__ from ._config import config_context, get_config -from .callback import BaseCallback, ComputationTree -from .callback._base import CallbackContext +from .callback import BaseCallback, build_computation_tree from .exceptions import InconsistentVersionWarning from .utils import _IS_32BIT from .utils._estimator_html_repr import estimator_html_repr @@ -713,7 +710,7 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): if not propagated_callbacks: return - sub_estimator._parent_ct_node = parent_node + sub_estimator._parent_node = parent_node sub_estimator._set_callbacks( getattr(sub_estimator, "_callbacks", []) + propagated_callbacks @@ -743,41 +740,14 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): root : ComputationNode instance The root of the computation tree. """ - self._computation_tree = ComputationTree( + self._computation_tree = build_computation_tree( estimator_name=self.__class__.__name__, levels=levels, - parent_node=getattr(self, "_parent_ct_node", None), + parent=getattr(self, "_parent_node", None), ) if not hasattr(self, "_callbacks"): - return self._computation_tree.root, None, None, None, None - - X_val, y_val = None, None - - if any(callback.request_validation_split for callback in self._callbacks): - splitter = next( - callback.validation_split - for callback in self._callbacks - if hasattr(callback, "validation_split") - ) - - train, val = next(splitter.split(X)) - if X is not None: - X, X_val = X[train], X[val] - if y is not None: - y, y_val = y[train], y[val] - - # - CallbackContext( - self._callbacks, - finalizer=partial(rmtree, ignore_errors=True), - finalizer_args=self._computation_tree.tree_dir, - ) - - # - file_path = self._computation_tree.tree_dir / "computation_tree.pkl" - with open(file_path, "wb") as f: - pickle.dump(self._computation_tree, f) + return self._computation_tree # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. @@ -785,18 +755,13 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): if not callback._is_propagated(estimator=self): callback.on_fit_begin(estimator=self, X=X, y=y) - return self._computation_tree.root, X, y, X_val, y_val + return self._computation_tree def _eval_callbacks_on_fit_end(self): """Evaluate the on_fit_end method of the callbacks""" - if not hasattr(self, "_callbacks"): + if not hasattr(self, "_callbacks") or not hasattr(self, "_computation_tree"): return - if not hasattr(self, "_computation_tree"): - return - - self._computation_tree._tree_status[0] = True - # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. for callback in self._callbacks: diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 2069ae3a58681..7c6e4a07741e3 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -1,21 +1,13 @@ # License: BSD 3 clause +# Authors: the scikit-learn developers from ._base import BaseCallback -from ._computation_tree import ComputationNode, ComputationTree, load_computation_tree -from ._early_stopping import EarlyStopping -from ._monitoring import Monitoring +from ._computation_tree import ComputationNode, build_computation_tree from ._progressbar import ProgressBar -from ._snapshot import Snapshot -from ._text_verbose import TextVerbose __all__ = [ "BaseCallback", + "build_computation_tree", "ComputationNode", - "ComputationTree", - "load_computation_tree", - "Monitoring", - "EarlyStopping", "ProgressBar", - "Snapshot", - "TextVerbose", ] diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 42e18b431db52..0bc4023f2a266 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -1,6 +1,6 @@ # License: BSD 3 clause +# Authors: the scikit-learn developers -import weakref from abc import ABC, abstractmethod @@ -9,7 +9,7 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): """Evaluate the on_fit_iter_end method of the callbacks - This function should be called at the end of each computation node. + This function must be called at the end of each computation node. Parameters ---------- @@ -27,8 +27,6 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): if not hasattr(estimator, "_callbacks") or node is None: return False - estimator._computation_tree._tree_status[node.tree_status_idx] = True - # stopping_criterion and reconstruction_attributes can be costly to compute. # They are passed as lambdas for lazy evaluation. We only actually # compute them if a callback requests it. @@ -56,9 +54,11 @@ def on_fit_begin(self, estimator, *, X=None, y=None): ---------- estimator: estimator instance The estimator the callback is set on. + X: ndarray or sparse matrix, default=None The training data. - y: ndarray, default=None + + y: ndarray or sparse matrix, default=None The target. """ pass @@ -103,11 +103,6 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) used by generic callbacks but by a callback designed for a specific estimator instead. - - extra_verbose: dict - Model specific . This is not meant to be - used by generic callbacks but by a callback designed for a specific - estimator instead. - Returns ------- stop : bool or None @@ -131,7 +126,7 @@ def _is_propagated(self, estimator): """Check if this callback attached to estimator has been propagated from a meta-estimator. """ - return self.auto_propagate and hasattr(estimator, "_parent_ct_node") + return self.auto_propagate and hasattr(estimator, "_parent_node") @property def request_stopping_criterion(self): @@ -140,20 +135,3 @@ def request_stopping_criterion(self): @property def request_from_reconstruction_attributes(self): return False - - @property - def request_validation_split(self): - return False - - def _set_context(self, context): - if not hasattr(self, "_callback_contexts"): - self._callback_contexts = [] - - self._callback_contexts.append(context) - - -class CallbackContext: - def __init__(self, callbacks, finalizer, finalizer_args): - for callback in callbacks: - callback._set_context(self) - weakref.finalize(self, finalizer, finalizer_args) diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index e5be721f60ff6..0add054ea6180 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -1,21 +1,14 @@ # License: BSD 3 clause - -import os -import pickle -from pathlib import Path -from tempfile import mkdtemp -from uuid import uuid4 - -import numpy as np +# Authors: the scikit-learn developers class ComputationNode: - """A node in a ComputationTree + """A node in a computation tree Parameters ---------- - computation_tree : ComputationTree instance - The computation tree it belongs to. + estimator_name : str + The name of the estimator this computation node belongs to. parent : ComputationNode instance, default=None The parent node. None means this is the root. @@ -26,260 +19,77 @@ class ComputationNode: description : str, default=None A description of this computation node. None means it's a leaf. - tree_status_idx : int, default=0 - The index of the status of this node in the `tree_status` array of its - computation tree. - - idx : int, default=0 - The index of this node in the children list of its parent. - Attributes ---------- children : list The list of its children nodes. For a leaf, it's an empty list - - depth : int - The depth of this node in its computation tree. The root has a depth of 0. """ def __init__( self, - computation_tree, - parent=None, - max_iter=None, + estimator_name, description=None, - tree_status_idx=0, + max_iter=None, idx=0, + parent=None, ): - self.computation_tree = computation_tree + # estimator_name and description are tuples because an estimator can be + # a sub-estimator of a meta-estimator. In that case, the root of the computation + # tree of the sub-estimator and a leaf of the computation tree of the + # meta-estimator correspond to the same computation step. Therefore, both + # nodes are merged into a single node, retaining the information of both. + self.estimator_name = (estimator_name,) + self.description = (description,) + self.parent = parent self.max_iter = max_iter - self.description = description - self.tree_status_idx = tree_status_idx self.idx = idx - self.children = [] - self.depth = 0 if self.parent is None else self.parent.depth + 1 - - def get_ancestors(self, include_ancestor_trees=True): - """Get the list of all nodes in the path from the node to the root - - Parameters - ---------- - include_ancestor_trees : bool, default=True - If True, propagate to the tree of the `parent_node` of this tree if it - exists and so on. - - Returns - ------- - ancestors : list - The list of ancestors of this node (included). - """ - node = self - ancestors = [node] - - while node.parent is not None: - node = node.parent - ancestors.append(node) - - if include_ancestor_trees: - node_parent_tree = node.computation_tree.parent_node - if node_parent_tree is not None: - ancestors.extend(node_parent_tree.get_ancestors()) - - return ancestors - - def __repr__(self): - return ( - f"ComputationNode(description={self.description}, " - f"depth={self.depth}, idx={self.idx})" - ) - - -class ComputationTree: - """Data structure to store the computation tree of an estimator - - Parameters - ---------- - estimator_name : str - The name of the estimator. - - levels : list of dict - A description of the nested levels of computation of the estimator to build the - tree. It's a list of dict with "descr" and "max_iter" keys. - - parent_node : ComputationNode, default=None - The node where the estimator is used in the computation tree of a - meta-estimator. This node is not set to be the parent of the root of this tree. - - Attributes - ---------- - depth : int - The depth of the tree. It corresponds to the depth of its deepest leaf. - - root : ComputationNode instance - The root of the computation tree. - - tree_dir : pathlib.Path instance - The path of the directory where the computation tree is dumped during the fit of - its estimator. If it has a parent tree, this is a sub-directory of the - `tree_dir` of its parent. - - uid : uuid.UUID - Unique indentifier for a ComputationTree instance. - """ - def __init__(self, estimator_name, levels, *, parent_node=None): - self.estimator_name = estimator_name - self.parent_node = parent_node - - self.depth = len(levels) - 1 - self.root, self.n_nodes = self._build_tree(levels) - - self.uid = uuid4() - - parent_tree_dir = ( - None - if self.parent_node is None - else self.parent_node.computation_tree.tree_dir - ) - if parent_tree_dir is None: - self.tree_dir = Path(mkdtemp()) - else: - # This tree has a parent tree. Place it in a subdir of its parent dir - # and give it a name that allows from the parent tree to find the sub dir - # of the sub tree of a given leaf. - self.tree_dir = parent_tree_dir / str(parent_node.tree_status_idx) - self.tree_dir.mkdir() - self._filename = self.tree_dir / "tree_status.memmap" - - self._set_tree_status(mode="w+") - self._tree_status[:] = False - - def _build_tree(self, levels): - """Build the computation tree from the description of the levels""" - root = ComputationNode( - computation_tree=self, - max_iter=levels[0]["max_iter"], - description=levels[0]["descr"], - ) - - n_nodes = self._recursive_build_tree(root, levels) - - return root, n_nodes - - def _recursive_build_tree(self, parent, levels, n_nodes=1): - """Recursively build the tree from the root the leaves""" - if parent.depth == self.depth: - return n_nodes - - for i in range(parent.max_iter): - children_max_iter = levels[parent.depth + 1]["max_iter"] - description = levels[parent.depth + 1]["descr"] + self.children = [] - node = ComputationNode( - computation_tree=self, - parent=parent, - max_iter=children_max_iter, - description=description, - tree_status_idx=n_nodes, - idx=i, + @property + def depth(self): + """The depth of this node in the computation tree""" + return 0 if self.parent is None else self.parent.depth + 1 + + @property + def path(self): + """List of all the nodes in the path from the root to this node""" + return [self] if self.parent is None else self.parent.path + [self] + + def __iter__(self): + """Pre-order depth-first traversal""" + yield self + for node in self.children: + yield from node + + +def build_computation_tree(estimator_name, levels, parent=None, idx=0): + """Build the computation tree from the description of the levels""" + this_level = levels[0] + + node = ComputationNode( + estimator_name=estimator_name, + parent=parent, + max_iter=this_level["max_iter"], + description=this_level["descr"], + idx=idx, + ) + + if parent is not None and parent.max_iter is None: + # parent node is a leaf of the computation tree of an outer estimator. It means + # that this node is the root of the computation tree of this estimator. They + # both correspond the same computation step, so we merge both nodes. + node.description = parent.description + node.description + node.estimator_name = parent.estimator_name + node.estimator_name + node.parent = parent.parent + node.idx = parent.idx + parent.parent.children[node.idx] = node + + if node.max_iter is not None: + for i in range(node.max_iter): + node.children.append( + build_computation_tree(estimator_name, levels[1:], parent=node, idx=i) ) - parent.children.append(node) - - n_nodes = self._recursive_build_tree(node, levels, n_nodes + 1) - - return n_nodes - - def _set_tree_status(self, mode): - """Create a memory-map to the tree_status array stored on the disk""" - # This has to be done each time we unpickle the tree - self._tree_status = np.memmap( - self._filename, dtype=bool, mode=mode, shape=(self.n_nodes,) - ) - - def get_progress(self, node): - """Return the number of finished child nodes of this node""" - if self._tree_status[node.tree_status_idx]: - return node.max_iter - - # Since the children of a node are not ordered (to account for parallel - # execution), we can't rely on the highest index for which the status is True. - return sum( - [self._tree_status[child.tree_status_idx] for child in node.children] - ) - - def get_child_computation_tree_dir(self, node): - if node.children: - raise ValueError("node is not a leaf") - return self.tree_dir / str(node.tree_status_idx) - - def iterate(self, include_leaves=False): - """Return an iterable over the nodes of the computation tree - - Nodes are discovered in a depth first search manner. - - Parameters - ---------- - include_leaves : bool - Whether or not to include the leaves of the tree in the iterable - - Returns - ------- - nodes_list : list - A list of the nodes of the computation tree. - """ - return self._recursive_iterate(include_leaves=include_leaves) - - def _recursive_iterate(self, node=None, include_leaves=False, node_list=None): - """Recursively constructs the iterable""" - # TODO make it an iterator ? - if node is None: - node = self.root - node_list = [] - - if node.children or include_leaves: - node_list.append(node) - - for child in node.children: - self._recursive_iterate(child, include_leaves, node_list) - - return node_list - - def __repr__(self): - res = ( - f"[{self.estimator_name}] {self.root.description} : progress " - f"{self.get_progress(self.root)} / {self.root.max_iter}\n" - ) - for node in self.iterate(include_leaves=False): - if node is not self.root: - res += ( - f"{' ' * node.depth}{node.description} {node.idx}: progress " - f"{self.get_progress(node)} / {node.max_iter}\n" - ) - return res - - -def load_computation_tree(directory): - """load the computation tree of a directory - - Parameters - ---------- - directory : pathlib.Path instance - The directory where the computation tree is dumped - - Returns - ------- - computation_tree : ComputationTree instance - The loaded computation tree - """ - file_path = directory / "computation_tree.pkl" - if not file_path.exists() or not os.path.getsize(file_path) > 0: - # Do not try to load the tree when it's created but not yet written - return - - with open(file_path, "rb") as f: - computation_tree = pickle.load(f) - - computation_tree._set_tree_status(mode="r") - return computation_tree + return node diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py deleted file mode 100644 index b3137c4ff7812..0000000000000 --- a/sklearn/callback/_early_stopping.py +++ /dev/null @@ -1,81 +0,0 @@ -# License: BSD 3 clause - -from . import BaseCallback - - -class EarlyStopping(BaseCallback): - request_from_reconstruction_attributes = True - - def __init__( - self, - monitor="objective_function", - on="validation_set", - higher_is_better=False, - validation_split="auto", - max_no_improvement=10, - threshold=1e-2, - ): - from ..model_selection import KFold - - self.validation_split = validation_split - if validation_split == "auto": - self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) - self.monitor = monitor - self.on = on - self.higher_is_better = higher_is_better - self.max_no_improvement = max_no_improvement - self.threshold = threshold - - self._no_improvement = {} - self._last_monitored = {} - - def on_fit_begin(self, estimator, X=None, y=None): - pass - - def on_fit_iter_end(self, *, estimator, node, **kwargs): - if node.depth != node.computation_tree.depth: - return - - reconstructed_estimator = kwargs.pop("from_reconstruction_attributes") - data = kwargs.pop("data") - - X = data["X_val"] if self.on == "validation_set" else data["X"] - y = data["y_val"] if self.on == "validation_set" else data["y"] - - if self.monitor == "objective_function": - new_monitored, *_ = reconstructed_estimator.objective_function( - X, y, normalize=True - ) - elif callable(self.monitor): - new_monitored = self.monitor(reconstructed_estimator, X, y) - elif self.monitor is None or isinstance(self.monitor, str): - from ..metrics import check_scoring - - scorer = check_scoring(reconstructed_estimator, self.monitor) - new_monitored = scorer(reconstructed_estimator, X, y) - - if self._score_improved(node, new_monitored): - self._no_improvement[node.parent] = 0 - self._last_monitored[node.parent] = new_monitored - else: - self._no_improvement[node.parent] += 1 - - if self._no_improvement[node.parent] >= self.max_no_improvement: - return True - - def _score_improved(self, node, new_monitored): - if node.parent not in self._last_monitored: - return True - - last_monitored = self._last_monitored[node.parent] - if self.higher_is_better: - return new_monitored > last_monitored * (1 + self.threshold) - else: - return new_monitored < last_monitored * (1 - self.threshold) - - def on_fit_end(self): - pass - - @property - def request_validation_split(self): - return self.on == "validation_set" diff --git a/sklearn/callback/_monitoring.py b/sklearn/callback/_monitoring.py deleted file mode 100644 index cfff4d1215c3b..0000000000000 --- a/sklearn/callback/_monitoring.py +++ /dev/null @@ -1,124 +0,0 @@ -# License: BSD 3 clause - -# import os -from pathlib import Path -from tempfile import TemporaryDirectory - -import matplotlib.pyplot as plt -import pandas as pd - -from . import BaseCallback - - -class Monitoring(BaseCallback): - """Monitor model convergence. - - Parameters - ---------- - monitor : - - X_val : ndarray, default=None - Validation data - - y_val : ndarray, default=None - Validation target - - Attributes - ---------- - data : pandas.DataFrame - The monitored quantities at each iteration. - """ - - request_from_reconstruction_attributes = True - - def __init__( - self, - *, - monitor="objective_function", - on="validation_set", - validation_split="auto", - ): - from ..model_selection import KFold - - self.validation_split = validation_split - if validation_split == "auto": - self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) - self.monitor = monitor - self.on = on - - self._data_dir = TemporaryDirectory() - self._data_files = {} - - if isinstance(self.monitor, str): - self.monitor_name = self.monitor - elif callable(self.monitor): - self.monitor_name = self.monitor.__name__ - - def on_fit_begin(self, estimator, *, X=None, y=None): - fname = Path(self._data_dir.name) / f"{estimator._computation_tree.uid}.csv" - with open(fname, "w") as file: - file.write(f"iteration,{self.monitor_name}_train,{self.monitor_name}_val\n") - self._data_files[estimator._computation_tree] = fname - - def on_fit_iter_end( - self, *, estimator, node, from_reconstruction_attributes, data, **kwargs - ): - if node.depth != node.computation_tree.depth: - return - - new_estimator = from_reconstruction_attributes - - X, y, X_val, y_val = data["X"], data["y"], data["X_val"], data["y_val"] - - if self.monitor == "objective_function": - new_monitored_train, *_ = new_estimator.objective_function( - X, y, normalize=True - ) - if X_val is not None: - new_monitored_val, *_ = new_estimator.objective_function( - X_val, y_val, normalize=True - ) - elif callable(self.monitor): - new_monitored_train = self.monitor(new_estimator, X, y) - if X_val is not None: - new_monitored_val = self.monitor(new_estimator, X_val, y_val) - elif self.monitor is None or isinstance(self.monitor, str): - from ..metrics import check_scoring - - scorer = check_scoring(new_estimator, self.monitor) - new_monitored_train = scorer(new_estimator, X, y) - if X_val is not None: - new_monitored_val = scorer(new_estimator, X_val, y_val) - - if X_val is None: - new_monitored_val = None - - with open(self._data_files[node.computation_tree], "a") as f: - f.write(f"{node.idx},{new_monitored_train},{new_monitored_val}\n") - - def on_fit_end(self): - pass - - # @property - # def data(self): - - def plot(self): - data_files = [p for p in Path(self._data_dir.name).iterdir() if p.is_file()] - for f in data_files: - data = pd.read_csv(f) - fig, ax = plt.subplots() - ax.plot( - data["iteration"], data[f"{self.monitor_name}_train"], label="train set" - ) - if self.on != "train_set": - ax.plot( - data["iteration"], - data[f"{self.monitor_name}_val"], - label="validation set", - ) - - ax.set_xlabel("Number of iterations") - ax.set_ylabel(self.monitor_name) - - ax.legend() - plt.show() diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index f8ed251add34a..a99c40f509aaf 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -1,91 +1,38 @@ # License: BSD 3 clause +# Authors: the scikit-learn developers import importlib -from threading import Event, Thread +from multiprocessing import Manager +from threading import Thread -from . import BaseCallback, load_computation_tree - - -def _check_backend_support(backend, caller_name): - """Raise ImportError with detailed error message if backend is not installed. - - Parameters - ---------- - backend : {"rich", "tqdm"} - The requested backend. - - caller_name : str - The name of the caller that requires the backend. - """ - try: - importlib.import_module(backend) # noqa - except ImportError as e: - raise ImportError(f"{caller_name} requires {backend} installed.") from e +from . import BaseCallback class ProgressBar(BaseCallback): - """Callback that displays progress bars for each iterative steps of the estimator - - Parameters - ---------- - backend: {"rich", "tqdm"}, default="rich" - The backend for the progress bars display. - - max_depth_show : int, default=None - The maximum nested level of progress bars to display. - - max_depth_keep : int, default=None - The maximum nested level of progress bars to keep displayed when they are - finished. - """ + """Callback that displays progress bars for each iterative steps of an estimator""" auto_propagate = True - def __init__(self, backend="rich", max_depth_show=None, max_depth_keep=None): - if backend not in ("rich", "tqdm"): - raise ValueError( - f"backend should be 'rich' or 'tqdm', got {self.backend} instead." - ) - _check_backend_support(backend, caller_name="Progressbar") - self.backend = backend - - if max_depth_show is not None and max_depth_show < 0: - raise ValueError("max_depth_show should be >= 0.") - self.max_depth_show = max_depth_show - - if max_depth_keep is not None and max_depth_keep < 0: - raise ValueError("max_depth_keep should be >= 0.") - self.max_depth_keep = max_depth_keep + def __init__(self): + try: + importlib.import_module("rich") # noqa + except ImportError as e: + raise ImportError("ProgressBar requires rich installed.") from e def on_fit_begin(self, estimator, X=None, y=None): - self._stop_event = Event() - - if self.backend == "rich": - self.progress_monitor = _RichProgressMonitor( - estimator=estimator, - event=self._stop_event, - max_depth_show=self.max_depth_show, - max_depth_keep=self.max_depth_keep, - ) - elif self.backend == "tqdm": - self.progress_monitor = _TqdmProgressMonitor( - estimator=estimator, - event=self._stop_event, - ) - + self._queue = Manager().Queue() + self.progress_monitor = _RichProgressMonitor(queue=self._queue) self.progress_monitor.start() def on_fit_iter_end(self, *, estimator, node, **kwargs): - pass + self._queue.put(node) def on_fit_end(self): - self._stop_event.set() + self._queue.put(None) self.progress_monitor.join() def __getstate__(self): state = self.__dict__.copy() - if "_stop_event" in state: - del state["_stop_event"] if "progress_monitor" in state: del state["progress_monitor"] return state @@ -96,7 +43,8 @@ def __getstate__(self): # insert tasks between existing tasks. try: - from rich.progress import Progress + from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn + from rich.style import Style class _Progress(Progress): def get_renderables(self): @@ -111,42 +59,20 @@ class _RichProgressMonitor(Thread): """Thread monitoring the progress of an estimator with rich based display The display is a list of nested rich tasks using rich.Progress. There is one for - each node in the computation tree of the estimator and in the computation trees of - estimators used in the estimator. + each non-leaf node in the computation tree of the estimator. Parameters ---------- - estimator : estimator instance - The estimator to monitor - - event : threading.Event instance - This thread will run until event is set. - - max_depth_show : int, default=None - The maximum nested level of progress bars to display. - - max_depth_keep : int, default=None - The maximum nested level of progress bars to keep displayed when they are - finished. + queue : multiprocessing.Manager.Queue instance + This thread will run until the queue is empty. """ - def __init__(self, estimator, event, max_depth_show=None, max_depth_keep=None): + def __init__(self, *, queue): Thread.__init__(self) - self.computation_tree = estimator._computation_tree - self.event = event - self.max_depth_show = max_depth_show - self.max_depth_keep = max_depth_keep - - # _computation_trees is a dict `directory: tuple` where - # - tuple[0] is the computation tree of the directory - # - tuple[1] is a dict `node.tree_status_idx: task_id` - self._computation_trees = {} + self.queue = queue def run(self): - from rich.progress import BarColumn, TextColumn, TimeRemainingColumn - from rich.style import Style - - with _Progress( + self.progress_ctx = _Progress( TextColumn("[progress.description]{task.description}"), BarColumn( complete_style=Style(color="dark_orange"), @@ -155,155 +81,104 @@ def run(self): TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), auto_refresh=False, - ) as progress_ctx: - self._progress_ctx = progress_ctx - - while not self.event.wait(0.05): - self._recursive_update_tasks() - self._progress_ctx.refresh() - - self._recursive_update_tasks() - self._progress_ctx.refresh() - - def _recursive_update_tasks(self, this_dir=None, depth=0): - """Recursively loop through directories and init or update tasks - - Parameters - ---------- - this_dir : pathlib.Path instance - The directory to - - depth : int - The current depth - """ - if self.max_depth_show is not None and depth > self.max_depth_show: - # Fast exit if this dir is deeper than what we want to show anyway - return - - if this_dir is None: - this_dir = self.computation_tree.tree_dir - # _ordered_tasks holds the list of the tasks in the order we want them to - # be displayed. - self._progress_ctx._ordered_tasks = [] - - if this_dir not in self._computation_trees: - # First time we discover this directory -> store the computation tree - # If the computation tree is not readable yet, skip and try again next time - computation_tree = load_computation_tree(this_dir) - if computation_tree is None: - return - - self._computation_trees[this_dir] = (computation_tree, {}) - - computation_tree, task_ids = self._computation_trees[this_dir] - - for node in computation_tree.iterate(include_leaves=True): - if node.children: - # node is not a leaf, create or update its task - if node.tree_status_idx not in task_ids: - visible = True - if ( - self.max_depth_show is not None - and depth + node.depth > self.max_depth_show - ): - # If this node is deeper than what we want to show, we create - # the task anyway but make it not visible - visible = False - - task_ids[node.tree_status_idx] = self._progress_ctx.add_task( - self._format_task_description(node, computation_tree, depth), - total=node.max_iter, - visible=visible, - ) - - task_id = task_ids[node.tree_status_idx] - task = self._progress_ctx.tasks[task_id] - self._progress_ctx._ordered_tasks.append(task) - - parent_task = self._get_parent_task(node, computation_tree, task_ids) - if parent_task is not None and parent_task.finished: - # If the task of the parent node is finished, make this task - # finished. It can happen if some computations are stopped - # before reaching max_iter. - visible = True - if ( - self.max_depth_keep is not None - and depth + node.depth > self.max_depth_keep - ): - # If this node is deeper than what we want to keep in the output - # make it not visible - visible = False - self._progress_ctx.update( - task_id, completed=node.max_iter, visible=visible, refresh=False - ) - else: - node_progress = computation_tree.get_progress(node) - if node_progress != task.completed: - self._progress_ctx.update( - task_id, completed=node_progress, refresh=False - ) + ) + + # Holds the root of the tree of rich tasks (i.e. progress bars) that will be + # created dynamically as the computation tree of the estimator is traversed. + self.root_task = None + + with self.progress_ctx: + while node:= self.queue.get(): + + self._update_task_tree(node) + self._update_tasks() + self.progress_ctx.refresh() + + def _update_task_tree(self, node): + """Update the tree of tasks from a new node""" + curr_task, parent_task = None, None + + for curr_node in node.path: + + if curr_node.parent is None: # root node + if self.root_task is None: + self.root_task = TaskNode(curr_node, progress_ctx=self.progress_ctx) + curr_task = self.root_task + elif curr_node.idx not in parent_task.children: + curr_task = TaskNode(curr_node, progress_ctx=self.progress_ctx, parent=parent_task) + parent_task.children[curr_node.idx] = curr_task + else: # task already exists + curr_task = parent_task.children[curr_node.idx] + parent_task = curr_task + + # Mark the deepest task as finished (this is the one corresponding the + # computation node that we just get from the queue). + curr_task.finished = True + + def _update_tasks(self): + """Loop through the tasks in their display oder and update their progress""" + self.progress_ctx._ordered_tasks = [] + + for task_node in self.root_task: + task = self.progress_ctx.tasks[task_node.task_id] + + if task_node.parent is not None and task_node.parent.finished: + # If the parent task is finished, then mark the current task as + # finished. It can happen if an estimator doesn't reach its max number + # of iterations (e.g. early stopping). + completed = task.total else: - # node is a leaf, look for tasks of its sub computation tree before - # going to the next node - child_dir = computation_tree.get_child_computation_tree_dir(node) - # child_dir = this_dir / str(node.tree_status_idx) - if child_dir.exists(): - self._recursive_update_tasks( - child_dir, depth + computation_tree.depth - ) - - def _format_task_description(self, node, computation_tree, depth): - """Return a formatted description for the task of the node""" - colors = ["red", "green", "blue", "yellow"] + completed = sum(t.finished for t in task_node.children.values()) - indent = f"{' ' * (depth + node.depth)}" - style = f"[{colors[(depth + node.depth)%len(colors)]}]" + if completed == task.total: + task_node.finished = True - description = f"{computation_tree.estimator_name} - {node.description}" - if node.parent is None and computation_tree.parent_node is not None: - description = ( - f"{computation_tree.parent_node.description} " - f"{computation_tree.parent_node.idx} |" - f" {description}" + self.progress_ctx.update( + task_node.task_id, completed=completed, refresh=False ) - if node.parent is not None: - description = f"{description} {node.idx}" + self.progress_ctx._ordered_tasks.append(task) - return f"{style}{indent}{description}" - def _get_parent_task(self, node, computation_tree, task_ids): - """Get the task of the parent node""" - if node.parent is not None: - # node is not the root, return the task of its parent - task_id = task_ids[node.parent.tree_status_idx] - return self._progress_ctx.tasks[task_id] - if computation_tree.parent_node is not None: - # node is the root, return the task of the parent of the parent_node of - # its computation tree - parent_dir = computation_tree.parent_node.computation_tree.tree_dir - _, parent_tree_task_ids = self._computation_trees[parent_dir] - task_id = parent_tree_task_ids[ - computation_tree.parent_node.parent.tree_status_idx - ] - return self._progress_ctx._tasks[task_id] - return - - -class _TqdmProgressMonitor(Thread): - def __init__(self, estimator, event): - Thread.__init__(self) - self.computation_tree = estimator._computation_tree - self.event = event +class TaskNode: + """A node in the tree of rich tasks + + Parameters + ---------- + node : ComputationNode instance + The computation node this task corresponds to. + + progress_ctx : rich.Progress instance + The progress context to which this task belongs. + + parent : TaskNode instance + The parent of this task. + """ + def __init__(self, node, progress_ctx, parent=None): + self.node_idx = node.idx + self.parent = parent + self.children = {} + self.finished = False - def run(self): - from tqdm import tqdm + if node.max_iter is not None: + description = self._format_task_description(node) + self.task_id = progress_ctx.add_task(description, total=node.max_iter) - root = self.computation_tree.root + def _format_task_description(self, node): + """Return a formatted description for the task of the node""" + colors = ["red", "green", "blue", "yellow"] + + indent = f"{' ' * (node.depth)}" + style = f"[{colors[(node.depth)%len(colors)]}]" - with tqdm(total=len(root.children)) as pbar: - while not self.event.wait(0.05): - node_progress = self.computation_tree.get_progress(root) - if node_progress != pbar.total: - pbar.update(node_progress - pbar.n) + description = f"{node.estimator_name[0]} - {node.description[0]} #{node.idx}" + if len(node.estimator_name) == 2: + description += f" | {node.estimator_name[1]} - {node.description[1]}" + + return f"{style}{indent}{description}" - pbar.update(pbar.total - pbar.n) + def __iter__(self): + """Pre-order depth-first traversal, excluding leaves""" + if self.children: + yield self + for child in self.children.values(): + yield from child diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py deleted file mode 100644 index cfb76c5ec1139..0000000000000 --- a/sklearn/callback/_snapshot.py +++ /dev/null @@ -1,65 +0,0 @@ -# License: BSD 3 clause - -import pickle -from datetime import datetime -from pathlib import Path - -from . import BaseCallback - - -class Snapshot(BaseCallback): - """Take regular snapshots of an estimator - - Parameters - ---------- - keep_last_n : int or None, default=1 - Only the last `keep_last_n` snapshots are kept on the disk. None means all - snapshots are kept. - - base_dir : str or pathlib.Path instance, default=None - The directory where the snapshots should be stored. If None, they are stored in - the current directory. - """ - - request_from_reconstruction_attributes = True - - def __init__(self, keep_last_n=1, base_dir=None): - self.keep_last_n = keep_last_n - if keep_last_n is not None and keep_last_n <= 0: - raise ValueError( - "keep_last_n must be a positive integer, got" - f" {self.keep_last_n} instead." - ) - - self.base_dir = Path("." if base_dir is None else base_dir) - - def on_fit_begin(self, estimator, X=None, y=None): - subdir = self._get_subdir(estimator._computation_tree) - subdir.mkdir() - - def on_fit_iter_end(self, *, estimator, node, **kwargs): - new_estimator = kwargs.get("from_reconstruction_attributes", None) - if new_estimator is None: - return - - subdir = self._get_subdir(node.computation_tree) - snapshot_filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" - - with open(subdir / snapshot_filename, "wb") as f: - pickle.dump(new_estimator, f) - - if self.keep_last_n is not None: - for snapshot in sorted(subdir.iterdir())[: -self.keep_last_n]: - snapshot.unlink(missing_ok=True) - - def on_fit_end(self): - pass - - def _get_subdir(self, computation_tree): - """Return the sub directory containing the snapshots of the estimator""" - subdir = ( - self.base_dir - / f"snapshots_{computation_tree.estimator_name}_{str(computation_tree.uid)}" - ) - - return subdir diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py deleted file mode 100644 index 9773f1c8a6f51..0000000000000 --- a/sklearn/callback/_text_verbose.py +++ /dev/null @@ -1,40 +0,0 @@ -# License: BSD 3 clause - -import time - -from . import BaseCallback - - -class TextVerbose(BaseCallback): - auto_propagate = True - request_stopping_criterion = True - - def on_fit_begin(self, estimator, X=None, y=None): - self._start_time = time.perf_counter() - - def on_fit_iter_end(self, *, node, **kwargs): - if node.depth != node.computation_tree.depth: - return - - stopping_criterion = kwargs.get("stopping_criterion", None) - tol = kwargs.get("tol", None) - - current_time = time.perf_counter() - self._start_time - - s = f"{node.description} {node.idx}" - parent = node.parent - while parent is not None and parent.parent is not None: - s = f"{parent.description} {parent.idx} - {s}" - parent = parent.parent - - msg = ( - f"[{parent.computation_tree.estimator_name}] {s} | time {current_time:.5f}s" - ) - - if stopping_criterion is not None and tol is not None: - msg += f" | stopping_criterion={stopping_criterion:.3E} | tol={tol:.3E}" - - print(msg) - - def on_fit_end(self): - pass diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index d867bdcfa77d2..e0839ffd94d46 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -1,4 +1,5 @@ -from functools import partial +# License: BSD 3 clause +# Authors: the scikit-learn developers from joblib.parallel import Parallel, delayed @@ -41,7 +42,7 @@ def __init__(self, max_iter=20): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): - root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( + root = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.max_iter}, {"descr": "iter", "max_iter": None}, @@ -54,11 +55,6 @@ def fit(self, X, y): if _eval_callbacks_on_fit_iter_end( estimator=self, node=root.children[i], - from_reconstruction_attributes=partial( - self._from_reconstruction_attributes, - reconstruction_attributes=lambda: {"n_iter_": i + 1}, - ), - data={"X": X, "y": y, "X_val": X_val, "y_val": y_val}, ): break @@ -66,9 +62,6 @@ def fit(self, X, y): return self - def objective_function(self, X, y=None, normalize=False): - return 0, 0, 0 - class MetaEstimator(BaseEstimator): _parameter_constraints = {} @@ -84,7 +77,7 @@ def __init__( @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): - root, X, y, _, _ = self._eval_callbacks_on_fit_begin( + root = self._eval_callbacks_on_fit_begin( levels=[ {"descr": "fit", "max_iter": self.n_outer}, {"descr": "outer", "max_iter": self.n_inner}, diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 4bfdc27db7b52..17af6045df49c 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -1,6 +1,5 @@ # License: BSD 3 clause - -from pathlib import Path +# Authors: the scikit-learn developers import pytest @@ -51,7 +50,7 @@ def test_propagate_callbacks(): sub_estimator = Estimator() estimator._propagate_callbacks(sub_estimator, parent_node=None) - assert hasattr(sub_estimator, "_parent_ct_node") + assert hasattr(sub_estimator, "_parent_node") assert not_propagated_callback not in sub_estimator._callbacks assert propagated_callback in sub_estimator._callbacks @@ -83,7 +82,7 @@ def test_auto_propagated_callbacks(): def test_eval_callbacks_on_fit_begin(): - """Check that _eval_callbacks_on_fit_begin creates and dumps the computation tree""" + """Check that _eval_callbacks_on_fit_begin creates the computation tree""" estimator = Estimator()._set_callbacks(TestingCallback()) assert not hasattr(estimator, "_computation_tree") @@ -91,36 +90,5 @@ def test_eval_callbacks_on_fit_begin(): {"descr": "fit", "max_iter": 10}, {"descr": "iter", "max_iter": None}, ] - ct_root, *_ = estimator._eval_callbacks_on_fit_begin(levels=levels) + estimator._eval_callbacks_on_fit_begin(levels=levels) assert hasattr(estimator, "_computation_tree") - assert ct_root is estimator._computation_tree.root - - ct_pickle = Path(estimator._computation_tree.tree_dir) / "computation_tree.pkl" - assert ct_pickle.exists() - - -def test_callback_context_finalize(): - """Check that the folder containing the computation tree of the estimator is - deleted when there are no reference left to its callbacks. - """ - callback = TestingCallback() - - # estimator is not fitted, its computation tree is not built yet - est = Estimator()._set_callbacks(callbacks=callback) - assert not hasattr(est, "_computation_tree") - - # estimator is fitted, a folder has been created to hold its computation tree - est.fit(X=None, y=None) - assert hasattr(est, "_computation_tree") - tree_dir = est._computation_tree.tree_dir - assert tree_dir.is_dir() - - # there is no more reference to the estimator, but there is still a reference to the - # callback which might need to access the computation tree - del est - assert tree_dir.is_dir() - - # there is no more reference to the callback, the computation tree folder must be - # deleted - del callback - assert not tree_dir.is_dir() diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py deleted file mode 100644 index 5adb16a79bef9..0000000000000 --- a/sklearn/callback/tests/test_callbacks.py +++ /dev/null @@ -1,77 +0,0 @@ -# License: BSD 3 clause - -import pickle -import sys -import tempfile - -import numpy as np -import pytest - -from sklearn.callback import ( - EarlyStopping, - Monitoring, - ProgressBar, - Snapshot, - TextVerbose, -) -from sklearn.callback.tests._utils import Estimator, MetaEstimator - -X = np.zeros((100, 3)) -y = np.zeros(100, dtype=int) - - -@pytest.mark.parametrize( - "Callback", - [ - Monitoring, - EarlyStopping, - ProgressBar, - Snapshot, - TextVerbose, - ], -) -def test_callback_doesnt_hold_ref_to_estimator(Callback): - callback = Callback() - est = Estimator() - callback_refcount = sys.getrefcount(callback) - est_refcount = sys.getrefcount(est) - - est._set_callbacks(callbacks=callback) - est.fit(X, y) - # estimator has a ref on the callback but the callback has no ref to the estimator - assert sys.getrefcount(est) == est_refcount - assert sys.getrefcount(callback) == callback_refcount + 1 - - -@pytest.mark.parametrize("n_jobs", (1, 2)) -@pytest.mark.parametrize("prefer", ("threads", "processes")) -def test_snapshot_meta_estimator(n_jobs, prefer): - """Test for the Snapshot callback""" - estimator = Estimator(max_iter=20) - - with tempfile.TemporaryDirectory() as tmp_dir: - keep_last_n = 5 - callback = Snapshot(keep_last_n=keep_last_n, base_dir=tmp_dir) - estimator._set_callbacks(callback) - metaestimator = MetaEstimator( - estimator=estimator, n_outer=4, n_inner=3, n_jobs=n_jobs, prefer=prefer - ) - - metaestimator.fit(X, y) - - # There's a subdir of base_dir for each clone of estimator fitted in - # metaestimator. There are n_outer * n_inner such clones - snapshot_dirs = list(callback.base_dir.iterdir()) - assert len(snapshot_dirs) == metaestimator.n_outer * metaestimator.n_inner - - for snapshot_dir in snapshot_dirs: - snapshots = sorted(snapshot_dir.iterdir()) - assert len(snapshots) == keep_last_n - - for i, snapshot in enumerate(snapshots): - with open(snapshot, "rb") as f: - loaded_estimator = pickle.load(f) - - # We kept last 5 snapshots out of 20 iterations. - # This one is the 16 + i-th. - assert loaded_estimator.n_iter_ == 16 + i From 3ab3d7f0a58278302dc20be524e347c30dcd2199 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 25 Oct 2023 17:58:19 +0200 Subject: [PATCH 22/55] wip --- sklearn/base.py | 22 --- .../callback/tests/test_computation_tree.py | 90 ++++--------- sklearn/decomposition/_nmf.py | 125 +----------------- sklearn/decomposition/tests/test_nmf.py | 30 ----- .../gradient_boosting.py | 48 ------- sklearn/linear_model/_logistic.py | 61 +-------- sklearn/linear_model/_sag.py | 4 - sklearn/linear_model/_sag_fast.pyx.tp | 21 +-- sklearn/model_selection/_search.py | 76 +---------- sklearn/model_selection/_validation.py | 8 -- sklearn/pipeline.py | 29 +--- sklearn/utils/optimize.py | 15 +-- 12 files changed, 43 insertions(+), 486 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index bc3073527250e..293519ee3a8b4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -768,28 +768,6 @@ def _eval_callbacks_on_fit_end(self): if not callback._is_propagated(estimator=self): callback.on_fit_end() - def _from_reconstruction_attributes(self, *, reconstruction_attributes): - """Return an as if fitted copy of this estimator - - Parameters - ---------- - reconstruction_attributes : callable - A callable that has no arguments and returns the necessary fitted attributes - to create a working fitted estimator from this instance. - - Using a callable allows lazy evaluation of the potentially costly - reconstruction attributes. - - Returns - ------- - fitted_estimator : estimator instance - The fitted copy of this estimator. - """ - new_estimator = copy.copy(self) # XXX deepcopy ? - for key, val in reconstruction_attributes().items(): - setattr(new_estimator, key, val) - return new_estimator - @property def _repr_html_(self): """HTML representation of estimator. diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index 5a6da95eea469..b3612c1df43d1 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -1,8 +1,9 @@ # License: BSD 3 clause +# Authors: the scikit-learn developers import numpy as np -from sklearn.callback import ComputationTree +from sklearn.callback import ComputationNode, build_computation_tree levels = [ {"descr": "level0", "max_iter": 3}, @@ -14,17 +15,15 @@ def test_computation_tree(): """Check the construction of the computation tree""" - computation_tree = ComputationTree(estimator_name="estimator", levels=levels) - assert computation_tree.estimator_name == "estimator" + computation_tree = build_computation_tree(estimator_name="estimator", levels=levels) + assert computation_tree.estimator_name == ("estimator",) + assert computation_tree.parent is None + assert computation_tree.idx == 0 - root = computation_tree.root - assert root.parent is None - assert root.idx == 0 + assert len(computation_tree.children) == computation_tree.max_iter == 3 + assert [node.idx for node in computation_tree.children] == list(range(3)) - assert len(root.children) == root.max_iter == 3 - assert [node.idx for node in root.children] == list(range(3)) - - for node1 in root.children: + for node1 in computation_tree.children: assert len(node1.children) == 5 assert [n.idx for n in node1.children] == list(range(5)) @@ -40,68 +39,31 @@ def test_n_nodes(): """Check that the number of node in a comutation tree corresponds to what we expect from the level descriptions """ - computation_tree = ComputationTree(estimator_name="", levels=levels) + computation_tree = build_computation_tree(estimator_name="", levels=levels) max_iter_per_level = [level["max_iter"] for level in levels[:-1]] expected_n_nodes = 1 + np.sum(np.cumprod(max_iter_per_level)) - assert computation_tree.n_nodes == expected_n_nodes - assert len(computation_tree.iterate(include_leaves=True)) == expected_n_nodes - assert computation_tree._tree_status.shape == (expected_n_nodes,) - + actual_n_nodes = sum(1 for _ in computation_tree) -def test_tree_status_idx(): - """Check that each node has a unique index in the _tree_status array and that their - order corresponds to the order given by a depth first search. - """ - computation_tree = ComputationTree(estimator_name="", levels=levels) - - indexes = [ - node.tree_status_idx for node in computation_tree.iterate(include_leaves=True) - ] - assert indexes == list(range(computation_tree.n_nodes)) + assert actual_n_nodes == expected_n_nodes -def test_get_ancestors(): - """Check the ancestor search and its propagation to parent trees""" - parent_levels = [ - {"descr": "parent_level0", "max_iter": 2}, - {"descr": "parent_level1", "max_iter": 4}, - {"descr": "parent_level2", "max_iter": None}, - ] +def test_path(): + """Check that the path from the root to a node is correct""" + computation_tree = build_computation_tree(estimator_name="", levels=levels) - parent_computation_tree = ComputationTree( - estimator_name="parent_estimator", levels=parent_levels - ) - parent_node = parent_computation_tree.root.children[0].children[2] - # indices of each node (in its parent children) in this chain are 0, 0, 2. - # (root is always 0). - expected_parent_indices = [2, 0, 0] - - computation_tree = ComputationTree( - estimator_name="estimator", levels=levels, parent_node=parent_node - ) - node = computation_tree.root.children[1].children[3].children[5] - expected_node_indices = [5, 3, 1, 0] - - ancestors = node.get_ancestors(include_ancestor_trees=False) - assert ancestors == [ + assert computation_tree.path == [computation_tree] + + node = computation_tree.children[1].children[2].children[3] + expected_path = [ + computation_tree, + computation_tree.children[1], + computation_tree.children[1].children[2], node, - node.parent, - node.parent.parent, - node.parent.parent.parent, ] - assert [n.idx for n in ancestors] == expected_node_indices - assert computation_tree.root in ancestors + assert node.path == expected_path + + assert all(node.depth == i for i, node in enumerate(expected_path)) + - ancestors = node.get_ancestors(include_ancestor_trees=True) - assert ancestors == [ - node, - node.parent, - node.parent.parent, - node.parent.parent.parent, - parent_node, - parent_node.parent, - parent_node.parent.parent, - ] - assert [n.idx for n in ancestors] == expected_node_indices + expected_parent_indices diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 46f8645a1f06d..db46540e26708 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -10,7 +10,6 @@ import time import warnings from abc import ABC -from functools import partial from math import sqrt from numbers import Integral, Real @@ -25,7 +24,6 @@ TransformerMixin, _fit_context, ) -from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import ConvergenceWarning from ..utils import check_array, check_random_state, gen_batches, metadata_routing from ..utils._param_validation import ( @@ -410,7 +408,6 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle, random_state): def _fit_coordinate_descent( X, - X_val, W, H, tol=1e-4, @@ -423,8 +420,6 @@ def _fit_coordinate_descent( verbose=0, shuffle=False, random_state=None, - estimator=None, - parent_node=None, ): """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent @@ -437,9 +432,6 @@ def _fit_coordinate_descent( X : array-like of shape (n_samples, n_features) Constant matrix. - X_val : array-like of shape (n_samples_val, n_features) - Constant validation matrix. - W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -480,12 +472,6 @@ def _fit_coordinate_descent( results across multiple function calls. See :term:`Glossary `. - estimator : estimator instance, default=None - The estimator calling this function. Used by callbacks. - - parent_node : ComputationNode instance, default=None - The parent node of the current node. Used by callbacks. - Returns ------- W : ndarray of shape (n_samples, n_components) @@ -507,8 +493,6 @@ def _fit_coordinate_descent( # so W and Ht are both in C order in memory Ht = check_array(H.T, order="C") X = check_array(X, accept_sparse="csr") - if X_val is not None: - X_val = check_array(X_val, accept_sparse="csr") rng = check_random_state(random_state) @@ -531,25 +515,6 @@ def _fit_coordinate_descent( if violation_init == 0: break - if _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=parent_node.children[n_iter - 1] if parent_node is not None else None, - stopping_criterion=lambda: violation / violation_init, - tol=tol, - fit_state={"H": Ht.T, "W": W}, - from_reconstruction_attributes=partial( - estimator._from_reconstruction_attributes, - reconstruction_attributes=lambda: { - "n_components_": Ht.T.shape[0], - "components_": H, - "n_iter_": n_iter, - "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), - }, - ), - data={"X": X, "y": None, "X_val": X_val, "y_val": None}, - ): - break - if verbose: print("violation:", violation / violation_init) @@ -768,7 +733,6 @@ def _multiplicative_update_h( def _fit_multiplicative_update( X, - X_val, W, H, beta_loss="frobenius", @@ -780,8 +744,6 @@ def _fit_multiplicative_update( l2_reg_H=0, update_H=True, verbose=0, - estimator=None, - parent_node=None, ): """Compute Non-negative Matrix Factorization with Multiplicative Update. @@ -794,9 +756,6 @@ def _fit_multiplicative_update( X : array-like of shape (n_samples, n_features) Constant input matrix. - X_val : array-like of shape (n_samples_val, n_features) - Constant validation matrix. - W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -837,12 +796,6 @@ def _fit_multiplicative_update( verbose : int, default=0 The verbosity level. - estimator : estimator instance, default=None - The estimator calling this function. Used by callbacks. - - parent_node : ComputationNode instance, default=None - The parent node of the current node. Used by callbacks. - Returns ------- W : ndarray of shape (n_samples, n_components) @@ -918,31 +871,6 @@ def _fit_multiplicative_update( if beta_loss <= 1: H[H < np.finfo(np.float64).eps] = 0.0 - if _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=parent_node.children[n_iter - 1] if parent_node is not None else None, - stopping_criterion=lambda: ( - ( - previous_error - - _beta_divergence(X, W, H, beta_loss, square_root=True) - ) - / error_at_init - ), - tol=tol, - fit_state={"H": H, "W": W}, - from_reconstruction_attributes=partial( - estimator._from_reconstruction_attributes, - reconstruction_attributes=lambda: { - "n_components_": H.shape[0], - "components_": H, - "n_iter_": n_iter, - "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), - }, - ), - data={"X": X, "y": None, "X_val": X_val, "y_val": None}, - ): - break - # test convergence criterion every 10 iterations if tol > 0 and n_iter % 10 == 0: error = _beta_divergence(X, W, H, beta_loss, square_root=True) @@ -1421,28 +1349,6 @@ def inverse_transform(self, Xt=None, W=None): check_is_fitted(self) return Xt @ self.components_ - def objective_function(self, X, y=None, *, W=None, H=None, normalize=False): - if W is None: - W = self.transform(X) - if H is None: - H = self.components_ - - data_fit = _beta_divergence(X, W, H, self._beta_loss) - - l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._compute_regularization(X) - penalization = ( - l1_reg_W * W.sum() - + l1_reg_H * H.sum() - + l2_reg_W * (W**2).sum() - + l2_reg_H * (H**2).sum() - ) - - if normalize: - data_fit /= X.shape[0] - penalization /= X.shape[0] - - return data_fit + penalization, data_fit, penalization - @property def _n_features_out(self): """Number of transformed output features.""" @@ -1756,28 +1662,20 @@ def fit_transform(self, X, y=None, W=None, H=None): X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32] ) - root, X, _, X_val, _ = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.max_iter}, - {"descr": "iter", "max_iter": None}, - ], - X=X, - ) - - W, H, n_iter = self._fit_transform(X, X_val, W=W, H=H, parent_node=root) + with config_context(assume_finite=True): + W, H, n_iter = self._fit_transform(X, W=W, H=H) self.reconstruction_err_ = _beta_divergence( X, W, H, self._beta_loss, square_root=True ) + self.n_components_ = H.shape[0] self.components_ = H self.n_iter_ = n_iter return W - def _fit_transform( - self, X, X_val=None, W=None, H=None, update_H=True, parent_node=None - ): + def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): """Learn a NMF model for the data X and returns the transformed data. Parameters @@ -1837,7 +1735,6 @@ def _fit_transform( if self.solver == "cd": W, H, n_iter = _fit_coordinate_descent( X, - X_val, W, H, self.tol, @@ -1850,13 +1747,10 @@ def _fit_transform( verbose=self.verbose, shuffle=self.shuffle, random_state=self.random_state, - estimator=self, - parent_node=parent_node, ) elif self.solver == "mu": W, H, n_iter, *_ = _fit_multiplicative_update( X, - X_val, W, H, self._beta_loss, @@ -1866,10 +1760,8 @@ def _fit_transform( l1_reg_H, l2_reg_W, l2_reg_H, - update_H=update_H, - verbose=self.verbose, - estimator=self, - parent_node=parent_node, + update_H, + self.verbose, ) else: raise ValueError("Invalid solver parameter '%s'." % self.solver) @@ -2549,8 +2441,3 @@ def partial_fit(self, X, y=None, W=None, H=None): self.n_steps_ += 1 return self - - @property - def _n_features_out(self): - """Number of transformed output features.""" - return self.components_.shape[0] diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index 573af0a5e7258..ce13d9358b5d8 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -1,7 +1,5 @@ -import pickle import re import sys -import tempfile import warnings from io import StringIO @@ -10,7 +8,6 @@ from scipy import linalg from sklearn.base import clone -from sklearn.callback import Snapshot from sklearn.decomposition import NMF, MiniBatchNMF, non_negative_factorization from sklearn.decomposition import _nmf as nmf # For testing internals from sklearn.exceptions import ConvergenceWarning @@ -1063,30 +1060,3 @@ def test_nmf_custom_init_shape_error(): with pytest.raises(ValueError, match="Array with wrong second dimension passed"): nmf.fit(X, H=H, W=rng.random_sample((6, 3))) - - -@pytest.mark.parametrize("solver, beta_loss", [("mu", 0), ("mu", 2), ("cd", 2)]) -def test_nmf_callback_reconstruction_attributes(solver, beta_loss): - # Check that the reconstruction attributes passed to the callback allow to make - # a new estimator as if the fit ended when the callback is called. - X = np.random.RandomState(0).random_sample((100, 20)) - - nmf = NMF(n_components=5, solver=solver, beta_loss=beta_loss, random_state=0) - nmf.fit(X) - - with tempfile.TemporaryDirectory() as tmp_dir: - callback = Snapshot(base_dir=tmp_dir) - nmf._set_callbacks(callback) - nmf.fit(X) - - # load model from last iteration - snapshot_dir = next(callback.base_dir.iterdir()) - snapshot = sorted(snapshot_dir.iterdir())[-1] - with open(snapshot, "rb") as f: - loaded_nmf = pickle.load(f) - - # The model saved during the last iteration is the same as the original model - assert nmf.n_iter_ == loaded_nmf.n_iter_ - assert_allclose(nmf.components_, loaded_nmf.components_) - assert_allclose(nmf.reconstruction_err_, loaded_nmf.reconstruction_err_) - assert_allclose(nmf.transform(X), loaded_nmf.transform(X)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 27a2030ca2d08..c3af930654b73 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -25,7 +25,6 @@ _fit_context, is_classifier, ) -from ...callback._base import _eval_callbacks_on_fit_iter_end from ...metrics import check_scoring from ...model_selection import train_test_split from ...preprocessing import LabelEncoder @@ -476,19 +475,6 @@ def fit(self, X, y, sample_weight=None): X_train, y_train, sample_weight_train = X, y, sample_weight X_val = y_val = sample_weight_val = None - begin_at_stage = ( - 0 if not (self._is_fitted() and self.warm_start) else self.n_iter_ - ) - - root, X_train, y_train, X_val, y_val = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.max_iter - begin_at_stage}, - {"descr": "iter", "max_iter": None}, - ], - X=X, - y=y, - ) - # Bin the data # For ease of use of the API, the user-facing GBDT classes accept the # parameter max_bins, which doesn't take into account the bin for @@ -783,26 +769,6 @@ def fit(self, X, y, sample_weight=None): if should_early_stop: break - if _eval_callbacks_on_fit_iter_end( - estimator=self, - node=root.children[iteration - begin_at_stage], - fit_state={}, - from_reconstruction_attributes=partial( - self._from_reconstruction_attributes, - reconstruction_attributes=lambda: { - "train_score_": np.asarray(self.train_score_), - "validation_score_": np.asarray(self.validation_score_), - }, - ), - data={ - "X": X_binned_train, - "y": y_train, - "X_val": X_binned_val, - "y_val": y_val, - }, - ): - break - if self.verbose: duration = time() - fit_start_time n_total_leaves = sum( @@ -841,22 +807,8 @@ def fit(self, X, y, sample_weight=None): self.train_score_ = np.asarray(self.train_score_) self.validation_score_ = np.asarray(self.validation_score_) del self._in_fit # hard delete so we're sure it can't be used anymore - return self - def objective_function(self, X, y, *, raw_predictions=None, normalize=False): - if raw_predictions is None: - raw_predictions = self._raw_predict(X) - - loss = self._loss( - y_true=y, - raw_prediction=raw_predictions, - ) - if normalize: - loss /= raw_predictions.shape[0] - - return loss, loss, 0 - def _is_fitted(self): return len(getattr(self, "_predictors", [])) > 0 diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 8e3886d38b781..e6ac6ff087945 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -22,7 +22,6 @@ from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..base import _fit_context -from ..callback._base import _eval_callbacks_on_fit_iter_end from ..metrics import get_scorer from ..model_selection import check_cv from ..preprocessing import LabelBinarizer, LabelEncoder @@ -128,8 +127,6 @@ def _logistic_regression_path( sample_weight=None, l1_ratio=None, n_threads=1, - estimator=None, - parent_node=None, ): """Compute a Logistic Regression model for a list of regularization parameters. @@ -455,19 +452,11 @@ def _logistic_regression_path( coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): - # Distinguish between LogReg and LogRegCV - node = ( - None - if parent_node is None - else parent_node if len(Cs) == 1 else parent_node.children - ) - if solver == "lbfgs": l2_reg_strength = 1.0 / (C * sw_sum) iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) ] - children = iter(node.children) if node is not None else None opt_res = optimize.minimize( func, w0, @@ -481,10 +470,6 @@ def _logistic_regression_path( "gtol": tol, "ftol": 64 * np.finfo(float).eps, }, - callback=lambda xk: _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=next(children) if children is not None else None, - ), ) n_iter_i = _check_optimize_result( solver, @@ -497,15 +482,7 @@ def _logistic_regression_path( l2_reg_strength = 1.0 / (C * sw_sum) args = (X, target, sample_weight, l2_reg_strength, n_threads) w0, n_iter_i = _newton_cg( - hess, - func, - grad, - w0, - args=args, - maxiter=max_iter, - tol=tol, - estimator=estimator, - parent_node=node, + hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) elif solver == "newton-cholesky": l2_reg_strength = 1.0 / (C * sw_sum) @@ -580,8 +557,6 @@ def _logistic_regression_path( max_squared_sum, warm_start_sag, is_saga=(solver == "saga"), - estimator=estimator, - parent_node=node, ) else: @@ -602,20 +577,8 @@ def _logistic_regression_path( else: coefs.append(w0.copy()) - if len(Cs) > 1: - _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=node, - ) - n_iter[i] = n_iter_i - if multi_class == "ovr": - _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=parent_node, - ) - return np.array(coefs), np.array(Cs), n_iter @@ -1333,24 +1296,6 @@ def fit(self, X, y, sample_weight=None): if warm_start_coef is None: warm_start_coef = [None] * n_classes - if len(classes_) == 1: - levels = [ - {"descr": "fit", "max_iter": self.max_iter}, - {"descr": "iter", "max_iter": None}, - ] - else: - levels = [ - {"descr": "fit", "max_iter": len(classes_)}, - {"descr": "class", "max_iter": self.max_iter}, - {"descr": "iter", "max_iter": None}, - ] - root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( - levels=levels, X=X, y=y - ) - - # distinguish between multinomial and ovr - nodes = [root] if len(classes_) == 1 else root.children - path_func = delayed(_logistic_regression_path) # The SAG solver releases the GIL so it's more efficient to use @@ -1395,10 +1340,8 @@ def fit(self, X, y, sample_weight=None): max_squared_sum=max_squared_sum, sample_weight=sample_weight, n_threads=n_threads, - estimator=self, - parent_node=node, ) - for class_, warm_start_coef_, node in zip(classes_, warm_start_coef, nodes) + for class_, warm_start_coef_ in zip(classes_, warm_start_coef) ) fold_coefs_, _, n_iter_ = zip(*fold_coefs_) diff --git a/sklearn/linear_model/_sag.py b/sklearn/linear_model/_sag.py index 88f8f9d50bf0c..2626955ec2a7f 100644 --- a/sklearn/linear_model/_sag.py +++ b/sklearn/linear_model/_sag.py @@ -100,8 +100,6 @@ def sag_solver( max_squared_sum=None, warm_start_mem=None, is_saga=False, - estimator=None, - parent_node=None, ): """SAG solver for Ridge and LogisticRegression. @@ -346,8 +344,6 @@ def sag_solver( intercept_decay, is_saga, verbose, - estimator=estimator, - parent_node=parent_node, ) if n_iter_ == max_iter: diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index a342d95fb9dbb..97bf3020d6602 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -34,7 +34,6 @@ from ._sgd_fast cimport LossFunction from ._sgd_fast cimport Log, SquaredLoss from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 -from ..callback._base import _eval_callbacks_on_fit_iter_end from libc.stdio cimport printf @@ -218,9 +217,7 @@ def sag{{name_suffix}}( {{c_type}}[::1] intercept_sum_gradient_init, double intercept_decay, bint saga, - bint verbose, - estimator, - parent_node, + bint verbose ): """Stochastic Average Gradient (SAG) and SAGA solvers. @@ -541,22 +538,6 @@ def sag{{name_suffix}}( max_weight = fmax{{name_suffix}}(max_weight, fabs(weights[idx])) max_change = fmax{{name_suffix}}(max_change, fabs(weights[idx] - previous_weights[idx])) previous_weights[idx] = weights[idx] - - with gil: - if _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=parent_node.children[n_iter] if parent_node is not None else None, - stopping_criterion = ( - lambda: max_change / max_weight - if max_weight != 0 - else 0 - if max_weight == max_change == 0 - else np.inf - ), - tol=tol, - ): - break - if ((max_weight != 0 and max_change / max_weight <= tol) or max_weight == 0 and max_change == 0): if verbose: diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 7b140dc732464..9de03c2c663ec 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -18,7 +18,7 @@ from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial, reduce -from itertools import cycle, product +from itertools import product import numpy as np from numpy.ma import MaskedArray @@ -900,9 +900,7 @@ def fit(self, X, y=None, **params): all_out = [] all_more_results = defaultdict(list) - def evaluate_candidates( - candidate_params, cv=None, more_results=None, parent_node=None - ): + def evaluate_candidates(candidate_params, cv=None, more_results=None): cv = cv or cv_orig candidate_params = list(candidate_params) n_candidates = len(candidate_params) @@ -915,11 +913,6 @@ def evaluate_candidates( ) ) - if parent_node is not None: - nodes = parent_node.children - else: - nodes = cycle([None]) - out = parallel( delayed(_fit_and_score)( clone(base_estimator), @@ -931,18 +924,10 @@ def evaluate_candidates( split_progress=(split_idx, n_splits), candidate_progress=(cand_idx, n_candidates), **fit_and_score_kwargs, - caller=self, - node=node, ) - for ( - (cand_idx, parameters), - (split_idx, (train, test)), - ), node in zip( - product( - enumerate(candidate_params), - enumerate(cv.split(X, y, **routed_params.splitter.split)), - ), - nodes, + for (cand_idx, parameters), (split_idx, (train, test)) in product( + enumerate(candidate_params), + enumerate(cv.split(X, y, **routed_params.splitter.split)), ) ) @@ -1537,58 +1522,9 @@ def __init__( ) self.param_grid = param_grid - def fit(self, X, y=None, *, groups=None, **fit_params): - """Run fit with all sets of parameters. - - Parameters - ---------- - - X : array-like of shape (n_samples, n_features) - Training vector, where `n_samples` is the number of samples and - `n_features` is the number of features. - - y : array-like of shape (n_samples, n_output) or (n_samples,), default=None - Target relative to X for classification or regression; - None for unsupervised learning. - - groups : array-like of shape (n_samples,), default=None - Group labels for the samples used while splitting the dataset into - train/test set. Only used in conjunction with a "Group" :term:`cv` - instance (e.g., :class:`~sklearn.model_selection.GroupKFold`). - - **fit_params : dict of str -> object - Parameters passed to the `fit` method of the estimator. - - If a fit parameter is an array-like whose length is equal to - `num_samples` then it will be split across CV groups along with `X` - and `y`. For example, the :term:`sample_weight` parameter is split - because `len(sample_weights) = len(X)`. - - Returns - ------- - self : object - Instance of fitted estimator. - """ - self._param_grid = ParameterGrid(self.param_grid) - - self._checked_cv_orig = check_cv( - self.cv, y, classifier=is_classifier(self.estimator) - ) - n_splits = self._checked_cv_orig.get_n_splits(X, y, groups) - - self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": len(self._param_grid) * n_splits}, - {"descr": "param - fold", "max_iter": None}, - ], - X=X, - y=y, - ) - super().fit(X, y=y, groups=groups, **fit_params) - def _run_search(self, evaluate_candidates): """Search all candidates in param_grid""" - evaluate_candidates(self._param_grid, parent_node=self._computation_tree.root) + evaluate_candidates(ParameterGrid(self.param_grid)) class RandomizedSearchCV(BaseSearchCV): diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index e2e23500a00a2..f3c8735043408 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -25,7 +25,6 @@ from joblib import logger from ..base import clone, is_classifier -from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import FitFailedWarning, UnsetMetadataPassedError from ..metrics import check_scoring, get_scorer_names from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer @@ -749,8 +748,6 @@ def _fit_and_score( split_progress=None, candidate_progress=None, error_score=np.nan, - caller=None, - node=None, ): """Fit estimator and compute scores for a given dataset split. @@ -880,9 +877,6 @@ def _fit_and_score( # ref: https://github.com/scikit-learn/scikit-learn/pull/26786 estimator = estimator.set_params(**clone(parameters, safe=False)) - if caller is not None: - caller._propagate_callbacks(estimator, parent_node=node) - start_time = time.time() X_train, y_train = _safe_split(estimator, X, y, train) @@ -949,8 +943,6 @@ def _fit_and_score( end_msg += result_msg print(end_msg) - _eval_callbacks_on_fit_iter_end(estimator=caller, node=node) - result["test_scores"] = test_scores if return_train_score: result["train_scores"] = train_scores diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 6169084f5ee6d..713c7d6116f53 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -16,14 +16,9 @@ from scipy import sparse from .base import TransformerMixin, _fit_context, clone -from .callback._base import _eval_callbacks_on_fit_iter_end from .exceptions import NotFittedError from .preprocessing import FunctionTransformer -from .utils import ( - Bunch, - _print_elapsed_time, - check_pandas_support, -) +from .utils import Bunch, _print_elapsed_time, check_pandas_support from .utils._estimator_html_repr import _VisualBlock from .utils._metadata_requests import METHODS from .utils._param_validation import HasMethods, Hidden @@ -387,23 +382,12 @@ def _fit(self, X, y=None, routed_params=None): # Setup the memory memory = check_memory(self.memory) - root, *_ = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": len(self.steps)}, - {"descr": "step", "max_iter": None}, - ], - X=X, - y=y, - ) - fit_transform_one_cached = memory.cache(_fit_transform_one) for step_idx, name, transformer in self._iter( with_final=False, filter_passthrough=False ): - node = root.children[step_idx] if transformer is None or transformer == "passthrough": - _eval_callbacks_on_fit_iter_end(estimator=self, node=node) with _print_elapsed_time("Pipeline", self._log_message(step_idx)): continue @@ -413,9 +397,6 @@ def _fit(self, X, y=None, routed_params=None): cloned_transformer = transformer else: cloned_transformer = clone(transformer) - - self._propagate_callbacks(cloned_transformer, parent_node=node) - # Fit or load from cache the current transformer X, fitted_transformer = fit_transform_one_cached( cloned_transformer, @@ -430,9 +411,6 @@ def _fit(self, X, y=None, routed_params=None): # transformer. This is necessary when loading the transformer # from the cache. self.steps[step_idx] = (name, fitted_transformer) - - _eval_callbacks_on_fit_iter_end(estimator=self, node=node) - return X @_fit_context( @@ -486,14 +464,9 @@ def fit(self, X, y=None, **params): Xt = self._fit(X, y, routed_params) with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": - node = self._computation_tree.root.children[-1] - self._propagate_callbacks(self._final_estimator, parent_node=node) - last_step_params = routed_params[self.steps[-1][0]] self._final_estimator.fit(Xt, y, **last_step_params["fit"]) - _eval_callbacks_on_fit_iter_end(estimator=self, node=node) - return self def _can_fit_transform(self): diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py index 807c78810df37..024b0bcaf95ee 100644 --- a/sklearn/utils/optimize.py +++ b/sklearn/utils/optimize.py @@ -18,7 +18,6 @@ import numpy as np import scipy -from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import ConvergenceWarning from .fixes import line_search_wolfe1, line_search_wolfe2 @@ -157,8 +156,6 @@ def _newton_cg( maxinner=200, line_search=True, warn=True, - estimator=None, - parent_node=None, ): """ Minimization of scalar function of one or more variables using the @@ -220,17 +217,7 @@ def _newton_cg( fgrad, fhess_p = grad_hess(xk, *args) absgrad = np.abs(fgrad) - max_absgrad = np.max(absgrad) - - if _eval_callbacks_on_fit_iter_end( - estimator=estimator, - node=None if parent_node is None else parent_node.children[k], - stopping_criterion=lambda: max_absgrad, - tol=tol, - ): - break - - if max_absgrad <= tol: + if np.max(absgrad) <= tol: break maggrad = np.sum(absgrad) From 39c04cc499589b338ddca6dbd98f01b509859ee5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 25 Oct 2023 18:32:58 +0200 Subject: [PATCH 23/55] wip --- sklearn/base.py | 1 - sklearn/callback/_base.py | 8 ++-- sklearn/callback/_computation_tree.py | 42 +++++++++++++++---- sklearn/callback/_progressbar.py | 15 +++---- sklearn/callback/tests/__init__.py | 0 .../callback/tests/test_computation_tree.py | 6 +-- 6 files changed, 49 insertions(+), 23 deletions(-) create mode 100644 sklearn/callback/tests/__init__.py diff --git a/sklearn/base.py b/sklearn/base.py index 293519ee3a8b4..64637fbad9da4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -6,7 +6,6 @@ import copy import functools import inspect -import pickle import platform import re import warnings diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 0bc4023f2a266..aa99303ef220d 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -52,13 +52,13 @@ def on_fit_begin(self, estimator, *, X=None, y=None): Parameters ---------- - estimator: estimator instance + estimator : estimator instance The estimator the callback is set on. - X: ndarray or sparse matrix, default=None + X : ndarray or sparse matrix, default=None The training data. - y: ndarray or sparse matrix, default=None + y : ndarray or sparse matrix, default=None The target. """ pass @@ -81,7 +81,7 @@ def on_fit_iter_end(self, estimator, node, **kwargs): node : ComputationNode instance The caller computation node. - kwargs : dict + **kwargs : dict arguments passed to the callback. Possible keys are - stopping_criterion: float diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index 0add054ea6180..fccea345134f1 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -10,14 +10,17 @@ class ComputationNode: estimator_name : str The name of the estimator this computation node belongs to. - parent : ComputationNode instance, default=None - The parent node. None means this is the root. + description : str, default=None + A description of this computation node. None means it's a leaf. max_iter : int, default=None The number of its children. None means it's a leaf. - description : str, default=None - A description of this computation node. None means it's a leaf. + idx : int, default=0 + The index of this node among its siblings. + + parent : ComputationNode instance, default=None + The parent node. None means this is the root. Attributes ---------- @@ -51,7 +54,7 @@ def __init__( def depth(self): """The depth of this node in the computation tree""" return 0 if self.parent is None else self.parent.depth + 1 - + @property def path(self): """List of all the nodes in the path from the root to this node""" @@ -65,7 +68,32 @@ def __iter__(self): def build_computation_tree(estimator_name, levels, parent=None, idx=0): - """Build the computation tree from the description of the levels""" + """Build the computation tree from the description of the levels. + + Parameters + ---------- + estimator_name : str + The name of the estimator this computation tree belongs to. + + levels : list of dict + The description of the levels of the computation tree. Each dict must have + the following keys: + - descr: str + A description of the level + - max_iter: int or None + The number of its children. None means it's a leaf. + + parent : ComputationNode instance, default=None + The parent node. None means this is the root. + + idx : int, default=0 + The index of this node among its siblings. + + Returns + ------- + computation_tree : ComputationNode instance + The root of the computation tree. + """ this_level = levels[0] node = ComputationNode( @@ -77,7 +105,7 @@ def build_computation_tree(estimator_name, levels, parent=None, idx=0): ) if parent is not None and parent.max_iter is None: - # parent node is a leaf of the computation tree of an outer estimator. It means + # parent node is a leaf of the computation tree of an outer estimator. It means # that this node is the root of the computation tree of this estimator. They # both correspond the same computation step, so we merge both nodes. node.description = parent.description + node.description diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index a99c40f509aaf..6440b1856d695 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -43,7 +43,7 @@ def __getstate__(self): # insert tasks between existing tasks. try: - from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn + from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn from rich.style import Style class _Progress(Progress): @@ -88,8 +88,7 @@ def run(self): self.root_task = None with self.progress_ctx: - while node:= self.queue.get(): - + while node := self.queue.get(): self._update_task_tree(node) self._update_tasks() self.progress_ctx.refresh() @@ -99,13 +98,14 @@ def _update_task_tree(self, node): curr_task, parent_task = None, None for curr_node in node.path: - if curr_node.parent is None: # root node if self.root_task is None: self.root_task = TaskNode(curr_node, progress_ctx=self.progress_ctx) curr_task = self.root_task elif curr_node.idx not in parent_task.children: - curr_task = TaskNode(curr_node, progress_ctx=self.progress_ctx, parent=parent_task) + curr_task = TaskNode( + curr_node, progress_ctx=self.progress_ctx, parent=parent_task + ) parent_task.children[curr_node.idx] = curr_task else: # task already exists curr_task = parent_task.children[curr_node.idx] @@ -141,18 +141,19 @@ def _update_tasks(self): class TaskNode: """A node in the tree of rich tasks - + Parameters ---------- node : ComputationNode instance The computation node this task corresponds to. - + progress_ctx : rich.Progress instance The progress context to which this task belongs. parent : TaskNode instance The parent of this task. """ + def __init__(self, node, progress_ctx, parent=None): self.node_idx = node.idx self.parent = parent diff --git a/sklearn/callback/tests/__init__.py b/sklearn/callback/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index b3612c1df43d1..7258e0833f4a4 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -3,7 +3,7 @@ import numpy as np -from sklearn.callback import ComputationNode, build_computation_tree +from sklearn.callback import build_computation_tree levels = [ {"descr": "level0", "max_iter": 3}, @@ -54,7 +54,7 @@ def test_path(): computation_tree = build_computation_tree(estimator_name="", levels=levels) assert computation_tree.path == [computation_tree] - + node = computation_tree.children[1].children[2].children[3] expected_path = [ computation_tree, @@ -65,5 +65,3 @@ def test_path(): assert node.path == expected_path assert all(node.depth == i for i, node in enumerate(expected_path)) - - From 73ecb31f9e8daf000091dfa7396900d3cb1640f2 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 25 Oct 2023 18:33:30 +0200 Subject: [PATCH 24/55] wip --- sklearn/callback/_computation_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index fccea345134f1..9705b4f70e3dc 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -69,7 +69,7 @@ def __iter__(self): def build_computation_tree(estimator_name, levels, parent=None, idx=0): """Build the computation tree from the description of the levels. - + Parameters ---------- estimator_name : str From 90589197f14a3de48e773589092f23364fbf5bc8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 27 Oct 2023 10:25:10 +0200 Subject: [PATCH 25/55] mypy --- sklearn/callback/tests/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index e0839ffd94d46..78e66a478e8e3 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -35,7 +35,7 @@ def on_fit_iter_end(self, estimator, node, **kwargs): class Estimator(BaseEstimator): - _parameter_constraints = {} + _parameter_constraints: dict = {} def __init__(self, max_iter=20): self.max_iter = max_iter @@ -64,7 +64,7 @@ def fit(self, X, y): class MetaEstimator(BaseEstimator): - _parameter_constraints = {} + _parameter_constraints: dict = {} def __init__( self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" From 309f755937d8f15e7ed83ba18fc556b5b67bedf5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 27 Oct 2023 14:58:35 +0200 Subject: [PATCH 26/55] add test for progressbars --- sklearn/callback/_base.py | 3 -- sklearn/callback/_progressbar.py | 8 ++-- sklearn/callback/tests/test_progressbar.py | 43 ++++++++++++++++++++++ 3 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 sklearn/callback/tests/test_progressbar.py diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index aa99303ef220d..496a06d6799ba 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -61,12 +61,10 @@ def on_fit_begin(self, estimator, *, X=None, y=None): y : ndarray or sparse matrix, default=None The target. """ - pass @abstractmethod def on_fit_end(self): """Method called at the end of the fit method of the estimator""" - pass @abstractmethod def on_fit_iter_end(self, estimator, node, **kwargs): @@ -108,7 +106,6 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) stop : bool or None Whether or not to stop the current level of iterations at this node. """ - pass @property def auto_propagate(self): diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 6440b1856d695..f6636c4e48182 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -78,7 +78,7 @@ def run(self): complete_style=Style(color="dark_orange"), finished_style=Style(color="cyan"), ), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("[bright_magenta]{task.percentage:>3.0f}%"), TimeRemainingColumn(), auto_refresh=False, ) @@ -166,12 +166,14 @@ def __init__(self, node, progress_ctx, parent=None): def _format_task_description(self, node): """Return a formatted description for the task of the node""" - colors = ["red", "green", "blue", "yellow"] + colors = ["bright_magenta", "cyan", "dark_orange"] indent = f"{' ' * (node.depth)}" style = f"[{colors[(node.depth)%len(colors)]}]" - description = f"{node.estimator_name[0]} - {node.description[0]} #{node.idx}" + description = f"{node.estimator_name[0]} - {node.description[0]}" + if node.parent is not None: + description += f" #{node.idx}" if len(node.estimator_name) == 2: description += f" | {node.estimator_name[1]} - {node.description[1]}" diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py new file mode 100644 index 0000000000000..552535f07c44f --- /dev/null +++ b/sklearn/callback/tests/test_progressbar.py @@ -0,0 +1,43 @@ +import textwrap + +import pytest + +from sklearn.callback import ProgressBar + +from ._utils import Estimator, MetaEstimator + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +@pytest.mark.parametrize("prefer", ["threads", "processes"]) +def test_progressbar(n_jobs, prefer, capsys): + """Check the output of the progress bars and their completion""" + pytest.importorskip("rich") + + est = Estimator() + meta_est = MetaEstimator(est, n_jobs=n_jobs, prefer=prefer) + meta_est._set_callbacks(ProgressBar()) + meta_est.fit(None, None) + + captured = capsys.readouterr() + + expected_output = """\ + MetaEstimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - outer #0 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - outer #1 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - outer #2 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - outer #3 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + """ + + assert captured.out == textwrap.dedent(expected_output) From 3569329acafd345e7e29d9e6b9d807489c5e420a Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 27 Oct 2023 15:07:59 +0200 Subject: [PATCH 27/55] can't guarantee same order of tasks --- sklearn/callback/_base.py | 28 +++++++++++-------- sklearn/callback/tests/test_progressbar.py | 32 +++++++--------------- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 496a06d6799ba..c61cc1c1eb92a 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -30,13 +30,15 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): # stopping_criterion and reconstruction_attributes can be costly to compute. # They are passed as lambdas for lazy evaluation. We only actually # compute them if a callback requests it. - if any(cb.request_stopping_criterion for cb in estimator._callbacks): - kwarg = kwargs.pop("stopping_criterion", lambda: None)() - kwargs["stopping_criterion"] = kwarg + # TODO: This is not used yet but will be necessary for next callbacks + # Uncomment when needed + # if any(cb.request_stopping_criterion for cb in estimator._callbacks): + # kwarg = kwargs.pop("stopping_criterion", lambda: None)() + # kwargs["stopping_criterion"] = kwarg - if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks): - kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() - kwargs["from_reconstruction_attributes"] = kwarg + # if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks): + # kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() + # kwargs["from_reconstruction_attributes"] = kwarg return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) @@ -125,10 +127,12 @@ def _is_propagated(self, estimator): """ return self.auto_propagate and hasattr(estimator, "_parent_node") - @property - def request_stopping_criterion(self): - return False + # TODO: This is not used yet but will be necessary for next callbacks + # Uncomment when needed + # @property + # def request_stopping_criterion(self): + # return False - @property - def request_from_reconstruction_attributes(self): - return False + # @property + # def request_from_reconstruction_attributes(self): + # return False diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 552535f07c44f..33ee5d42ec0bb 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -1,4 +1,4 @@ -import textwrap +import re import pytest @@ -20,24 +20,12 @@ def test_progressbar(n_jobs, prefer, capsys): captured = capsys.readouterr() - expected_output = """\ - MetaEstimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - outer #0 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - outer #1 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - outer #2 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - outer #3 ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #0 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #1 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - MetaEstimator - inner #2 | Estimator - fit ━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - """ - - assert captured.out == textwrap.dedent(expected_output) + assert re.search(r"MetaEstimator - fit", captured.out) + for i in range(4): + assert re.search(rf"MetaEstimator - outer #{i}", captured.out) + for i in range(3): + assert re.search(rf"MetaEstimator - inner #{i} | Estimator - fit", captured.out) + + # Check that all bars are 100% complete + assert re.search(r"100%", captured.out) + assert not re.search(r"[1-9]%", captured.out) From 2e28e4ad5a1b70ff367380076aa6a31564afb3f8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 27 Oct 2023 15:33:11 +0200 Subject: [PATCH 28/55] cln --- sklearn/callback/tests/test_progressbar.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 33ee5d42ec0bb..a22970d65571a 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -1,3 +1,6 @@ +# License: BSD 3 clause +# Authors: the scikit-learn developers + import re import pytest From 5270bad35ac5fd8c43b39d2bfaa8965b0d07189e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 21 Nov 2023 09:39:57 +0100 Subject: [PATCH 29/55] address nitpicks --- sklearn/base.py | 10 +++--- sklearn/callback/__init__.py | 5 +++ sklearn/callback/_base.py | 4 +-- sklearn/callback/_computation_tree.py | 6 ++-- sklearn/callback/_progressbar.py | 33 +++++++++---------- sklearn/callback/tests/_utils.py | 24 +++++++------- .../test_base_estimator_callback_methods.py | 15 +++++---- .../callback/tests/test_computation_tree.py | 18 +++++----- sklearn/callback/tests/test_progressbar.py | 2 +- 9 files changed, 62 insertions(+), 55 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 8e17444112c41..336e15cfc0b52 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -116,7 +116,7 @@ def _clone_parametrized(estimator, *, safe=True): params_set = new_object.get_params(deep=False) - # copy callbacks + # attach callbacks to the new estimator if hasattr(estimator, "_callbacks"): new_object._callbacks = estimator._callbacks @@ -672,7 +672,7 @@ def _set_callbacks(self, callbacks): # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all # meta-estimators. def _propagate_callbacks(self, sub_estimator, *, parent_node): - """Propagate the auto-propagated callbacks to a sub-estimator + """Propagate the auto-propagated callbacks to a sub-estimator. Parameters ---------- @@ -680,7 +680,7 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): The sub-estimator to propagate the callbacks to. parent_node : ComputationNode instance - The computation node in this estimator to set as parent_node to the + The computation node in this estimator to set as `parent_node` to the computation tree of the sub-estimator. It must be the node where the fit method of the sub-estimator is called. """ @@ -716,7 +716,7 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): ) def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): - """Evaluate the on_fit_begin method of the callbacks + """Evaluate the `on_fit_begin` method of the callbacks. The computation tree is also built at this point. @@ -757,7 +757,7 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): return self._computation_tree def _eval_callbacks_on_fit_end(self): - """Evaluate the on_fit_end method of the callbacks""" + """Evaluate the `on_fit_end` method of the callbacks.""" if not hasattr(self, "_callbacks") or not hasattr(self, "_computation_tree"): return diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 7c6e4a07741e3..c614c38c58b43 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -1,3 +1,8 @@ +""" +The :mod:`sklearn.callback` module implements the framework and off the shelf +callbacks for scikit-learn estimators. +""" + # License: BSD 3 clause # Authors: the scikit-learn developers diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index c61cc1c1eb92a..43ec4390de92a 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -7,14 +7,14 @@ # Not a method of BaseEstimator because it might not be directly called from fit but # by a non-method function called by fit def _eval_callbacks_on_fit_iter_end(**kwargs): - """Evaluate the on_fit_iter_end method of the callbacks + """Evaluate the `on_fit_iter_end` method of the callbacks. This function must be called at the end of each computation node. Parameters ---------- kwargs : dict - arguments passed to the callback. + Arguments passed to the callback. Returns ------- diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index 9705b4f70e3dc..2a7bbece6b90e 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -3,7 +3,7 @@ class ComputationNode: - """A node in a computation tree + """A node in a computation tree. Parameters ---------- @@ -52,12 +52,12 @@ def __init__( @property def depth(self): - """The depth of this node in the computation tree""" + """The depth of this node in the computation tree.""" return 0 if self.parent is None else self.parent.depth + 1 @property def path(self): - """List of all the nodes in the path from the root to this node""" + """List of all the nodes in the path from the root to this node.""" return [self] if self.parent is None else self.parent.path + [self] def __iter__(self): diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index f6636c4e48182..48a1d9882456d 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -9,7 +9,7 @@ class ProgressBar(BaseCallback): - """Callback that displays progress bars for each iterative steps of an estimator""" + """Callback that displays progress bars for each iterative steps of an estimator.""" auto_propagate = True @@ -34,19 +34,18 @@ def on_fit_end(self): def __getstate__(self): state = self.__dict__.copy() if "progress_monitor" in state: - del state["progress_monitor"] + del state["progress_monitor"] # a thread is not picklable return state -# Custom Progress class to allow showing the tasks in a given order (given by setting -# the _ordered_tasks attribute). In particular it allows to dynamically create and -# insert tasks between existing tasks. - try: from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn from rich.style import Style class _Progress(Progress): + # Custom Progress class to allow showing the tasks in a given order (given by + # setting the _ordered_tasks attribute). In particular it allows to dynamically + # create and insert tasks between existing tasks. def get_renderables(self): table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) yield table @@ -56,14 +55,14 @@ def get_renderables(self): class _RichProgressMonitor(Thread): - """Thread monitoring the progress of an estimator with rich based display + """Thread monitoring the progress of an estimator with rich based display. - The display is a list of nested rich tasks using rich.Progress. There is one for + The display is a list of nested rich tasks using `rich.Progress`. There is one for each non-leaf node in the computation tree of the estimator. Parameters ---------- - queue : multiprocessing.Manager.Queue instance + queue : `multiprocessing.Manager.Queue` instance This thread will run until the queue is empty. """ @@ -94,7 +93,7 @@ def run(self): self.progress_ctx.refresh() def _update_task_tree(self, node): - """Update the tree of tasks from a new node""" + """Update the tree of tasks from a new node.""" curr_task, parent_task = None, None for curr_node in node.path: @@ -116,7 +115,7 @@ def _update_task_tree(self, node): curr_task.finished = True def _update_tasks(self): - """Loop through the tasks in their display oder and update their progress""" + """Loop through the tasks in their display order and update their progress.""" self.progress_ctx._ordered_tasks = [] for task_node in self.root_task: @@ -140,17 +139,17 @@ def _update_tasks(self): class TaskNode: - """A node in the tree of rich tasks + """A node in the tree of rich tasks. Parameters ---------- - node : ComputationNode instance + node : `ComputationNode` instance The computation node this task corresponds to. - progress_ctx : rich.Progress instance + progress_ctx : `rich.Progress` instance The progress context to which this task belongs. - parent : TaskNode instance + parent : `TaskNode` instance The parent of this task. """ @@ -165,7 +164,7 @@ def __init__(self, node, progress_ctx, parent=None): self.task_id = progress_ctx.add_task(description, total=node.max_iter) def _format_task_description(self, node): - """Return a formatted description for the task of the node""" + """Return a formatted description for the task of the node.""" colors = ["bright_magenta", "cyan", "dark_orange"] indent = f"{' ' * (node.depth)}" @@ -180,7 +179,7 @@ def _format_task_description(self, node): return f"{style}{indent}{description}" def __iter__(self): - """Pre-order depth-first traversal, excluding leaves""" + """Pre-order depth-first traversal, excluding leaves.""" if self.children: yield self for child in self.children.values(): diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 78e66a478e8e3..6a7662f1434cd 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -1,11 +1,10 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -from joblib.parallel import Parallel, delayed - from sklearn.base import BaseEstimator, _fit_context, clone from sklearn.callback import BaseCallback from sklearn.callback._base import _eval_callbacks_on_fit_iter_end +from sklearn.utils.parallel import Parallel, delayed class TestingCallback(BaseCallback): @@ -24,6 +23,8 @@ class TestingAutoPropagatedCallback(TestingCallback): class NotValidCallback: + """Unvalid callback since it does not inherit from `BaseCallback`.""" + def on_fit_begin(self, estimator, *, X=None, y=None): pass @@ -88,20 +89,19 @@ def fit(self, X, y): ) Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( - delayed(self._func)(self.estimator, X, y, node, i) - for i, node in enumerate(root.children) + delayed(_func)(self, self.estimator, X, y, node) + for _, node in enumerate(root.children) ) return self - def _func(self, estimator, X, y, parent_node, i): - for j, node in enumerate(parent_node.children): - est = clone(estimator) - self._propagate_callbacks(est, parent_node=node) - est.fit(X, y) - _eval_callbacks_on_fit_iter_end(estimator=self, node=node) +def _func(meta_estimator, inner_estimator, X, y, parent_node): + for _, node in enumerate(parent_node.children): + est = clone(inner_estimator) + meta_estimator._propagate_callbacks(est, parent_node=node) + est.fit(X, y) - _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) + _eval_callbacks_on_fit_iter_end(estimator=meta_estimator, node=node) - return + _eval_callbacks_on_fit_iter_end(estimator=meta_estimator, node=parent_node) diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 17af6045df49c..1d5e5d1678083 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -21,18 +21,21 @@ ], ) def test_set_callbacks(callbacks): - """Sanity check for the _set_callbacks method""" + """Sanity check for the `_set_callbacks` method.""" estimator = Estimator() set_callbacks_return = estimator._set_callbacks(callbacks) assert hasattr(estimator, "_callbacks") - assert estimator._callbacks in (callbacks, [callbacks]) + + expected_callbacks = [callbacks] if not isinstance(callbacks, list) else callbacks + assert estimator._callbacks == expected_callbacks + assert set_callbacks_return is estimator @pytest.mark.parametrize("callbacks", [None, NotValidCallback()]) def test_set_callbacks_error(callbacks): - """Check the error message when not passing a valid callback to _set_callbacks""" + """Check the error message when not passing a valid callback to `_set_callbacks`.""" estimator = Estimator() with pytest.raises(TypeError, match="callbacks must be subclasses of BaseCallback"): @@ -40,7 +43,7 @@ def test_set_callbacks_error(callbacks): def test_propagate_callbacks(): - """Sanity check for the _propagate_callbacks method""" + """Sanity check for the `_propagate_callbacks` method.""" not_propagated_callback = TestingCallback() propagated_callback = TestingAutoPropagatedCallback() @@ -56,7 +59,7 @@ def test_propagate_callbacks(): def test_propagate_callback_no_callback(): - """Check that no callback is propagated if there's no callback""" + """Check that no callback is propagated if there's no callback.""" estimator = Estimator() sub_estimator = Estimator() estimator._propagate_callbacks(sub_estimator, parent_node=None) @@ -82,7 +85,7 @@ def test_auto_propagated_callbacks(): def test_eval_callbacks_on_fit_begin(): - """Check that _eval_callbacks_on_fit_begin creates the computation tree""" + """Check that `_eval_callbacks_on_fit_begin` creates the computation tree.""" estimator = Estimator()._set_callbacks(TestingCallback()) assert not hasattr(estimator, "_computation_tree") diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index 7258e0833f4a4..ee2e3fa971fdb 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -5,7 +5,7 @@ from sklearn.callback import build_computation_tree -levels = [ +LEVELS = [ {"descr": "level0", "max_iter": 3}, {"descr": "level1", "max_iter": 5}, {"descr": "level2", "max_iter": 7}, @@ -14,8 +14,8 @@ def test_computation_tree(): - """Check the construction of the computation tree""" - computation_tree = build_computation_tree(estimator_name="estimator", levels=levels) + """Check the construction of the computation tree.""" + computation_tree = build_computation_tree(estimator_name="estimator", levels=LEVELS) assert computation_tree.estimator_name == ("estimator",) assert computation_tree.parent is None assert computation_tree.idx == 0 @@ -36,12 +36,12 @@ def test_computation_tree(): def test_n_nodes(): - """Check that the number of node in a comutation tree corresponds to what we expect - from the level descriptions + """Check that the number of node in a computation tree corresponds to what we expect + from the level descriptions. """ - computation_tree = build_computation_tree(estimator_name="", levels=levels) + computation_tree = build_computation_tree(estimator_name="", levels=LEVELS) - max_iter_per_level = [level["max_iter"] for level in levels[:-1]] + max_iter_per_level = [level["max_iter"] for level in LEVELS[:-1]] expected_n_nodes = 1 + np.sum(np.cumprod(max_iter_per_level)) actual_n_nodes = sum(1 for _ in computation_tree) @@ -50,8 +50,8 @@ def test_n_nodes(): def test_path(): - """Check that the path from the root to a node is correct""" - computation_tree = build_computation_tree(estimator_name="", levels=levels) + """Check that the path from the root to a node is correct.""" + computation_tree = build_computation_tree(estimator_name="", levels=LEVELS) assert computation_tree.path == [computation_tree] diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index a22970d65571a..33f92eebb2013 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("n_jobs", [1, 2]) @pytest.mark.parametrize("prefer", ["threads", "processes"]) def test_progressbar(n_jobs, prefer, capsys): - """Check the output of the progress bars and their completion""" + """Check the output of the progress bars and their completion.""" pytest.importorskip("rich") est = Estimator() From ae5faccf43ca941ae031c45237a9d0ac2ecdc1e0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 22 Nov 2023 17:56:05 +0100 Subject: [PATCH 30/55] make rich soft dependency --- build_tools/azure/debian_atlas_32bit_lock.txt | 2 +- .../azure/debian_atlas_32bit_requirements.txt | 2 +- ...38_conda_defaults_openblas_environment.yml | 1 + ...onda_defaults_openblas_linux-64_conda.lock | 35 ++-- .../py38_conda_forge_mkl_win-64_conda.lock | 73 ++++---- ...forge_openblas_ubuntu_2204_environment.yml | 1 + ...e_openblas_ubuntu_2204_linux-64_conda.lock | 91 +++++----- ...latest_conda_forge_mkl_linux-64_conda.lock | 79 ++++---- ...t_conda_forge_mkl_linux-64_environment.yml | 1 + ...onda_forge_mkl_no_coverage_environment.yml | 1 + ..._forge_mkl_no_coverage_linux-64_conda.lock | 89 +++++----- ...pylatest_conda_forge_mkl_osx-64_conda.lock | 111 ++++++------ ...est_conda_forge_mkl_osx-64_environment.yml | 1 + ...latest_conda_mkl_no_openmp_environment.yml | 1 + ...test_conda_mkl_no_openmp_osx-64_conda.lock | 22 ++- ...latest_pip_openblas_pandas_environment.yml | 1 + ...st_pip_openblas_pandas_linux-64_conda.lock | 43 ++--- ...pylatest_pip_scipy_dev_linux-64_conda.lock | 20 +-- build_tools/azure/pypy3_linux-64_conda.lock | 64 +++---- build_tools/azure/ubuntu_atlas_lock.txt | 6 +- .../azure/ubuntu_atlas_requirements.txt | 2 +- build_tools/circle/doc_environment.yml | 1 + build_tools/circle/doc_linux-64_conda.lock | 168 +++++++++--------- .../doc_min_dependencies_environment.yml | 1 + .../doc_min_dependencies_linux-64_conda.lock | 82 ++++----- .../py39_conda_forge_linux-aarch64_conda.lock | 60 +++---- .../update_environments_and_lock_files.py | 11 +- sklearn/_min_dependencies.py | 1 + sklearn/callback/_progressbar.py | 7 +- sklearn/utils/__init__.py | 16 ++ 30 files changed, 520 insertions(+), 473 deletions(-) diff --git a/build_tools/azure/debian_atlas_32bit_lock.txt b/build_tools/azure/debian_atlas_32bit_lock.txt index 3734664c28110..02577734f5601 100644 --- a/build_tools/azure/debian_atlas_32bit_lock.txt +++ b/build_tools/azure/debian_atlas_32bit_lock.txt @@ -12,7 +12,7 @@ cython==0.29.33 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt iniconfig==2.0.0 # via pytest -joblib==1.1.1 +joblib==1.2.0 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt packaging==23.2 # via pytest diff --git a/build_tools/azure/debian_atlas_32bit_requirements.txt b/build_tools/azure/debian_atlas_32bit_requirements.txt index 6cddc77610e5c..52f7aeaac577f 100644 --- a/build_tools/azure/debian_atlas_32bit_requirements.txt +++ b/build_tools/azure/debian_atlas_32bit_requirements.txt @@ -2,7 +2,7 @@ # following script to centralize the configuration for CI builds: # build_tools/update_environments_and_lock_files.py cython==0.29.33 # min -joblib==1.1.1 # min +joblib==1.2.0 # min threadpoolctl==2.2.0 pytest==7.1.2 # min pytest-cov==2.9.0 # min diff --git a/build_tools/azure/py38_conda_defaults_openblas_environment.yml b/build_tools/azure/py38_conda_defaults_openblas_environment.yml index f0e5c653cb2de..9f69928464d45 100644 --- a/build_tools/azure/py38_conda_defaults_openblas_environment.yml +++ b/build_tools/azure/py38_conda_defaults_openblas_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl=2.2.0 - matplotlib=3.3.4 # min - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/py38_conda_defaults_openblas_linux-64_conda.lock b/build_tools/azure/py38_conda_defaults_openblas_linux-64_conda.lock index d0fa05f74e541..132c8e7bf8ffc 100644 --- a/build_tools/azure/py38_conda_defaults_openblas_linux-64_conda.lock +++ b/build_tools/azure/py38_conda_defaults_openblas_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 59b748d4b41a3e69462c0c657961aebaa5b15bc3caad670dff038296fa151c6e +# input_hash: 7876783d709e65ea6ff0d3fffd31775a638a6bed643e8f330ef09c5529720a09 @EXPLICIT https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda#c3473ff8bdb3d124ed5ff11ec380d6f9 https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-openblas.conda#9ddfcaef10d79366c90128f5dc444be8 @@ -14,7 +14,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-5.1-1_gnu.conda#71d28 https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-11.2.0-h1234567_1.conda#a87728dabf3151fb9cfa990bd2eb0464 https://repo.anaconda.com/pkgs/main/linux-64/expat-2.5.0-h6a678d5_0.conda#9a21d99d49a0a556cf9590430dec8ec0 https://repo.anaconda.com/pkgs/main/linux-64/giflib-5.2.1-h5eee18b_3.conda#aa7d64adb3cd8a75d398167f8c29afc3 -https://repo.anaconda.com/pkgs/main/linux-64/icu-58.2-he6710b0_3.conda#48cc14d5ad1a9bcd8dac17211a8deb8b +https://repo.anaconda.com/pkgs/main/linux-64/icu-73.1-h6a678d5_0.conda#6d09df641fc23f7d277a04dc7ea32dd4 https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9e-h5eee18b_1.conda#ac373800fda872108412d1ccfe3fa572 https://repo.anaconda.com/pkgs/main/linux-64/lerc-3.0-h295c915_0.conda#b97309770412f10bed8d9448f6f98f87 https://repo.anaconda.com/pkgs/main/linux-64/libdeflate-1.17-h5eee18b_1.conda#82831ef0b6c9595382d74e0c281f6742 @@ -25,8 +25,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.3.2-h5eee18b_0.conda https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.15-h7f8727e_0.conda#ada518dcadd6aaee9aae47ba9a671553 https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.4-h6a678d5_0.conda#53915e9402180a7f22ea619c41089520 https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda#5558eec6e2191741a92f832ea826251c -https://repo.anaconda.com/pkgs/main/linux-64/nspr-4.35-h6a678d5_0.conda#208fff5d60133bcff6998a70c9f5203b -https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.11-h7f8727e_2.conda#6cad6f2dcde73f8625d729c6db1272d0 +https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.12-h7f8727e_0.conda#48caaebab690276acf1bc1f3b56febf4 https://repo.anaconda.com/pkgs/main/linux-64/pcre-8.45-h295c915_0.conda#b32ccc24d1d9808618c1e898da60f68d https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.2-h5eee18b_0.conda#bcd31de48a0dcb44bc5b99675800c5cc https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_0.conda#333e31fbfbb5057c92fa845ad6adef93 @@ -34,10 +33,9 @@ https://repo.anaconda.com/pkgs/main/linux-64/ccache-3.7.9-hfe4627d_0.conda#bef6f https://repo.anaconda.com/pkgs/main/linux-64/glib-2.69.1-he621ea3_2.conda#51cf1899782b3f3744aedd143fbc07f3 https://repo.anaconda.com/pkgs/main/linux-64/libcups-2.4.2-h2d74bed_1.conda#3f265c2172a9e8c90a74037b6fa13685 https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20221030-h5eee18b_0.conda#7c724a17739aceaf9d1633ff06962137 -https://repo.anaconda.com/pkgs/main/linux-64/libevent-2.1.12-hdbd6064_1.conda#99312bf9d90f1ea14534b40afb61ce63 https://repo.anaconda.com/pkgs/main/linux-64/libllvm14-14.0.6-hdb19cb5_3.conda#aefea2b45cf32f12b4f1ffaa70aa3201 https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.39-h5eee18b_0.conda#f6aee38184512eb05b06c2e94d39ab22 -https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.10.4-hcbfbd50_0.conda#c42cffdb0bc28d37a4eb33aed114f554 +https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.10.4-hf1b16e4_1.conda#e87849ce513f9968794f20bba620e6a4 https://repo.anaconda.com/pkgs/main/linux-64/readline-8.2-h5eee18b_0.conda#be42180685cce6e6b0329201d9f48efb https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.12-h1ccaba5_0.conda#fa10ff4aa631fa4aa090a6234d7770b9 https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.5.5-hc292b87_0.conda#0f59d57dc21f585f4c282d60dfb46505 @@ -48,7 +46,6 @@ https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.20.1-h143b758_1.conda#cf1acc https://repo.anaconda.com/pkgs/main/linux-64/libclang13-14.0.6-default_he11475f_1.conda#44890feda1cf51639d9c94afbacce011 https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.5.1-h6a678d5_0.conda#235a671f74f0c4ecad9f9b3b107e3566 https://repo.anaconda.com/pkgs/main/linux-64/libxkbcommon-1.0.1-h5eee18b_1.conda#888b2e8f1bbf21017c503826e2d24b50 -https://repo.anaconda.com/pkgs/main/linux-64/libxslt-1.1.37-h2085143_0.conda#680f9676bf55bdafd276eaa12fbb0f28 https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.41.2-h5eee18b_0.conda#c7086c9ceb6cfe1c4c729a774a2d88a5 https://repo.anaconda.com/pkgs/main/linux-64/cyrus-sasl-2.1.28-h52b45da_1.conda#d634af1577e4008f9228ae96ce671c44 https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.14.1-h4c34cd2_2.conda#f0b472f5b544f8d57beb09ed4a2932e1 @@ -57,10 +54,9 @@ https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.12-h3be6417_0.conda#719db47 https://repo.anaconda.com/pkgs/main/linux-64/libclang-14.0.6-default_hc6dbbc7_1.conda#8f12583c4027b2861cff470f6b8837c4 https://repo.anaconda.com/pkgs/main/linux-64/libpq-12.15-hdbd6064_1.conda#218227d255f6056b6f49f52dd0d1731f https://repo.anaconda.com/pkgs/main/linux-64/libwebp-1.3.2-h11a3e52_0.conda#9e0d6c9abdd97b076c66d4cf488589ee -https://repo.anaconda.com/pkgs/main/linux-64/nss-3.89.1-h6a678d5_0.conda#4d9d28fc3a0ca4916f281d2f5429ac50 https://repo.anaconda.com/pkgs/main/linux-64/openjpeg-2.4.0-h3ad879b_0.conda#86baecb47ecaa7f7ff2657a1f03b90c9 https://repo.anaconda.com/pkgs/main/linux-64/python-3.8.18-h955ad1f_0.conda#fa35c1028f48db26df051ee75dd9422f -https://repo.anaconda.com/pkgs/main/linux-64/certifi-2023.7.22-py38h06a4308_0.conda#59416ad8979a654bb8f5184b62d8a9e7 +https://repo.anaconda.com/pkgs/main/linux-64/certifi-2023.11.17-py38h06a4308_0.conda#3c4c381d8521859fcfde56ef2e3e5c40 https://repo.anaconda.com/pkgs/main/noarch/cycler-0.11.0-pyhd3eb1b0_0.conda#f5e365d2cdb66d547eb8c3ab93843aab https://repo.anaconda.com/pkgs/main/linux-64/cython-0.29.33-py38h6a678d5_0.conda#eb105388ba8bcf5ce82cf4cd5deeb5f9 https://repo.anaconda.com/pkgs/main/linux-64/exceptiongroup-1.0.4-py38h06a4308_0.conda#db954e73dca6076c64a1004d71b45784 @@ -68,6 +64,7 @@ https://repo.anaconda.com/pkgs/main/noarch/execnet-1.9.0-pyhd3eb1b0_0.conda#f895 https://repo.anaconda.com/pkgs/main/noarch/iniconfig-1.1.1-pyhd3eb1b0_0.tar.bz2#e40edff2c5708f342cef43c7f280c507 https://repo.anaconda.com/pkgs/main/linux-64/joblib-1.2.0-py38h06a4308_0.conda#ee7f1f50ae15650057e5d5301900ae34 https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.4.4-py38h6a678d5_0.conda#7424aa335d22974192800ec19a68486e +https://repo.anaconda.com/pkgs/main/linux-64/mdurl-0.1.0-py38h06a4308_0.conda#69312410e814f0bff66c43199238b373 https://repo.anaconda.com/pkgs/main/linux-64/mysql-5.7.24-h721c034_2.conda#dfc19ca2466d275c4c1f73b62c57f37b https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.17.3-py38h2f8d375_0.conda#40edbb76ecacefb1e6ab639b514822b1 https://repo.anaconda.com/pkgs/main/linux-64/packaging-23.1-py38h06a4308_0.conda#9ec9b6ee22dad7f49806c51218befd5b @@ -75,7 +72,9 @@ https://repo.anaconda.com/pkgs/main/linux-64/pillow-10.0.1-py38ha6cbd5a_0.conda# https://repo.anaconda.com/pkgs/main/linux-64/pluggy-1.0.0-py38h06a4308_1.conda#87bb1d3f6cf3e409a1dac38cee99918e https://repo.anaconda.com/pkgs/main/linux-64/ply-3.11-py38_0.conda#d6a69c576c6e4d19e3074eaae3d149f2 https://repo.anaconda.com/pkgs/main/noarch/py-1.11.0-pyhd3eb1b0_0.conda#7205a898ed2abbf6e9b903dff6abe08e +https://repo.anaconda.com/pkgs/main/linux-64/pygments-2.15.1-py38h06a4308_1.conda#79e8654fed904cfd833c44ef0d1307a2 https://repo.anaconda.com/pkgs/main/linux-64/pyparsing-3.0.9-py38h06a4308_0.conda#becbbf51d2b05de228eed968e20f963d +https://repo.anaconda.com/pkgs/main/linux-64/pyqt5-sip-12.13.0-py38h5eee18b_0.conda#0ebb310c44968880835aefbf9fbbfa2c https://repo.anaconda.com/pkgs/main/linux-64/pytz-2023.3.post1-py38h06a4308_0.conda#351d59ddfed216ab9b05481d3bb63106 https://repo.anaconda.com/pkgs/main/linux-64/setuptools-68.0.0-py38h06a4308_0.conda#24f9c895455f3992d6b04957fd0e7546 https://repo.anaconda.com/pkgs/main/noarch/six-1.16.0-pyhd3eb1b0_1.conda#34586824d411d36af2fa40e799c172d0 @@ -83,21 +82,21 @@ https://repo.anaconda.com/pkgs/main/noarch/threadpoolctl-2.2.0-pyh0d69192_0.cond https://repo.anaconda.com/pkgs/main/noarch/toml-0.10.2-pyhd3eb1b0_0.conda#cda05f5f6d8509529d1a2743288d197a https://repo.anaconda.com/pkgs/main/linux-64/tomli-2.0.1-py38h06a4308_0.conda#791cce9de9913e9587b0a85cd8419123 https://repo.anaconda.com/pkgs/main/linux-64/tornado-6.3.3-py38h5eee18b_0.conda#8030fb73590f8370a558f783b4f9f030 +https://repo.anaconda.com/pkgs/main/linux-64/typing_extensions-4.7.1-py38h06a4308_0.conda#cd44242195553e21e029f66627a43387 https://repo.anaconda.com/pkgs/main/linux-64/coverage-7.2.2-py38h5eee18b_0.conda#a05c1732d4e67102d2aa8d7e56de778b +https://repo.anaconda.com/pkgs/main/linux-64/markdown-it-py-2.2.0-py38h06a4308_1.conda#e2ace6e5d56b948f53e11e9452f49f99 https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.17.3-py38h7e8d029_0.conda#5f2b196b515f8fe6b37e3d224650577d https://repo.anaconda.com/pkgs/main/linux-64/pytest-7.4.0-py38h06a4308_0.conda#ba6c58ef1c6ba5247ccc17d41fdd71e5 https://repo.anaconda.com/pkgs/main/noarch/python-dateutil-2.8.2-pyhd3eb1b0_0.conda#211ee00320b08a1ac9fea6677649f6c9 -https://repo.anaconda.com/pkgs/main/linux-64/qt-main-5.15.2-h7358343_9.conda#d3eac069d7e4e93b866a07c2274c9ee7 -https://repo.anaconda.com/pkgs/main/linux-64/sip-6.6.2-py38h6a678d5_0.conda#cb3f0d10f7f79870945f4dbbe0000f92 +https://repo.anaconda.com/pkgs/main/linux-64/qt-main-5.15.2-h53bd1ea_10.conda#bd0c79e82df6323f638bdcb871891b61 +https://repo.anaconda.com/pkgs/main/linux-64/sip-6.7.12-py38h6a678d5_0.conda#3a940732bb7fcf43ec398ce06be29eb4 https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.3.4-py38h62a2d02_0.conda#7156fafe3362d0b6a2de43e0002febb3 https://repo.anaconda.com/pkgs/main/linux-64/pandas-1.2.4-py38ha9443f7_0.conda#5bd3fd807a294f387feabc65821b75d0 -https://repo.anaconda.com/pkgs/main/linux-64/pyqt5-sip-12.11.0-py38h6a678d5_1.conda#7bc403c7d55f1465e922964d293d2186 -https://repo.anaconda.com/pkgs/main/linux-64/pytest-cov-4.1.0-py38h06a4308_0.conda#ef981a8b88a9ecf7a84bf50516211e0c -https://repo.anaconda.com/pkgs/main/noarch/pytest-forked-1.3.0-pyhd3eb1b0_0.tar.bz2#07970bffdc78f417d7f8f1c7e620f5c4 -https://repo.anaconda.com/pkgs/main/linux-64/qt-webengine-5.15.9-h9ab4d14_7.conda#907aa480f11eabd16bd6c72c81720ef2 +https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.15.10-py38h6a678d5_0.conda#5251f84010c75d82f672974e69c67cd6 +https://repo.anaconda.com/pkgs/main/linux-64/pytest-cov-4.1.0-py38h06a4308_1.conda#6b5a671f724b1520b19f48988ad99083 +https://repo.anaconda.com/pkgs/main/linux-64/pytest-forked-1.6.0-py38h06a4308_0.conda#aff806e2ad3b684150eeaceaf9be72c4 +https://repo.anaconda.com/pkgs/main/linux-64/rich-13.3.5-py38h06a4308_0.conda#d6cadc35a9501abb7392d72e25ee86e1 https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.5.0-py38habc2bb6_0.conda#a27a97fc2377ab74cbd33ce22d3c3353 +https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.3.4-py38h06a4308_0.conda#96033fd3465abc467ae394c6852930de https://repo.anaconda.com/pkgs/main/linux-64/pyamg-4.2.3-py38h79cecc1_0.conda#6e7f4f94000b244396de8bf4e6ae8dc4 https://repo.anaconda.com/pkgs/main/noarch/pytest-xdist-2.5.0-pyhd3eb1b0_0.conda#d15cdc4207bcf8ca920822597f1d138d -https://repo.anaconda.com/pkgs/main/linux-64/qtwebkit-5.212-h3fafdc1_5.conda#e811bbc0456e3d3a02cab199492153ee -https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.15.7-py38h6a678d5_1.conda#62232dc285be8e7e85ae9596d89b3b95 -https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.3.4-py38h06a4308_0.conda#96033fd3465abc467ae394c6852930de diff --git a/build_tools/azure/py38_conda_forge_mkl_win-64_conda.lock b/build_tools/azure/py38_conda_forge_mkl_win-64_conda.lock index 6c58c70b3e79c..ac89cbd89fced 100644 --- a/build_tools/azure/py38_conda_forge_mkl_win-64_conda.lock +++ b/build_tools/azure/py38_conda_forge_mkl_win-64_conda.lock @@ -1,10 +1,10 @@ # Generated by conda-lock. # platform: win-64 -# input_hash: 4ac1abe3eccdd48c0d50af8de11dd3c144459b84f500eae8f575232e0be3a07d +# input_hash: e3af9571d95aff7d02e118db6e2ccbce90cd3cf3c663b4ed8a5e8c3fef5b1318 @EXPLICIT -https://conda.anaconda.org/conda-forge/win-64/ca-certificates-2023.7.22-h56e8100_0.conda#b1c2327b36f1a25d96f2039b0d3e3739 -https://conda.anaconda.org/conda-forge/win-64/intel-openmp-2023.2.0-h57928b3_50496.conda#519f9c42672f1e8a334ec9471e93f4fe -https://conda.anaconda.org/conda-forge/win-64/mkl-include-2023.2.0-h6a75c08_50496.conda#6d5b648b65d7d4fd9feb95be42c11011 +https://conda.anaconda.org/conda-forge/win-64/ca-certificates-2023.11.17-h56e8100_0.conda#1163114b483f26761f993c709e65271f +https://conda.anaconda.org/conda-forge/win-64/intel-openmp-2023.2.0-h57928b3_50497.conda#a401f3cae152deb75bbed766a90a6312 +https://conda.anaconda.org/conda-forge/win-64/mkl-include-2023.2.0-h6a75c08_50497.conda#02fd1f15c56cc902aeaf3df3497cf266 https://conda.anaconda.org/conda-forge/win-64/msys2-conda-epoch-20160418-1.tar.bz2#b0309b72560df66f71a9d5e34a5efdfa https://conda.anaconda.org/conda-forge/win-64/python_abi-3.8-4_cp38.conda#b1059de1664cef9a785dda079a50f1ed https://conda.anaconda.org/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_0.tar.bz2#72608f6cd3e5898229c3ea16deb1ac43 @@ -14,7 +14,7 @@ https://conda.anaconda.org/conda-forge/win-64/vc14_runtime-14.36.32532-hdcecf7f_ https://conda.anaconda.org/conda-forge/win-64/m2w64-gcc-libs-core-5.3.0-7.tar.bz2#4289d80fb4d272f1f3b56cfe87ac90bd https://conda.anaconda.org/conda-forge/win-64/vc-14.3-h64f974e_17.conda#67ff6791f235bb606659bf2a5c169191 https://conda.anaconda.org/conda-forge/win-64/vs2015_runtime-14.36.32532-h05e6639_17.conda#4618046c39f7c81861e53ded842e738a -https://conda.anaconda.org/conda-forge/win-64/bzip2-1.0.8-h8ffe710_4.tar.bz2#7c03c66026944073040cb19a4f3ec3c9 +https://conda.anaconda.org/conda-forge/win-64/bzip2-1.0.8-hcfcfb64_5.conda#26eb8ca6ea332b675e11704cce84a3be https://conda.anaconda.org/conda-forge/win-64/icu-73.2-h63175ca_0.conda#0f47d9e3192d9e09ae300da0d28e0f56 https://conda.anaconda.org/conda-forge/win-64/lerc-4.0.0-h63175ca_0.tar.bz2#1900cb3cab5055833cfddb0ba233b074 https://conda.anaconda.org/conda-forge/win-64/libbrotlicommon-1.1.0-hcfcfb64_1.conda#f77f319fb82980166569e1280d5b2864 @@ -23,13 +23,13 @@ https://conda.anaconda.org/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2#2c https://conda.anaconda.org/conda-forge/win-64/libiconv-1.17-h8ffe710_0.tar.bz2#050119977a86e4856f0416e2edcf81bb https://conda.anaconda.org/conda-forge/win-64/libjpeg-turbo-3.0.0-hcfcfb64_1.conda#3f1b948619c45b1ca714d60c7389092c https://conda.anaconda.org/conda-forge/win-64/libogg-1.3.4-h8ffe710_1.tar.bz2#04286d905a0dcb7f7d4a12bdfe02516d -https://conda.anaconda.org/conda-forge/win-64/libsqlite-3.43.2-hcfcfb64_0.conda#a4a81906f6ce911113f672973777f305 +https://conda.anaconda.org/conda-forge/win-64/libsqlite-3.44.0-hcfcfb64_0.conda#446fb1973cfeb8b32de4add3c9ac1057 https://conda.anaconda.org/conda-forge/win-64/libwebp-base-1.3.2-hcfcfb64_0.conda#dcde8820959e64378d4e06147ffecfdd https://conda.anaconda.org/conda-forge/win-64/libzlib-1.2.13-hcfcfb64_5.conda#5fdb9c6a113b6b6cb5e517fd972d5f41 https://conda.anaconda.org/conda-forge/win-64/m2w64-gcc-libgfortran-5.3.0-6.tar.bz2#066552ac6b907ec6d72c0ddab29050dc -https://conda.anaconda.org/conda-forge/win-64/openssl-3.1.3-hcfcfb64_0.conda#16b2c80ad196f18acd31b588ef28cb9a +https://conda.anaconda.org/conda-forge/win-64/openssl-3.1.4-hcfcfb64_0.conda#2eebbc64373a1c6db62ad23304e9678e https://conda.anaconda.org/conda-forge/win-64/pthreads-win32-2.9.1-hfa6e2cd_3.tar.bz2#e2da8758d7d51ff6aa78a14dfb9dbed4 -https://conda.anaconda.org/conda-forge/win-64/tk-8.6.13-hcfcfb64_0.conda#74405f2ccbb40af409fee1a71ce70dc6 +https://conda.anaconda.org/conda-forge/win-64/tk-8.6.13-h5226925_1.conda#fc048363eb8f03cd1737600a5d08aafe https://conda.anaconda.org/conda-forge/win-64/xz-5.2.6-h8d14728_0.tar.bz2#515d77642eaa3639413c6b1bc3f94219 https://conda.anaconda.org/conda-forge/win-64/gettext-0.21.1-h5728263_0.tar.bz2#299d4fd6798a45337042ff5a48219e5f https://conda.anaconda.org/conda-forge/win-64/krb5-1.21.2-heb0366b_0.conda#6e8b0f22b4eef3b3cb3849bb4c3d47f9 @@ -38,26 +38,26 @@ https://conda.anaconda.org/conda-forge/win-64/libbrotlienc-1.1.0-hcfcfb64_1.cond https://conda.anaconda.org/conda-forge/win-64/libclang13-15.0.7-default_h77d9078_3.conda#ba26634d038b91466bb4242c8b5e0cfa https://conda.anaconda.org/conda-forge/win-64/libpng-1.6.39-h19919ed_0.conda#ab6febdb2dbd9c00803609079db4de71 https://conda.anaconda.org/conda-forge/win-64/libvorbis-1.3.7-h0e60522_0.tar.bz2#e1a22282de0169c93e4ffe6ce6acc212 -https://conda.anaconda.org/conda-forge/win-64/libxml2-2.11.5-hc3477c8_1.conda#27974f880a010b1441093d9f737a949f +https://conda.anaconda.org/conda-forge/win-64/libxml2-2.11.6-hc3477c8_0.conda#08ffbb4c22dd3622e122058368f8b708 https://conda.anaconda.org/conda-forge/win-64/m2w64-gcc-libs-5.3.0-7.tar.bz2#fe759119b8b3bfa720b8762c6fdc35de -https://conda.anaconda.org/conda-forge/win-64/pcre2-10.40-h17e33f8_0.tar.bz2#2519de0d9620dc2bc7e19caf6867136d +https://conda.anaconda.org/conda-forge/win-64/pcre2-10.42-h17e33f8_0.conda#59610c61da3af020289a806ec9c6a7fd https://conda.anaconda.org/conda-forge/win-64/python-3.8.18-h4de0772_0_cpython.conda#d261509b6d608edf6027143f205cf19b https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.5-h12be248_0.conda#792bb5da68bf0a6cac6a6072ecb8dbeb https://conda.anaconda.org/conda-forge/win-64/brotli-bin-1.1.0-hcfcfb64_1.conda#0105229d7c5fabaa840043a86c10ec64 https://conda.anaconda.org/conda-forge/win-64/brotli-python-1.1.0-py38hd3f51b4_1.conda#72708ea626a2530148ea49eb743576f4 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 -https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.0-pyhd8ed1ab_0.conda#fef8ef5f0a54546b9efee39468229917 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 +https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda#7f4a9e3fcff3f6356ae99244a014da6a https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/win-64/cython-3.0.4-py38hd3f51b4_0.conda#50a79eef478496ae81b1912c62d2df3e -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/win-64/cython-3.0.5-py38hd3f51b4_0.conda#ffa28c4a25a89ba45f015db6f8b9f26b +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/win-64/freetype-2.12.1-hdaf720e_2.conda#3761b23693f768dc75a8fd0a73ca053f https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2#34272b248891bddccc64479f9a7fffed https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/win-64/kiwisolver-1.4.5-py38hb1fd069_1.conda#19a5ecd89c16b22db1d1830e93392aab https://conda.anaconda.org/conda-forge/win-64/libclang-15.0.7-default_h77d9078_3.conda#71c8b6249c9e9e18b3aec705e95c1040 -https://conda.anaconda.org/conda-forge/win-64/libglib-2.78.0-he8f3873_0.conda#25f5b3502a82ac425c72c3bc0efbecb5 +https://conda.anaconda.org/conda-forge/win-64/libglib-2.78.1-h16e383f_1.conda#092b567b75f9f699e8d1fbaf37064b8e https://conda.anaconda.org/conda-forge/win-64/libhwloc-2.9.3-default_haede6df_1009.conda#87da045f6d26ce9fe20ad76a18f6a18a https://conda.anaconda.org/conda-forge/win-64/libtiff-4.6.0-h6e2ebb7_2.conda#08d653b74ee2dec0131ad4259ffbb126 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 @@ -75,52 +75,51 @@ https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5 https://conda.anaconda.org/conda-forge/win-64/tornado-6.3.3-py38h91455d4_1.conda#1daea9d484de0ed524b80c9772484102 https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda#5b1be40a26d10a06f6d4f1f9e19fa0c7 https://conda.anaconda.org/conda-forge/win-64/unicodedata2-15.1.0-py38h91455d4_0.conda#556fb89abee3970c76556144bdab3263 -https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.2-pyhd8ed1ab_0.conda#1ccd092478b3e0ee10d7a891adbf8a4f +https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.3-pyhd8ed1ab_0.conda#3fc026b9c87d091c4b34a6c997324ae8 https://conda.anaconda.org/conda-forge/noarch/win_inet_pton-1.1.0-pyhd8ed1ab_6.tar.bz2#30878ecc4bd36e8deeea1e3c151b2e0b https://conda.anaconda.org/conda-forge/win-64/xorg-libxau-1.0.11-hcd874cb_0.conda#c46ba8712093cb0114404ae8a7582e1a https://conda.anaconda.org/conda-forge/win-64/xorg-libxdmcp-1.1.3-hcd874cb_0.tar.bz2#46878ebb6b9cbd8afcf8088d7ef00ece https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a https://conda.anaconda.org/conda-forge/win-64/brotli-1.1.0-hcfcfb64_1.conda#f47f6db2528e38321fb00ae31674c133 https://conda.anaconda.org/conda-forge/win-64/coverage-7.3.2-py38h91455d4_0.conda#6d4fd016918358448d9055caa59cb616 -https://conda.anaconda.org/conda-forge/win-64/glib-tools-2.78.0-h12be248_0.conda#466538fb59949a3c015b55671dc7e52c -https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.0-pyhd8ed1ab_0.conda#48b0d98e0c0ec810d3ccc2a0926c8c0e +https://conda.anaconda.org/conda-forge/win-64/glib-tools-2.78.1-h12be248_1.conda#8a3af479aa812a2a2cb0a4ab2be52dc9 +https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/win-64/lcms2-2.15-h67d730c_3.conda#f92e86636451e3f6cea03e395346fa90 https://conda.anaconda.org/conda-forge/win-64/libxcb-1.15-hcd874cb_0.conda#090d91b69396f14afef450c285f9758c https://conda.anaconda.org/conda-forge/win-64/openjpeg-2.5.0-h3d672ee_3.conda#45a9628a04efb6fc326fff0a8f47b799 -https://conda.anaconda.org/conda-forge/noarch/pip-23.3-pyhd8ed1ab_0.conda#a06f102f59c8e3bb8b3e46e71c384709 +https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda#2400c0b86889f43aa52067161e1fb108 +https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda#6bb4ee32cd435deaeac72776c001e7ac https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyh0701188_6.tar.bz2#56cd9fe388baac0e90c7149cfac95b60 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/win-64/sip-6.7.12-py38hd3f51b4_0.conda#8234c36685a08c47f11865ffc7ed36a9 https://conda.anaconda.org/conda-forge/win-64/tbb-2021.10.0-h91493d7_2.conda#5b8c97cf8f0e81d6c22c0bda9978790d -https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda#384462e63262a527bda564fa2d9126c0 -https://conda.anaconda.org/conda-forge/win-64/fonttools-4.43.1-py38h91455d4_0.conda#d960b3ead49389adc21a688a75cd30ef -https://conda.anaconda.org/conda-forge/win-64/glib-2.78.0-h12be248_0.conda#1ed98e4da48693079f2fe83298c5b0ac -https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.0-pyhd8ed1ab_0.conda#6a62c2cc25376a0d050b3d1d221c3ee9 -https://conda.anaconda.org/conda-forge/win-64/mkl-2023.2.0-h6a75c08_50496.conda#03da367d935ecf4d3e4005cf705d0e21 +https://conda.anaconda.org/conda-forge/win-64/fonttools-4.45.0-py38h91455d4_0.conda#024466f195a98b53fe3667f9468eb30b +https://conda.anaconda.org/conda-forge/win-64/glib-2.78.1-h12be248_1.conda#247e1bc91e6698e1b9846c4d4df509fa +https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710 +https://conda.anaconda.org/conda-forge/win-64/mkl-2023.2.0-h6a75c08_50497.conda#064cea9f45531e7b53584acf4bd8b044 https://conda.anaconda.org/conda-forge/win-64/pillow-10.1.0-py38hc375fad_0.conda#d671ae9247896e544d8b2df9feaf1f89 -https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.11.0-pyhd8ed1ab_0.conda#8f567c0a74aa44cf732f15773b4083b0 https://conda.anaconda.org/conda-forge/win-64/pyqt5-sip-12.12.2-py38hd3f51b4_5.conda#32974507018705cbe32a392473cd6ec1 https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.1.0-pyhd8ed1ab_0.conda#06eb685a3a0b146347a58dda979485da https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 -https://conda.anaconda.org/conda-forge/noarch/urllib3-2.0.7-pyhd8ed1ab_0.conda#270e71c14d37074b1d066ee21cf0c4a6 -https://conda.anaconda.org/conda-forge/win-64/gstreamer-1.22.6-hb4038d2_2.conda#e6d2009457a1e5d9653fd06873a7a367 -https://conda.anaconda.org/conda-forge/win-64/libblas-3.9.0-19_win64_mkl.conda#4f8a1a63cfbf74bc7b2813d9c6c205be -https://conda.anaconda.org/conda-forge/win-64/mkl-devel-2023.2.0-h57928b3_50496.conda#381330681b4506191e1a71699ea9e6fc +https://conda.anaconda.org/conda-forge/noarch/urllib3-2.1.0-pyhd8ed1ab_0.conda#f8ced8ee63830dec7ecc1be048d1470a +https://conda.anaconda.org/conda-forge/win-64/gstreamer-1.22.7-hb4038d2_0.conda#9b2f6622276ed34d20eb36e6a4ce2f50 +https://conda.anaconda.org/conda-forge/win-64/libblas-3.9.0-20_win64_mkl.conda#6cad6cd2fbdeef4d651b8f752a4da960 +https://conda.anaconda.org/conda-forge/win-64/mkl-devel-2023.2.0-h57928b3_50497.conda#0d52cfab24361c77268b54920c11903c https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b -https://conda.anaconda.org/conda-forge/win-64/gst-plugins-base-1.22.6-h001b923_2.conda#20e57b894392cb792cdf5c501b35a8f6 -https://conda.anaconda.org/conda-forge/win-64/libcblas-3.9.0-19_win64_mkl.conda#1b9ede5cff953aa1a5f4d9f8ec644972 -https://conda.anaconda.org/conda-forge/win-64/liblapack-3.9.0-19_win64_mkl.conda#574e6e8bcc85df2885eb2a87d31ae005 -https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyhd8ed1ab_4.conda#3cdaf7af08850933662b1e228bc6b5bc -https://conda.anaconda.org/conda-forge/win-64/liblapacke-3.9.0-19_win64_mkl.conda#c77175c01902a5a8eb9e5598bd9e7756 +https://conda.anaconda.org/conda-forge/win-64/gst-plugins-base-1.22.7-h001b923_0.conda#e4b56ad6c21e861456f32bfc79b43c4b +https://conda.anaconda.org/conda-forge/win-64/libcblas-3.9.0-20_win64_mkl.conda#e6d36cfcb2f2dff0f659d2aa0813eb2d +https://conda.anaconda.org/conda-forge/win-64/liblapack-3.9.0-20_win64_mkl.conda#9510d07424d70fcac553d86b3e4a7c14 +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed +https://conda.anaconda.org/conda-forge/win-64/liblapacke-3.9.0-20_win64_mkl.conda#960008cd6e9827a5c9b68e77fdf3d29f https://conda.anaconda.org/conda-forge/win-64/numpy-1.24.4-py38h1d91fd2_0.conda#bb13551a7913ff4de74df687f03ba14e https://conda.anaconda.org/conda-forge/win-64/qt-main-5.15.8-h9e85ed6_17.conda#568b134e26f3e2a44ff24028c27b8c0e -https://conda.anaconda.org/conda-forge/win-64/blas-devel-3.9.0-19_win64_mkl.conda#fa4fbc5210838e16ccdad9fff235d1ff +https://conda.anaconda.org/conda-forge/win-64/blas-devel-3.9.0-20_win64_mkl.conda#40f21d1e894795983dec1036847e7460 https://conda.anaconda.org/conda-forge/win-64/contourpy-1.1.1-py38hb1fd069_1.conda#13df3a01683e407c2745cc0b6aa6beca https://conda.anaconda.org/conda-forge/win-64/pyqt-5.15.9-py38hd6c051e_5.conda#7d7f5b99c3929f02566314f252f9ef53 https://conda.anaconda.org/conda-forge/win-64/scipy-1.10.1-py38h1aea9ed_3.conda#1ed766b46170f86ead2ae6b9b8151191 -https://conda.anaconda.org/conda-forge/win-64/blas-2.119-mkl.conda#ae72b73d43b1b758c0f56933ba304050 +https://conda.anaconda.org/conda-forge/win-64/blas-2.120-mkl.conda#169d630727008b4356a138a3a0f595d4 https://conda.anaconda.org/conda-forge/win-64/matplotlib-base-3.7.3-py38h2724991_0.conda#80ee24705fa140b2febf66a1f9fb9b39 https://conda.anaconda.org/conda-forge/win-64/matplotlib-3.7.3-py38haa244fe_0.conda#30c703c4b30df6b261308086e5171a9d diff --git a/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_environment.yml b/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_environment.yml index bbbb3bb4cef6c..3297f4ba131cd 100644 --- a/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_environment.yml +++ b/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock b/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock index d7b262a895e76..775028f7554fa 100644 --- a/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock +++ b/build_tools/azure/py38_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock @@ -1,23 +1,23 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: c5e4221207552c628ae5d6fd5d1a639c5fa48e17df39e521800953029c61bb2a +# input_hash: 2ae0a4d59a28951631554a8905e9c64cefd33da5f2c2f8e29d1d5056387cd522 @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2#19410c3df09dfb12d1206132a1d357c5 https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.8-4_cp38.conda#ea6b353536f42246cd130c7fef1285cf https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.10-hd590300_0.conda#75dae9a4201732aa78a530b826ee5fe0 https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda#cc47e1facc155f91abd89b11e48e72ff @@ -28,7 +28,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1 https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda#1635570038840ee3f9c71d22aa5b8b6d https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_2.conda#78fdab09d9138851dde2b5fe2a11019e +https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_3.conda#c714d905cdfa0e70200f68b80cc04764 https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2#b62b52da46c39ee2bc3c162ac7f1804d https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 @@ -39,9 +39,9 @@ https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.co https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75 -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.42.2-h59595ed_0.conda#700edd63ccd5fc66b70b1c028cea9a68 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a @@ -60,32 +60,32 @@ https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25c https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_2.conda#e75a75a6eaf6f318dae2631158c46575 +https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_3.conda#73031c79546ad06f1fe62e57fdd021bc https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda#c2097d0b46367996f09b4e8e4920384a https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.5-h232c23b_1.conda#f3858448893839820d4bcfb14ad3ecdf -https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_5.conda#1e8ef4090ca4f0d66404a7441e1dbf3c -https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.6-h232c23b_0.conda#427a3e59d66cb5d145020bd9c6493334 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda#39f910d205726805a958da408ca194ba https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 -https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.0-hebfc3b9_0.conda#e618003da3547216310088478e475945 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.2-hd590300_0.conda#3d7d5e5cebf8af5aadb040732860f1b6 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.1-h783c2da_1.conda#70052d6c1e84643e30ffefb21ab6950f https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-h5cf9203_3.conda#9efe82d44b76a7529a1d702e5a37752e -https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.24-pthreads_h413a1c8_0.conda#6e4ef6ca28655124dcde9bd500e44c32 +https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.25-pthreads_h413a1c8_0.conda#d172b34a443b95f86089e8229ddc9a17 https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_5.conda#b72f016c910ff9295b1377d3e17da3f2 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda#7caef74bbfa730e014b20f0852068509 https://conda.anaconda.org/conda-forge/linux-64/python-3.8.18-hd12c33a_0_cpython.conda#334cb629e10d209f1c17630f653168b1 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 @@ -96,32 +96,34 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.con https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda#f27a24d46e3ea7b70a1f98e50c62508f https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py38h17151c0_1.conda#7a5a699c8992fc51ef25e980f4502c2a https://conda.anaconda.org/conda-forge/linux-64/ccache-4.8.1-h1fcd64f_0.conda#fd37a0c47d8b3667b73af0549037ce83 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 -https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.0-pyhd8ed1ab_0.conda#fef8ef5f0a54546b9efee39468229917 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 +https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda#7f4a9e3fcff3f6356ae99244a014da6a https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.4-py38h17151c0_0.conda#d2a8d935c0af114551c4e527bd83d446 +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.5-py38h17151c0_0.conda#06a221160ec7bd58e49f324de146742a https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.0-hfc55251_0.conda#e10134de3558dd95abda6987b5548f4f +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.1-hfc55251_1.conda#a50918d10114a0bf80fb46c7cc692058 https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2#34272b248891bddccc64479f9a7fffed https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py38h7f3f72f_1.conda#b66dcd4f710628fc5563ad56f02ca89b https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 -https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-19_linux64_openblas.conda#420f4e9be59d0dc9133a0f43f7bab3f3 +https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-20_linux64_openblas.conda#2b7bb4f7562c8cf334fc2e20c2d28abc https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_3.conda#1720df000b48e31842500323cb7be18c https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 -https://conda.anaconda.org/conda-forge/linux-64/libpq-16.0-hfc447b1_1.conda#e4a9a5ba40123477db33e02a78dffb01 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.1-hfc447b1_0.conda#2b7f1893cf40b4ccdc0230bcd94d5ed9 https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-254-h3516f8a_0.conda#df4b1cd0c91b4234fb02b5701a4cdddc +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 -https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.24-pthreads_h7a3da1a_0.conda#ebe8e905b06dfc5b4b40642d34b1d2f3 +https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.25-pthreads_h7a3da1a_0.conda#87661673941b5e702275fdf0fc095ad0 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 @@ -140,37 +142,38 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.co https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda#ed67c36f215b310412b2af935bf3e530 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.43.1-py38h01eb140_0.conda#13ee70161cad7093ce2f9f83b300ed06 -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.0-hfc55251_0.conda#2f55a36b549f51a7e0c2b1e3c3f0ccd4 -https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.0-pyhd8ed1ab_0.conda#48b0d98e0c0ec810d3ccc2a0926c8c0e +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.45.0-py38h01eb140_0.conda#7e8df1d608701e620f9ee87b76108be4 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.1-hfc55251_1.conda#8d7242302bb3d03b9a690b6dda872603 +https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-19_linux64_openblas.conda#d12374af44575413fbbd4a217d46ea33 +https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-20_linux64_openblas.conda#36d486d72ab64ffea932329a1d3729a3 https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_3.conda#0922208521c0463e690bbaebba7eb551 -https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-19_linux64_openblas.conda#9f100edf65436e3eabc2a51fc00b2c37 +https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-20_linux64_openblas.conda#6fabc51f5e647d09cc010c40061557e0 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-h5d7e998_0.conda#d8edd0e29db6fb6b6988e1a28d35d994 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b https://conda.anaconda.org/conda-forge/linux-64/pillow-10.1.0-py38ha43c96d_0.conda#67ca17c651f86159a3b8ed1132d97c12 +https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda#6bb4ee32cd435deaeac72776c001e7ac https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py38h17151c0_0.conda#ae2edf79b63f97071aea203b22a6774a -https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda#384462e63262a527bda564fa2d9126c0 -https://conda.anaconda.org/conda-forge/noarch/urllib3-2.0.7-pyhd8ed1ab_0.conda#270e71c14d37074b1d066ee21cf0c4a6 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.6-h98fc4e7_2.conda#1c95f7c612f9121353c4ef764678113e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.2.1-h3d44ed6_0.conda#98db5f8813f45e2b29766aff0e4a499c -https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.0-pyhd8ed1ab_0.conda#6a62c2cc25376a0d050b3d1d221c3ee9 -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-19_linux64_openblas.conda#685e99d3214f5ac9d1ec6b37983985a6 +https://conda.anaconda.org/conda-forge/noarch/urllib3-2.1.0-pyhd8ed1ab_0.conda#f8ced8ee63830dec7ecc1be048d1470a +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.7-h98fc4e7_0.conda#6c919bafe5e03428a8e2ef319d7ef990 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 +https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710 +https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-20_linux64_openblas.conda#05c5862c7dc25e65ba6c471d96429dae https://conda.anaconda.org/conda-forge/linux-64/numpy-1.24.4-py38h59b608b_0.conda#8c3e050afeeb2b32575bdb8955cc67b2 -https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.11.0-pyhd8ed1ab_0.conda#8f567c0a74aa44cf732f15773b4083b0 https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py38h17151c0_5.conda#3d66f5c4a0af2713f60ec11bf1230136 https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-19_linux64_openblas.conda#96bca12f1b7c48298dd1abf3e11121af +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 +https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_openblas.conda#9932a1d4e9ecf2d35fb19475446e361e https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.1-py38h7f3f72f_1.conda#18ae206b2d413e5cc8d2bb8ab48aa165 -https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.6-h8e1006c_2.conda#3d8e98279bad55287f2ef9047996f33c +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.7-h8e1006c_0.conda#065e2c1d49afa3fdc1a01f1dacd6ab09 https://conda.anaconda.org/conda-forge/linux-64/pandas-2.0.3-py38h01efb38_1.conda#01a2b6144e65631e2fe24e569d0738ee -https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyhd8ed1ab_4.conda#3cdaf7af08850933662b1e228bc6b5bc +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e -https://conda.anaconda.org/conda-forge/linux-64/blas-2.119-openblas.conda#f536a14a54da8b2aedd5a967d1e407c9 +https://conda.anaconda.org/conda-forge/linux-64/blas-2.120-openblas.conda#c8f6916a81a340650078171b1d852574 https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.7.3-py38h58ed7fa_0.conda#d8db25d58823182ce93233964f307a47 https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h82b777d_17.conda#4f01e33dbb406085a16a2813ab067e95 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.10.1-py38h59b608b_3.conda#2f2a57462fcfbc67dfdbb0de6f7484c2 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index c5fac7dea82dd..9575b4985fac3 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -1,26 +1,27 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 7aa55d66dfbd0f6267a9aff8c750d1e9f42cd339726c8f9c4d1299341b064849 +# input_hash: 3b28d20370f88920dc1e7cdc408d6f9d5dd33d46147313a4cca95c7ed8eac48f @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2#19410c3df09dfb12d1206132a1d357c5 https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.11-4_cp311.conda#d786502c97404c94d7d58d258a445a65 https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.10-hd590300_0.conda#75dae9a4201732aa78a530b826ee5fe0 https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 https://conda.anaconda.org/conda-forge/linux-64/aws-c-common-0.9.0-hd590300_0.conda#71b89db63b5b504e7afc8ad901172e1e -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 -https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.20.1-hd590300_0.conda#6642e4faa4804be3a0e7edfefbd16595 +https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h9c3ff4c_4.tar.bz2#f4f75dc7038aaeb6eaae16a5ef5350b3 +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 +https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.22.1-hd590300_0.conda#8430bd266c7b2cfbda403f7585d5ee86 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/gflags-2.2.2-he1b5a44_1004.tar.bz2#cddaf2c63ea4a5901cf09524c490ecdc https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 @@ -35,7 +36,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda https://conda.anaconda.org/conda-forge/linux-64/libev-4.33-h516909a_1.tar.bz2#6f8720dff19e17ce5d48cfe7f3d2f0a3 https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_2.conda#78fdab09d9138851dde2b5fe2a11019e +https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_3.conda#c714d905cdfa0e70200f68b80cc04764 https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2#b62b52da46c39ee2bc3c162ac7f1804d https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 @@ -48,9 +49,9 @@ https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.co https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75 -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.42.2-h59595ed_0.conda#700edd63ccd5fc66b70b1c028cea9a68 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/rdma-core-28.9-h59595ed_1.conda#aeffb7c06b5f65e55e6c637408dc4100 @@ -78,31 +79,30 @@ https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25c https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_2.conda#e75a75a6eaf6f318dae2631158c46575 +https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_3.conda#73031c79546ad06f1fe62e57fdd021bc https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda#c2097d0b46367996f09b4e8e4920384a -https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.52.0-h61bc06f_0.conda#613955a50485812985c059e7b269f42e +https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.58.0-h47da74e_0.conda#9b13d5ee90fc9f09d54fd403247342b4 https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-3.21.12-hfc55251_2.conda#e3a7d4ba09b8dc939b98fef55f539220 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.0-h0841786_0.conda#1f5a58e686b13bcfde88b93f547d23fe https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.5-h232c23b_1.conda#f3858448893839820d4bcfb14ad3ecdf -https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_5.conda#1e8ef4090ca4f0d66404a7441e1dbf3c -https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.6-h232c23b_0.conda#427a3e59d66cb5d145020bd9c6493334 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 https://conda.anaconda.org/conda-forge/linux-64/s2n-1.3.49-h06160fa_0.conda#1d78349eb26366ecc034a4afe70a8534 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/ucx-1.14.1-h64cca9d_5.conda#39aa3b356d10d7e5add0c540945a0944 https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/aws-c-io-0.13.32-he9a53bd_1.conda#8a24e5820f4a0ffd2ed9c4722cd5d7ca -https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.0.9-h166bdaf_9.conda#d47dee1856d9cb955b8076eeff304a5b https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 -https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.0-hebfc3b9_0.conda#e618003da3547216310088478e475945 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.2-hd590300_0.conda#3d7d5e5cebf8af5aadb040732860f1b6 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.1-h783c2da_1.conda#70052d6c1e84643e30ffefb21ab6950f https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.54.3-hb20ce57_0.conda#7af7c59ab24db007dfd82e0a3a343f66 https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.9.3-default_h554bfaf_1009.conda#f36ddc11ca46958197a45effdd286e45 @@ -110,8 +110,8 @@ https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-h5cf9203_3.cond https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libthrift-0.18.1-h8fd135c_2.conda#bbf65f7688512872f063810623b755dc https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_5.conda#b72f016c910ff9295b1377d3e17da3f2 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda#7caef74bbfa730e014b20f0852068509 https://conda.anaconda.org/conda-forge/linux-64/orc-1.9.0-h2f23424_1.conda#9571eb3eb0f7fe8b59956a7786babbcd https://conda.anaconda.org/conda-forge/linux-64/python-3.11.6-hab00c5b_0_cpython.conda#b0dfbe2fcbfdb097d321bfd50ecddab1 @@ -123,31 +123,32 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.con https://conda.anaconda.org/conda-forge/noarch/array-api-compat-1.4-pyhd8ed1ab_0.conda#074126c948c25ddcb8298ec8685a7f3d https://conda.anaconda.org/conda-forge/linux-64/aws-c-event-stream-0.3.1-h2e3709c_4.conda#2cf21b1cbc1c096a28ffa2892257a2c1 https://conda.anaconda.org/conda-forge/linux-64/aws-c-http-0.7.11-h00aa349_4.conda#cb932dff7328ff620ce8059c9968b095 -https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h166bdaf_9.conda#4601544b4982ba1861fa9b9c607b2c06 https://conda.anaconda.org/conda-forge/linux-64/ccache-4.8.1-h1fcd64f_0.conda#fd37a0c47d8b3667b73af0549037ce83 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.4-py311hb755f60_0.conda#6dc0be74820d94a0606595130218981c +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.5-py311hb755f60_0.conda#25b42509a68f96e612534af3fe2cf033 https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.0-hfc55251_0.conda#e10134de3558dd95abda6987b5548f4f +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.1-hfc55251_1.conda#a50918d10114a0bf80fb46c7cc692058 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda#2c65bdf442b0d37aad080c8a4e0d452f https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_3.conda#1720df000b48e31842500323cb7be18c https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 https://conda.anaconda.org/conda-forge/linux-64/libcurl-8.4.0-hca28451_0.conda#1158ac1d2613b28685644931f11ee807 -https://conda.anaconda.org/conda-forge/linux-64/libpq-16.0-hfc447b1_1.conda#e4a9a5ba40123477db33e02a78dffb01 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.1-hfc447b1_0.conda#2b7f1893cf40b4ccdc0230bcd94d5ed9 https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-254-h3516f8a_0.conda#df4b1cd0c91b4234fb02b5701a4cdddc +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3.post1-pyhd8ed1ab_0.conda#c93346b446cd08c169d843ae5fc0da97 @@ -167,43 +168,45 @@ https://conda.anaconda.org/conda-forge/linux-64/aws-c-auth-0.7.3-h28f7589_1.cond https://conda.anaconda.org/conda-forge/linux-64/aws-c-mqtt-0.9.3-hb447be9_1.conda#c520669eb0be9269a5f0d8ef62531882 https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e https://conda.anaconda.org/conda-forge/linux-64/coverage-7.3.2-py311h459d7ec_0.conda#7b3145fed7adc7c63a0e08f6f29f5480 -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.43.1-py311h459d7ec_0.conda#ac995b680de3bdce2531c553b27dfe7e -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.0-hfc55251_0.conda#2f55a36b549f51a7e0c2b1e3c3f0ccd4 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.45.0-py311h459d7ec_0.conda#316e188c8068e6d9c55d23bbead5d6c3 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.1-hfc55251_1.conda#8d7242302bb3d03b9a690b6dda872603 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_3.conda#0922208521c0463e690bbaebba7eb551 https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-2.12.0-hac9eb74_1.conda#0dee716254497604762957076ac76540 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-h5d7e998_0.conda#d8edd0e29db6fb6b6988e1a28d35d994 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b https://conda.anaconda.org/conda-forge/linux-64/mkl-2022.2.1-h84fe81f_16997.conda#a7ce56d5757f5b57e7daabe703ade5bb https://conda.anaconda.org/conda-forge/linux-64/pillow-10.1.0-py311ha6c5da5_0.conda#83a988daf5c49e57f7d2086fb6781fe8 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py311hb755f60_0.conda#02336abab4cb5dd794010ef53c54bd09 https://conda.anaconda.org/conda-forge/linux-64/aws-c-s3-0.3.14-hf3aad02_1.conda#a968ffa7e9fe0c257628033d393e512f https://conda.anaconda.org/conda-forge/linux-64/blas-1.0-mkl.tar.bz2#349aef876b1d8c9dccae01de20d5b385 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.6-h98fc4e7_2.conda#1c95f7c612f9121353c4ef764678113e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.2.1-h3d44ed6_0.conda#98db5f8813f45e2b29766aff0e4a499c +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.7-h98fc4e7_0.conda#6c919bafe5e03428a8e2ef319d7ef990 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-16_linux64_mkl.tar.bz2#85f61af03fd291dae33150ffe89dc09a https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py311hb755f60_5.conda#e4d262cc3600e70b505a6761d29f6207 https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.1.0-pyhd8ed1ab_0.conda#06eb685a3a0b146347a58dda979485da https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 https://conda.anaconda.org/conda-forge/linux-64/aws-crt-cpp-0.21.0-hb942446_5.conda#07d92ed5403ad7b5c66ffd7d5b8f7e57 -https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.6-h8e1006c_2.conda#3d8e98279bad55287f2ef9047996f33c +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.7-h8e1006c_0.conda#065e2c1d49afa3fdc1a01f1dacd6ab09 https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-16_linux64_mkl.tar.bz2#361bf757b95488de76c4f123805742d3 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_mkl.tar.bz2#a2f166748917d6d6e4707841ca1f519e https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/linux-64/aws-sdk-cpp-1.10.57-h85b1a90_19.conda#0605d3d60857fc07bd6a11e878fe0f08 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.0-py311h64a7726_0.conda#bf16a9f625126e378302f08e7ed67517 https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h82b777d_17.conda#4f01e33dbb406085a16a2813ab067e95 -https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.1-py311h9547e67_1.conda#52d3de443952d33c5cee6b24b172ce96 +https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda#40828c5b36ef52433e21f89943e09f33 https://conda.anaconda.org/conda-forge/linux-64/libarrow-12.0.1-hb87d912_8_cpu.conda#3f3b11398fe79b578e3c44dd00a44e4a -https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.1-py311h320fe9a_1.conda#a4371a95a8ae703a22949af28467b93d -https://conda.anaconda.org/conda-forge/linux-64/polars-0.19.9-py311hf926cbc_0.conda#ca8b9ed276ef4acfb993cf6e523eebb8 +https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.3-py311h320fe9a_0.conda#3ea3486e16d559dfcb539070ed330a1e +https://conda.anaconda.org/conda-forge/linux-64/polars-0.19.15-py311hf926cbc_0.conda#ffcf4f1b60fd589aa5b035bf67783c41 https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py311hf0fb5b6_5.conda#ec7e45bc76d9d0b69a74a2075932b8e8 https://conda.anaconda.org/conda-forge/linux-64/pytorch-1.13.1-cpu_py311h410fd25_1.conda#ddd2fadddf89e3dc3d541a2537fce010 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.11.3-py311h64a7726_1.conda#e4b4d3b764e2d029477d0db88248a8b5 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.0-py311h54ef318_2.conda#5655371cc61b8c31c369a7e709acb294 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py311h54ef318_0.conda#9f80753bc008bfc9b95f39d9ff9f1694 https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py311h92ebd52_1.conda#586ea5aa4a4ce2e7dbecb6c7416fc8ac https://conda.anaconda.org/conda-forge/linux-64/pyarrow-12.0.1-py311h39c9aba_8_cpu.conda#587370a25bb2c50cce90909ce20d38b8 https://conda.anaconda.org/conda-forge/linux-64/pytorch-cpu-1.13.1-cpu_py311hdb170b5_1.conda#a805d5f103e493f207613283d8acbbe1 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.0-py311h38be061_2.conda#0289918d4a09bbd0b85fd23ddf1c3ac1 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.2-py311h38be061_0.conda#ecffdcca48fcf288c2d9554e749be7ec diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml index 07ec7bb7ff206..d04526a784582 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_environment.yml index 02392a4e05aa8..c7e51147850a7 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_linux-64_conda.lock index 9fba6d514f5a8..3473e6f4dd428 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_no_coverage_linux-64_conda.lock @@ -1,25 +1,25 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 223cf367742008b437f38ff4642c0e70494f665cf9434d4da5c6483c757397fd +# input_hash: 1599ddf513acb123ca46ba08a74656deffa48a0250a9cee59d81b7c6baaaccee @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2#19410c3df09dfb12d1206132a1d357c5 https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 -https://conda.anaconda.org/conda-forge/linux-64/mkl-include-2023.2.0-h84fe81f_50495.conda#099fd5afb2dfb20ab1c37d4b1074476c +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 +https://conda.anaconda.org/conda-forge/linux-64/mkl-include-2023.2.0-h84fe81f_50496.conda#7af9fd0b2d7219f4a4200a34561340f6 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.12-4_cp312.conda#dccc2d142812964fcc6abdc97b672dff https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.10-hd590300_0.conda#75dae9a4201732aa78a530b826ee5fe0 https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda#cc47e1facc155f91abd89b11e48e72ff @@ -30,7 +30,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1 https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda#1635570038840ee3f9c71d22aa5b8b6d https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_2.conda#78fdab09d9138851dde2b5fe2a11019e +https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_3.conda#c714d905cdfa0e70200f68b80cc04764 https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2#b62b52da46c39ee2bc3c162ac7f1804d https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 @@ -41,9 +41,9 @@ https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.co https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75 -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.42.2-h59595ed_0.conda#700edd63ccd5fc66b70b1c028cea9a68 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a @@ -62,32 +62,32 @@ https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25c https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_2.conda#e75a75a6eaf6f318dae2631158c46575 +https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_3.conda#73031c79546ad06f1fe62e57fdd021bc https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda#c2097d0b46367996f09b4e8e4920384a https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.5-h232c23b_1.conda#f3858448893839820d4bcfb14ad3ecdf -https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_5.conda#1e8ef4090ca4f0d66404a7441e1dbf3c -https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.6-h232c23b_0.conda#427a3e59d66cb5d145020bd9c6493334 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda#39f910d205726805a958da408ca194ba https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 -https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.0-hebfc3b9_0.conda#e618003da3547216310088478e475945 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.2-hd590300_0.conda#3d7d5e5cebf8af5aadb040732860f1b6 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.1-h783c2da_1.conda#70052d6c1e84643e30ffefb21ab6950f https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.9.3-default_h554bfaf_1009.conda#f36ddc11ca46958197a45effdd286e45 https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-h5cf9203_3.conda#9efe82d44b76a7529a1d702e5a37752e https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_5.conda#b72f016c910ff9295b1377d3e17da3f2 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda#7caef74bbfa730e014b20f0852068509 https://conda.anaconda.org/conda-forge/linux-64/python-3.12.0-hab00c5b_0_cpython.conda#7f97faab5bebcc2580f4f299285323da https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 @@ -97,28 +97,30 @@ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.con https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda#49e482d882669206653b095f5206c05b https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda#f27a24d46e3ea7b70a1f98e50c62508f https://conda.anaconda.org/conda-forge/linux-64/ccache-4.8.1-h1fcd64f_0.conda#fd37a0c47d8b3667b73af0549037ce83 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.4-py312h30efb56_0.conda#1630f67636e5277b195c845a0ead7602 +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.5-py312h30efb56_0.conda#d89575b5415d88a9bf5ba7cb54aa41c7 https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.0-hfc55251_0.conda#e10134de3558dd95abda6987b5548f4f +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.1-hfc55251_1.conda#a50918d10114a0bf80fb46c7cc692058 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py312h8572e83_1.conda#c1e71f2bc05d8e8e033aefac2c490d05 https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_3.conda#1720df000b48e31842500323cb7be18c https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 -https://conda.anaconda.org/conda-forge/linux-64/libpq-16.0-hfc447b1_1.conda#e4a9a5ba40123477db33e02a78dffb01 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.1-hfc447b1_0.conda#2b7f1893cf40b4ccdc0230bcd94d5ed9 https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-254-h3516f8a_0.conda#df4b1cd0c91b4234fb02b5701a4cdddc +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3.post1-pyhd8ed1ab_0.conda#c93346b446cd08c169d843ae5fc0da97 @@ -129,41 +131,44 @@ https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.2.0-pyha21a80b_0.c https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2#f832c45a477c78bebd107098db465095 https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py312h98912ed_1.conda#5bd63a3bf512694536cee3e48463a47c +https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda#5b1be40a26d10a06f6d4f1f9e19fa0c7 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda#9d7bcddf49cbf727730af10e71022c73 https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.40-hd590300_0.conda#07c15d846a2e4d673da22cbd85fdb6d2 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda#82b6df12252e6f32402b96dacc656fec https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda#ed67c36f215b310412b2af935bf3e530 https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.43.1-py312h98912ed_0.conda#80e254c2ceb3e171895efcbefd85db38 -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.0-hfc55251_0.conda#2f55a36b549f51a7e0c2b1e3c3f0ccd4 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.45.0-py312h98912ed_0.conda#8842dbf1d1956a18ca8aeafb038cad88 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.1-hfc55251_1.conda#8d7242302bb3d03b9a690b6dda872603 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_3.conda#0922208521c0463e690bbaebba7eb551 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-h5d7e998_0.conda#d8edd0e29db6fb6b6988e1a28d35d994 -https://conda.anaconda.org/conda-forge/linux-64/mkl-2023.2.0-h84fe81f_50495.conda#f89293539ccddb83eed23b2b14b4d906 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b +https://conda.anaconda.org/conda-forge/linux-64/mkl-2023.2.0-h84fe81f_50496.conda#81d4a1a57d618adf0152db973d93b2ad https://conda.anaconda.org/conda-forge/linux-64/pillow-10.1.0-py312hf3581a9_0.conda#c04d3de9d831a69a5fdfab1413ec2fb6 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py312h30efb56_0.conda#32633871002ee9902f747d2236e0d122 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.6-h98fc4e7_2.conda#1c95f7c612f9121353c4ef764678113e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.2.1-h3d44ed6_0.conda#98db5f8813f45e2b29766aff0e4a499c -https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-19_linux64_mkl.conda#ec166f71f3d4c92ef1a714717b9b22eb -https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2023.2.0-ha770c72_50495.conda#bceef28ba896a368dd10cc6216fea12f +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.7-h98fc4e7_0.conda#6c919bafe5e03428a8e2ef319d7ef990 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 +https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-20_linux64_mkl.conda#8bf521f6007b0b0eb91515a1165b5d85 +https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2023.2.0-ha770c72_50496.conda#3b4c50e31ff098b18a450e4f5f860adf https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py312h30efb56_5.conda#8a2a122dc4fe14d8cff38f1cf426381f https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 -https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.6-h8e1006c_2.conda#3d8e98279bad55287f2ef9047996f33c -https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-19_linux64_mkl.conda#2468764de45bdcd1b2baf35a93312ca8 -https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-19_linux64_mkl.conda#aaa1703ee4c30735dbfeabc8287ce81e +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.7-h8e1006c_0.conda#065e2c1d49afa3fdc1a01f1dacd6ab09 +https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-20_linux64_mkl.conda#7a2972758a03adc92d856072c71c9170 +https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-20_linux64_mkl.conda#4db0cd03efcdab535f6f066aca4cddbb https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-19_linux64_mkl.conda#b595484789fe86969f0b5662c4473df5 +https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-20_linux64_mkl.conda#3dea5e9be386b963d7f4368966e238b3 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.0-py312heda63a1_0.conda#9b4f35e4d83e2a8b17868b65beb438d9 https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h82b777d_17.conda#4f01e33dbb406085a16a2813ab067e95 -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-19_linux64_mkl.conda#7a04ef5f2294b05fcece16e4a1f04d7a -https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.1-py312h8572e83_1.conda#111d116d3b299ff9a74c518127da8386 -https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.1-py312hfb8ada1_1.conda#77ebdc18c9a54c929a332d59965f197f +https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_mkl.conda#079d50df2338a3d47522d7e84c3dfbf6 +https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py312h8572e83_0.conda#b6249daaaf4577e6f72d95fc4ab767c6 +https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.3-py312hfb8ada1_0.conda#ef74af58f348d62a35c58e82aef5f868 https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py312h949fe66_5.conda#f6548a564e2d01b2a42020259503945b https://conda.anaconda.org/conda-forge/linux-64/scipy-1.11.3-py312heda63a1_1.conda#c89108be4deb842ced096623aa932fd0 -https://conda.anaconda.org/conda-forge/linux-64/blas-2.119-mkl.conda#923c56d369fe68c2da0d554d3e5edc2c -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.0-py312he5832f3_2.conda#f8763aa397e9989dde2e3ff9d5be8f2a +https://conda.anaconda.org/conda-forge/linux-64/blas-2.120-mkl.conda#9444330235a4828878cbe9c897ba0aa3 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py312he5832f3_0.conda#1bf345f8df6896b5a8016f16188946ba https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py312hfb10629_1.conda#79ec33a3b3e9e6858e40e6f253b174ab -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.0-py312h7900ff3_2.conda#ec445ebb50fab44f00bf2d9ea98878e5 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.2-py312h7900ff3_0.conda#b409beb1dc6ebb34b767b7fb8fc70b9d diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock index c98cb2728e40f..9d23849f62048 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock @@ -1,9 +1,9 @@ # Generated by conda-lock. # platform: osx-64 -# input_hash: 02abef27514db5e5119c3cdc253e84a06374c1b308495298b46bdb14dcc52ae9 +# input_hash: 576ee8ecb3e0fc6bd4381c4c9defb95845bca7e35121f16b0253c2b57e36a252 @EXPLICIT -https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h0d85af4_4.tar.bz2#37edc4e6304ca87316e160f5ca0bd1b5 -https://conda.anaconda.org/conda-forge/osx-64/ca-certificates-2023.7.22-h8857fd0_0.conda#bf2c54c18997bf3542af074c10191771 +https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h10d778d_5.conda#6097a6ca9ada32699b5fc4312dd6ef18 +https://conda.anaconda.org/conda-forge/osx-64/ca-certificates-2023.11.17-h8857fd0_0.conda#c687e9d14c49e3d3946d50a413cdbf16 https://conda.anaconda.org/conda-forge/osx-64/icu-73.2-hf5e326d_0.conda#5cc301d759ec03f28328428e28f65591 https://conda.anaconda.org/conda-forge/osx-64/libbrotlicommon-1.1.0-h0dc2134_1.conda#9e6c31441c9aa24e41ace40d6151aab6 https://conda.anaconda.org/conda-forge/osx-64/libcxx-16.0.6-hd57cbcb_0.conda#7d6972792161077908b62971802f289a @@ -15,9 +15,8 @@ https://conda.anaconda.org/conda-forge/osx-64/libiconv-1.17-hac89ed1_0.tar.bz2#6 https://conda.anaconda.org/conda-forge/osx-64/libjpeg-turbo-3.0.0-h0dc2134_1.conda#72507f8e3961bc968af17435060b6dd6 https://conda.anaconda.org/conda-forge/osx-64/libwebp-base-1.3.2-h0dc2134_0.conda#4e7e9d244e87d66c18d36894fd6a8ae5 https://conda.anaconda.org/conda-forge/osx-64/libzlib-1.2.13-h8a1eda9_5.conda#4a3ad23f6e16f99c04e166767193d700 -https://conda.anaconda.org/conda-forge/osx-64/llvm-openmp-17.0.3-hb6ac08f_0.conda#b70adc70bc7527a207c81c2e6b43532c -https://conda.anaconda.org/conda-forge/osx-64/mkl-include-2023.2.0-h6bab518_50499.conda#0279312581ec159355692cf4dbd951c4 -https://conda.anaconda.org/conda-forge/osx-64/ncurses-6.4-hf0c8a7f_0.conda#c3dbae2411164d9b02c69090a9a91857 +https://conda.anaconda.org/conda-forge/osx-64/llvm-openmp-17.0.5-hb6ac08f_0.conda#8ca3784280b7cb54163a46e8a918fb43 +https://conda.anaconda.org/conda-forge/osx-64/mkl-include-2023.2.0-h6bab518_50500.conda#835abb8ded5e26f23ea6996259c7972e https://conda.anaconda.org/conda-forge/osx-64/pthread-stubs-0.4-hc929b4f_1001.tar.bz2#addd19059de62181cd11ae8f4ef26084 https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.12-4_cp312.conda#87201ac4314b911b74197e588cca3639 https://conda.anaconda.org/conda-forge/osx-64/tbb-2021.10.0-h1c7c39f_2.conda#73434bcf87082942e938352afae9b0fa @@ -25,55 +24,66 @@ https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e https://conda.anaconda.org/conda-forge/osx-64/xorg-libxau-1.0.11-h0dc2134_0.conda#9566b4c29274125b0266d0177b5eb97b https://conda.anaconda.org/conda-forge/osx-64/xorg-libxdmcp-1.1.3-h35c211d_0.tar.bz2#86ac76d6bf1cbb9621943eb3bd9ae36e https://conda.anaconda.org/conda-forge/osx-64/xz-5.2.6-h775f41a_0.tar.bz2#a72f9d4ea13d55d745ff1ed594747f10 -https://conda.anaconda.org/conda-forge/osx-64/gmp-6.2.1-h2e338ed_0.tar.bz2#dedc96914428dae572a39e69ee2a392f +https://conda.anaconda.org/conda-forge/osx-64/gmp-6.3.0-h93d8f39_0.conda#a4ffd4bfd88659cbecbd36b61594bf0d https://conda.anaconda.org/conda-forge/osx-64/isl-0.25-hb486fe8_0.tar.bz2#45a9a46c78c0ea5c275b535f7923bde3 https://conda.anaconda.org/conda-forge/osx-64/lerc-4.0.0-hb486fe8_0.tar.bz2#f9d6a4c82889d5ecedec1d90eb673c55 https://conda.anaconda.org/conda-forge/osx-64/libbrotlidec-1.1.0-h0dc2134_1.conda#9ee0bab91b2ca579e10353738be36063 https://conda.anaconda.org/conda-forge/osx-64/libbrotlienc-1.1.0-h0dc2134_1.conda#8a421fe09c6187f0eb5e2338a8a8be6d https://conda.anaconda.org/conda-forge/osx-64/libgfortran5-13.2.0-h2873a65_1.conda#3af564516b5163cd8cc08820413854bc https://conda.anaconda.org/conda-forge/osx-64/libpng-1.6.39-ha978bb4_0.conda#35e4928794c5391aec14ffdf1deaaee5 -https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.43.2-h92b6c6a_0.conda#61b88c5f99f1537ed30b34758bd54d54 +https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.44.0-h92b6c6a_0.conda#5dd5e957ebfee02720c30e0e2d127bbe https://conda.anaconda.org/conda-forge/osx-64/libxcb-1.15-hb7f2c08_0.conda#5513f57e0238c87c12dffedbcc9c1a4a -https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.11.5-h3346baf_1.conda#7584dee6af7de378aed0ae49aebedb8a -https://conda.anaconda.org/conda-forge/osx-64/mkl-2023.2.0-hfe8bc8c_50499.conda#c0e68227c6292c21d84aee7f017a8e66 -https://conda.anaconda.org/conda-forge/osx-64/openssl-3.1.3-h8a1eda9_0.conda#26f9b58f905547e658e9587f8e8cfe43 -https://conda.anaconda.org/conda-forge/osx-64/readline-8.2-h9e318b2_1.conda#f17f77f2acf4d344734bda76829ce14e +https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.11.6-hc0ae0f7_0.conda#2b6ec8c6366ea74db4b910469addad1d +https://conda.anaconda.org/conda-forge/osx-64/mkl-2023.2.0-h54c2260_50500.conda#0a342ccdc79e4fcd359245ac51941e7b +https://conda.anaconda.org/conda-forge/osx-64/ncurses-6.4-h93d8f39_2.conda#e58f366bd4d767e9ab97ab8b272e7670 +https://conda.anaconda.org/conda-forge/osx-64/openssl-3.1.4-hd75f5a5_0.conda#bc9201da6eb1e0df4107901df5371347 https://conda.anaconda.org/conda-forge/osx-64/tapi-1100.0.11-h9ce4665_0.tar.bz2#f9ff42ccf809a21ba6f8607f8de36108 -https://conda.anaconda.org/conda-forge/osx-64/tk-8.6.13-hef22860_0.conda#0c25eedcc888b6d765948ab62a18c03e +https://conda.anaconda.org/conda-forge/osx-64/tk-8.6.13-h1abcd95_1.conda#bf830ba5afc507c6232d4ef0fb1a882d https://conda.anaconda.org/conda-forge/osx-64/zlib-1.2.13-h8a1eda9_5.conda#75a8a98b1c4671c5d2897975731da42d https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.5-h829000d_0.conda#80abc41d0c48b82fe0f04e7f42f5cb7e https://conda.anaconda.org/conda-forge/osx-64/brotli-bin-1.1.0-h0dc2134_1.conda#ece565c215adcc47fc1db4e651ee094b https://conda.anaconda.org/conda-forge/osx-64/freetype-2.12.1-h60636b9_2.conda#25152fce119320c980e5470e64834b50 -https://conda.anaconda.org/conda-forge/osx-64/libblas-3.9.0-19_osx64_mkl.conda#d1f8be60e9f3e45c329099c897e4433a +https://conda.anaconda.org/conda-forge/osx-64/libblas-3.9.0-20_osx64_mkl.conda#160fdc97a51d66d51dc782fb67d35205 https://conda.anaconda.org/conda-forge/osx-64/libgfortran-5.0.0-13_2_0_h97931a8_1.conda#b55fd11ab6318a6e67ac191309701d5a https://conda.anaconda.org/conda-forge/osx-64/libllvm15-15.0.7-he4b1e75_3.conda#ecc6df80c4b0445ac0de9cabae189db3 https://conda.anaconda.org/conda-forge/osx-64/libtiff-4.6.0-h684deea_2.conda#2ca10a325063e000ad6d2a5900061e0d -https://conda.anaconda.org/conda-forge/osx-64/mkl-devel-2023.2.0-h694c41f_50499.conda#cb7ef3645802dd874e87c03150d1cc89 -https://conda.anaconda.org/conda-forge/osx-64/mpfr-4.2.0-h4f9bd69_0.conda#f48a2f4515be334c5cfeed82517b96e0 -https://conda.anaconda.org/conda-forge/osx-64/python-3.12.0-h30d4d87_0_cpython.conda#d11dc8f4551011fb6baa2865f1ead48f +https://conda.anaconda.org/conda-forge/osx-64/mkl-devel-2023.2.0-h694c41f_50500.conda#1b4d0235ef253a1e19459351badf4f9f +https://conda.anaconda.org/conda-forge/osx-64/mpfr-4.2.1-h0c69b56_0.conda#d545aecded064848432bc994075dfccf +https://conda.anaconda.org/conda-forge/osx-64/readline-8.2-h9e318b2_1.conda#f17f77f2acf4d344734bda76829ce14e https://conda.anaconda.org/conda-forge/osx-64/sigtool-0.1.3-h88f4db0_0.tar.bz2#fbfb84b9de9a6939cb165c02c69b1865 https://conda.anaconda.org/conda-forge/osx-64/brotli-1.1.0-h0dc2134_1.conda#9272dd3b19c4e8212f8542cefd5c3d67 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 -https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 -https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.4-py312h444b7ae_0.conda#43db6578d18532fb960309dde65a742b -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 -https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 -https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 -https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.5-py312h49ebfd2_1.conda#21f174a5cfb5964069c374171a979157 https://conda.anaconda.org/conda-forge/osx-64/lcms2-2.15-hd6ba6f3_3.conda#8059507d52f477fbd4b81841e085e25b https://conda.anaconda.org/conda-forge/osx-64/ld64_osx-64-609-h0fd476b_15.conda#f98d11f8e568521e1e3f88cbe5a4d53c -https://conda.anaconda.org/conda-forge/osx-64/libcblas-3.9.0-19_osx64_mkl.conda#7ffbeb8b259cf8dfb88e02339c2a329f +https://conda.anaconda.org/conda-forge/osx-64/libcblas-3.9.0-20_osx64_mkl.conda#51089a4865eb4aec2bc5c7468bd07f9f https://conda.anaconda.org/conda-forge/osx-64/libclang-cpp15-15.0.7-default_hdb78580_3.conda#73639154fe4a7ca500d1361eef58fb65 https://conda.anaconda.org/conda-forge/osx-64/libhiredis-1.0.2-h2beb688_0.tar.bz2#524282b2c46c9dedf051b3bc2ae05494 -https://conda.anaconda.org/conda-forge/osx-64/liblapack-3.9.0-19_osx64_mkl.conda#9e3476408faadc41bc4f5b33d1919e8a +https://conda.anaconda.org/conda-forge/osx-64/liblapack-3.9.0-20_osx64_mkl.conda#58f08e12ad487fac4a08f90ff0b87aec https://conda.anaconda.org/conda-forge/osx-64/llvm-tools-15.0.7-he4b1e75_3.conda#7177e9334a86af1b1581f14607ced61c https://conda.anaconda.org/conda-forge/osx-64/mpc-1.3.1-h81bd1dd_0.conda#c752c0eb6c250919559172c011e5f65b -https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 https://conda.anaconda.org/conda-forge/osx-64/openjpeg-2.5.0-ha4da562_3.conda#40a36f8e9a6fdf6a78c6428ee6c44188 +https://conda.anaconda.org/conda-forge/osx-64/python-3.12.0-h30d4d87_0_cpython.conda#d11dc8f4551011fb6baa2865f1ead48f +https://conda.anaconda.org/conda-forge/osx-64/ccache-4.8.1-h28e096f_0.conda#dcc8cc97fdab7a5fad9e1a6bbad9ed0e +https://conda.anaconda.org/conda-forge/osx-64/cctools_osx-64-973.0.1-habff3f6_15.conda#015a19b3900b9fce20a48551e66f7699 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 +https://conda.anaconda.org/conda-forge/osx-64/clang-15-15.0.7-default_hdb78580_3.conda#688d6b9e178cb7786a07e3cfca2a8f09 +https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 +https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 +https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.5-py312h444b7ae_0.conda#e81230af852d3e935cea94d66f5ea7a0 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 +https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 +https://conda.anaconda.org/conda-forge/osx-64/gfortran_impl_osx-64-12.3.0-h54fd467_1.conda#5f4d40236e204c6e62cd0a316244f316 +https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 +https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.5-py312h49ebfd2_1.conda#21f174a5cfb5964069c374171a979157 +https://conda.anaconda.org/conda-forge/osx-64/ld64-609-ha91a046_15.conda#3c27099076bcf8fdde53785efb46dbf1 +https://conda.anaconda.org/conda-forge/osx-64/liblapacke-3.9.0-20_osx64_mkl.conda#124ae8e384268a8da66f1d64114a1eda +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 +https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 +https://conda.anaconda.org/conda-forge/osx-64/numpy-1.26.0-py312h5df92dc_0.conda#aec65f472e6633e23844b555bf640a3f https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 +https://conda.anaconda.org/conda-forge/osx-64/pillow-10.1.0-py312h0c70c2f_0.conda#50fc3446a464ff986aa4496e1eebf60b https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3.post1-pyhd8ed1ab_0.conda#c93346b446cd08c169d843ae5fc0da97 @@ -83,42 +93,35 @@ https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.2.0-pyha21a80b_0.c https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2#f832c45a477c78bebd107098db465095 https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 https://conda.anaconda.org/conda-forge/osx-64/tornado-6.3.3-py312h104f124_1.conda#6835d4940d6fbd41e1a32d58dfae8f06 -https://conda.anaconda.org/conda-forge/osx-64/ccache-4.8.1-h28e096f_0.conda#dcc8cc97fdab7a5fad9e1a6bbad9ed0e -https://conda.anaconda.org/conda-forge/osx-64/cctools_osx-64-973.0.1-habff3f6_15.conda#015a19b3900b9fce20a48551e66f7699 -https://conda.anaconda.org/conda-forge/osx-64/clang-15-15.0.7-default_hdb78580_3.conda#688d6b9e178cb7786a07e3cfca2a8f09 +https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda#5b1be40a26d10a06f6d4f1f9e19fa0c7 +https://conda.anaconda.org/conda-forge/osx-64/blas-devel-3.9.0-20_osx64_mkl.conda#cc3260179093918b801e373c6e888e02 +https://conda.anaconda.org/conda-forge/osx-64/cctools-973.0.1-hd9ad811_15.conda#00374d616829d1e4d2266e571147848a +https://conda.anaconda.org/conda-forge/osx-64/clang-15.0.7-h694c41f_3.conda#8a48d466e519b8db7dda7c5d27cc1d31 +https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.2.0-py312hbf0bb39_0.conda#74190e06053cda7139a0cb71f3e618fd https://conda.anaconda.org/conda-forge/osx-64/coverage-7.3.2-py312h104f124_0.conda#1e98139a6dc6e29569dff47a1895a40c -https://conda.anaconda.org/conda-forge/osx-64/fonttools-4.43.1-py312h41838bb_0.conda#f7aba11c6253c8069fc8eab0344b5100 -https://conda.anaconda.org/conda-forge/osx-64/gfortran_impl_osx-64-12.3.0-h54fd467_1.conda#5f4d40236e204c6e62cd0a316244f316 +https://conda.anaconda.org/conda-forge/osx-64/fonttools-4.45.0-py312h41838bb_0.conda#8af369cf887521ff6a068715aed7be15 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/osx-64/ld64-609-ha91a046_15.conda#3c27099076bcf8fdde53785efb46dbf1 -https://conda.anaconda.org/conda-forge/osx-64/liblapacke-3.9.0-19_osx64_mkl.conda#943eeb9302476d460b4cbc9627cb7fa1 -https://conda.anaconda.org/conda-forge/osx-64/numpy-1.26.0-py312h5df92dc_0.conda#aec65f472e6633e23844b555bf640a3f -https://conda.anaconda.org/conda-forge/osx-64/pillow-10.1.0-py312h0c70c2f_0.conda#50fc3446a464ff986aa4496e1eebf60b -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 -https://conda.anaconda.org/conda-forge/osx-64/blas-devel-3.9.0-19_osx64_mkl.conda#9b935938179cb8fdc559cb17b89371bc -https://conda.anaconda.org/conda-forge/osx-64/cctools-973.0.1-hd9ad811_15.conda#00374d616829d1e4d2266e571147848a -https://conda.anaconda.org/conda-forge/osx-64/clang-15.0.7-h694c41f_3.conda#8a48d466e519b8db7dda7c5d27cc1d31 -https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.1.1-py312h49ebfd2_1.conda#04a065595201b53fa694da2caf8a791c -https://conda.anaconda.org/conda-forge/osx-64/pandas-2.1.1-py312hb3462e8_1.conda#760ed0214d47448a7e130a6a31a0ed1f -https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.1.0-pyhd8ed1ab_0.conda#06eb685a3a0b146347a58dda979485da -https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 https://conda.anaconda.org/conda-forge/osx-64/scipy-1.11.3-py312h2c2f0bb_1.conda#0671c0db4a12dd2c176db00022a223fb -https://conda.anaconda.org/conda-forge/osx-64/blas-2.119-mkl.conda#eabf5c1e4f1e9e2eee06e533618a09b0 +https://conda.anaconda.org/conda-forge/osx-64/blas-2.120-mkl.conda#b041a7677a412f3d925d8208936cb1e2 https://conda.anaconda.org/conda-forge/osx-64/clangxx-15.0.7-default_hdb78580_3.conda#58df9ff86fefc7684670be729b41412f -https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.8.0-py312h1fe5000_2.conda#843921e02619d401fb3f7ca6ca810ab7 +https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.8.2-py312h302682c_0.conda#6a3b7c29d663a9cda13afb8f2638cc46 +https://conda.anaconda.org/conda-forge/osx-64/pandas-2.1.3-py312haf8ecfc_0.conda#d96a4b2b3dc4ae11f7fc8b736a12c3fb https://conda.anaconda.org/conda-forge/osx-64/pyamg-5.0.1-py312h674694f_1.conda#e5b9c0f8b5c367467425ff34353ef761 -https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e +https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.1.0-pyhd8ed1ab_0.conda#06eb685a3a0b146347a58dda979485da +https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 https://conda.anaconda.org/conda-forge/noarch/compiler-rt_osx-64-15.0.7-he1888fc_1.conda#e1f93ea86259a549f2dcbfd245bf0422 -https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.0-py312hb401068_2.conda#9848bbc974514541827e80edb9c184e2 +https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.2-py312hb401068_0.conda#926f479dcab7d6d26bba7fe39f67e3b2 +https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/osx-64/compiler-rt-15.0.7-he1888fc_1.conda#8ec296a4b097aeb2d85eafaf745c770a -https://conda.anaconda.org/conda-forge/osx-64/clang_impl_osx-64-15.0.7-h03d6864_5.conda#f80669234aa12eb3d1e85ebfaf490c25 -https://conda.anaconda.org/conda-forge/osx-64/clang_osx-64-15.0.7-hb91bd55_5.conda#83046189f3b7c66452e5fcb9e260a47a +https://conda.anaconda.org/conda-forge/osx-64/clang_osx-64-15.0.7-h03d6864_4.conda#c7d519705112545d61798032be53188c https://conda.anaconda.org/conda-forge/osx-64/c-compiler-1.6.0-h63c33a9_0.conda#d7f3b8d3a85b4e7eded31adb611bb665 -https://conda.anaconda.org/conda-forge/osx-64/clangxx_impl_osx-64-15.0.7-h3d2d1bf_5.conda#f3fc3518dd41811fc7cf4f60e31042b4 +https://conda.anaconda.org/conda-forge/osx-64/clangxx_osx-64-15.0.7-h3d2d1bf_4.conda#55c56aeb34837cd2b249cc210fe83044 https://conda.anaconda.org/conda-forge/osx-64/gfortran_osx-64-12.3.0-h18f7dce_1.conda#436af2384c47aedb94af78a128e174f1 -https://conda.anaconda.org/conda-forge/osx-64/clangxx_osx-64-15.0.7-hb91bd55_5.conda#79968c368f2dc38efc996f9d4a426a5e -https://conda.anaconda.org/conda-forge/osx-64/gfortran-12.3.0-h2c809b3_1.conda#c48adbaa8944234b80ef287c37e329b0 https://conda.anaconda.org/conda-forge/osx-64/cxx-compiler-1.6.0-h1c7c39f_0.conda#9adaf7c9d4e1e15e70a8dd46befbbab2 +https://conda.anaconda.org/conda-forge/osx-64/gfortran-12.3.0-h2c809b3_1.conda#c48adbaa8944234b80ef287c37e329b0 https://conda.anaconda.org/conda-forge/osx-64/fortran-compiler-1.6.0-h932d759_0.conda#d2bc049eae716dd6879079ddd209ffc3 https://conda.anaconda.org/conda-forge/osx-64/compilers-1.6.0-h694c41f_0.conda#d4c66ca84aa87a6c63f4c8a6498052d9 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_environment.yml index 4ddb80c7cae3d..8bb38cb1a4567 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/pylatest_conda_mkl_no_openmp_environment.yml b/build_tools/azure/pylatest_conda_mkl_no_openmp_environment.yml index 64a33fe7d7522..43ccc08a7dd84 100644 --- a/build_tools/azure/pylatest_conda_mkl_no_openmp_environment.yml +++ b/build_tools/azure/pylatest_conda_mkl_no_openmp_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock b/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock index cf75fc6e69cb1..0ff809481f3ea 100644 --- a/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: osx-64 -# input_hash: 03f7604aefb9752d2367c457bdf4e4923158be96db35ac0dd1d5dc60a9981cd1 +# input_hash: 3365b0ce9ef91b9728635b3ed85970db67f3cbe714516eb897b29e4c8a1bdecc @EXPLICIT https://repo.anaconda.com/pkgs/main/osx-64/blas-1.0-mkl.conda#cb2c87e85ac8e0ceae776d26d4214c8a https://repo.anaconda.com/pkgs/main/osx-64/bzip2-1.0.8-h1de35cc_0.conda#19fcb113b170fe2a0be96b47801fed7d @@ -18,21 +18,21 @@ https://repo.anaconda.com/pkgs/main/noarch/tzdata-2023c-h04d1e81_0.conda#29db02a https://repo.anaconda.com/pkgs/main/osx-64/xz-5.4.2-h6c40b1e_0.conda#5e546d3c9765b4441e511804d58f6e3f https://repo.anaconda.com/pkgs/main/osx-64/zlib-1.2.13-h4dc903c_0.conda#d0202dd912bfb45d3422786531717882 https://repo.anaconda.com/pkgs/main/osx-64/ccache-3.7.9-hf120daa_0.conda#a01515a32e721c51d631283f991bc8ea -https://repo.anaconda.com/pkgs/main/osx-64/intel-openmp-2023.1.0-ha357a0b_43547.conda#aa6031369dd8c8cc6b2f393a0b2d9f0c +https://repo.anaconda.com/pkgs/main/osx-64/intel-openmp-2023.1.0-ha357a0b_43548.conda#ba8a89ffe593eb88e4c01334753c40c3 https://repo.anaconda.com/pkgs/main/osx-64/lerc-3.0-he9d5cce_0.conda#aec2c3dbef836849c9260f05be04f3db https://repo.anaconda.com/pkgs/main/osx-64/libbrotlidec-1.0.9-hca72f7f_7.conda#b85983951745cc666d9a1b42894210b2 https://repo.anaconda.com/pkgs/main/osx-64/libbrotlienc-1.0.9-hca72f7f_7.conda#e306d7a1599202a7c95762443f110832 https://repo.anaconda.com/pkgs/main/osx-64/libgfortran5-11.3.0-h9dfd629_28.conda#1fa1a27ee100b1918c3021dbfa3895a3 https://repo.anaconda.com/pkgs/main/osx-64/libpng-1.6.39-h6c40b1e_0.conda#a3c824835f53ad27aeb86d2b55e47804 https://repo.anaconda.com/pkgs/main/osx-64/lz4-c-1.9.4-hcec6c5f_0.conda#44291e9e6920cfff30caf1299f48db38 -https://repo.anaconda.com/pkgs/main/osx-64/openssl-3.0.11-hca72f7f_2.conda#c418222bdf3216cccc95ddc45a9e2b61 +https://repo.anaconda.com/pkgs/main/osx-64/openssl-3.0.12-hca72f7f_0.conda#0c97cb1bc867408ada7ea18e440ea3c8 https://repo.anaconda.com/pkgs/main/osx-64/readline-8.2-hca72f7f_0.conda#971667436260e523f6f7355fdfa238bf https://repo.anaconda.com/pkgs/main/osx-64/tbb-2021.8.0-ha357a0b_0.conda#fb48530a3eea681c11dafb95b3387c0f https://repo.anaconda.com/pkgs/main/osx-64/tk-8.6.12-h5d9f67b_0.conda#047f0af5486d19163e37fd7f8ae3d29f https://repo.anaconda.com/pkgs/main/osx-64/brotli-bin-1.0.9-hca72f7f_7.conda#110bdca1a20710820e61f7fa3047f737 https://repo.anaconda.com/pkgs/main/osx-64/freetype-2.12.1-hd8bbffd_0.conda#1f276af321375ee7fe8056843044fa76 https://repo.anaconda.com/pkgs/main/osx-64/libgfortran-5.0.0-11_3_0_hecd8cb5_28.conda#2eb13b680803f1064e53873ae0aaafb3 -https://repo.anaconda.com/pkgs/main/osx-64/mkl-2023.1.0-h8e150cf_43559.conda#f5a09d45a003f817d5c43935e20ca0c8 +https://repo.anaconda.com/pkgs/main/osx-64/mkl-2023.1.0-h8e150cf_43560.conda#85d0f3431dd5c6ae44f8725fdd3d3e59 https://repo.anaconda.com/pkgs/main/osx-64/sqlite-3.41.2-h6c40b1e_0.conda#6947a501943529c7536b7e4ba53802c1 https://repo.anaconda.com/pkgs/main/osx-64/zstd-1.5.5-hc035e20_0.conda#5e0b7ddb1b7dc6b630e1f9a03499c19c https://repo.anaconda.com/pkgs/main/osx-64/brotli-1.0.9-hca72f7f_7.conda#68e54d12ec67591deb2ffd70348fb00f @@ -47,12 +47,14 @@ https://repo.anaconda.com/pkgs/main/osx-64/joblib-1.2.0-py311hecd8cb5_0.conda#af https://repo.anaconda.com/pkgs/main/osx-64/kiwisolver-1.4.4-py311hcec6c5f_0.conda#f2cf31e2a762f071fd6bc4d74ea2bfc8 https://repo.anaconda.com/pkgs/main/osx-64/lcms2-2.12-hf1fd2bf_0.conda#697aba7a3308226df7a93ccfeae16ffa https://repo.anaconda.com/pkgs/main/osx-64/libwebp-1.3.2-hf6ce154_0.conda#91790dd6960d374adbdfda5ddaa44b4b +https://repo.anaconda.com/pkgs/main/osx-64/mdurl-0.1.0-py311hecd8cb5_0.conda#18c799813d3fdd2ea68c807d54038820 https://repo.anaconda.com/pkgs/main/osx-64/mkl-service-2.4.0-py311h6c40b1e_1.conda#f709b80c57a0fcc577319920d1b7228b https://repo.anaconda.com/pkgs/main/noarch/munkres-1.1.4-py_0.conda#148362ba07f92abab76999a680c80084 https://repo.anaconda.com/pkgs/main/osx-64/openjpeg-2.4.0-h66ea3da_0.conda#882833bd7befc5e60e6fba9c518c1b79 https://repo.anaconda.com/pkgs/main/osx-64/packaging-23.1-py311hecd8cb5_0.conda#4f5c491cd2de9d61f61c0ea3340ab46a https://repo.anaconda.com/pkgs/main/osx-64/pluggy-1.0.0-py311hecd8cb5_1.conda#98e4da64cd934965a0caf4136280ff35 https://repo.anaconda.com/pkgs/main/noarch/py-1.11.0-pyhd3eb1b0_0.conda#7205a898ed2abbf6e9b903dff6abe08e +https://repo.anaconda.com/pkgs/main/osx-64/pygments-2.15.1-py311hecd8cb5_1.conda#9e0f1e7667af6f469dfba22aa87dc6e2 https://repo.anaconda.com/pkgs/main/osx-64/pyparsing-3.0.9-py311hecd8cb5_0.conda#a4262f849ecc82af69f58da0cbcaaf04 https://repo.anaconda.com/pkgs/main/noarch/python-tzdata-2023.3-pyhd3eb1b0_0.conda#479c037de0186d114b9911158427624e https://repo.anaconda.com/pkgs/main/osx-64/pytz-2023.3.post1-py311hecd8cb5_0.conda#32d107281d133e3935dfb6935153e438 @@ -62,17 +64,19 @@ https://repo.anaconda.com/pkgs/main/noarch/threadpoolctl-2.2.0-pyh0d69192_0.cond https://repo.anaconda.com/pkgs/main/noarch/toml-0.10.2-pyhd3eb1b0_0.conda#cda05f5f6d8509529d1a2743288d197a https://repo.anaconda.com/pkgs/main/osx-64/tornado-6.3.3-py311h6c40b1e_0.conda#e98809ea222b3da0ebeae40bc73dfdb0 https://repo.anaconda.com/pkgs/main/noarch/fonttools-4.25.0-pyhd3eb1b0_0.conda#bb9c5b5a6d892fca5efe4bf0203b6a48 +https://repo.anaconda.com/pkgs/main/osx-64/markdown-it-py-2.2.0-py311hecd8cb5_1.conda#fcb5ce9cc6f0157d39022a57d4319729 https://repo.anaconda.com/pkgs/main/osx-64/numpy-base-1.24.3-py311h53bf9ac_1.conda#1b1957e3823208a006d0699999335c7d https://repo.anaconda.com/pkgs/main/osx-64/pillow-10.0.1-py311h7d39338_0.conda#0caf29bc5e73a3f5ca2f299c0b50a404 https://repo.anaconda.com/pkgs/main/osx-64/pytest-7.4.0-py311hecd8cb5_0.conda#8c5496a4a1f36160ac5556495faa4a24 https://repo.anaconda.com/pkgs/main/noarch/python-dateutil-2.8.2-pyhd3eb1b0_0.conda#211ee00320b08a1ac9fea6677649f6c9 -https://repo.anaconda.com/pkgs/main/osx-64/pytest-cov-4.1.0-py311hecd8cb5_0.conda#61d021237f3fce3f15344ddfb8ca2fea -https://repo.anaconda.com/pkgs/main/noarch/pytest-forked-1.3.0-pyhd3eb1b0_0.tar.bz2#07970bffdc78f417d7f8f1c7e620f5c4 +https://repo.anaconda.com/pkgs/main/osx-64/pytest-cov-4.1.0-py311hecd8cb5_1.conda#b1e41a8eda3f119b39b13f3a4d0c5bf5 +https://repo.anaconda.com/pkgs/main/osx-64/pytest-forked-1.6.0-py311hecd8cb5_0.conda#b1154a9887bee381b3405ec37f8b13f3 +https://repo.anaconda.com/pkgs/main/osx-64/rich-13.3.5-py311hecd8cb5_0.conda#6360bb4bc83d7dfe34cd5c04e25690fa https://repo.anaconda.com/pkgs/main/noarch/pytest-xdist-2.5.0-pyhd3eb1b0_0.conda#d15cdc4207bcf8ca920822597f1d138d https://repo.anaconda.com/pkgs/main/osx-64/bottleneck-1.3.5-py311hb9e55a9_0.conda#5aa1b58b421d4608b16184f8468253ef -https://repo.anaconda.com/pkgs/main/osx-64/contourpy-1.0.5-py311ha357a0b_0.conda#a130f83ba4b5d008e0c134c73e10b8fb -https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-3.7.2-py311hecd8cb5_0.conda#fb32402e331974eb0d7117a597be3702 -https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-base-3.7.2-py311h8251f7d_0.conda#d2d777472551de1f9fcac29db016d17e +https://repo.anaconda.com/pkgs/main/osx-64/contourpy-1.2.0-py311ha357a0b_0.conda#c9189b40e5b4be360aef22be336a4838 +https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-3.8.0-py311hecd8cb5_0.conda#f720f09a9d1bb976aa92a13180cf7133 +https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-base-3.8.0-py311h41a4f6b_0.conda#da5175158820055096f25520004fb9b3 https://repo.anaconda.com/pkgs/main/osx-64/mkl_fft-1.3.8-py311h6c40b1e_0.conda#7e70133e3cf6151d2826da7ae3af609f https://repo.anaconda.com/pkgs/main/osx-64/mkl_random-1.2.4-py311ha357a0b_0.conda#b363dccbb0219bb2f810a05b9bde92fb https://repo.anaconda.com/pkgs/main/osx-64/numpy-1.24.3-py311h728a8a3_1.conda#68069c79ebb0cdd2561026a909a57183 diff --git a/build_tools/azure/pylatest_pip_openblas_pandas_environment.yml b/build_tools/azure/pylatest_pip_openblas_pandas_environment.yml index ddbc75c1d9110..8872b7f0207bd 100644 --- a/build_tools/azure/pylatest_pip_openblas_pandas_environment.yml +++ b/build_tools/azure/pylatest_pip_openblas_pandas_environment.yml @@ -15,6 +15,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist==2.5.0 diff --git a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock index 734156f07c01f..c125bba517054 100644 --- a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: d01d23bd27bcd50d2b3643492f966c8e390822d72b69f31bf66c2fe98a265a4c +# input_hash: 28b441bc163dfcfe14d91eab7f60d6797ff276c7eb53721b4b45c60cb306ef3c @EXPLICIT https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda#c3473ff8bdb3d124ed5ff11ec380d6f9 https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2023.08.22-h06a4308_0.conda#243d5065a09a3e85ab888c05f5b6445a @@ -12,7 +12,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-5.1-1_gnu.conda#71d28 https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-11.2.0-h1234567_1.conda#a87728dabf3151fb9cfa990bd2eb0464 https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda#06e288f9250abef59b9a367d151fc339 https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda#5558eec6e2191741a92f832ea826251c -https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.11-h7f8727e_2.conda#6cad6f2dcde73f8625d729c6db1272d0 +https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.12-h7f8727e_0.conda#48caaebab690276acf1bc1f3b56febf4 https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.2-h5eee18b_0.conda#bcd31de48a0dcb44bc5b99675800c5cc https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_0.conda#333e31fbfbb5057c92fa845ad6adef93 https://repo.anaconda.com/pkgs/main/linux-64/ccache-3.7.9-hfe4627d_0.conda#bef6fc681c273bb7bd0c67d1a591365e @@ -24,15 +24,15 @@ https://repo.anaconda.com/pkgs/main/linux-64/setuptools-68.0.0-py39h06a4308_0.co https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.41.2-py39h06a4308_0.conda#ec1b8213c3585defaa6042ed2f95861d https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3-py39h06a4308_0.conda#25664fd7e573de5af347e27cdb036023 # pip alabaster @ https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl#sha256=1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3 -# pip babel @ https://files.pythonhosted.org/packages/ff/37/b0241795c3a320a3def948cd0d06daf70310e7fea1d8fda312629bc22ea9/Babel-2.13.0-py3-none-any.whl#sha256=fbfcae1575ff78e26c7449136f1abbefc3c13ce542eeb13d43d50d8b047216ec -# pip certifi @ https://files.pythonhosted.org/packages/4c/dd/2234eab22353ffc7d94e8d13177aaa050113286e93e7b40eae01fbf7c3d9/certifi-2023.7.22-py3-none-any.whl#sha256=92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 -# pip charset-normalizer @ https://files.pythonhosted.org/packages/9e/45/824835a9c165eae015eb7b4a875a581918b9fc96439f8d9a5ca0868f0b7d/charset_normalizer-3.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 +# pip babel @ https://files.pythonhosted.org/packages/86/14/5dc2eb02b7cc87b2f95930310a2cc5229198414919a116b564832c747bc1/Babel-2.13.1-py3-none-any.whl#sha256=7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed +# pip certifi @ https://files.pythonhosted.org/packages/64/62/428ef076be88fa93716b576e4a01f919d25968913e817077a386fcbe4f42/certifi-2023.11.17-py3-none-any.whl#sha256=e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474 +# pip charset-normalizer @ https://files.pythonhosted.org/packages/98/69/5d8751b4b670d623aa7a47bef061d69c279e9f922f6705147983aa76c3ce/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 # pip cycler @ https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl#sha256=85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 -# pip cython @ https://files.pythonhosted.org/packages/f8/7c/a94e50c8bf3cbfd0209319264c535c9cdf65e4d015954d6cb4efcd21fe09/Cython-3.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=36299ffd5663203c25d3a76980f077e23b6d4f574d142f0f43943f57be445639 +# pip cython @ https://files.pythonhosted.org/packages/40/19/46acdb9d0cad7cc2a80f721d6132dc50f7864d1598d9298e83b78c7d6bf9/Cython-3.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=db21997270e943aee9cb7694112d24a4702fbe1977fbe53b3cb4db3d02be73d9 # pip docutils @ https://files.pythonhosted.org/packages/26/87/f238c0670b94533ac0353a4e2a1a771a0cc73277b88bff23d3ae35a256c1/docutils-0.20.1-py3-none-any.whl#sha256=96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6 -# pip exceptiongroup @ https://files.pythonhosted.org/packages/ad/83/b71e58666f156a39fb29417e4c8ca4bc7400c0dd4ed9e8842ab54dc8c344/exceptiongroup-1.1.3-py3-none-any.whl#sha256=343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3 +# pip exceptiongroup @ https://files.pythonhosted.org/packages/b8/9a/5028fd52db10e600f1c4674441b968cf2ea4959085bfb5b99fb1250e5f68/exceptiongroup-1.2.0-py3-none-any.whl#sha256=4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14 # pip execnet @ https://files.pythonhosted.org/packages/e8/9c/a079946da30fac4924d92dbc617e5367d454954494cf1e71567bcc4e00ee/execnet-2.0.2-py3-none-any.whl#sha256=88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41 -# pip fonttools @ https://files.pythonhosted.org/packages/06/19/05e6d60206300d030948c2e09c28ad7f6e3c6e2299b9a5beb01261b38f0a/fonttools-4.43.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=2062542a7565091cea4cc14dd99feff473268b5b8afdee564f7067dd9fff5860 +# pip fonttools @ https://files.pythonhosted.org/packages/a0/bb/27f4c18f0830c0d79783a04bac121f519718e6a468a879c5d94d51271423/fonttools-4.45.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dc991712aaea9d545b13ec480aaf2ebd12ccdea180fce864dd9863f5134f5a06 # pip idna @ https://files.pythonhosted.org/packages/fc/34/3030de6f1370931b9dbb4dad48f6ab1015ab1d32447850b9fc94e60097be/idna-3.4-py3-none-any.whl#sha256=90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 # pip imagesize @ https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl#sha256=0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b # pip iniconfig @ https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl#sha256=b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 @@ -40,13 +40,14 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3-py39h06a4308_0.conda#25664 # pip kiwisolver @ https://files.pythonhosted.org/packages/c0/a8/841594f11d0b88d8aeb26991bc4dac38baa909dc58d0c4262a4f7893bcbf/kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff # pip lazy-loader @ https://files.pythonhosted.org/packages/a1/c3/65b3814e155836acacf720e5be3b5757130346670ac454fee29d3eda1381/lazy_loader-0.3-py3-none-any.whl#sha256=1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554 # pip markupsafe @ https://files.pythonhosted.org/packages/de/63/cb7e71984e9159ec5f45b5e81e896c8bdd0e45fe3fc6ce02ab497f0d790e/MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e -# pip networkx @ https://files.pythonhosted.org/packages/f6/eb/5585c96636bbb2755865c31d83a19dd220ef88e716df4659dacb86e009cc/networkx-3.2-py3-none-any.whl#sha256=8b25f564bd28f94ac821c58b04ae1a3109e73b001a7d476e4bb0d00d63706bf8 -# pip numpy @ https://files.pythonhosted.org/packages/89/ac/53100546dcd9aa400a73c7770b13cad9a3b18bf83433499e36b5efe9850f/numpy-1.26.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908 +# pip mdurl @ https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl#sha256=84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 +# pip networkx @ https://files.pythonhosted.org/packages/d5/f0/8fbc882ca80cf077f1b246c0e3c3465f7f415439bdea6b899f6b19f61f70/networkx-3.2.1-py3-none-any.whl#sha256=f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2 +# pip numpy @ https://files.pythonhosted.org/packages/2f/75/f007cc0e6a373207818bef17f463d3305e9dd380a70db0e523e7660bf21f/numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818 # pip packaging @ https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl#sha256=8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 # pip pillow @ https://files.pythonhosted.org/packages/9f/3a/ada56d489446dbb7679d242bfd7bb159cee8a7989c34dd34045103d5280d/Pillow-10.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=1a8413794b4ad9719346cd9306118450b7b00d9a15846451549314a58ac42219 # pip pluggy @ https://files.pythonhosted.org/packages/05/b8/42ed91898d4784546c5f06c60506400548db3f7a4b3fb441cba4e5c17952/pluggy-1.3.0-py3-none-any.whl#sha256=d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7 # pip py @ https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl#sha256=607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378 -# pip pygments @ https://files.pythonhosted.org/packages/43/88/29adf0b44ba6ac85045e63734ae0997d3c58d8b1a91c914d240828d0d73d/Pygments-2.16.1-py3-none-any.whl#sha256=13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 +# pip pygments @ https://files.pythonhosted.org/packages/97/9c/372fef8377a6e340b1704768d20daaded98bf13282b5327beb2e2fe2c7ef/pygments-2.17.2-py3-none-any.whl#sha256=b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c # pip pyparsing @ https://files.pythonhosted.org/packages/39/92/8486ede85fcc088f1b3dba4ce92dd29d126fd96b0008ea213167940a2475/pyparsing-3.1.1-py3-none-any.whl#sha256=32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb # pip pytz @ https://files.pythonhosted.org/packages/32/4d/aaf7eff5deb402fd9a24a1449a8119f00d74ae9c2efa79f8ef9994261fc2/pytz-2023.3.post1-py2.py3-none-any.whl#sha256=ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7 # pip six @ https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl#sha256=8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -55,28 +56,28 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3-py39h06a4308_0.conda#25664 # pip tabulate @ https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl#sha256=024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f # pip threadpoolctl @ https://files.pythonhosted.org/packages/81/12/fd4dea011af9d69e1cad05c75f3f7202cdcbeac9b712eea58ca779a72865/threadpoolctl-3.2.0-py3-none-any.whl#sha256=2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032 # pip tomli @ https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl#sha256=939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc -# pip typing-extensions @ https://files.pythonhosted.org/packages/24/21/7d397a4b7934ff4028987914ac1044d3b7d52712f30e2ac7a2ae5bc86dd0/typing_extensions-4.8.0-py3-none-any.whl#sha256=8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 # pip tzdata @ https://files.pythonhosted.org/packages/d5/fb/a79efcab32b8a1f1ddca7f35109a50e4a80d42ac1c9187ab46522b2407d7/tzdata-2023.3-py2.py3-none-any.whl#sha256=7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda -# pip urllib3 @ https://files.pythonhosted.org/packages/d2/b2/b157855192a68541a91ba7b2bbcb91f1b4faa51f8bae38d8005c034be524/urllib3-2.0.7-py3-none-any.whl#sha256=fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +# pip urllib3 @ https://files.pythonhosted.org/packages/96/94/c31f58c7a7f470d5665935262ebd7455c7e4c7782eb525658d3dbf4b9403/urllib3-2.1.0-py3-none-any.whl#sha256=55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3 # pip zipp @ https://files.pythonhosted.org/packages/d9/66/48866fc6b158c81cc2bfecc04c480f105c6040e8b077bc54c634b4a67926/zipp-3.17.0-py3-none-any.whl#sha256=0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 -# pip contourpy @ https://files.pythonhosted.org/packages/2b/c0/24c34c41a180f875419b536125799c61e2330b997d77a5a818a3bc3e08cd/contourpy-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=efe0fab26d598e1ec07d72cf03eaeeba8e42b4ecf6b9ccb5a356fde60ff08b85 +# pip contourpy @ https://files.pythonhosted.org/packages/a9/ba/d8fd1380876f1e9114157606302e3644c85f6d116aeba354c212ee13edc7/contourpy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=11f8d2554e52f459918f7b8e6aa20ec2a3bce35ce95c1f0ef4ba36fbda306df5 # pip coverage @ https://files.pythonhosted.org/packages/f1/e7/6d778d717d178c8c73103e2c467f3c8d8ebc9cacb825ebe3f3cf05e7c6df/coverage-7.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=149de1d2401ae4655c436a3dced6dd153f4c3309f599c3d4bd97ab172eaf02d9 -# pip imageio @ https://files.pythonhosted.org/packages/f6/37/e21e6f38b93878ba80302e95b8ccd4718d80f0c53055ccae343e606b1e2d/imageio-2.31.5-py3-none-any.whl#sha256=97f68e12ba676f2f4b541684ed81f7f3370dc347e8321bc68ee34d37b2dbac9f +# pip imageio @ https://files.pythonhosted.org/packages/fa/04/9abe71dfe8c77f5ee58e8c50df3b562884f7494b56c318b867bd2dcb6ec8/imageio-2.33.0-py3-none-any.whl#sha256=d580d6576d0ae39c459a444a23f6f61fe72123a3df2264f5fce8c87784a4be2e # pip importlib-metadata @ https://files.pythonhosted.org/packages/cc/37/db7ba97e676af155f5fcb1a35466f446eadc9104e25b83366e8088c9c926/importlib_metadata-6.8.0-py3-none-any.whl#sha256=3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb -# pip importlib-resources @ https://files.pythonhosted.org/packages/65/6e/09d8816b5cb7a4006ef8ad1717a2703ad9f331dae9717d9f22488a2d6469/importlib_resources-6.1.0-py3-none-any.whl#sha256=aa50258bbfa56d4e33fbd8aa3ef48ded10d1735f11532b8df95388cc6bdb7e83 +# pip importlib-resources @ https://files.pythonhosted.org/packages/93/e8/facde510585869b5ec694e8e0363ffe4eba067cb357a8398a55f6a1f8023/importlib_resources-6.1.1-py3-none-any.whl#sha256=e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6 # pip jinja2 @ https://files.pythonhosted.org/packages/bc/c3/f068337a370801f372f2f8f6bad74a5c140f6fda3d9de154052708dd3c65/Jinja2-3.1.2-py3-none-any.whl#sha256=6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61 -# pip pytest @ https://files.pythonhosted.org/packages/df/d0/e192c4275aecabf74faa1aacd75ef700091913236ec78b1a98f62a2412ee/pytest-7.4.2-py3-none-any.whl#sha256=1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002 +# pip markdown-it-py @ https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl#sha256=355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 +# pip pytest @ https://files.pythonhosted.org/packages/f3/8c/f16efd81ca8e293b2cc78f111190a79ee539d0d5d36ccd49975cb3beac60/pytest-7.4.3-py3-none-any.whl#sha256=0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac # pip python-dateutil @ https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl#sha256=961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 # pip requests @ https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl#sha256=58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f -# pip scipy @ https://files.pythonhosted.org/packages/88/8c/9d1f74196c296046af1f20e6d3fc7fbb27387282315e1643f450bba14329/scipy-1.11.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c77da50c9a91e23beb63c2a711ef9e9ca9a2060442757dffee34ea41847d8156 -# pip setuptools-scm @ https://files.pythonhosted.org/packages/0e/a3/b9a8b0adfe672bf0df5901707aa929d30a97ee390ba651910186776746d2/setuptools_scm-8.0.4-py3-none-any.whl#sha256=b47844cd2a84b83b3187a5782c71128c28b4c94cad8bfb871da2784a5cb54c4f +# pip scipy @ https://files.pythonhosted.org/packages/db/86/bf3f01f003224c00dd94d9443d676023ed65d63ea2e34356888dc7fa8f48/scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937 # pip tifffile @ https://files.pythonhosted.org/packages/f5/72/68ea763b5f3e3d9871492683059ed4724fd700dbe54aa03cdda7a9692129/tifffile-2023.9.26-py3-none-any.whl#sha256=1de47fa945fddaade256e25ad4f375ae65547f3c1354063aded881c32a64cf89 # pip lightgbm @ https://files.pythonhosted.org/packages/98/a9/01f50aee85949ba713733b69a3f0b42d39719a414a0e29bdf2a9f05ecc53/lightgbm-4.1.0.tar.gz#sha256=bee59dd269a93b093f2c610d4a6683a7ea87c63d3ea35c622123ce2c020b2abc -# pip matplotlib @ https://files.pythonhosted.org/packages/e0/8b/b62bc50b01bb2d4af96bc0045c39d60209e2701e172789ceace20a0866b2/matplotlib-3.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=5de39dc61ca35342cf409e031f70f18219f2c48380d3886c1cf5ad9f17898e06 -# pip pandas @ https://files.pythonhosted.org/packages/bc/7e/a9e11bd272e3135108892b6230a115568f477864276181eada3a35d03237/pandas-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a0dbfea0dd3901ad4ce2306575c54348d98499c95be01b8d885a2737fe4d7a98 +# pip matplotlib @ https://files.pythonhosted.org/packages/53/1f/653d60d2ec81a6095fa3e571cf2de57742bab8a51a5c01de26730ce3dc53/matplotlib-3.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=5864bdd7da445e4e5e011b199bb67168cdad10b501750367c496420f2ad00843 +# pip pandas @ https://files.pythonhosted.org/packages/4e/7b/6c251522fd21ad2a51f26df677582ed917650cb8dff286e17625e7a6531b/pandas-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=1329dbe93a880a3d7893149979caa82d6ba64a25e471682637f846d9dbc10dd2 # pip pyamg @ https://files.pythonhosted.org/packages/35/1c/8b2aa6fbb2bae258ab6cdb35b09635bf50865ac2bcdaf220db3d972cc0d8/pyamg-5.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=1332acec6d5ede9440c8ced0ef20952f5b766387116f254b79880ce29fdecee7 # pip pytest-cov @ https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl#sha256=6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a # pip pytest-forked @ https://files.pythonhosted.org/packages/f4/af/9c0bda43e486a3c9bf1e0f876d0f241bc3f229d7d65d09331a0868db9629/pytest_forked-1.6.0-py3-none-any.whl#sha256=810958f66a91afb1a1e2ae83089d8dc1cd2437ac96b12963042fbb9fb4d16af0 +# pip rich @ https://files.pythonhosted.org/packages/be/be/1520178fa01eabe014b16e72a952b9f900631142ccd03dc36cf93e30c1ce/rich-13.7.0-py3-none-any.whl#sha256=6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # pip scikit-image @ https://files.pythonhosted.org/packages/a3/7e/4cd853a855ac34b4ef3ef6a5c3d1c2e96eaca1154fc6be75db55ffa87393/scikit_image-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=3b7a6c89e8d6252332121b58f50e1625c35f7d6a85489c0b6b7ee4f5155d547a # pip pytest-xdist @ https://files.pythonhosted.org/packages/21/08/b1945d4b4986eb1aa10cf84efc5293bba39da80a2f95db3573dd90678408/pytest_xdist-2.5.0-py3-none-any.whl#sha256=6fe5c74fec98906deb8f2d2b616b5c782022744978e7bd4695d39c8f42d0ce65 # pip numpydoc @ https://files.pythonhosted.org/packages/9c/94/09c437fd4a5fb5adf0468c0865c781dbc11d399544b55f1163d5d4414afb/numpydoc-1.6.0-py3-none-any.whl#sha256=b6ddaa654a52bdf967763c1e773be41f1c3ae3da39ee0de973f2680048acafaa diff --git a/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock b/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock index 505502ab4ea4b..67839fe964441 100644 --- a/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 28ec764eefc982520846833c9ea571cf6ea5a0593dee76d7a7560b34e341e35b +# input_hash: 3c1cccbd7d97038317ff1f757471e00fa9433f59f0d7f6f9b9a1ac9e062bfd88 @EXPLICIT https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda#c3473ff8bdb3d124ed5ff11ec380d6f9 https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2023.08.22-h06a4308_0.conda#243d5065a09a3e85ab888c05f5b6445a @@ -14,7 +14,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.8-h7b6447c_0.conda#9303f4 https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda#06e288f9250abef59b9a367d151fc339 https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.41.5-h5eee18b_0.conda#4a6a2354414c9080327274aa514e5299 https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda#5558eec6e2191741a92f832ea826251c -https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.11-h7f8727e_2.conda#6cad6f2dcde73f8625d729c6db1272d0 +https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.12-h7f8727e_0.conda#48caaebab690276acf1bc1f3b56febf4 https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.2-h5eee18b_0.conda#bcd31de48a0dcb44bc5b99675800c5cc https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_0.conda#333e31fbfbb5057c92fa845ad6adef93 https://repo.anaconda.com/pkgs/main/linux-64/ccache-3.7.9-hfe4627d_0.conda#bef6fc681c273bb7bd0c67d1a591365e @@ -26,9 +26,9 @@ https://repo.anaconda.com/pkgs/main/linux-64/setuptools-68.0.0-py311h06a4308_0.c https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.41.2-py311h06a4308_0.conda#2d4ff85d3dfb7749ae0485ee148d4ea5 https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3-py311h06a4308_0.conda#36ce82a87d13dce74e2a722a8a836115 # pip alabaster @ https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl#sha256=1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3 -# pip babel @ https://files.pythonhosted.org/packages/ff/37/b0241795c3a320a3def948cd0d06daf70310e7fea1d8fda312629bc22ea9/Babel-2.13.0-py3-none-any.whl#sha256=fbfcae1575ff78e26c7449136f1abbefc3c13ce542eeb13d43d50d8b047216ec -# pip certifi @ https://files.pythonhosted.org/packages/4c/dd/2234eab22353ffc7d94e8d13177aaa050113286e93e7b40eae01fbf7c3d9/certifi-2023.7.22-py3-none-any.whl#sha256=92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 -# pip charset-normalizer @ https://files.pythonhosted.org/packages/ff/b6/9222090f396f33cd58aa5b08b9bbf8871416b746a0c7b412a41a973674a5/charset_normalizer-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 +# pip babel @ https://files.pythonhosted.org/packages/86/14/5dc2eb02b7cc87b2f95930310a2cc5229198414919a116b564832c747bc1/Babel-2.13.1-py3-none-any.whl#sha256=7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed +# pip certifi @ https://files.pythonhosted.org/packages/64/62/428ef076be88fa93716b576e4a01f919d25968913e817077a386fcbe4f42/certifi-2023.11.17-py3-none-any.whl#sha256=e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474 +# pip charset-normalizer @ https://files.pythonhosted.org/packages/40/26/f35951c45070edc957ba40a5b1db3cf60a9dbb1b350c2d5bef03e01e61de/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 # pip coverage @ https://files.pythonhosted.org/packages/bc/01/bf243cf5c926681b35d0c6aa9a3b33da35ab65323c4a593d386b08a0249e/coverage-7.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=4175e10cc8dda0265653e8714b3174430b07c1dca8957f4966cbd6c2b1b8065a # pip docutils @ https://files.pythonhosted.org/packages/26/87/f238c0670b94533ac0353a4e2a1a771a0cc73277b88bff23d3ae35a256c1/docutils-0.20.1-py3-none-any.whl#sha256=96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6 # pip execnet @ https://files.pythonhosted.org/packages/e8/9c/a079946da30fac4924d92dbc617e5367d454954494cf1e71567bcc4e00ee/execnet-2.0.2-py3-none-any.whl#sha256=88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41 @@ -37,21 +37,21 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3-py311h06a4308_0.conda#36ce # pip iniconfig @ https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl#sha256=b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # pip markupsafe @ https://files.pythonhosted.org/packages/fe/21/2eff1de472ca6c99ec3993eab11308787b9879af9ca8bbceb4868cf4f2ca/MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 # pip packaging @ https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl#sha256=8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 -# pip platformdirs @ https://files.pythonhosted.org/packages/56/29/3ec311dc18804409ecf0d2b09caa976f3ae6215559306b5b530004e11156/platformdirs-3.11.0-py3-none-any.whl#sha256=e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e +# pip platformdirs @ https://files.pythonhosted.org/packages/31/16/70be3b725073035aa5fc3229321d06e22e73e3e09f6af78dcfdf16c7636c/platformdirs-4.0.0-py3-none-any.whl#sha256=118c954d7e949b35437270383a3f2531e99dd93cf7ce4dc8340d3356d30f173b # pip pluggy @ https://files.pythonhosted.org/packages/05/b8/42ed91898d4784546c5f06c60506400548db3f7a4b3fb441cba4e5c17952/pluggy-1.3.0-py3-none-any.whl#sha256=d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7 # pip py @ https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl#sha256=607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378 -# pip pygments @ https://files.pythonhosted.org/packages/43/88/29adf0b44ba6ac85045e63734ae0997d3c58d8b1a91c914d240828d0d73d/Pygments-2.16.1-py3-none-any.whl#sha256=13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 +# pip pygments @ https://files.pythonhosted.org/packages/97/9c/372fef8377a6e340b1704768d20daaded98bf13282b5327beb2e2fe2c7ef/pygments-2.17.2-py3-none-any.whl#sha256=b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c # pip six @ https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl#sha256=8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # pip snowballstemmer @ https://files.pythonhosted.org/packages/ed/dc/c02e01294f7265e63a7315fe086dd1df7dacb9f840a804da846b96d01b96/snowballstemmer-2.2.0-py2.py3-none-any.whl#sha256=c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a # pip sphinxcontrib-jsmath @ https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl#sha256=2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178 # pip tabulate @ https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl#sha256=024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f # pip threadpoolctl @ https://files.pythonhosted.org/packages/81/12/fd4dea011af9d69e1cad05c75f3f7202cdcbeac9b712eea58ca779a72865/threadpoolctl-3.2.0-py3-none-any.whl#sha256=2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032 -# pip urllib3 @ https://files.pythonhosted.org/packages/d2/b2/b157855192a68541a91ba7b2bbcb91f1b4faa51f8bae38d8005c034be524/urllib3-2.0.7-py3-none-any.whl#sha256=fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +# pip urllib3 @ https://files.pythonhosted.org/packages/96/94/c31f58c7a7f470d5665935262ebd7455c7e4c7782eb525658d3dbf4b9403/urllib3-2.1.0-py3-none-any.whl#sha256=55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3 # pip jinja2 @ https://files.pythonhosted.org/packages/bc/c3/f068337a370801f372f2f8f6bad74a5c140f6fda3d9de154052708dd3c65/Jinja2-3.1.2-py3-none-any.whl#sha256=6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61 -# pip pytest @ https://files.pythonhosted.org/packages/df/d0/e192c4275aecabf74faa1aacd75ef700091913236ec78b1a98f62a2412ee/pytest-7.4.2-py3-none-any.whl#sha256=1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002 +# pip pytest @ https://files.pythonhosted.org/packages/f3/8c/f16efd81ca8e293b2cc78f111190a79ee539d0d5d36ccd49975cb3beac60/pytest-7.4.3-py3-none-any.whl#sha256=0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac # pip python-dateutil @ https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl#sha256=961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 # pip requests @ https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl#sha256=58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f -# pip pooch @ https://files.pythonhosted.org/packages/84/8c/4da580db7fb4cfce8f5ed78e7d2aa542e6f201edd69d3d8a96917a8ff63c/pooch-1.7.0-py3-none-any.whl#sha256=74258224fc33d58f53113cf955e8d51bf01386b91492927d0d1b6b341a765ad7 +# pip pooch @ https://files.pythonhosted.org/packages/1a/a5/5174dac3957ac412e80a00f30b6507031fcab7000afc9ea0ac413bddcff2/pooch-1.8.0-py3-none-any.whl#sha256=1bfba436d9e2ad5199ccad3583cca8c241b8736b5bb23fe67c213d52650dbb66 # pip pytest-cov @ https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl#sha256=6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a # pip pytest-forked @ https://files.pythonhosted.org/packages/f4/af/9c0bda43e486a3c9bf1e0f876d0f241bc3f229d7d65d09331a0868db9629/pytest_forked-1.6.0-py3-none-any.whl#sha256=810958f66a91afb1a1e2ae83089d8dc1cd2437ac96b12963042fbb9fb4d16af0 # pip pytest-xdist @ https://files.pythonhosted.org/packages/21/08/b1945d4b4986eb1aa10cf84efc5293bba39da80a2f95db3573dd90678408/pytest_xdist-2.5.0-py3-none-any.whl#sha256=6fe5c74fec98906deb8f2d2b616b5c782022744978e7bd4695d39c8f42d0ce65 diff --git a/build_tools/azure/pypy3_linux-64_conda.lock b/build_tools/azure/pypy3_linux-64_conda.lock index 409a935b942fe..0c5b790f599bc 100644 --- a/build_tools/azure/pypy3_linux-64_conda.lock +++ b/build_tools/azure/pypy3_linux-64_conda.lock @@ -1,26 +1,26 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 296e0e62aa19cfbc6aa6d615c86db2d06be56b4b5f76bf148152aff936fcddf5 +# input_hash: 35e4a4f1db15219fa4cb71af7b54acc24ec7c3b3610c479f979c6c44cbd93db7 @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-4_pypy39_pp73.conda#c1b2f29111681a4036ed21eaa3f44620 https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2#76bbff344f0134279f225174e9064c8f https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1.conda#aec6c91c7371c26392a06708a73c70e5 https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda#1635570038840ee3f9c71d22aa5b8b6d https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_2.conda#78fdab09d9138851dde2b5fe2a11019e +https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_3.conda#c714d905cdfa0e70200f68b80cc04764 https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.conda#30de3fd9b3b602f7473f30e684eeea8c https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.11-hd590300_0.conda#2c80dc38fface310c9bd81b17037fee5 @@ -31,42 +31,42 @@ https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2#2161 https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda#f07002e225d7a60a694d42a7bf5ff53f https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda#5fc11c6020d421960607d821310fcd4d -https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_2.conda#e75a75a6eaf6f318dae2631158c46575 +https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_3.conda#73031c79546ad06f1fe62e57fdd021bc https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda#39f910d205726805a958da408ca194ba https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/gdbm-1.18-h0a1914f_2.tar.bz2#b77bc399b07a19c00fe12fdc95ee0297 https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a -https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.24-pthreads_h413a1c8_0.conda#6e4ef6ca28655124dcde9bd500e44c32 +https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.25-pthreads_h413a1c8_0.conda#d172b34a443b95f86089e8229ddc9a17 https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.43.2-h2c6b66d_0.conda#c37b95bcd6c6833dacfd5df0ae2f4303 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.44.0-h2c6b66d_0.conda#df56c636df4a98990462d66ac7be2330 https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda#49e482d882669206653b095f5206c05b https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda#f27a24d46e3ea7b70a1f98e50c62508f https://conda.anaconda.org/conda-forge/linux-64/ccache-4.8.1-h1fcd64f_0.conda#fd37a0c47d8b3667b73af0549037ce83 https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 -https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-19_linux64_openblas.conda#420f4e9be59d0dc9133a0f43f7bab3f3 -https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.24-pthreads_h7a3da1a_0.conda#ebe8e905b06dfc5b4b40642d34b1d2f3 +https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-20_linux64_openblas.conda#2b7bb4f7562c8cf334fc2e20c2d28abc +https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.25-pthreads_h7a3da1a_0.conda#87661673941b5e702275fdf0fc095ad0 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 -https://conda.anaconda.org/conda-forge/linux-64/pypy3.9-7.3.13-h9557127_0.conda#96a942d305b5076755335a6fe2bb7cac -https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-19_linux64_openblas.conda#d12374af44575413fbbd4a217d46ea33 -https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-19_linux64_openblas.conda#9f100edf65436e3eabc2a51fc00b2c37 +https://conda.anaconda.org/conda-forge/linux-64/pypy3.9-7.3.13-h9557127_1.conda#39a062cd784476b77f3899d210ff6abc +https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-20_linux64_openblas.conda#36d486d72ab64ffea932329a1d3729a3 +https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-20_linux64_openblas.conda#6fabc51f5e647d09cc010c40061557e0 https://conda.anaconda.org/conda-forge/linux-64/python-3.9.18-0_73_pypy.conda#aaa2e3d23f19ed80f19970585d8e02ec -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.4-py39hc10206b_0.conda#662bf4ed6f0d1cf28dd6de14566ad9bf -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.5-py39hc10206b_0.conda#92f8e0504db2b47d7f9f8fcee8965503 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py39ha90811c_1.conda#25edffabcb0760fc1821597c4ce920db -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-19_linux64_openblas.conda#685e99d3214f5ac9d1ec6b37983985a6 +https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-20_linux64_openblas.conda#05c5862c7dc25e65ba6c471d96429dae https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.0-py39h6dedee3_0.conda#0479148a00457b130bbf89ff60ecad72 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 @@ -82,18 +82,18 @@ https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5 https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py39hf860d4a_1.conda#ed9f2e116805d111f969b78e71203eef https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hf860d4a_0.conda#f699157518d28d00c87542b4ec1273be https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-19_linux64_openblas.conda#96bca12f1b7c48298dd1abf3e11121af -https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.1-py39ha90811c_1.conda#a6c42b5114eb7d8792e12ab3e0d27ebc -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.43.1-py39hf860d4a_0.conda#30758f7f0597bf9a14f920453638f1df -https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.0-pyhd8ed1ab_0.conda#48b0d98e0c0ec810d3ccc2a0926c8c0e +https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_openblas.conda#9932a1d4e9ecf2d35fb19475446e361e +https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39ha90811c_0.conda#f3b2afc64bf0cbe901a9b00d44611c61 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.45.0-py39hf860d4a_0.conda#8a1b616fcd4add269e7d5bf2fca87cdf +https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.11.3-py39h6dedee3_1.conda#595610a3cd404ad02ce81308b6b344ba -https://conda.anaconda.org/conda-forge/linux-64/blas-2.119-openblas.conda#f536a14a54da8b2aedd5a967d1e407c9 -https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.0-pyhd8ed1ab_0.conda#6a62c2cc25376a0d050b3d1d221c3ee9 +https://conda.anaconda.org/conda-forge/linux-64/blas-2.120-openblas.conda#c8f6916a81a340650078171b1d852574 +https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710 https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py39h5fd064f_1.conda#e364cfb3ffb590ccef24b5a92389e751 https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.0-py39h4e7d633_2.conda#36130e852c31e1122d9cd45694dd6b0e +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py39h4e7d633_0.conda#a60f8c577d2db485f0b92bef480d6277 https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.0-py39h4162558_2.conda#63e41cd37208e82246704ad415833c6a +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.2-py39h4162558_0.conda#24444011be733e7bde8617eb8fe725e1 diff --git a/build_tools/azure/ubuntu_atlas_lock.txt b/build_tools/azure/ubuntu_atlas_lock.txt index 4b2968352b422..65ffa33bd87ff 100644 --- a/build_tools/azure/ubuntu_atlas_lock.txt +++ b/build_tools/azure/ubuntu_atlas_lock.txt @@ -6,13 +6,13 @@ # cython==0.29.33 # via -r build_tools/azure/ubuntu_atlas_requirements.txt -exceptiongroup==1.1.3 +exceptiongroup==1.2.0 # via pytest execnet==2.0.2 # via pytest-xdist iniconfig==2.0.0 # via pytest -joblib==1.1.1 +joblib==1.2.0 # via -r build_tools/azure/ubuntu_atlas_requirements.txt packaging==23.2 # via pytest @@ -20,7 +20,7 @@ pluggy==1.3.0 # via pytest py==1.11.0 # via pytest-forked -pytest==7.4.2 +pytest==7.4.3 # via # -r build_tools/azure/ubuntu_atlas_requirements.txt # pytest-forked diff --git a/build_tools/azure/ubuntu_atlas_requirements.txt b/build_tools/azure/ubuntu_atlas_requirements.txt index dcbb6285cd212..b4fad825466a7 100644 --- a/build_tools/azure/ubuntu_atlas_requirements.txt +++ b/build_tools/azure/ubuntu_atlas_requirements.txt @@ -2,7 +2,7 @@ # following script to centralize the configuration for CI builds: # build_tools/update_environments_and_lock_files.py cython==0.29.33 # min -joblib==1.1.1 # min +joblib==1.2.0 # min threadpoolctl==2.0.0 # min pytest pytest-xdist==2.5.0 diff --git a/build_tools/circle/doc_environment.yml b/build_tools/circle/doc_environment.yml index b12ebf12f254d..0a23818988761 100644 --- a/build_tools/circle/doc_environment.yml +++ b/build_tools/circle/doc_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib - pandas + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index d531db979b788..c6fba6a781323 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -1,33 +1,33 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: a1599a69c6966abc8fac90f7e4ec1ab0bdd0ca5c1fe5dbbea18d9cbe0b5582ee +# input_hash: fafda966dd645ecc15fa03e13b78402cd6102bf1ce290240ce35e541ebc07b7d @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2#19410c3df09dfb12d1206132a1d357c5 https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-2.6.32-he073ed8_16.conda#7ca122655873935e02c91279c5b03c8c https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-devel_linux-64-12.3.0-h8bca6fd_2.conda#ed613582de7b8569fdc53ca141be176a -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-devel_linux-64-12.3.0-h8bca6fd_2.conda#7268a17e56eb099d1b8869bbbf46de4c -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 +https://conda.anaconda.org/conda-forge/noarch/libgcc-devel_linux-64-12.3.0-h8bca6fd_103.conda#1d7f6d1825bd6bf21ee04336ec87a777 +https://conda.anaconda.org/conda-forge/noarch/libstdcxx-devel_linux-64-12.3.0-h8bca6fd_103.conda#3f784d2c059e960156d1ab3858cbf200 +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-4_cp39.conda#bfe4b3259a8ac6cdf0037752904da6a7 https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 -https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_2.conda#e2042154faafe61969556f28bade94b9 +https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_3.conda#7124cbb46b13d395bdde68f2d215c989 https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.12-he073ed8_16.conda#071ea8dceff4d30ac511f4a2f8437cd1 https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.40-hf600244_0.conda#33084421a8c0af6aef1b439707f7662a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/binutils-2.40-hdd6e379_0.conda#ccc940fddbc3fcd3d79cd4c654c4b5c4 https://conda.anaconda.org/conda-forge/linux-64/binutils_linux-64-2.40-hbdbef99_2.conda#adfebae9fdc63a598495dfe3b006973a https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.10-hd590300_0.conda#75dae9a4201732aa78a530b826ee5fe0 -https://conda.anaconda.org/conda-forge/linux-64/aom-3.6.1-h59595ed_0.conda#8457db6d1175ee86c8e077f6ac60ff55 +https://conda.anaconda.org/conda-forge/linux-64/aom-3.7.1-h59595ed_0.conda#504e70332b8322cda93b7bceb5925fca https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 https://conda.anaconda.org/conda-forge/linux-64/charls-2.4.2-h59595ed_0.conda#4336bd67920dd504cd8c6761d6a99645 https://conda.anaconda.org/conda-forge/linux-64/dav1d-1.2.1-hd590300_0.conda#418c6ca5929a611cbd69204907a83995 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 @@ -43,22 +43,22 @@ https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1 https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda#1635570038840ee3f9c71d22aa5b8b6d https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_2.conda#78fdab09d9138851dde2b5fe2a11019e +https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_3.conda#c714d905cdfa0e70200f68b80cc04764 https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2#b62b52da46c39ee2bc3c162ac7f1804d https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2#6e8cc2173440d77708196c5b93771680 https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2#15345e56d527b330e1cacbdf58676e8f -https://conda.anaconda.org/conda-forge/linux-64/libsanitizer-12.3.0-h0f45ef3_2.conda#4655db64eca78a6fcc4fb654fc1f8d57 +https://conda.anaconda.org/conda-forge/linux-64/libsanitizer-12.3.0-h0f45ef3_3.conda#eda05ab0db8f8490945fd99244183e3a https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda#40b61aab5c7ba9ff276c41cfffe6b80b https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.conda#30de3fd9b3b602f7473f30e684eeea8c https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad https://conda.anaconda.org/conda-forge/linux-64/libzopfli-1.0.3-h9c3ff4c_0.tar.bz2#c66fe2d123249af7651ebde8984c51c2 https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75 -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.42.2-h59595ed_0.conda#700edd63ccd5fc66b70b1c028cea9a68 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.6.6-he8a937b_2.conda#77d9955b4abddb811cb8ab1aa7d743e4 @@ -76,45 +76,45 @@ https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2#2161 https://conda.anaconda.org/conda-forge/linux-64/zfp-1.0.0-h59595ed_4.conda#9cfbafab420f42b572f3c032ad59da85 https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.0.7-h0b41bf4_0.conda#49e8329110001f04923fe7e864990b0c https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 -https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-12.3.0-he2b93b0_2.conda#2f4d8677dc7dd87f93e9abfb2ce86808 -https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.0.1-h87da1f6_2.conda#0281e5f0887a512d7cc2a843173ca243 +https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-12.3.0-he2b93b0_3.conda#71c68ea75afe6ac7a9c62c08f5d67a5a +https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.0.2-hed45d22_0.conda#ad3e851b008cbf5bfb0d229b6a776842 https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda#f07002e225d7a60a694d42a7bf5ff53f https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda#5fc11c6020d421960607d821310fcd4d https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25cb5999faa414e5ccb2c1388f62d3d5 https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 -https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_2.conda#e75a75a6eaf6f318dae2631158c46575 +https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_3.conda#73031c79546ad06f1fe62e57fdd021bc https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda#c2097d0b46367996f09b4e8e4920384a https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.5-h232c23b_1.conda#f3858448893839820d4bcfb14ad3ecdf -https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_5.conda#1e8ef4090ca4f0d66404a7441e1dbf3c -https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.6-h232c23b_0.conda#427a3e59d66cb5d145020bd9c6493334 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/blosc-1.21.5-h0f2a231_0.conda#009521b7ed97cca25f8f997f9e745976 https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda#39f910d205726805a958da408ca194ba -https://conda.anaconda.org/conda-forge/linux-64/c-blosc2-2.10.5-hb4ffafa_0.conda#831974f6788127baf502fca9001a5fb4 +https://conda.anaconda.org/conda-forge/linux-64/c-blosc2-2.11.2-hb4ffafa_0.conda#aa776e4716e54633d1279cf7599c3711 https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/gcc-12.3.0-h8d2909c_2.conda#e2f2f81f367e14ca1f77a870bda2fe59 https://conda.anaconda.org/conda-forge/linux-64/gcc_linux-64-12.3.0-h76fc315_2.conda#11517e7b5c910c5b5d6985c0c7eb7f50 -https://conda.anaconda.org/conda-forge/linux-64/gfortran_impl_linux-64-12.3.0-hfcedea8_2.conda#09d48cadff6669068c3bf7ae7dc8ea4a -https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-12.3.0-he2b93b0_2.conda#f89b9916afc36fc5562fbfc11330a8a2 +https://conda.anaconda.org/conda-forge/linux-64/gfortran_impl_linux-64-12.3.0-hfcedea8_3.conda#929fbb7d28a3727e96170e613253d2f4 +https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-12.3.0-he2b93b0_3.conda#b6ce9868fc6c65a18c22fd983e2d7e6f https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 -https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.0-hebfc3b9_0.conda#e618003da3547216310088478e475945 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.2-hd590300_0.conda#3d7d5e5cebf8af5aadb040732860f1b6 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.1-h783c2da_1.conda#70052d6c1e84643e30ffefb21ab6950f https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-h5cf9203_3.conda#9efe82d44b76a7529a1d702e5a37752e -https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.24-pthreads_h413a1c8_0.conda#6e4ef6ca28655124dcde9bd500e44c32 +https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.25-pthreads_h413a1c8_0.conda#d172b34a443b95f86089e8229ddc9a17 https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_5.conda#b72f016c910ff9295b1377d3e17da3f2 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda#7caef74bbfa730e014b20f0852068509 https://conda.anaconda.org/conda-forge/linux-64/python-3.9.18-h0755675_0_cpython.conda#3ede353bc605068d9677e700b1847382 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 @@ -126,19 +126,19 @@ https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.13-pyhd8ed1ab_0.cond https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda#f27a24d46e3ea7b70a1f98e50c62508f https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py39h3d6467e_1.conda#c48418c8b35f1d59ae9ae1174812b40a https://conda.anaconda.org/conda-forge/linux-64/c-compiler-1.6.0-hd590300_0.conda#ea6c792f792bdd7ae6e7e2dee32f0a48 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 -https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.1-pyhd8ed1ab_0.conda#985378f74689fccce52f158027bd9acd +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 +https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda#7f4a9e3fcff3f6356ae99244a014da6a https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.4-py39h3d6467e_0.conda#993291d01bb80e1e559c55d2b55972b4 +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.5-py39h3d6467e_0.conda#8a666e66408ec097bf7b6d44353d6294 https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d https://conda.anaconda.org/conda-forge/linux-64/docutils-0.20.1-py39hf3d152e_2.conda#8effc3913cfe3a29f2a89cda29bbff04 -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d https://conda.anaconda.org/conda-forge/linux-64/gfortran-12.3.0-h499e0f7_2.conda#0558a8c44eb7a18e6682bd3a8ae6dcab https://conda.anaconda.org/conda-forge/linux-64/gfortran_linux-64-12.3.0-h7fe76b4_2.conda#3a749210487c0358b6f135a648cbbf60 -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.0-hfc55251_0.conda#e10134de3558dd95abda6987b5548f4f +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.1-hfc55251_1.conda#a50918d10114a0bf80fb46c7cc692058 https://conda.anaconda.org/conda-forge/linux-64/gxx-12.3.0-h8d2909c_2.conda#673bac341be6b90ef9e8abae7e52ca46 https://conda.anaconda.org/conda-forge/linux-64/gxx_linux-64-12.3.0-h8a814eb_2.conda#f517b1525e9783849bd56a5dc45a9960 https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2#34272b248891bddccc64479f9a7fffed @@ -147,22 +147,23 @@ https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py39h7633fee_1.conda#c9f74d717e5a2847a9f8b779c54130f2 https://conda.anaconda.org/conda-forge/noarch/lazy_loader-0.3-pyhd8ed1ab_0.conda#69ea1d0fa7ab33b48c88394ad1dead65 https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 -https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-19_linux64_openblas.conda#420f4e9be59d0dc9133a0f43f7bab3f3 +https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-20_linux64_openblas.conda#2b7bb4f7562c8cf334fc2e20c2d28abc https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_3.conda#1720df000b48e31842500323cb7be18c https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 -https://conda.anaconda.org/conda-forge/linux-64/libpq-16.0-hfc447b1_1.conda#e4a9a5ba40123477db33e02a78dffb01 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.1-hfc447b1_0.conda#2b7f1893cf40b4ccdc0230bcd94d5ed9 https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-254-h3516f8a_0.conda#df4b1cd0c91b4234fb02b5701a4cdddc https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.3-py39hd1e30aa_1.conda#ee2b4665b852ec6ff2758f3c1b91233d +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 -https://conda.anaconda.org/conda-forge/noarch/networkx-3.2-pyhd8ed1ab_1.conda#522039fb968d6d0a10e872e6f3856f53 -https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.24-pthreads_h7a3da1a_0.conda#ebe8e905b06dfc5b4b40642d34b1d2f3 +https://conda.anaconda.org/conda-forge/noarch/networkx-3.2.1-pyhd8ed1ab_0.conda#425fce3b531bed6ec3c74fab3e5f0a1c +https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.25-pthreads_h7a3da1a_0.conda#87661673941b5e702275fdf0fc095ad0 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.5-py39hd1e30aa_1.conda#c2e412b0f11e5983bcfc35d9beb91ecb https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 -https://conda.anaconda.org/conda-forge/noarch/pygments-2.16.1-pyhd8ed1ab_0.conda#40e5cb18165466773619e5c963f00a7b +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 @@ -178,59 +179,60 @@ https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5 https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py39hd1e30aa_1.conda#cbe186eefb0bcd91e8f47c3908489874 https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda#5b1be40a26d10a06f6d4f1f9e19fa0c7 https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hd1e30aa_0.conda#1da984bbb6e765743e13388ba7b7b2c8 -https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.2-pyhd8ed1ab_0.conda#1ccd092478b3e0ee10d7a891adbf8a4f +https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.3-pyhd8ed1ab_0.conda#3fc026b9c87d091c4b34a6c997324ae8 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda#9d7bcddf49cbf727730af10e71022c73 https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.40-hd590300_0.conda#07c15d846a2e4d673da22cbd85fdb6d2 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda#82b6df12252e6f32402b96dacc656fec https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda#ed67c36f215b310412b2af935bf3e530 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/noarch/babel-2.13.0-pyhd8ed1ab_0.conda#22541af7a9eb59fc6afcadb7ecdf9219 +https://conda.anaconda.org/conda-forge/noarch/babel-2.13.1-pyhd8ed1ab_0.conda#3ccff479c246692468f604df9c85ef26 https://conda.anaconda.org/conda-forge/linux-64/brunsli-0.1-h9c3ff4c_0.tar.bz2#c1ac6229d0bfd14f8354ff9ad2a26cad https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e https://conda.anaconda.org/conda-forge/linux-64/cxx-compiler-1.6.0-h00ab1b0_0.conda#364c6ae36c4e36fcbd4d273cf4db78af -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.43.1-py39hd1e30aa_0.conda#74b032179f7782051800908cb2250132 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.45.0-py39hd1e30aa_0.conda#8f71ddc27b34b892d1116c6ccec272ef https://conda.anaconda.org/conda-forge/linux-64/fortran-compiler-1.6.0-heb67821_0.conda#b65c49dda97ae497abcbdf3a8ba0018f -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.0-hfc55251_0.conda#2f55a36b549f51a7e0c2b1e3c3f0ccd4 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.1-hfc55251_1.conda#8d7242302bb3d03b9a690b6dda872603 https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-6.8.0-pyha770c72_0.conda#4e9f59a060c3be52bc4ddc46ee9b6946 -https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.0-pyhd8ed1ab_0.conda#48b0d98e0c0ec810d3ccc2a0926c8c0e +https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.2-pyhd8ed1ab_1.tar.bz2#c8490ed5c70966d232fdd389d0dbed37 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-19_linux64_openblas.conda#d12374af44575413fbbd4a217d46ea33 +https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-20_linux64_openblas.conda#36d486d72ab64ffea932329a1d3729a3 https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_3.conda#0922208521c0463e690bbaebba7eb551 -https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-19_linux64_openblas.conda#9f100edf65436e3eabc2a51fc00b2c37 +https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-20_linux64_openblas.conda#6fabc51f5e647d09cc010c40061557e0 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-h5d7e998_0.conda#d8edd0e29db6fb6b6988e1a28d35d994 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b https://conda.anaconda.org/conda-forge/noarch/memory_profiler-0.61.0-pyhd8ed1ab_0.tar.bz2#8b45f9f2b2f7a98b0ec179c8991a4a9b https://conda.anaconda.org/conda-forge/linux-64/pillow-10.1.0-py39had0adad_0.conda#eeaa413fddccecb2ab7f747bdb55b07f https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda#2400c0b86889f43aa52067161e1fb108 -https://conda.anaconda.org/conda-forge/noarch/plotly-5.17.0-pyhd8ed1ab_0.conda#76a0b213abcd3ffc1e8fa78804b69dc0 +https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda#6bb4ee32cd435deaeac72776c001e7ac +https://conda.anaconda.org/conda-forge/noarch/plotly-5.18.0-pyhd8ed1ab_0.conda#9f6a8664f1fe752f79473eeb9bf33a60 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py39h3d6467e_0.conda#e667a3ab0df62c54e60e1843d2e6defb -https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda#384462e63262a527bda564fa2d9126c0 -https://conda.anaconda.org/conda-forge/noarch/urllib3-2.0.7-pyhd8ed1ab_0.conda#270e71c14d37074b1d066ee21cf0c4a6 +https://conda.anaconda.org/conda-forge/noarch/urllib3-2.1.0-pyhd8ed1ab_0.conda#f8ced8ee63830dec7ecc1be048d1470a https://conda.anaconda.org/conda-forge/linux-64/compilers-1.6.0-ha770c72_0.conda#e2259de4640a51a28c21931ae98e4975 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.6-h98fc4e7_2.conda#1c95f7c612f9121353c4ef764678113e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.2.1-h3d44ed6_0.conda#98db5f8813f45e2b29766aff0e4a499c -https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.0-pyhd8ed1ab_0.conda#6a62c2cc25376a0d050b3d1d221c3ee9 -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-19_linux64_openblas.conda#685e99d3214f5ac9d1ec6b37983985a6 +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.7-h98fc4e7_0.conda#6c919bafe5e03428a8e2ef319d7ef990 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 +https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710 +https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-20_linux64_openblas.conda#05c5862c7dc25e65ba6c471d96429dae https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.0-py39h474f0d3_0.conda#62f1d2e05327bf62728afa448f2a9261 -https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.11.0-pyhd8ed1ab_0.conda#8f567c0a74aa44cf732f15773b4083b0 https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py39h3d6467e_5.conda#93aff412f3e49fdb43361c0215cbd72d https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-19_linux64_openblas.conda#96bca12f1b7c48298dd1abf3e11121af -https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.1-py39h7633fee_1.conda#33afb3357cd0d120ecb26778d37579e4 -https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.6-h8e1006c_2.conda#3d8e98279bad55287f2ef9047996f33c +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 +https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_openblas.conda#9932a1d4e9ecf2d35fb19475446e361e +https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39h7633fee_0.conda#ed71ad3e30eb03da363fb797419cce98 +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.7-h8e1006c_0.conda#065e2c1d49afa3fdc1a01f1dacd6ab09 https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2023.9.18-py39hf9b8f0e_2.conda#38f576a701ea508ed210087c711a06ee https://conda.anaconda.org/conda-forge/noarch/imageio-2.31.5-pyh8c1a49c_0.conda#6820ccf6a3a27df348f18c85dd89014a -https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.1-py39hddac248_1.conda#f32809db710b8aac48fbc14c13058530 -https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyhd8ed1ab_4.conda#3cdaf7af08850933662b1e228bc6b5bc +https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.3-py39hddac248_0.conda#961b398d8c421a3752e26f01f2dcbdac +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/linux-64/pywavelets-1.4.1-py39h44dd56e_1.conda#d037c20e3da2e85f03ebd20ad480c359 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.11.3-py39h474f0d3_1.conda#55441724fedb3042d38ffa5220f00804 -https://conda.anaconda.org/conda-forge/linux-64/blas-2.119-openblas.conda#f536a14a54da8b2aedd5a967d1e407c9 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.0-py39he9076e7_2.conda#404144d0628ebbbbd56d161c677cc71b +https://conda.anaconda.org/conda-forge/linux-64/blas-2.120-openblas.conda#c8f6916a81a340650078171b1d852574 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py39he9076e7_0.conda#6085411aa2f0b2b801d3b46e1d3b83c5 https://conda.anaconda.org/conda-forge/noarch/patsy-0.5.3-pyhd8ed1ab_0.tar.bz2#50ef6b29b1fb0768ca82c5aeb4fb2d96 https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py39hda80f44_1.conda#6df47699edb4d8d3365de2d189a456bc https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h82b777d_17.conda#4f01e33dbb406085a16a2813ab067e95 @@ -239,11 +241,11 @@ https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py39h52134e7_5.conda https://conda.anaconda.org/conda-forge/linux-64/scikit-image-0.22.0-py39hddac248_2.conda#8d502a4d2cbe5a45ff35ca8af8cbec0a https://conda.anaconda.org/conda-forge/noarch/seaborn-base-0.13.0-pyhd8ed1ab_0.conda#082666331726b2438986cfe33ae9a8ee https://conda.anaconda.org/conda-forge/linux-64/statsmodels-0.14.0-py39h44dd56e_2.conda#a00daa168ddb25c6bb30952374c011e8 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.0-py39hf3d152e_2.conda#ffe5ae58957da676064e2ce5d039d259 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.2-py39hf3d152e_0.conda#18d40a5ada9a801cabaf5d47c15c6282 https://conda.anaconda.org/conda-forge/noarch/seaborn-0.13.0-hd8ed1ab_0.conda#ebd31a95a7008b7e164dad9dbbb5bb5a https://conda.anaconda.org/conda-forge/noarch/numpydoc-1.5.0-pyhd8ed1ab_0.tar.bz2#3c275d7168a6a135329f4acb364c229a https://conda.anaconda.org/conda-forge/noarch/sphinx-copybutton-0.5.2-pyhd8ed1ab_0.conda#ac832cc43adc79118cf6e23f1f9b8995 -https://conda.anaconda.org/conda-forge/noarch/sphinx-gallery-0.14.0-pyhd8ed1ab_0.conda#b3788794f88c9512393032e448428261 +https://conda.anaconda.org/conda-forge/noarch/sphinx-gallery-0.15.0-pyhd8ed1ab_0.conda#1a49ca9515ef9a96edff2eea06143dc6 https://conda.anaconda.org/conda-forge/noarch/sphinx-prompt-1.4.0-pyhd8ed1ab_0.tar.bz2#88ee91e8679603f2a5bd036d52919cc2 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-applehelp-1.0.7-pyhd8ed1ab_0.conda#aebfabcb60c33a89c1f9290cab49bc93 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-devhelp-1.0.5-pyhd8ed1ab_0.conda#ebf08f5184d8eaa486697bc060031953 @@ -251,11 +253,11 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-htmlhelp-2.0.4-pyhd8 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-qthelp-1.0.6-pyhd8ed1ab_0.conda#cf5c9649272c677a964a7313279e3a9b https://conda.anaconda.org/conda-forge/noarch/sphinx-7.2.6-pyhd8ed1ab_0.conda#bbfd1120d1824d2d073bc65935f0e4c0 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.9-pyhd8ed1ab_0.conda#0612e497d7860728f2cda421ea2aec09 -https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.8.2-pyhd8ed1ab_0.conda#7f330c6004309c83cc63aed39b70d711 +https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.9.0-pyhd8ed1ab_0.conda#7eeb01dc276a321b3ed02319d461db79 # pip attrs @ https://files.pythonhosted.org/packages/f0/eb/fcb708c7bf5056045e9e98f62b93bd7467eb718b0202e7698eb11d66416c/attrs-23.1.0-py3-none-any.whl#sha256=1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04 # pip cloudpickle @ https://files.pythonhosted.org/packages/96/43/dae06432d0c4b1dc9e9149ad37b4ca8384cf6eb7700cd9215b177b914f0a/cloudpickle-3.0.0-py3-none-any.whl#sha256=246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 # pip defusedxml @ https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl#sha256=a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61 -# pip fastjsonschema @ https://files.pythonhosted.org/packages/7f/1a/8aad366cf1779351741e5c791ae76dc8b293f72e9448c689cc2e730f06cb/fastjsonschema-2.18.1-py3-none-any.whl#sha256=aec6a19e9f66e9810ab371cc913ad5f4e9e479b63a7072a2cd060a9369e329a8 +# pip fastjsonschema @ https://files.pythonhosted.org/packages/63/e9/d3dca06ea6b8e58e65716973bc7d9bee9bc39ce233595aa04d04e89a1089/fastjsonschema-2.19.0-py3-none-any.whl#sha256=b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e # pip fqdn @ https://files.pythonhosted.org/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl#sha256=3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014 # pip json5 @ https://files.pythonhosted.org/packages/70/ba/fa37123a86ae8287d6678535a944f9c3377d8165e536310ed6f6cb0f0c0e/json5-0.9.14-py2.py3-none-any.whl#sha256=740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f # pip jsonpointer @ https://files.pythonhosted.org/packages/12/f6/0232cc0c617e195f06f810534d00b74d2f348fe71b2118009ad8ad31f878/jsonpointer-2.4-py2.py3-none-any.whl#sha256=15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a @@ -264,17 +266,17 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.8.2-pyhd8ed1 # pip overrides @ https://files.pythonhosted.org/packages/da/28/3fa6ef8297302fc7b3844980b6c5dbc71cdbd4b61e9b2591234214d5ab39/overrides-7.4.0-py3-none-any.whl#sha256=3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d # pip pandocfilters @ https://files.pythonhosted.org/packages/5e/a8/878258cffd53202a6cc1903c226cf09e58ae3df6b09f8ddfa98033286637/pandocfilters-1.5.0-py2.py3-none-any.whl#sha256=33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f # pip pkginfo @ https://files.pythonhosted.org/packages/b3/f2/6e95c86a23a30fa205ea6303a524b20cbae27fbee69216377e3d95266406/pkginfo-1.9.6-py3-none-any.whl#sha256=4b7a555a6d5a22169fcc9cf7bfd78d296b0361adad412a346c1226849af5e546 -# pip prometheus-client @ https://files.pythonhosted.org/packages/ad/b3/6e18c89bf6bd120590ea538a62cae16dc763ff2745b18377c4be5495c4aa/prometheus_client-0.17.1-py3-none-any.whl#sha256=e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101 +# pip prometheus-client @ https://files.pythonhosted.org/packages/bb/9f/ad934418c48d01269fc2af02229ff64bcf793fd5d7f8f82dc5e7ea7ef149/prometheus_client-0.19.0-py3-none-any.whl#sha256=c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92 # pip ptyprocess @ https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35 # pip pycparser @ https://files.pythonhosted.org/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl#sha256=8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9 # pip python-json-logger @ https://files.pythonhosted.org/packages/35/a6/145655273568ee78a581e734cf35beb9e33a370b29c5d3c8fee3744de29f/python_json_logger-2.0.7-py3-none-any.whl#sha256=f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd # pip pyyaml @ https://files.pythonhosted.org/packages/7d/39/472f2554a0f1e825bd7c5afc11c817cd7a2f3657460f7159f691fbb37c51/PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c # pip rfc3986-validator @ https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl#sha256=2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9 -# pip rpds-py @ https://files.pythonhosted.org/packages/e8/3c/01949c5c1c4ae6d2811a7eadbdfbb205cb20e68bf9d683ad87e7e3f1522f/rpds_py-0.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=b2039f8d545f20c4e52713eea51a275e62153ee96c8035a32b2abb772b6fc9e5 +# pip rpds-py @ https://files.pythonhosted.org/packages/61/a8/e2337cd6df296529290bfae41e9f72c278b9b3d156628b43d47e62fb64a3/rpds_py-0.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=80080972e1d000ad0341c7cc58b6855c80bd887675f92871221451d13a975072 # pip send2trash @ https://files.pythonhosted.org/packages/a9/78/e4df1e080ed790acf3a704edf521006dd96b9841bd2e2a462c0d255e0565/Send2Trash-1.8.2-py3-none-any.whl#sha256=a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679 # pip sniffio @ https://files.pythonhosted.org/packages/c3/a0/5dba8ed157b0136607c7f2151db695885606968d1fae123dc3391e0cfdbf/sniffio-1.3.0-py3-none-any.whl#sha256=eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384 # pip soupsieve @ https://files.pythonhosted.org/packages/4c/f3/038b302fdfbe3be7da016777069f26ceefe11a681055ea1f7817546508e3/soupsieve-2.5-py3-none-any.whl#sha256=eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7 -# pip traitlets @ https://files.pythonhosted.org/packages/85/e9/d82415708306eb348fb16988c4697076119dfbfa266f17f74e514a23a723/traitlets-5.11.2-py3-none-any.whl#sha256=98277f247f18b2c5cabaf4af369187754f4fb0e85911d473f72329db8a7f4fae +# pip traitlets @ https://files.pythonhosted.org/packages/ed/fd/cfc0d27ca11f3dd12b2a90d06875d8bfb532ef40ce67be4066d10807f4aa/traitlets-5.13.0-py3-none-any.whl#sha256=baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619 # pip types-python-dateutil @ https://files.pythonhosted.org/packages/1c/af/5af2e2a02bc464c1c7818c260606343020b96c0d5b64f637d9e91aee24fe/types_python_dateutil-2.8.19.14-py3-none-any.whl#sha256=f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9 # pip uri-template @ https://files.pythonhosted.org/packages/e7/00/3fca040d7cf8a32776d3d81a00c8ee7457e00f80c649f1e4a863c8321ae9/uri_template-1.3.0-py3-none-any.whl#sha256=a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363 # pip webcolors @ https://files.pythonhosted.org/packages/d5/e1/3e9013159b4cbb71df9bd7611cbf90dc2c621c8aeeb677fc41dad72f2261/webcolors-1.13-py3-none-any.whl#sha256=29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf @@ -286,25 +288,25 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.8.2-pyhd8ed1 # pip bleach @ https://files.pythonhosted.org/packages/ea/63/da7237f805089ecc28a3f36bca6a21c31fcbc2eb380f3b8f1be3312abd14/bleach-6.1.0-py3-none-any.whl#sha256=3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6 # pip cffi @ https://files.pythonhosted.org/packages/ea/ac/e9e77bc385729035143e54cc8c4785bd480eaca9df17565963556b0b7a93/cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098 # pip doit @ https://files.pythonhosted.org/packages/44/83/a2960d2c975836daa629a73995134fd86520c101412578c57da3d2aa71ee/doit-0.36.0-py3-none-any.whl#sha256=ebc285f6666871b5300091c26eafdff3de968a6bd60ea35dd1e3fc6f2e32479a -# pip jupyter-core @ https://files.pythonhosted.org/packages/ac/92/bec527b68e2b56d0b1a30db19ce8370cba69fb68d34c981f4549564ca551/jupyter_core-5.4.0-py3-none-any.whl#sha256=66e252f675ac04dcf2feb6ed4afb3cd7f68cf92f483607522dc251f32d471571 -# pip referencing @ https://files.pythonhosted.org/packages/be/8e/56d6f1e2d591f4d6cbcba446cac4a1b0dc4f584537e2071d9bcee8eeab6b/referencing-0.30.2-py3-none-any.whl#sha256=449b6669b6121a9e96a7f9e410b245d471e8d48964c67113ce9afe50c8dd7bdf +# pip jupyter-core @ https://files.pythonhosted.org/packages/ab/ea/af6508f71d2bcbf4db538940120cc3d3f10287f62105e756bd315aa345b5/jupyter_core-5.5.0-py3-none-any.whl#sha256=e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805 +# pip referencing @ https://files.pythonhosted.org/packages/29/c1/69342fbc8efd1aac5cda853cea771763b95d92325c4f8f83b499c07bc698/referencing-0.31.0-py3-none-any.whl#sha256=381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882 # pip rfc3339-validator @ https://files.pythonhosted.org/packages/7b/44/4e421b96b67b2daff264473f7465db72fbdf36a07e05494f50300cc7b0c6/rfc3339_validator-0.1.4-py2.py3-none-any.whl#sha256=24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa -# pip terminado @ https://files.pythonhosted.org/packages/84/a7/c7628d79651b8c8c775d27b374315a825141b5783512e82026fb210dd639/terminado-0.17.1-py3-none-any.whl#sha256=8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae +# pip terminado @ https://files.pythonhosted.org/packages/69/df/deebc9fb14a49062a3330f673e80b100e665b54d998163b3f62620b6240c/terminado-0.18.0-py3-none-any.whl#sha256=87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e # pip tinycss2 @ https://files.pythonhosted.org/packages/da/99/fd23634d6962c2791fb8cb6ccae1f05dcbfc39bce36bba8b1c9a8d92eae8/tinycss2-1.2.1-py3-none-any.whl#sha256=2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847 # pip argon2-cffi-bindings @ https://files.pythonhosted.org/packages/ec/f7/378254e6dd7ae6f31fe40c8649eea7d4832a42243acaf0f1fff9083b2bed/argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae # pip isoduration @ https://files.pythonhosted.org/packages/7b/55/e5326141505c5d5e34c5e0935d2908a74e4561eca44108fbfb9c13d2911a/isoduration-20.11.0-py3-none-any.whl#sha256=b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042 -# pip jsonschema-specifications @ https://files.pythonhosted.org/packages/1c/24/83349ac2189cc2435e84da3f69ba3c97314d3c0622628e55171c6798ed80/jsonschema_specifications-2023.7.1-py3-none-any.whl#sha256=05adf340b659828a004220a9613be00fa3f223f2b82002e273dee62fd50524b1 +# pip jsonschema-specifications @ https://files.pythonhosted.org/packages/20/a9/384ec45013ab883d7c2bf120f2988682986fdead973decf0bae28a4523e7/jsonschema_specifications-2023.11.1-py3-none-any.whl#sha256=f596778ab612b3fd29f72ea0d990393d0540a5aab18bf0407a46632eab540779 # pip jupyter-server-terminals @ https://files.pythonhosted.org/packages/ea/7f/36db12bdb90f5237766dcbf59892198daab7260acbcf03fc75e2a2a82672/jupyter_server_terminals-0.4.4-py3-none-any.whl#sha256=75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36 -# pip jupyterlite-core @ https://files.pythonhosted.org/packages/2e/0c/3715d2e2b125c770241a1189fe2836899f99fbe5d8b7cef49144a414b6ed/jupyterlite_core-0.1.3-py3-none-any.whl#sha256=f900a2fc359ac0660214144efc7420ddd737f36aac36601928e8e1ecb5013bb9 +# pip jupyterlite-core @ https://files.pythonhosted.org/packages/b4/5d/9708684e65d244493ff4c970ea882b508da8d46e59a4cc99076991c16732/jupyterlite_core-0.2.0-py3-none-any.whl#sha256=255e8272941d0e950d05cfcfc28bde244c0404d2d5990da1b8b3485c44fe1718 # pip pyzmq @ https://files.pythonhosted.org/packages/a2/e0/08605421a2ede5d87adbde9685599fa7e6af1df700c657759a1892ced942/pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae # pip argon2-cffi @ https://files.pythonhosted.org/packages/a4/6a/e8a041599e78b6b3752da48000b14c8d1e8a04ded09c88c714ba047f34f5/argon2_cffi-23.1.0-py3-none-any.whl#sha256=c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea -# pip jsonschema @ https://files.pythonhosted.org/packages/0f/bf/a84bc75f069f4f156e1c0d9892fb7325945106c6ecaad9f29d24360872af/jsonschema-4.19.1-py3-none-any.whl#sha256=cd5f1f9ed9444e554b38ba003af06c0a8c2868131e56bfbef0550fb450c0330e -# pip jupyter-client @ https://files.pythonhosted.org/packages/dc/05/e91a1a935a25ca1b46c78260def39125b2cfca96c2adbc285d365af23e3f/jupyter_client-8.4.0-py3-none-any.whl#sha256=6a2a950ec23a8f62f9e4c66acec7f0ea6c7d1f80ba0992e747b10c56ce2e6dbe -# pip jupyterlite-pyodide-kernel @ https://files.pythonhosted.org/packages/22/5d/48eff296cd38723bb42cfe51a2bb3fe3c7ed3d9b2eb20b1e9ca0c4b521d2/jupyterlite_pyodide_kernel-0.1.3-py3-none-any.whl#sha256=302f9052f61097118fdbbd252b447fe3394652db007b8f99965a2bae9561d0d4 -# pip jupyter-events @ https://files.pythonhosted.org/packages/47/47/cd46c2d3e409bed27338aec1610dfa13da67f64c671f739b7eff0954c14d/jupyter_events-0.8.0-py3-none-any.whl#sha256=81f07375c7673ff298bfb9302b4a981864ec64edaed75ca0fe6f850b9b045525 +# pip jsonschema @ https://files.pythonhosted.org/packages/0f/ed/0058234d8dd2b1fc6beeea8eab945191a05e9d391a63202f49fe23327586/jsonschema-4.20.0-py3-none-any.whl#sha256=ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3 +# pip jupyter-client @ https://files.pythonhosted.org/packages/43/ae/5f4f72980765e2e5e02b260f9c53bcc706cefa7ac9c8d7240225c55788d4/jupyter_client-8.6.0-py3-none-any.whl#sha256=909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99 +# pip jupyterlite-pyodide-kernel @ https://files.pythonhosted.org/packages/1c/a8/d4c30081747f4c5d3d75c1e77251ef64f2c3b927023f2796168a83aa65e2/jupyterlite_pyodide_kernel-0.2.0-py3-none-any.whl#sha256=17d713f0eeb3f778c4d51129834096d364c16f05ac06e10292383c43c0eb5bd9 +# pip jupyter-events @ https://files.pythonhosted.org/packages/e3/55/0c1aa72f4317e826a471dc4adc3036acd11d496ded68c4bbac2a88551519/jupyter_events-0.9.0-py3-none-any.whl#sha256=d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf # pip nbformat @ https://files.pythonhosted.org/packages/f4/e7/ef30a90b70eba39e675689b9eaaa92530a71d7435ab8f9cae520814e0caf/nbformat-5.9.2-py3-none-any.whl#sha256=1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9 -# pip nbclient @ https://files.pythonhosted.org/packages/ac/5a/d670ca51e6c3d98574b9647599821590efcd811d71f58e9c89fc59a17685/nbclient-0.8.0-py3-none-any.whl#sha256=25e861299e5303a0477568557c4045eccc7a34c17fc08e7959558707b9ebe548 -# pip nbconvert @ https://files.pythonhosted.org/packages/5b/08/6af17268360385f767c7a53dd2b71b9718c61911464fb34f5453c80cfe48/nbconvert-7.9.2-py3-none-any.whl#sha256=39fe4b8bdd1b0104fdd86fc8a43a9077ba64c720bda4c6132690d917a0a154ee -# pip jupyter-server @ https://files.pythonhosted.org/packages/ab/88/d47e331c0041a4ba10b38f1669662e3456c7c2801e3b3250c27159cf164a/jupyter_server-2.8.0-py3-none-any.whl#sha256=c57270faa6530393ae69783a2d2f1874c718b9f109080581ea076b05713249fa -# pip jupyterlab-server @ https://files.pythonhosted.org/packages/96/cd/cdabe44549d60e0967904f0bdd9e3756b521112317612a3997eb2fda9181/jupyterlab_server-2.25.0-py3-none-any.whl#sha256=c9f67a98b295c5dee87f41551b0558374e45d449f3edca153dd722140630dcb2 -# pip jupyterlite-sphinx @ https://files.pythonhosted.org/packages/38/c9/5f1142c005cf8d75830b10029e53f074324bc85cfca1f1d0f22a207b771c/jupyterlite_sphinx-0.9.3-py3-none-any.whl#sha256=be6332d16490ea2fa90b78187a2c5e1c357195966a25741d60b1790346571041 +# pip nbclient @ https://files.pythonhosted.org/packages/6b/3a/607149974149f847125c38a62b9ea2b8267eb74823bbf8d8c54ae0212a00/nbclient-0.9.0-py3-none-any.whl#sha256=a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15 +# pip nbconvert @ https://files.pythonhosted.org/packages/84/61/460af4b68b3c681d1f82d48646cf2acb8f6d29edf9a8366dc37ae69e902a/nbconvert-7.11.0-py3-none-any.whl#sha256=d1d417b7f34a4e38887f8da5bdfd12372adf3b80f995d57556cb0972c68909fe +# pip jupyter-server @ https://files.pythonhosted.org/packages/6e/79/178c7a551d50734a779b1fb7688089e46c6141d8b108b2c9cbb028c27437/jupyter_server-2.11.0-py3-none-any.whl#sha256=c9bd6e6d71dc5a2a25df167dc323422997f14682b008bfecb5d7920a55020ea7 +# pip jupyterlab-server @ https://files.pythonhosted.org/packages/a2/97/abbbe35fc67b6f9423309988f2e411f7cb117b08321866d3d8b720f4c0d4/jupyterlab_server-2.25.2-py3-none-any.whl#sha256=5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf +# pip jupyterlite-sphinx @ https://files.pythonhosted.org/packages/fa/f9/ad6d7164eca7ab9d523fc9b8c8a4a5508b424ee051f44a01797be224aeaa/jupyterlite_sphinx-0.10.0-py3-none-any.whl#sha256=72f332bf2748902802b719fbce598234e27facfcdc9aec020bf8cf025b12ba62 diff --git a/build_tools/circle/doc_min_dependencies_environment.yml b/build_tools/circle/doc_min_dependencies_environment.yml index b9e2c19bf3737..dbeb1748bd0c7 100644 --- a/build_tools/circle/doc_min_dependencies_environment.yml +++ b/build_tools/circle/doc_min_dependencies_environment.yml @@ -13,6 +13,7 @@ dependencies: - threadpoolctl - matplotlib=3.3.4 # min - pandas=1.0.5 # min + - rich - pyamg - pytest - pytest-xdist=2.5.0 diff --git a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock index 5063621f08491..59e18f9c4c6d6 100644 --- a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock +++ b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock @@ -1,34 +1,32 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 89bff8490bf2ceb18906f6ddd4607d9246575b235172ddd0f03716ae6cb21a3d +# input_hash: f370bb493c9b53c07e2c969481f9f4d777016bf5eb3e1b86785da2ac8d930ef7 @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 -https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda#a73ecd2988327ad4c8f2c331482917f2 +https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2#19410c3df09dfb12d1206132a1d357c5 https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-2.6.32-he073ed8_16.conda#7ca122655873935e02c91279c5b03c8c https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.36.1-hea4e1c9_2.tar.bz2#bd4f2e711b39af170e7ff15163fe87ee -https://conda.anaconda.org/conda-forge/linux-64/libgcc-devel_linux-64-7.5.0-hda03d7c_20.tar.bz2#2146b25eb2a762a44fab709338a7b6d9 https://conda.anaconda.org/conda-forge/linux-64/libgfortran4-7.5.0-h14aa051_20.tar.bz2#a072eab836c3a9578ce72b5640ce592d -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-devel_linux-64-7.5.0-hb016644_20.tar.bz2#31d5500f621954679ee41d7f5d1089fb -https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_2.conda#9172c297304f2a20134fc56c97fbe229 +https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda#937eaed008f6bf2191c5fe76f87755e9 https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.8-4_cp38.conda#ea6b353536f42246cd130c7fef1285cf https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-7.5.0-h14aa051_20.tar.bz2#c3b2ad091c043c08689e64b10741484b -https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_2.conda#e2042154faafe61969556f28bade94b9 +https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_3.conda#7124cbb46b13d395bdde68f2d215c989 https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.12-he073ed8_16.conda#071ea8dceff4d30ac511f4a2f8437cd1 https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.36.1-h193b22a_2.tar.bz2#32aae4265554a47ea77f7c09f86aeb3b https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/binutils-2.36.1-hdd6e379_2.tar.bz2#3111f86041b5b6863545ca49130cca95 https://conda.anaconda.org/conda-forge/linux-64/binutils_linux-64-2.36-hf3e587d_33.tar.bz2#72b245322c589284f1b92a5c971e5cb6 https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 -https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_2.conda#c28003b0be0494f9a7664389146716ff +https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda#23fdf1fef05baeb7eadc2aed5fb0011f https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.10-hd590300_0.conda#75dae9a4201732aa78a530b826ee5fe0 https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 -https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 -https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-7.5.0-habd7529_20.tar.bz2#42140612518a7ce78f571d64b6a50ba3 +https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 +https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-7.5.0-hda68d29_13.tar.bz2#fe83fa08b9fe7ceccfd0068571b92827 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda#cc47e1facc155f91abd89b11e48e72ff @@ -48,9 +46,9 @@ https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.co https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75 -https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda#681105bccc2a3f7f1a837d47d39c9179 +https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 -https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.3-hd590300_0.conda#7bb88ce04c8deb9f7d763ae04a1da72f +https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda#412ba6938c3e2abaca8b1129ea82e238 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.42.2-h59595ed_0.conda#700edd63ccd5fc66b70b1c028cea9a68 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a @@ -66,21 +64,21 @@ https://conda.anaconda.org/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2#4c https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 https://conda.anaconda.org/conda-forge/linux-64/gcc_linux-64-7.5.0-h47867f9_33.tar.bz2#3a31c3f430a31184a5d07e67d3b24e2c https://conda.anaconda.org/conda-forge/linux-64/gfortran_impl_linux-64-7.5.0-h56cb351_20.tar.bz2#8f897b30195bd3a2251b4c51c3cc91cf -https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-7.5.0-hd0bb8aa_20.tar.bz2#dbe78fc5fb9c339f8e55426559e12f7b +https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-7.5.0-h64c220c_13.tar.bz2#4000478f05b985c952ad8e100d1c13fd https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25cb5999faa414e5ccb2c1388f62d3d5 https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda#c2097d0b46367996f09b4e8e4920384a https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.43.2-h2797004_0.conda#4b441a1ee22397d5a27dc1126b849edd +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.0-h2797004_0.conda#b58e6816d137f3aabf77d341dd5d732b https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.5-h232c23b_1.conda#f3858448893839820d4bcfb14ad3ecdf -https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_5.conda#1e8ef4090ca4f0d66404a7441e1dbf3c -https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.11.6-h232c23b_0.conda#427a3e59d66cb5d145020bd9c6493334 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 -https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-h2797004_0.conda#513336054f884f95d9fd925748f41ef3 +https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 @@ -89,13 +87,13 @@ https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda https://conda.anaconda.org/conda-forge/linux-64/gfortran_linux-64-7.5.0-h78c8a43_33.tar.bz2#b2879010fb369f4012040f7a27657cd8 https://conda.anaconda.org/conda-forge/linux-64/gxx_linux-64-7.5.0-h555fc39_33.tar.bz2#5cf979793d2c5130a012cb6480867adc https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 -https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.0-hebfc3b9_0.conda#e618003da3547216310088478e475945 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.2-hd590300_0.conda#3d7d5e5cebf8af5aadb040732860f1b6 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.1-h783c2da_1.conda#70052d6c1e84643e30ffefb21ab6950f https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-h5cf9203_3.conda#9efe82d44b76a7529a1d702e5a37752e https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 -https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.3-h4dfa4b3_0.conda#1a82298c57b609a31ab6f2342a307b69 -https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_5.conda#b72f016c910ff9295b1377d3e17da3f2 +https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.5-h4dfa4b3_0.conda#799291c22ec87a0c86c0a4fc0e22b1c5 +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda#7caef74bbfa730e014b20f0852068509 https://conda.anaconda.org/conda-forge/linux-64/python-3.8.18-hd12c33a_0_cpython.conda#334cb629e10d209f1c17630f653168b1 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 @@ -105,8 +103,8 @@ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.con https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda#49e482d882669206653b095f5206c05b https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.13-pyhd8ed1ab_0.conda#06006184e203b61d3525f90de394471e https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py38h17151c0_1.conda#7a5a699c8992fc51ef25e980f4502c2a -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 -https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.0-pyhd8ed1ab_0.conda#fef8ef5f0a54546b9efee39468229917 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 +https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda#7f4a9e3fcff3f6356ae99244a014da6a https://conda.anaconda.org/conda-forge/noarch/click-8.1.7-unix_pyh707e725_0.conda#f3ad426304898027fc619827ff428eca https://conda.anaconda.org/conda-forge/noarch/cloudpickle-3.0.0-pyhd8ed1ab_0.conda#753d29fe41bb881e4b9c004f0abf973f https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 @@ -115,12 +113,12 @@ https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5 https://conda.anaconda.org/conda-forge/linux-64/cython-0.29.33-py38h8dc9893_0.conda#5d50cd654981f0ccc7c878ac297afaa7 https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d https://conda.anaconda.org/conda-forge/linux-64/docutils-0.19-py38h578d9bd_1.tar.bz2#3746b24949251f1a00ae0d616d4cdc1b -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d https://conda.anaconda.org/conda-forge/linux-64/fortran-compiler-1.1.1-he991be0_0.tar.bz2#e38ac82cc517b9e245c1ae99f9f140da -https://conda.anaconda.org/conda-forge/noarch/fsspec-2023.9.2-pyh1a96a4e_0.conda#9d15cd3a0e944594ab528da37dc72ecc -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.0-hfc55251_0.conda#e10134de3558dd95abda6987b5548f4f +https://conda.anaconda.org/conda-forge/noarch/fsspec-2023.10.0-pyhca7485f_0.conda#5b86cf1ceaaa9be2ec4627377e538db1 +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.1-hfc55251_1.conda#a50918d10114a0bf80fb46c7cc692058 https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2#34272b248891bddccc64479f9a7fffed https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2#7de5386c8fea29e76b303f37dde4c352 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 @@ -128,19 +126,20 @@ https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py38h7f3f72f_1. https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hb7c19ff_3.conda#e96637dd92c5f340215c753a5c9a22d7 https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_3.conda#1720df000b48e31842500323cb7be18c https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 -https://conda.anaconda.org/conda-forge/linux-64/libpq-16.0-hfc447b1_1.conda#e4a9a5ba40123477db33e02a78dffb01 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.1-hfc447b1_0.conda#2b7f1893cf40b4ccdc0230bcd94d5ed9 https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-254-h3516f8a_0.conda#df4b1cd0c91b4234fb02b5701a4cdddc https://conda.anaconda.org/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2#91e27ef3d05cc772ce627e51cff111c4 https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.3-py38h01eb140_1.conda#2dabf287937cd631e292096cc6d0867e +https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.0-pyhd8ed1ab_0.tar.bz2#f8dab71fdc13b1bf29a01248b156d268 https://conda.anaconda.org/conda-forge/linux-64/mkl-2020.4-h726a3e6_304.tar.bz2#b9b35a50e5377b19da6ec0709ae77fc3 -https://conda.anaconda.org/conda-forge/noarch/networkx-3.2-pyhd8ed1ab_0.conda#cec8cc498664cc00a070676aa89e69a7 +https://conda.anaconda.org/conda-forge/noarch/networkx-3.1-pyhd8ed1ab_0.conda#254f787d5068bc89f578bf63893ce8b4 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-h488ebb8_3.conda#128c25b7fe6a25286a48f3a6a9b5b6f3 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.5-py38h01eb140_1.conda#89cb08bb523adf12fed3829558638d84 https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b4613d7e7a493916d867842a6a148054 -https://conda.anaconda.org/conda-forge/noarch/pygments-2.16.1-pyhd8ed1ab_0.conda#40e5cb18165466773619e5c963f00a7b +https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3.post1-pyhd8ed1ab_0.conda#c93346b446cd08c169d843ae5fc0da97 @@ -161,48 +160,49 @@ https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5 https://conda.anaconda.org/conda-forge/noarch/toolz-0.12.0-pyhd8ed1ab_0.tar.bz2#92facfec94bc02d6ccf42e7173831a36 https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py38h01eb140_1.conda#660cfc2fc5bd9e3b458ad394976652cf https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda#5b1be40a26d10a06f6d4f1f9e19fa0c7 -https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.2-pyhd8ed1ab_0.conda#1ccd092478b3e0ee10d7a891adbf8a4f +https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.3-pyhd8ed1ab_0.conda#3fc026b9c87d091c4b34a6c997324ae8 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda#9d7bcddf49cbf727730af10e71022c73 https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.40-hd590300_0.conda#07c15d846a2e4d673da22cbd85fdb6d2 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda#82b6df12252e6f32402b96dacc656fec https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda#ed67c36f215b310412b2af935bf3e530 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/noarch/babel-2.13.0-pyhd8ed1ab_0.conda#22541af7a9eb59fc6afcadb7ecdf9219 +https://conda.anaconda.org/conda-forge/noarch/babel-2.13.1-pyhd8ed1ab_0.conda#3ccff479c246692468f604df9c85ef26 https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e https://conda.anaconda.org/conda-forge/linux-64/compilers-1.1.1-0.tar.bz2#1ba267e19dbaf3db9dd0404e6fb9cdb9 https://conda.anaconda.org/conda-forge/linux-64/cytoolz-0.12.2-py38h01eb140_1.conda#56222b99bdd044e52c364c4fbee28a7a -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.0-hfc55251_0.conda#2f55a36b549f51a7e0c2b1e3c3f0ccd4 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.1-hfc55251_1.conda#8d7242302bb3d03b9a690b6dda872603 https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-6.8.0-pyha770c72_0.conda#4e9f59a060c3be52bc4ddc46ee9b6946 https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.2-pyhd8ed1ab_1.tar.bz2#c8490ed5c70966d232fdd389d0dbed37 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/linux-64/libblas-3.8.0-20_mkl.tar.bz2#8fbce60932c01d0e193a1a814f2002be https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_3.conda#0922208521c0463e690bbaebba7eb551 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-h5d7e998_0.conda#d8edd0e29db6fb6b6988e1a28d35d994 +https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b https://conda.anaconda.org/conda-forge/noarch/memory_profiler-0.61.0-pyhd8ed1ab_0.tar.bz2#8b45f9f2b2f7a98b0ec179c8991a4a9b https://conda.anaconda.org/conda-forge/noarch/partd-1.4.1-pyhd8ed1ab_0.conda#acf4b7c0bcd5fa3b0e05801c4d2accd6 https://conda.anaconda.org/conda-forge/linux-64/pillow-10.1.0-py38ha43c96d_0.conda#67ca17c651f86159a3b8ed1132d97c12 -https://conda.anaconda.org/conda-forge/noarch/pip-23.3-pyhd8ed1ab_0.conda#a06f102f59c8e3bb8b3e46e71c384709 +https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda#2400c0b86889f43aa52067161e1fb108 +https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda#6bb4ee32cd435deaeac72776c001e7ac https://conda.anaconda.org/conda-forge/noarch/plotly-5.14.0-pyhd8ed1ab_0.conda#6a7bcc42ef58dd6cf3da9333ea102433 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py38h17151c0_0.conda#ae2edf79b63f97071aea203b22a6774a -https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda#384462e63262a527bda564fa2d9126c0 -https://conda.anaconda.org/conda-forge/noarch/urllib3-2.0.7-pyhd8ed1ab_0.conda#270e71c14d37074b1d066ee21cf0c4a6 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.6-h98fc4e7_2.conda#1c95f7c612f9121353c4ef764678113e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.2.1-h3d44ed6_0.conda#98db5f8813f45e2b29766aff0e4a499c +https://conda.anaconda.org/conda-forge/noarch/urllib3-2.1.0-pyhd8ed1ab_0.conda#f8ced8ee63830dec7ecc1be048d1470a +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.7-h98fc4e7_0.conda#6c919bafe5e03428a8e2ef319d7ef990 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-6.8.0-hd8ed1ab_0.conda#b279b07ce18058034e5b3606ba103a8b https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.8.0-20_mkl.tar.bz2#14b25490fdcc44e879ac6c10fe764f68 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.8.0-20_mkl.tar.bz2#52c0ae3606eeae7e1d493f37f336f4f5 -https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.11.0-pyhd8ed1ab_0.conda#8f567c0a74aa44cf732f15773b4083b0 https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py38h17151c0_5.conda#3d66f5c4a0af2713f60ec11bf1230136 https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b +https://conda.anaconda.org/conda-forge/noarch/rich-13.7.0-pyhd8ed1ab_0.conda#d7a11d4f3024b2f4a6e0ae7377dd61e9 https://conda.anaconda.org/conda-forge/noarch/dask-core-2023.5.0-pyhd8ed1ab_0.conda#03ed2d040648a5ba1063bf1cb0d87b78 -https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.6-h8e1006c_2.conda#3d8e98279bad55287f2ef9047996f33c +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.7-h8e1006c_0.conda#065e2c1d49afa3fdc1a01f1dacd6ab09 https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.8.0-20_mkl.tar.bz2#8274dc30518af9df1de47f5d9e73165c https://conda.anaconda.org/conda-forge/linux-64/numpy-1.17.3-py38h95a1406_0.tar.bz2#bc0cbf611fe2f86eab29b98e51404f5e -https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyhd8ed1ab_4.conda#3cdaf7af08850933662b1e228bc6b5bc +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/noarch/sphinx-6.0.0-pyhd8ed1ab_2.conda#ac1d3b55da1669ee3a56973054fd7efb https://conda.anaconda.org/conda-forge/linux-64/blas-2.20-mkl.tar.bz2#e7d09a07f5413e53dca5282b8fa50bed diff --git a/build_tools/cirrus/py39_conda_forge_linux-aarch64_conda.lock b/build_tools/cirrus/py39_conda_forge_linux-aarch64_conda.lock index 4b9f2756ffc50..199153001d182 100644 --- a/build_tools/cirrus/py39_conda_forge_linux-aarch64_conda.lock +++ b/build_tools/cirrus/py39_conda_forge_linux-aarch64_conda.lock @@ -2,35 +2,35 @@ # platform: linux-aarch64 # input_hash: de5bfe2a68b349f08233af7b94fc3b2045503b21289e8d3bdb30a1613fd0ddb8 @EXPLICIT -https://conda.anaconda.org/conda-forge/linux-aarch64/ca-certificates-2023.7.22-hcefe29a_0.conda#95d7f998087114466fa91e7c2887fa2f +https://conda.anaconda.org/conda-forge/linux-aarch64/ca-certificates-2023.11.17-hcefe29a_0.conda#695a28440b58e3ba920bcac4ac7c73c6 https://conda.anaconda.org/conda-forge/linux-aarch64/ld_impl_linux-aarch64-2.40-h2d8c526_0.conda#16246d69e945d0b1969a6099e7c5d457 -https://conda.anaconda.org/conda-forge/linux-aarch64/libstdcxx-ng-13.2.0-h9a76618_2.conda#921c652898c8602bf2697d015f3efc77 +https://conda.anaconda.org/conda-forge/linux-aarch64/libstdcxx-ng-13.2.0-h9a76618_3.conda#7ad2164936c4975d94ca883d34809c0f https://conda.anaconda.org/conda-forge/linux-aarch64/python_abi-3.9-4_cp39.conda#c191905a08694e4a5cb1238e90233878 https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/linux-aarch64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#98a1185182fec3c434069fa74e6473d6 -https://conda.anaconda.org/conda-forge/linux-aarch64/libgcc-ng-13.2.0-hf8544c7_2.conda#f4dfb3bad7c8b38c3f8ed7f15a91a1ed -https://conda.anaconda.org/conda-forge/linux-aarch64/bzip2-1.0.8-hf897c2e_4.tar.bz2#2d787570a729e273a4e75775ddf3348a +https://conda.anaconda.org/conda-forge/linux-aarch64/libgcc-ng-13.2.0-hf8544c7_3.conda#00f021ee1a24c798ae53c87ee79597f1 +https://conda.anaconda.org/conda-forge/linux-aarch64/bzip2-1.0.8-h31becfc_5.conda#a64e35f01e0b7a2a152eca87d33b9c87 https://conda.anaconda.org/conda-forge/linux-aarch64/lerc-4.0.0-h4de3ea5_0.tar.bz2#1a0ffc65e03ce81559dbcb0695ad1476 https://conda.anaconda.org/conda-forge/linux-aarch64/libbrotlicommon-1.1.0-h31becfc_1.conda#1b219fd801eddb7a94df5bd001053ad9 https://conda.anaconda.org/conda-forge/linux-aarch64/libdeflate-1.19-h31becfc_0.conda#014e57e35f2dc95c9a12f63d4378e093 https://conda.anaconda.org/conda-forge/linux-aarch64/libffi-3.4.2-h3557bc0_5.tar.bz2#dddd85f4d52121fab0a8b099c5e06501 -https://conda.anaconda.org/conda-forge/linux-aarch64/libgfortran5-13.2.0-h582850c_2.conda#1be4fb84d6b6617a844933ca406c6bd5 +https://conda.anaconda.org/conda-forge/linux-aarch64/libgfortran5-13.2.0-h582850c_3.conda#d81dcb787465447542ad9c4cf0bab65e https://conda.anaconda.org/conda-forge/linux-aarch64/libjpeg-turbo-3.0.0-h31becfc_1.conda#ed24e702928be089d9ba3f05618515c6 https://conda.anaconda.org/conda-forge/linux-aarch64/libnsl-2.0.1-h31becfc_0.conda#c14f32510f694e3185704d89967ec422 https://conda.anaconda.org/conda-forge/linux-aarch64/libuuid-2.38.1-hb4cce97_0.conda#000e30b09db0b7c775b21695dff30969 https://conda.anaconda.org/conda-forge/linux-aarch64/libwebp-base-1.3.2-h31becfc_0.conda#1490de434d2a2c06a98af27641a2ffff https://conda.anaconda.org/conda-forge/linux-aarch64/libzlib-1.2.13-h31becfc_5.conda#b213aa87eea9491ef7b129179322e955 -https://conda.anaconda.org/conda-forge/linux-aarch64/ncurses-6.4-h2e1726e_0.conda#40beaf447150c2760affc591c7509595 -https://conda.anaconda.org/conda-forge/linux-aarch64/openssl-3.1.3-h31becfc_0.conda#19469b6fdd72c1ffe022d9845fbb85fa +https://conda.anaconda.org/conda-forge/linux-aarch64/ncurses-6.4-h0425590_2.conda#4ff0a396150dedad4269e16e5810f769 +https://conda.anaconda.org/conda-forge/linux-aarch64/openssl-3.1.4-h31becfc_0.conda#bc0e17d9ee24d18aa8ba435d86a2a460 https://conda.anaconda.org/conda-forge/linux-aarch64/pthread-stubs-0.4-hb9de7d4_1001.tar.bz2#d0183ec6ce0b5aaa3486df25fa5f0ded https://conda.anaconda.org/conda-forge/linux-aarch64/xorg-libxau-1.0.11-h31becfc_0.conda#13de34f69cb73165dbe08c1e9148bedb https://conda.anaconda.org/conda-forge/linux-aarch64/xorg-libxdmcp-1.1.3-h3557bc0_0.tar.bz2#a6c9016ae1ca5c47a3603ed4cd65fedd https://conda.anaconda.org/conda-forge/linux-aarch64/xz-5.2.6-h9cdd2b7_0.tar.bz2#83baad393a31d59c20b63ba4da6592df https://conda.anaconda.org/conda-forge/linux-aarch64/libbrotlidec-1.1.0-h31becfc_1.conda#8db7cff89510bec0b863a0a8ee6a7bce https://conda.anaconda.org/conda-forge/linux-aarch64/libbrotlienc-1.1.0-h31becfc_1.conda#ad3d3a826b5848d99936e4466ebbaa26 -https://conda.anaconda.org/conda-forge/linux-aarch64/libgfortran-ng-13.2.0-he9431aa_2.conda#720092257480c53e80f32cc819821fea +https://conda.anaconda.org/conda-forge/linux-aarch64/libgfortran-ng-13.2.0-he9431aa_3.conda#6c292066bb9876d7ba35c590868baaeb https://conda.anaconda.org/conda-forge/linux-aarch64/libpng-1.6.39-hf9034f9_0.conda#5ec9052384a6ac85e9111e9ac7c5ec4c -https://conda.anaconda.org/conda-forge/linux-aarch64/libsqlite-3.43.2-h194ca79_0.conda#16417bba0efbcb7539b942a611cc899a +https://conda.anaconda.org/conda-forge/linux-aarch64/libsqlite-3.44.0-h194ca79_0.conda#6d33a45e15846407c1a9a7388dda5436 https://conda.anaconda.org/conda-forge/linux-aarch64/libxcb-1.15-h2a766a3_0.conda#eb3d8c8170e3d03f2564ed2024aa00c8 https://conda.anaconda.org/conda-forge/linux-aarch64/readline-8.2-h8fc344f_1.conda#105eb1e16bf83bfb2eb380a48032b655 https://conda.anaconda.org/conda-forge/linux-aarch64/tk-8.6.13-h194ca79_0.conda#f75105e0585851f818e0009dd1dde4dc @@ -38,24 +38,24 @@ https://conda.anaconda.org/conda-forge/linux-aarch64/zstd-1.5.5-h4c53e97_0.conda https://conda.anaconda.org/conda-forge/linux-aarch64/brotli-bin-1.1.0-h31becfc_1.conda#9e4a13596ab651ea8d77aae023d0ce3f https://conda.anaconda.org/conda-forge/linux-aarch64/freetype-2.12.1-hf0a5ef3_2.conda#a5ab74c5bd158c3d5532b66d8d83d907 https://conda.anaconda.org/conda-forge/linux-aarch64/libhiredis-1.0.2-h05efe27_0.tar.bz2#a87f068744fd20334cd41489eb163bee -https://conda.anaconda.org/conda-forge/linux-aarch64/libopenblas-0.3.24-pthreads_h5a5ec62_0.conda#22555a102c05b77dc45ff22a21255935 +https://conda.anaconda.org/conda-forge/linux-aarch64/libopenblas-0.3.25-pthreads_h5a5ec62_0.conda#60e86bc93e3f213278dc5081115fb63b https://conda.anaconda.org/conda-forge/linux-aarch64/libtiff-4.6.0-h1708d11_2.conda#d5638e110e7f22e2602a8edd20656720 -https://conda.anaconda.org/conda-forge/linux-aarch64/llvm-openmp-17.0.3-h8b0cb96_0.conda#f88e733b73844b4518babf720a5fa938 +https://conda.anaconda.org/conda-forge/linux-aarch64/llvm-openmp-17.0.5-h8b0cb96_0.conda#07056470540d494e46d432e8468d9c24 https://conda.anaconda.org/conda-forge/linux-aarch64/python-3.9.18-h4ac3b42_0_cpython.conda#4d36e157278470ac06508579c6d36555 https://conda.anaconda.org/conda-forge/linux-aarch64/brotli-1.1.0-h31becfc_1.conda#e41f5862ac746428407f3fd44d2ed01f https://conda.anaconda.org/conda-forge/linux-aarch64/ccache-4.8.1-h6552966_0.conda#5b436a19e818f05fe0c9ab4f5ac61233 -https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda#7f3dbc9179b4dde7da98dfb151d0ad22 +https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-aarch64/cython-3.0.4-py39h387a81e_0.conda#8996fc67a21f5061c5d6fc81ee623c32 -https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.1.3-pyhd8ed1ab_0.conda#e6518222753f519e911e83136d2158d9 +https://conda.anaconda.org/conda-forge/linux-aarch64/cython-3.0.5-py39h387a81e_0.conda#5bac88110a57d287ceecb7ac78140247 +https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363 https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-aarch64/kiwisolver-1.4.5-py39had2cf8c_1.conda#ddb99610f7b950fdd5ff2aff19136363 https://conda.anaconda.org/conda-forge/linux-aarch64/lcms2-2.15-h922389a_3.conda#a3d4d0d61fc9619bc9066263ed014b45 -https://conda.anaconda.org/conda-forge/linux-aarch64/libblas-3.9.0-19_linuxaarch64_openblas.conda#b5e24d17a35602ac07c72e1133a3cc20 +https://conda.anaconda.org/conda-forge/linux-aarch64/libblas-3.9.0-20_linuxaarch64_openblas.conda#11590ed0fb5cebe7bbfa4bab8d8b07f8 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 -https://conda.anaconda.org/conda-forge/linux-aarch64/openblas-0.3.24-pthreads_h339cbfa_0.conda#3818aec554ea9f792fdacd306af8ad6c +https://conda.anaconda.org/conda-forge/linux-aarch64/openblas-0.3.25-pthreads_h339cbfa_0.conda#e0fe9dad1f26708f60e18e9cdd0986a3 https://conda.anaconda.org/conda-forge/linux-aarch64/openjpeg-2.5.0-h0d9d63b_3.conda#123f5df3bc7f0e23c6950fddb97d1f43 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.3.0-pyhd8ed1ab_0.conda#2390bd10bed1f3fdc7a537fb5a447d8d @@ -67,25 +67,25 @@ https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.2.0-pyha21a80b_0.c https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 https://conda.anaconda.org/conda-forge/linux-aarch64/tornado-6.3.3-py39h7cc1d5f_1.conda#c383c279123694d7a586ec47320d1cb1 https://conda.anaconda.org/conda-forge/linux-aarch64/unicodedata2-15.1.0-py39h898b7ef_0.conda#8c072c9329aeea97a46005625267a851 -https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.2-pyhd8ed1ab_0.conda#1ccd092478b3e0ee10d7a891adbf8a4f +https://conda.anaconda.org/conda-forge/noarch/wheel-0.41.3-pyhd8ed1ab_0.conda#3fc026b9c87d091c4b34a6c997324ae8 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/linux-aarch64/fonttools-4.43.1-py39h898b7ef_0.conda#6cf3ce57449761f27f19bbe6503dadd0 -https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.0-pyhd8ed1ab_0.conda#48b0d98e0c0ec810d3ccc2a0926c8c0e +https://conda.anaconda.org/conda-forge/linux-aarch64/fonttools-4.45.0-py39h898b7ef_0.conda#ead625927e6b77749066955be2912667 +https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/linux-aarch64/libcblas-3.9.0-19_linuxaarch64_openblas.conda#8d52c7095a072dde1990717b5f0ab267 -https://conda.anaconda.org/conda-forge/linux-aarch64/liblapack-3.9.0-19_linuxaarch64_openblas.conda#c2a01118ea07574a10c19f7e9997f73b +https://conda.anaconda.org/conda-forge/linux-aarch64/libcblas-3.9.0-20_linuxaarch64_openblas.conda#b41e55ae2cb9d3518da2cbe3677b3b3b +https://conda.anaconda.org/conda-forge/linux-aarch64/liblapack-3.9.0-20_linuxaarch64_openblas.conda#e7412a592d9ee7c92026eb1189687271 https://conda.anaconda.org/conda-forge/linux-aarch64/pillow-10.1.0-py39h8ce38d7_0.conda#afedc0abb518dac535cb861f24585160 -https://conda.anaconda.org/conda-forge/noarch/pip-23.3-pyhd8ed1ab_0.conda#a06f102f59c8e3bb8b3e46e71c384709 -https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.2-pyhd8ed1ab_0.conda#6dd662ff5ac9a783e5c940ce9f3fe649 +https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda#2400c0b86889f43aa52067161e1fb108 +https://conda.anaconda.org/conda-forge/noarch/pytest-7.4.3-pyhd8ed1ab_0.conda#5bdca0aca30b0ee62bb84854e027eae0 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 -https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.0-pyhd8ed1ab_0.conda#6a62c2cc25376a0d050b3d1d221c3ee9 -https://conda.anaconda.org/conda-forge/linux-aarch64/liblapacke-3.9.0-19_linuxaarch64_openblas.conda#61d8dfefa1b44482f1f2ea08f3cb88b2 +https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710 +https://conda.anaconda.org/conda-forge/linux-aarch64/liblapacke-3.9.0-20_linuxaarch64_openblas.conda#1b8192f036a2dc41fec67700bb8bacef https://conda.anaconda.org/conda-forge/linux-aarch64/numpy-1.26.0-py39h91c28bb_0.conda#cb45bbda25d8486609cab8ecf2c957e1 https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 -https://conda.anaconda.org/conda-forge/linux-aarch64/blas-devel-3.9.0-19_linuxaarch64_openblas.conda#abad19e55d06c157286e2127ecfcc941 -https://conda.anaconda.org/conda-forge/linux-aarch64/contourpy-1.1.1-py39hd16970a_1.conda#63e211dca08da78531b520f07cf55974 +https://conda.anaconda.org/conda-forge/linux-aarch64/blas-devel-3.9.0-20_linuxaarch64_openblas.conda#211c74d7600d8d1dec226daf5e28e2dc +https://conda.anaconda.org/conda-forge/linux-aarch64/contourpy-1.2.0-py39hd16970a_0.conda#dc11a4a2e020d1d71350baa7cb4980e4 https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/linux-aarch64/scipy-1.11.3-py39h91c28bb_1.conda#216b118cdb919665ad7d9d2faff412df -https://conda.anaconda.org/conda-forge/linux-aarch64/blas-2.119-openblas.conda#9e65d8f69a9f2280160df817cf7b330b -https://conda.anaconda.org/conda-forge/linux-aarch64/matplotlib-base-3.8.0-py39h8e43113_2.conda#0905c3e47e121956768b0f231d8e78ba -https://conda.anaconda.org/conda-forge/linux-aarch64/matplotlib-3.8.0-py39ha65689a_2.conda#6bef30b9cd148b0e51677458525a9662 +https://conda.anaconda.org/conda-forge/linux-aarch64/blas-2.120-openblas.conda#4354e2978d15f5b29b1557792e5c5c63 +https://conda.anaconda.org/conda-forge/linux-aarch64/matplotlib-base-3.8.2-py39h8e43113_0.conda#0dd681b8d2a93b799954714481761fe0 +https://conda.anaconda.org/conda-forge/linux-aarch64/matplotlib-3.8.2-py39ha65689a_0.conda#cbdd0df9ca705d88630c3eeabcf154e7 diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index 88b4a8ff215ec..c2888ff91f743 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -64,6 +64,7 @@ "threadpoolctl", "matplotlib", "pandas", + "rich", "pyamg", "pytest", "pytest-xdist", @@ -208,6 +209,7 @@ def remove_from(alist, to_remove): "numpy", "scipy", "pandas", + "rich", "cython", "joblib", "pillow", @@ -233,7 +235,8 @@ def remove_from(alist, to_remove): "conda_dependencies": ( ["pypy", "python"] + remove_from( - common_dependencies_without_coverage, ["python", "pandas", "pillow"] + common_dependencies_without_coverage, + ["python", "pandas", "rich", "pillow"], ) + ["ccache"] ), @@ -247,7 +250,9 @@ def remove_from(alist, to_remove): "folder": "build_tools/azure", "platform": "win-64", "channel": "conda-forge", - "conda_dependencies": remove_from(common_dependencies, ["pandas", "pyamg"]) + [ + "conda_dependencies": remove_from( + common_dependencies, ["pandas", "rich", "pyamg"] + ) + [ "wheel", "pip", ], @@ -322,7 +327,7 @@ def remove_from(alist, to_remove): "platform": "linux-aarch64", "channel": "conda-forge", "conda_dependencies": remove_from( - common_dependencies_without_coverage, ["pandas", "pyamg"] + common_dependencies_without_coverage, ["pandas", "rich", "pyamg"] ) + ["pip", "ccache"], "package_constraints": { "python": "3.9", diff --git a/sklearn/_min_dependencies.py b/sklearn/_min_dependencies.py index 7c0b8be6c1295..f1940fd8b537c 100644 --- a/sklearn/_min_dependencies.py +++ b/sklearn/_min_dependencies.py @@ -32,6 +32,7 @@ "matplotlib": ("3.3.4", "benchmark, docs, examples, tests"), "scikit-image": ("0.16.2", "docs, examples, tests"), "pandas": ("1.0.5", "benchmark, docs, examples, tests"), + "rich": ("13.6.0", "docs, examples, tests"), "seaborn": ("0.9.0", "docs, examples"), "memory_profiler": ("0.57.0", "benchmark, docs"), "pytest": (PYTEST_MIN_VERSION, "tests"), diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 48a1d9882456d..112a176a92bd8 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -1,10 +1,10 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -import importlib from multiprocessing import Manager from threading import Thread +from ..utils import check_rich_support from . import BaseCallback @@ -14,10 +14,7 @@ class ProgressBar(BaseCallback): auto_propagate = True def __init__(self): - try: - importlib.import_module("rich") # noqa - except ImportError as e: - raise ImportError("ProgressBar requires rich installed.") from e + check_rich_support() def on_fit_begin(self, estimator, X=None, y=None): self._queue = Manager().Queue() diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8b97fcf8ebfcb..121ed3bf5c95f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1213,3 +1213,19 @@ def check_pandas_support(caller_name): return pandas except ImportError as e: raise ImportError("{} requires pandas.".format(caller_name)) from e + + +def check_rich_support(caller_name): + """Raise ImportError with detailed error message if rich is not installed. + + caller should lazily import rich and call this helper before any computation. + + Parameters + ---------- + caller_name : str + The name of the caller that requires rich. + """ + try: + import rich # noqa + except ImportError as e: + raise ImportError("{} requires rich.".format(caller_name)) from e From df50ab37e74de846da9ee00d749c6cf9058d5929 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 22 Nov 2023 18:39:04 +0100 Subject: [PATCH 31/55] missing arg --- sklearn/callback/_progressbar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 112a176a92bd8..9360f454e3abe 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -14,7 +14,7 @@ class ProgressBar(BaseCallback): auto_propagate = True def __init__(self): - check_rich_support() + check_rich_support("Progressbar") def on_fit_begin(self, estimator, X=None, y=None): self._queue = Manager().Queue() From aaa2dec1c48418119b2ce7d26a59068b10491b54 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 23 Nov 2023 18:04:57 +0100 Subject: [PATCH 32/55] improve coverage --- sklearn/callback/tests/_utils.py | 6 +++--- .../tests/test_base_estimator_callback_methods.py | 10 ++++++++++ sklearn/callback/tests/test_progressbar.py | 14 ++++++++++++++ sklearn/utils/__init__.py | 2 +- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 6a7662f1434cd..5befdc48bda38 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -26,13 +26,13 @@ class NotValidCallback: """Unvalid callback since it does not inherit from `BaseCallback`.""" def on_fit_begin(self, estimator, *, X=None, y=None): - pass + pass # pragma: no cover def on_fit_end(self): - pass + pass # pragma: no cover def on_fit_iter_end(self, estimator, node, **kwargs): - pass + pass # pragma: no cover class Estimator(BaseEstimator): diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 1d5e5d1678083..1de957ee3e83e 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -95,3 +95,13 @@ def test_eval_callbacks_on_fit_begin(): ] estimator._eval_callbacks_on_fit_begin(levels=levels) assert hasattr(estimator, "_computation_tree") + + +def test_no_callback_early_stop(): + """Check that `eval_callbacks_on_fit_iter_end` doesn't trigger early stopping + when there's no callback. + """ + estimator = Estimator() + estimator.fit(X=None, y=None) + + assert estimator.n_iter_ == estimator.max_iter diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 33f92eebb2013..4c4db3fd8f7ac 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -6,6 +6,8 @@ import pytest from sklearn.callback import ProgressBar +from sklearn.utils import check_rich_support +from sklearn.utils._testing import SkipTest from ._utils import Estimator, MetaEstimator @@ -32,3 +34,15 @@ def test_progressbar(n_jobs, prefer, capsys): # Check that all bars are 100% complete assert re.search(r"100%", captured.out) assert not re.search(r"[1-9]%", captured.out) + + +def test_progressbar_requires_rich_error(): + """Check that we raise an informative error when rich is not installed.""" + try: + check_rich_support("test_fetch_openml_requires_pandas") + except ImportError: + err_msg = "Progressbar requires rich" + with pytest.raises(ImportError, match=err_msg): + ProgressBar() + else: + raise SkipTest("This test requires rich to not be installed.") diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 121ed3bf5c95f..6594c3c3ea830 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1228,4 +1228,4 @@ def check_rich_support(caller_name): try: import rich # noqa except ImportError as e: - raise ImportError("{} requires rich.".format(caller_name)) from e + raise ImportError(f"{caller_name} requires rich.") from e From a0667c4021d4a8c1653e296c75ccaba69e4d1448 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 9 Feb 2024 16:46:47 +0100 Subject: [PATCH 33/55] mixin for callback propagation --- sklearn/base.py | 46 ------------------ sklearn/callback/__init__.py | 3 +- sklearn/callback/_base.py | 48 +++++++++++++++++++ sklearn/callback/tests/_utils.py | 4 +- .../test_base_estimator_callback_methods.py | 18 +++---- 5 files changed, 61 insertions(+), 58 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 6944f8d6a0e89..899af566e67dc 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -697,52 +697,6 @@ def _set_callbacks(self, callbacks): return self - # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all - # meta-estimators. - def _propagate_callbacks(self, sub_estimator, *, parent_node): - """Propagate the auto-propagated callbacks to a sub-estimator. - - Parameters - ---------- - sub_estimator : estimator instance - The sub-estimator to propagate the callbacks to. - - parent_node : ComputationNode instance - The computation node in this estimator to set as `parent_node` to the - computation tree of the sub-estimator. It must be the node where the fit - method of the sub-estimator is called. - """ - if hasattr(sub_estimator, "_callbacks") and any( - callback.auto_propagate for callback in sub_estimator._callbacks - ): - bad_callbacks = [ - callback.__class__.__name__ - for callback in sub_estimator._callbacks - if callback.auto_propagate - ] - raise TypeError( - f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" - f" meta-estimator ({self.__class__.__name__}) can't have" - f" auto-propagated callbacks ({bad_callbacks})." - " Set them directly on the meta-estimator." - ) - - if not hasattr(self, "_callbacks"): - return - - propagated_callbacks = [ - callback for callback in self._callbacks if callback.auto_propagate - ] - - if not propagated_callbacks: - return - - sub_estimator._parent_node = parent_node - - sub_estimator._set_callbacks( - getattr(sub_estimator, "_callbacks", []) + propagated_callbacks - ) - def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): """Evaluate the `on_fit_begin` method of the callbacks. diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index c614c38c58b43..42fd228069c3e 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -6,12 +6,13 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -from ._base import BaseCallback +from ._base import BaseCallback, CallbackPropagatorMixin from ._computation_tree import ComputationNode, build_computation_tree from ._progressbar import ProgressBar __all__ = [ "BaseCallback", + "CallbackPropagatorMixin", "build_computation_tree", "ComputationNode", "ProgressBar", diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 43ec4390de92a..03f38c4c7d008 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -136,3 +136,51 @@ def _is_propagated(self, estimator): # @property # def request_from_reconstruction_attributes(self): # return False + + +class CallbackPropagatorMixin: + """Mixin class for meta-estimators expected to propagate callbacks.""" + + def _propagate_callbacks(self, sub_estimator, *, parent_node): + """Propagate the auto-propagated callbacks to a sub-estimator. + + Parameters + ---------- + sub_estimator : estimator instance + The sub-estimator to propagate the callbacks to. + + parent_node : ComputationNode instance + The computation node in this estimator to set as `parent_node` to the + computation tree of the sub-estimator. It must be the node where the fit + method of the sub-estimator is called. + """ + if hasattr(sub_estimator, "_callbacks") and any( + callback.auto_propagate for callback in sub_estimator._callbacks + ): + bad_callbacks = [ + callback.__class__.__name__ + for callback in sub_estimator._callbacks + if callback.auto_propagate + ] + raise TypeError( + f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" + f" meta-estimator ({self.__class__.__name__}) can't have" + f" auto-propagated callbacks ({bad_callbacks})." + " Set them directly on the meta-estimator." + ) + + if not hasattr(self, "_callbacks"): + return + + propagated_callbacks = [ + callback for callback in self._callbacks if callback.auto_propagate + ] + + if not propagated_callbacks: + return + + sub_estimator._parent_node = parent_node + + sub_estimator._set_callbacks( + getattr(sub_estimator, "_callbacks", []) + propagated_callbacks + ) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 5befdc48bda38..c1b07ebc1d61b 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -2,7 +2,7 @@ # Authors: the scikit-learn developers from sklearn.base import BaseEstimator, _fit_context, clone -from sklearn.callback import BaseCallback +from sklearn.callback import BaseCallback, CallbackPropagatorMixin from sklearn.callback._base import _eval_callbacks_on_fit_iter_end from sklearn.utils.parallel import Parallel, delayed @@ -64,7 +64,7 @@ def fit(self, X, y): return self -class MetaEstimator(BaseEstimator): +class MetaEstimator(BaseEstimator, CallbackPropagatorMixin): _parameter_constraints: dict = {} def __init__( diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 1de957ee3e83e..c377793d168a7 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -48,24 +48,24 @@ def test_propagate_callbacks(): propagated_callback = TestingAutoPropagatedCallback() estimator = Estimator() - estimator._set_callbacks([not_propagated_callback, propagated_callback]) + metaestimator = MetaEstimator(estimator) + metaestimator._set_callbacks([not_propagated_callback, propagated_callback]) - sub_estimator = Estimator() - estimator._propagate_callbacks(sub_estimator, parent_node=None) + metaestimator._propagate_callbacks(estimator, parent_node=None) - assert hasattr(sub_estimator, "_parent_node") - assert not_propagated_callback not in sub_estimator._callbacks - assert propagated_callback in sub_estimator._callbacks + assert hasattr(estimator, "_parent_node") + assert not_propagated_callback not in estimator._callbacks + assert propagated_callback in estimator._callbacks def test_propagate_callback_no_callback(): """Check that no callback is propagated if there's no callback.""" estimator = Estimator() - sub_estimator = Estimator() - estimator._propagate_callbacks(sub_estimator, parent_node=None) + metaestimator = MetaEstimator(estimator) + metaestimator._propagate_callbacks(estimator, parent_node=None) + assert not hasattr(metaestimator, "_callbacks") assert not hasattr(estimator, "_callbacks") - assert not hasattr(sub_estimator, "_callbacks") def test_auto_propagated_callbacks(): From 2fdbda3d462dbfd0e070fd6ce08f2c57d79f092f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Sat, 17 Feb 2024 17:49:23 +0100 Subject: [PATCH 34/55] rename _skl_callbacks --- sklearn/base.py | 16 ++++++------ sklearn/callback/_base.py | 25 +++++++++++-------- .../test_base_estimator_callback_methods.py | 12 ++++----- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 899af566e67dc..dc9d6029ff253 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -132,8 +132,8 @@ def _clone_parametrized(estimator, *, safe=True): params_set = new_object.get_params(deep=False) # attach callbacks to the new estimator - if hasattr(estimator, "_callbacks"): - new_object._callbacks = estimator._callbacks + if hasattr(estimator, "_skl_callbacks"): + new_object._skl_callbacks = estimator._skl_callbacks # quick sanity check of the parameters of the clone for name in new_object_params: @@ -693,7 +693,7 @@ def _set_callbacks(self, callbacks): if not all(isinstance(callback, BaseCallback) for callback in callbacks): raise TypeError("callbacks must be subclasses of BaseCallback.") - self._callbacks = callbacks + self._skl_callbacks = callbacks return self @@ -727,12 +727,12 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): parent=getattr(self, "_parent_node", None), ) - if not hasattr(self, "_callbacks"): + if not hasattr(self, "_skl_callbacks"): return self._computation_tree # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. - for callback in self._callbacks: + for callback in self._skl_callbacks: if not callback._is_propagated(estimator=self): callback.on_fit_begin(estimator=self, X=X, y=y) @@ -740,12 +740,14 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): def _eval_callbacks_on_fit_end(self): """Evaluate the `on_fit_end` method of the callbacks.""" - if not hasattr(self, "_callbacks") or not hasattr(self, "_computation_tree"): + if not hasattr(self, "_skl_callbacks") or not hasattr( + self, "_computation_tree" + ): return # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. - for callback in self._callbacks: + for callback in self._skl_callbacks: if not callback._is_propagated(estimator=self): callback.on_fit_end() diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 03f38c4c7d008..ab6a467e1e3fd 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -24,7 +24,7 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): estimator = kwargs.get("estimator") node = kwargs.get("node") - if not hasattr(estimator, "_callbacks") or node is None: + if not hasattr(estimator, "_skl_callbacks") or node is None: return False # stopping_criterion and reconstruction_attributes can be costly to compute. @@ -32,15 +32,20 @@ def _eval_callbacks_on_fit_iter_end(**kwargs): # compute them if a callback requests it. # TODO: This is not used yet but will be necessary for next callbacks # Uncomment when needed - # if any(cb.request_stopping_criterion for cb in estimator._callbacks): + # if any(cb.request_stopping_criterion for cb in estimator._skl_callbacks): # kwarg = kwargs.pop("stopping_criterion", lambda: None)() # kwargs["stopping_criterion"] = kwarg - # if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks): + # if any( + # cb.request_from_reconstruction_attributes + # for cb in estimator._skl_callbacks + # ): # kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() # kwargs["from_reconstruction_attributes"] = kwarg - return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) + return any( + callback.on_fit_iter_end(**kwargs) for callback in estimator._skl_callbacks + ) class BaseCallback(ABC): @@ -154,12 +159,12 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): computation tree of the sub-estimator. It must be the node where the fit method of the sub-estimator is called. """ - if hasattr(sub_estimator, "_callbacks") and any( - callback.auto_propagate for callback in sub_estimator._callbacks + if hasattr(sub_estimator, "_skl_callbacks") and any( + callback.auto_propagate for callback in sub_estimator._skl_callbacks ): bad_callbacks = [ callback.__class__.__name__ - for callback in sub_estimator._callbacks + for callback in sub_estimator._skl_callbacks if callback.auto_propagate ] raise TypeError( @@ -169,11 +174,11 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): " Set them directly on the meta-estimator." ) - if not hasattr(self, "_callbacks"): + if not hasattr(self, "_skl_callbacks"): return propagated_callbacks = [ - callback for callback in self._callbacks if callback.auto_propagate + callback for callback in self._skl_callbacks if callback.auto_propagate ] if not propagated_callbacks: @@ -182,5 +187,5 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): sub_estimator._parent_node = parent_node sub_estimator._set_callbacks( - getattr(sub_estimator, "_callbacks", []) + propagated_callbacks + getattr(sub_estimator, "_skl_callbacks", []) + propagated_callbacks ) diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index c377793d168a7..7ea2f79998fd7 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -25,10 +25,10 @@ def test_set_callbacks(callbacks): estimator = Estimator() set_callbacks_return = estimator._set_callbacks(callbacks) - assert hasattr(estimator, "_callbacks") + assert hasattr(estimator, "_skl_callbacks") expected_callbacks = [callbacks] if not isinstance(callbacks, list) else callbacks - assert estimator._callbacks == expected_callbacks + assert estimator._skl_callbacks == expected_callbacks assert set_callbacks_return is estimator @@ -54,8 +54,8 @@ def test_propagate_callbacks(): metaestimator._propagate_callbacks(estimator, parent_node=None) assert hasattr(estimator, "_parent_node") - assert not_propagated_callback not in estimator._callbacks - assert propagated_callback in estimator._callbacks + assert not_propagated_callback not in estimator._skl_callbacks + assert propagated_callback in estimator._skl_callbacks def test_propagate_callback_no_callback(): @@ -64,8 +64,8 @@ def test_propagate_callback_no_callback(): metaestimator = MetaEstimator(estimator) metaestimator._propagate_callbacks(estimator, parent_node=None) - assert not hasattr(metaestimator, "_callbacks") - assert not hasattr(estimator, "_callbacks") + assert not hasattr(metaestimator, "_skl_callbacks") + assert not hasattr(estimator, "_skl_callbacks") def test_auto_propagated_callbacks(): From aea9af700871e34926fe7942cd783ec28830ea13 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 19 Feb 2024 11:52:59 +0100 Subject: [PATCH 35/55] clone callbacks --- sklearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index dc9d6029ff253..6f9aac19dad6c 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -133,7 +133,7 @@ def _clone_parametrized(estimator, *, safe=True): # attach callbacks to the new estimator if hasattr(estimator, "_skl_callbacks"): - new_object._skl_callbacks = estimator._skl_callbacks + new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False) # quick sanity check of the parameters of the clone for name in new_object_params: From 44b615af0c797b9094b5d2afff0203ba29f2ba4e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 20 Feb 2024 10:51:24 +0100 Subject: [PATCH 36/55] some renaming and cleanup --- sklearn/base.py | 24 ++-- sklearn/callback/_base.py | 122 ++++++++++-------- sklearn/callback/_computation_tree.py | 45 ++++--- sklearn/callback/_progressbar.py | 10 +- sklearn/callback/tests/_utils.py | 24 ++-- .../test_base_estimator_callback_methods.py | 8 +- .../callback/tests/test_computation_tree.py | 28 ++-- 7 files changed, 142 insertions(+), 119 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 6f9aac19dad6c..86b01b4e87439 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -16,6 +16,7 @@ from . import __version__ from ._config import config_context, get_config from .callback import BaseCallback, build_computation_tree +from .callback._base import default_data from .exceptions import InconsistentVersionWarning from .utils import _IS_32BIT from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr @@ -697,7 +698,7 @@ def _set_callbacks(self, callbacks): return self - def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): + def _eval_callbacks_on_fit_begin(self, *, tree_structure, data): """Evaluate the `on_fit_begin` method of the callbacks. The computation tree is also built at this point. @@ -706,15 +707,16 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): Parameters ---------- - X : ndarray or sparse matrix, default=None - The training data. + tree_structure : list of dict + A description of the nested steps of computation of the estimator to build + the computation tree. It's a list of dict with keys "stage" and + "n_children". - y : ndarray, default=None - The target. - - levels : list of dict - A description of the nested levels of computation of the estimator to build - the computation tree. It's a list of dict with "descr" and "max_iter" keys. + data : dict + Dictionary containing the training and validation data. The keys are + "X_train", "y_train", "sample_weight_train", "X_val", "y_val", + "sample_weight_val". The values are the corresponding data. If a key is + missing, the corresponding value is None. Returns ------- @@ -723,7 +725,7 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): """ self._computation_tree = build_computation_tree( estimator_name=self.__class__.__name__, - levels=levels, + tree_structure=tree_structure, parent=getattr(self, "_parent_node", None), ) @@ -734,7 +736,7 @@ def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): # propagated from a meta-estimator. for callback in self._skl_callbacks: if not callback._is_propagated(estimator=self): - callback.on_fit_begin(estimator=self, X=X, y=y) + callback.on_fit_begin(estimator=self, data={**default_data, **data}) return self._computation_tree diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index ab6a467e1e3fd..d9be734ba251a 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -3,79 +3,43 @@ from abc import ABC, abstractmethod - -# Not a method of BaseEstimator because it might not be directly called from fit but -# by a non-method function called by fit -def _eval_callbacks_on_fit_iter_end(**kwargs): - """Evaluate the `on_fit_iter_end` method of the callbacks. - - This function must be called at the end of each computation node. - - Parameters - ---------- - kwargs : dict - Arguments passed to the callback. - - Returns - ------- - stop : bool - Whether or not to stop the fit at this node. - """ - estimator = kwargs.get("estimator") - node = kwargs.get("node") - - if not hasattr(estimator, "_skl_callbacks") or node is None: - return False - - # stopping_criterion and reconstruction_attributes can be costly to compute. - # They are passed as lambdas for lazy evaluation. We only actually - # compute them if a callback requests it. - # TODO: This is not used yet but will be necessary for next callbacks - # Uncomment when needed - # if any(cb.request_stopping_criterion for cb in estimator._skl_callbacks): - # kwarg = kwargs.pop("stopping_criterion", lambda: None)() - # kwargs["stopping_criterion"] = kwarg - - # if any( - # cb.request_from_reconstruction_attributes - # for cb in estimator._skl_callbacks - # ): - # kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() - # kwargs["from_reconstruction_attributes"] = kwarg - - return any( - callback.on_fit_iter_end(**kwargs) for callback in estimator._skl_callbacks - ) +# default values for the data dict passed to the callbacks +default_data = { + "X_train": None, + "y_train": None, + "sample_weight_train": None, + "X_val": None, + "y_val": None, + "sample_weight_val": None, +} class BaseCallback(ABC): """Abstract class for the callbacks""" @abstractmethod - def on_fit_begin(self, estimator, *, X=None, y=None): - """Method called at the beginning of the fit method of the estimator - - Only called + def on_fit_begin(self, estimator, *, data): + """Method called at the beginning of the fit method of the estimator. Parameters ---------- estimator : estimator instance The estimator the callback is set on. - X : ndarray or sparse matrix, default=None - The training data. - - y : ndarray or sparse matrix, default=None - The target. + data : dict + Dictionary containing the training and validation data. The keys are + "X_train", "y_train", "sample_weight_train", "X_val", "y_val", + "sample_weight_val". The values are the corresponding data. If a key is + missing, the corresponding value is None. """ @abstractmethod def on_fit_end(self): - """Method called at the end of the fit method of the estimator""" + """Method called at the end of the fit method of the estimator.""" @abstractmethod def on_fit_iter_end(self, estimator, node, **kwargs): - """Method called at the end of each computation node of the estimator + """Method called at the end of each computation node of the estimator. Parameters ---------- @@ -89,6 +53,12 @@ def on_fit_iter_end(self, estimator, node, **kwargs): **kwargs : dict arguments passed to the callback. Possible keys are + - data: dict + Dictionary containing the training and validation data. The keys are + "X_train", "y_train", "sample_weight_train", "X_val", "y_val", + "sample_weight_val". The values are the corresponding data. If a key is + missing, the corresponding value is None. + - stopping_criterion: float Usually iterations stop when `stopping_criterion <= tol`. This is only provided at the innermost level of iterations. @@ -189,3 +159,47 @@ def _propagate_callbacks(self, sub_estimator, *, parent_node): sub_estimator._set_callbacks( getattr(sub_estimator, "_skl_callbacks", []) + propagated_callbacks ) + + +# Not a method of BaseEstimator because it might not be directly called from fit but +# by a non-method function called by fit +def _eval_callbacks_on_fit_iter_end(**kwargs): + """Evaluate the `on_fit_iter_end` method of the callbacks. + + This function must be called at the end of each computation node. + + Parameters + ---------- + kwargs : dict + Arguments passed to the callback. + + Returns + ------- + stop : bool + Whether or not to stop the fit at this node. + """ + estimator = kwargs.get("estimator") + node = kwargs.get("node") + + if not hasattr(estimator, "_skl_callbacks") or node is None: + return False + + # stopping_criterion and reconstruction_attributes can be costly to compute. + # They are passed as lambdas for lazy evaluation. We only actually + # compute them if a callback requests it. + # TODO: This is not used yet but will be necessary for next callbacks + # Uncomment when needed + # if any(cb.request_stopping_criterion for cb in estimator._skl_callbacks): + # kwarg = kwargs.pop("stopping_criterion", lambda: None)() + # kwargs["stopping_criterion"] = kwarg + + # if any( + # cb.request_from_reconstruction_attributes + # for cb in estimator._skl_callbacks + # ): + # kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() + # kwargs["from_reconstruction_attributes"] = kwarg + + return any( + callback.on_fit_iter_end(**kwargs) for callback in estimator._skl_callbacks + ) diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py index 2a7bbece6b90e..ba32e04400810 100644 --- a/sklearn/callback/_computation_tree.py +++ b/sklearn/callback/_computation_tree.py @@ -10,10 +10,11 @@ class ComputationNode: estimator_name : str The name of the estimator this computation node belongs to. - description : str, default=None - A description of this computation node. None means it's a leaf. + stage : str, default=None + A description of the stage this computation node belongs to. + None means it's a leaf. - max_iter : int, default=None + n_children : int, default=None The number of its children. None means it's a leaf. idx : int, default=0 @@ -31,8 +32,8 @@ class ComputationNode: def __init__( self, estimator_name, - description=None, - max_iter=None, + stage=None, + n_children=None, idx=0, parent=None, ): @@ -42,10 +43,10 @@ def __init__( # meta-estimator correspond to the same computation step. Therefore, both # nodes are merged into a single node, retaining the information of both. self.estimator_name = (estimator_name,) - self.description = (description,) + self.stage = (stage,) self.parent = parent - self.max_iter = max_iter + self.n_children = n_children self.idx = idx self.children = [] @@ -67,7 +68,7 @@ def __iter__(self): yield from node -def build_computation_tree(estimator_name, levels, parent=None, idx=0): +def build_computation_tree(estimator_name, tree_structure, parent=None, idx=0): """Build the computation tree from the description of the levels. Parameters @@ -75,12 +76,12 @@ def build_computation_tree(estimator_name, levels, parent=None, idx=0): estimator_name : str The name of the estimator this computation tree belongs to. - levels : list of dict - The description of the levels of the computation tree. Each dict must have + tree_structure : list of dict + The description of the stages of the computation tree. Each dict must have the following keys: - - descr: str - A description of the level - - max_iter: int or None + - stage: str + A human readable description of the stage. + - n_children: int or None The number of its children. None means it's a leaf. parent : ComputationNode instance, default=None @@ -94,30 +95,32 @@ def build_computation_tree(estimator_name, levels, parent=None, idx=0): computation_tree : ComputationNode instance The root of the computation tree. """ - this_level = levels[0] + this_stage = tree_structure[0] node = ComputationNode( estimator_name=estimator_name, parent=parent, - max_iter=this_level["max_iter"], - description=this_level["descr"], + n_children=this_stage["n_children"], + stage=this_stage["stage"], idx=idx, ) - if parent is not None and parent.max_iter is None: + if parent is not None and parent.n_children is None: # parent node is a leaf of the computation tree of an outer estimator. It means # that this node is the root of the computation tree of this estimator. They # both correspond the same computation step, so we merge both nodes. - node.description = parent.description + node.description + node.stage = parent.stage + node.stage node.estimator_name = parent.estimator_name + node.estimator_name node.parent = parent.parent node.idx = parent.idx parent.parent.children[node.idx] = node - if node.max_iter is not None: - for i in range(node.max_iter): + if node.n_children is not None: + for i in range(node.n_children): node.children.append( - build_computation_tree(estimator_name, levels[1:], parent=node, idx=i) + build_computation_tree( + estimator_name, tree_structure[1:], parent=node, idx=i + ) ) return node diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 9360f454e3abe..00f55958ae42c 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -16,7 +16,7 @@ class ProgressBar(BaseCallback): def __init__(self): check_rich_support("Progressbar") - def on_fit_begin(self, estimator, X=None, y=None): + def on_fit_begin(self, estimator, data): self._queue = Manager().Queue() self.progress_monitor = _RichProgressMonitor(queue=self._queue) self.progress_monitor.start() @@ -156,9 +156,9 @@ def __init__(self, node, progress_ctx, parent=None): self.children = {} self.finished = False - if node.max_iter is not None: + if node.n_children is not None: description = self._format_task_description(node) - self.task_id = progress_ctx.add_task(description, total=node.max_iter) + self.task_id = progress_ctx.add_task(description, total=node.n_children) def _format_task_description(self, node): """Return a formatted description for the task of the node.""" @@ -167,11 +167,11 @@ def _format_task_description(self, node): indent = f"{' ' * (node.depth)}" style = f"[{colors[(node.depth)%len(colors)]}]" - description = f"{node.estimator_name[0]} - {node.description[0]}" + description = f"{node.estimator_name[0]} - {node.stage[0]}" if node.parent is not None: description += f" #{node.idx}" if len(node.estimator_name) == 2: - description += f" | {node.estimator_name[1]} - {node.description[1]}" + description += f" | {node.estimator_name[1]} - {node.stage[1]}" return f"{style}{indent}{description}" diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index c1b07ebc1d61b..0e51d1c165b2e 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -8,7 +8,7 @@ class TestingCallback(BaseCallback): - def on_fit_begin(self, estimator, *, X=None, y=None): + def on_fit_begin(self, estimator, *, data): pass def on_fit_end(self): @@ -25,7 +25,7 @@ class TestingAutoPropagatedCallback(TestingCallback): class NotValidCallback: """Unvalid callback since it does not inherit from `BaseCallback`.""" - def on_fit_begin(self, estimator, *, X=None, y=None): + def on_fit_begin(self, estimator, *, data): pass # pragma: no cover def on_fit_end(self): @@ -44,12 +44,11 @@ def __init__(self, max_iter=20): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): root = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.max_iter}, - {"descr": "iter", "max_iter": None}, + tree_structure=[ + {"stage": "fit", "n_children": self.max_iter}, + {"stage": "iter", "n_children": None}, ], - X=X, - y=y, + data={"X_train": X, "y_train": y}, ) for i in range(self.max_iter): @@ -79,13 +78,12 @@ def __init__( @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): root = self._eval_callbacks_on_fit_begin( - levels=[ - {"descr": "fit", "max_iter": self.n_outer}, - {"descr": "outer", "max_iter": self.n_inner}, - {"descr": "inner", "max_iter": None}, + tree_structure=[ + {"stage": "fit", "n_children": self.n_outer}, + {"stage": "outer", "n_children": self.n_inner}, + {"stage": "inner", "n_children": None}, ], - X=X, - y=y, + data={"X_train": X, "y_train": y}, ) Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py index 7ea2f79998fd7..838f117e65812 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -89,11 +89,11 @@ def test_eval_callbacks_on_fit_begin(): estimator = Estimator()._set_callbacks(TestingCallback()) assert not hasattr(estimator, "_computation_tree") - levels = [ - {"descr": "fit", "max_iter": 10}, - {"descr": "iter", "max_iter": None}, + tree_structure = [ + {"stage": "fit", "n_children": 10}, + {"stage": "iter", "n_children": None}, ] - estimator._eval_callbacks_on_fit_begin(levels=levels) + estimator._eval_callbacks_on_fit_begin(tree_structure=tree_structure, data={}) assert hasattr(estimator, "_computation_tree") diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py index ee2e3fa971fdb..de2442d4851a3 100644 --- a/sklearn/callback/tests/test_computation_tree.py +++ b/sklearn/callback/tests/test_computation_tree.py @@ -5,22 +5,24 @@ from sklearn.callback import build_computation_tree -LEVELS = [ - {"descr": "level0", "max_iter": 3}, - {"descr": "level1", "max_iter": 5}, - {"descr": "level2", "max_iter": 7}, - {"descr": "level3", "max_iter": None}, +TREE_STRUCTURE = [ + {"stage": "stage0", "n_children": 3}, + {"stage": "stage1", "n_children": 5}, + {"stage": "stage2", "n_children": 7}, + {"stage": "stage3", "n_children": None}, ] def test_computation_tree(): """Check the construction of the computation tree.""" - computation_tree = build_computation_tree(estimator_name="estimator", levels=LEVELS) + computation_tree = build_computation_tree( + estimator_name="estimator", tree_structure=TREE_STRUCTURE + ) assert computation_tree.estimator_name == ("estimator",) assert computation_tree.parent is None assert computation_tree.idx == 0 - assert len(computation_tree.children) == computation_tree.max_iter == 3 + assert len(computation_tree.children) == computation_tree.n_children == 3 assert [node.idx for node in computation_tree.children] == list(range(3)) for node1 in computation_tree.children: @@ -39,10 +41,12 @@ def test_n_nodes(): """Check that the number of node in a computation tree corresponds to what we expect from the level descriptions. """ - computation_tree = build_computation_tree(estimator_name="", levels=LEVELS) + computation_tree = build_computation_tree( + estimator_name="", tree_structure=TREE_STRUCTURE + ) - max_iter_per_level = [level["max_iter"] for level in LEVELS[:-1]] - expected_n_nodes = 1 + np.sum(np.cumprod(max_iter_per_level)) + n_children_per_level = [stage["n_children"] for stage in TREE_STRUCTURE[:-1]] + expected_n_nodes = 1 + np.sum(np.cumprod(n_children_per_level)) actual_n_nodes = sum(1 for _ in computation_tree) @@ -51,7 +55,9 @@ def test_n_nodes(): def test_path(): """Check that the path from the root to a node is correct.""" - computation_tree = build_computation_tree(estimator_name="", levels=LEVELS) + computation_tree = build_computation_tree( + estimator_name="", tree_structure=TREE_STRUCTURE + ) assert computation_tree.path == [computation_tree] From 07a6875e86a98ffb2911c1d634104f2d43b5260b Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 20 Feb 2024 18:33:48 +0100 Subject: [PATCH 37/55] Merge branch 'callbacks' into base (continued) --- ...st_pip_openblas_pandas_linux-64_conda.lock | 12 ++--- build_tools/azure/ubuntu_atlas_lock.txt | 2 +- build_tools/circle/doc_linux-64_conda.lock | 54 +++++++++---------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock index d11f0e7d26e88..7a703828a0256 100644 --- a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock @@ -32,7 +32,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda#685 # pip docutils @ https://files.pythonhosted.org/packages/26/87/f238c0670b94533ac0353a4e2a1a771a0cc73277b88bff23d3ae35a256c1/docutils-0.20.1-py3-none-any.whl#sha256=96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6 # pip exceptiongroup @ https://files.pythonhosted.org/packages/b8/9a/5028fd52db10e600f1c4674441b968cf2ea4959085bfb5b99fb1250e5f68/exceptiongroup-1.2.0-py3-none-any.whl#sha256=4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14 # pip execnet @ https://files.pythonhosted.org/packages/e8/9c/a079946da30fac4924d92dbc617e5367d454954494cf1e71567bcc4e00ee/execnet-2.0.2-py3-none-any.whl#sha256=88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41 -# pip fonttools @ https://files.pythonhosted.org/packages/ef/02/1e18cc5249b2e9cdd1d6c231373c4ba7ad18ff3ac9164b1ffcac6ed0aa35/fonttools-4.48.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c900508c46274d32d308ae8e82335117f11aaee1f7d369ac16502c9a78930b0a +# pip fonttools @ https://files.pythonhosted.org/packages/d8/d7/0f4563ea45c14c84fde44aca3cb0896e49d1d960ba1298e789b75b1d2625/fonttools-4.49.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=1f255ce8ed7556658f6d23f6afd22a6d9bbc3edb9b96c96682124dc487e1bf42 # pip idna @ https://files.pythonhosted.org/packages/c2/e7/a82b05cf63a603df6e68d59ae6a68bf5064484a0718ea5033660af4b54a9/idna-3.6-py3-none-any.whl#sha256=c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # pip imagesize @ https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl#sha256=0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b # pip iniconfig @ https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl#sha256=b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 @@ -58,10 +58,10 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda#685 # pip sphinxcontrib-qthelp @ https://files.pythonhosted.org/packages/80/b3/1beac14a88654d2e5120d0143b49be5ad450b86eb1963523d8dbdcc51eb2/sphinxcontrib_qthelp-1.0.7-py3-none-any.whl#sha256=e2ae3b5c492d58fcbd73281fbd27e34b8393ec34a073c792642cd8e529288182 # pip sphinxcontrib-serializinghtml @ https://files.pythonhosted.org/packages/38/24/228bb903ea87b9e08ab33470e6102402a644127108c7117ac9c00d849f82/sphinxcontrib_serializinghtml-1.1.10-py3-none-any.whl#sha256=326369b8df80a7d2d8d7f99aa5ac577f51ea51556ed974e7716cfd4fca3f6cb7 # pip tabulate @ https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl#sha256=024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f -# pip threadpoolctl @ https://files.pythonhosted.org/packages/81/12/fd4dea011af9d69e1cad05c75f3f7202cdcbeac9b712eea58ca779a72865/threadpoolctl-3.2.0-py3-none-any.whl#sha256=2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032 +# pip threadpoolctl @ https://files.pythonhosted.org/packages/b1/2c/f504e55d98418f2fcf756a56877e6d9a45dd5ed28b3d7c267b300e85ad5b/threadpoolctl-3.3.0-py3-none-any.whl#sha256=6155be1f4a39f31a18ea70f94a77e0ccd57dced08122ea61109e7da89883781e # pip tomli @ https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl#sha256=939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc # pip tzdata @ https://files.pythonhosted.org/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl#sha256=9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252 -# pip urllib3 @ https://files.pythonhosted.org/packages/88/75/311454fd3317aefe18415f04568edc20218453b709c63c58b9292c71be17/urllib3-2.2.0-py3-none-any.whl#sha256=ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 +# pip urllib3 @ https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl#sha256=450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d # pip zipp @ https://files.pythonhosted.org/packages/d9/66/48866fc6b158c81cc2bfecc04c480f105c6040e8b077bc54c634b4a67926/zipp-3.17.0-py3-none-any.whl#sha256=0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 # pip contourpy @ https://files.pythonhosted.org/packages/a9/ba/d8fd1380876f1e9114157606302e3644c85f6d116aeba354c212ee13edc7/contourpy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=11f8d2554e52f459918f7b8e6aa20ec2a3bce35ce95c1f0ef4ba36fbda306df5 # pip coverage @ https://files.pythonhosted.org/packages/ff/e3/351477165426da841458f2c1b732360dd42da140920e3cd4b70676e5b77f/coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1 @@ -70,13 +70,13 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda#685 # pip importlib-resources @ https://files.pythonhosted.org/packages/93/e8/facde510585869b5ec694e8e0363ffe4eba067cb357a8398a55f6a1f8023/importlib_resources-6.1.1-py3-none-any.whl#sha256=e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6 # pip jinja2 @ https://files.pythonhosted.org/packages/30/6d/6de6be2d02603ab56e72997708809e8a5b0fbfee080735109b40a3564843/Jinja2-3.1.3-py3-none-any.whl#sha256=7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa # pip markdown-it-py @ https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl#sha256=355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 -# pip pytest @ https://files.pythonhosted.org/packages/c7/10/727155d44c5e04bb08e880668e53079547282e4f950535234e5a80690564/pytest-8.0.0-py3-none-any.whl#sha256=50fb9cbe836c3f20f0dfa99c565201fb75dc54c8d76373cd1bde06b06657bdb6 +# pip pytest @ https://files.pythonhosted.org/packages/2e/28/30125a808a2448d72fdba26d01ef2bec76a3c860c8694b636e6104e38713/pytest-8.0.1-py3-none-any.whl#sha256=3e4f16fe1c0a9dc9d9389161c127c3edc5d810c38d6793042fb81d9f48a59fca # pip python-dateutil @ https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl#sha256=961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 # pip requests @ https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl#sha256=58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f # pip scipy @ https://files.pythonhosted.org/packages/a6/9d/f864266894b67cdb5731ab531afba68713da3d6d8252f698ccab775d3f68/scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490 -# pip tifffile @ https://files.pythonhosted.org/packages/16/09/b9f5e4f9448fd39b7c0c9cbb592409ab28e90a1913795260b975d8424cde/tifffile-2024.1.30-py3-none-any.whl#sha256=40cb48f661acdfea16cb00dc8941bd642b8eb5c59bca6de6a54091bee9ee2699 +# pip tifffile @ https://files.pythonhosted.org/packages/cd/0b/33610b4d0d1bb83a6bfd20ed838f52e02a44e9b439116cd4f3d424e81a80/tifffile-2024.2.12-py3-none-any.whl#sha256=870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21 # pip lightgbm @ https://files.pythonhosted.org/packages/ba/11/cb8b67f3cbdca05b59a032bb57963d4fe8c8d18c3870f30bed005b7f174d/lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl#sha256=104496a3404cb2452d3412cbddcfbfadbef9c372ea91e3a9b8794bcc5183bf07 -# pip matplotlib @ https://files.pythonhosted.org/packages/53/1f/653d60d2ec81a6095fa3e571cf2de57742bab8a51a5c01de26730ce3dc53/matplotlib-3.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=5864bdd7da445e4e5e011b199bb67168cdad10b501750367c496420f2ad00843 +# pip matplotlib @ https://files.pythonhosted.org/packages/35/82/ca05c3e3ec4a38eaf49a9bfa1a700658284ddaaa2e2523fa91fbb96d207a/matplotlib-3.8.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=6728dde0a3997396b053602dbd907a9bd64ec7d5cf99e728b404083698d3ca01 # pip pandas @ https://files.pythonhosted.org/packages/df/bc/663c52528d6b2c796d0f788655e5f0fd65842523715a18f4d4beaca8dcb2/pandas-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=eb61dc8567b798b969bcc1fc964788f5a68214d333cade8319c7ab33e2b5d88a # pip pyamg @ https://files.pythonhosted.org/packages/35/1c/8b2aa6fbb2bae258ab6cdb35b09635bf50865ac2bcdaf220db3d972cc0d8/pyamg-5.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=1332acec6d5ede9440c8ced0ef20952f5b766387116f254b79880ce29fdecee7 # pip pytest-cov @ https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl#sha256=6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a diff --git a/build_tools/azure/ubuntu_atlas_lock.txt b/build_tools/azure/ubuntu_atlas_lock.txt index d2a0e285efa1b..52e2b66062c86 100644 --- a/build_tools/azure/ubuntu_atlas_lock.txt +++ b/build_tools/azure/ubuntu_atlas_lock.txt @@ -18,7 +18,7 @@ packaging==23.2 # via pytest pluggy==1.4.0 # via pytest -pytest==8.0.0 +pytest==8.0.1 # via # -r build_tools/azure/ubuntu_atlas_requirements.txt # pytest-xdist diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index 162c3fbb9024c..0c64a989d724f 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -8,7 +8,7 @@ https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_1.conda#6185f640c43843e5ad6fd1c5372c3f80 -https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-2.6.32-he073ed8_16.conda#7ca122655873935e02c91279c5b03c8c +https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-2.6.32-he073ed8_17.conda#d731b543793afc0433c4fd593e693fce https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 https://conda.anaconda.org/conda-forge/noarch/libgcc-devel_linux-64-12.3.0-h8bca6fd_105.conda#e12ce6b051085b8f27e239f5e5f5bce5 https://conda.anaconda.org/conda-forge/noarch/libstdcxx-devel_linux-64-12.3.0-h8bca6fd_105.conda#b3c6062c84a8e172555ee104ea6a01ab @@ -17,7 +17,7 @@ https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-4_cp39.conda#bfe4 https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda#161081fc7cec0bfda0d86d7cb595f8d8 https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_5.conda#d211c42b9ce49aee3734fdc828731689 -https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.12-he073ed8_16.conda#071ea8dceff4d30ac511f4a2f8437cd1 +https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.12-he073ed8_17.conda#595db67e32b276298ff3d94d07d47fbf https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.40-hf600244_0.conda#33084421a8c0af6aef1b439707f7662a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/binutils-2.40-hdd6e379_0.conda#ccc940fddbc3fcd3d79cd4c654c4b5c4 @@ -109,14 +109,14 @@ https://conda.anaconda.org/conda-forge/linux-64/gfortran_impl_linux-64-12.3.0-hf https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-12.3.0-he2b93b0_5.conda#cddba8fd94e52012abea1caad722b9c2 https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda#32d16ad533c59bb0a3c5ffaf16110829 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.3-h783c2da_0.conda#9bd06b12bbfa6fd1740fd23af4b0f0c7 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-h783c2da_0.conda#d86baf8740d1a906b9716f2a0bac2f2d https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hb3ce162_4.conda#8a35df3cbc0c8b12cc8af9473ae75eef https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.26-pthreads_h413a1c8_0.conda#760ae35415f5ba8b15d09df5afe8b23a https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.6-h4dfa4b3_0.conda#c1665f9c1c9f6c93d8b4e492a6a39056 https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 -https://conda.anaconda.org/conda-forge/linux-64/nss-3.97-h1d7d5a4_0.conda#b916d71a3032416e3f9136090d814472 +https://conda.anaconda.org/conda-forge/linux-64/nss-3.98-h1d7d5a4_0.conda#54b56c2fdf973656b748e0378900ec13 https://conda.anaconda.org/conda-forge/linux-64/python-3.9.18-h0755675_1_cpython.conda#255a7002aeec7a067ff19b545aca6328 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h8ee46fc_1.conda#632413adcd8bc16b515cab87a2932913 @@ -139,7 +139,7 @@ https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#6 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d https://conda.anaconda.org/conda-forge/linux-64/gfortran-12.3.0-h499e0f7_2.conda#0558a8c44eb7a18e6682bd3a8ae6dcab https://conda.anaconda.org/conda-forge/linux-64/gfortran_linux-64-12.3.0-h7fe76b4_2.conda#3a749210487c0358b6f135a648cbbf60 -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.3-hfc55251_0.conda#41d2f46e0ac8372eeb959860713d9b21 +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.4-hfc55251_0.conda#d184ba1bf15a2bbb3be6118c90fd487d https://conda.anaconda.org/conda-forge/linux-64/gxx-12.3.0-h8d2909c_2.conda#673bac341be6b90ef9e8abae7e52ca46 https://conda.anaconda.org/conda-forge/linux-64/gxx_linux-64-12.3.0-h8a814eb_2.conda#f517b1525e9783849bd56a5dc45a9960 https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda#1a76f09108576397c41c0b0c5bd84134 @@ -167,18 +167,18 @@ https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.8-py39hd1e30aa_0.cond https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 -https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.4-pyhd8ed1ab_0.conda#c79cacf8a06a51552fc651652f170208 +https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2024.1-pyhd8ed1ab_0.conda#98206ea9954216ee7540f0c773f2104d https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda#3eeeeb9e4827ace8c0c1419c85d590ad -https://conda.anaconda.org/conda-forge/noarch/setuptools-69.0.3-pyhd8ed1ab_0.conda#40695fdfd15a92121ed2922900d0308b +https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.0-pyhd8ed1ab_1.conda#d76a248ad1b9d4a79c2ce39ee41d626c https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2 https://conda.anaconda.org/conda-forge/noarch/snowballstemmer-2.2.0-pyhd8ed1ab_0.tar.bz2#4d22a9315e78c6827f806065957d566e https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_0.conda#da1d979339e2714c30a8e806a33ec087 https://conda.anaconda.org/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_1.tar.bz2#4759805cce2d914c38472f70bf4d8bcb https://conda.anaconda.org/conda-forge/noarch/tenacity-8.2.3-pyhd8ed1ab_0.conda#1482e77f87c6a702a7e05ef22c9b197b -https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.2.0-pyha21a80b_0.conda#978d03388b62173b8e6f79162cf52b86 +https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.3.0-pyhc1e730c_0.conda#698d2d2b621640bddb9191f132967c9f https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2#f832c45a477c78bebd107098db465095 https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 -https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py39hd1e30aa_1.conda#cbe186eefb0bcd91e8f47c3908489874 +https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py39hd1e30aa_0.conda#1e865e9188204cdfb1fd2531780add88 https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda#a92a6440c3fe7052d63244f3aba2a4a7 https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hd1e30aa_0.conda#1da984bbb6e765743e13388ba7b7b2c8 https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda#1cdea58981c5cbc17b51973bcaddcea7 @@ -191,9 +191,9 @@ https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda#96 https://conda.anaconda.org/conda-forge/linux-64/brunsli-0.1-h9c3ff4c_0.tar.bz2#c1ac6229d0bfd14f8354ff9ad2a26cad https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e https://conda.anaconda.org/conda-forge/linux-64/cxx-compiler-1.7.0-h00ab1b0_0.conda#b4537c98cb59f8725b0e1e65816b4a28 -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.48.1-py39hd1e30aa_0.conda#402ef3d9608c7653187a3fd6fd45b445 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py39hd1e30aa_0.conda#dd1b02484cc8c31d4093111a82b6efb2 https://conda.anaconda.org/conda-forge/linux-64/fortran-compiler-1.7.0-heb67821_0.conda#7ef7c0f111dad1c8006504a0f1ccd820 -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.3-hfc55251_0.conda#e08e51acc7d1ae8dbe13255e7b4c64ac +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_0.conda#f36a7b2420c3fc3c48a3d609841d8fee https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda#746623a787e06191d80a2133e5daff17 https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda#e7d8df6509ba635247ff9aea31134262 @@ -206,12 +206,12 @@ https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0. https://conda.anaconda.org/conda-forge/noarch/memory_profiler-0.61.0-pyhd8ed1ab_0.tar.bz2#8b45f9f2b2f7a98b0ec179c8991a4a9b https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py39had0adad_0.conda#2972754dc054bb079d1d121918b5126f https://conda.anaconda.org/conda-forge/noarch/pip-24.0-pyhd8ed1ab_0.conda#f586ac1e56c8638b64f9c8122a7b8a67 -https://conda.anaconda.org/conda-forge/noarch/plotly-5.18.0-pyhd8ed1ab_0.conda#9f6a8664f1fe752f79473eeb9bf33a60 +https://conda.anaconda.org/conda-forge/noarch/plotly-5.19.0-pyhd8ed1ab_0.conda#669cd7065794633b9e64e6a9612ec700 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 -https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.0-pyhd8ed1ab_0.conda#5ba1cc5b924226349d4a49fb547b7579 +https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.1-pyhd8ed1ab_1.conda#246e33679291d4e85111d812e5103de7 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py39h3d6467e_0.conda#e667a3ab0df62c54e60e1843d2e6defb -https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.0-pyhd8ed1ab_0.conda#6a7e0694921f668a030d52f0c47baebd +https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda#08807a87fa7af10754d46f63b368e016 https://conda.anaconda.org/conda-forge/linux-64/compilers-1.7.0-ha770c72_0.conda#81458b3aed8ab8711951ec3c0c04e097 https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.9-h98fc4e7_0.conda#bcc7157b06fce7f5e055402a8135dfd8 https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 @@ -226,23 +226,23 @@ https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-21_linux64_open https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39h7633fee_0.conda#ed71ad3e30eb03da363fb797419cce98 https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda#614b81f8ed66c56b640faee7076ad14a https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2024.1.1-py39hf9b8f0e_0.conda#9ddd29852457d1152ca235eb87bc74fb -https://conda.anaconda.org/conda-forge/noarch/imageio-2.33.1-pyh8c1a49c_0.conda#1c34d58ac469a34e7e96832861368bce +https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda#b8853659d596f967c661f544dd89ede7 https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.0-py39hddac248_0.conda#95aaa7baa61432a1ce85dedb7b86d2dd https://conda.anaconda.org/conda-forge/noarch/patsy-0.5.6-pyhd8ed1ab_0.conda#a5b55d1cb110cdcedc748b5c3e16e687 -https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.7-py39h927a070_0.conda#24a2968bb1f6630daa0da4368aeeeb64 +https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.10-py39h927a070_0.conda#2c626921a52a9571bda297ef0fceb15a https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed https://conda.anaconda.org/conda-forge/linux-64/pywavelets-1.4.1-py39h44dd56e_1.conda#d037c20e3da2e85f03ebd20ad480c359 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py39h474f0d3_2.conda#6ab241b2023730f6b41712dc1b503afa https://conda.anaconda.org/conda-forge/linux-64/blas-2.121-openblas.conda#4a279792fd8861a15705516a52872eb6 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py39he9076e7_0.conda#6085411aa2f0b2b801d3b46e1d3b83c5 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.3-py39he9076e7_0.conda#5456bdfe5809ebf5689eda6c808b686e https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py39hda80f44_1.conda#6df47699edb4d8d3365de2d189a456bc -https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h450f30e_18.conda#ef0430f8df5dcdedcaaab340b228f30c +https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda#54866f708d43002a514d0b9b0f84bc11 https://conda.anaconda.org/conda-forge/linux-64/statsmodels-0.14.1-py39h44dd56e_0.conda#dc565186b972bd87e49b9c35390ddd8c -https://conda.anaconda.org/conda-forge/noarch/tifffile-2024.1.30-pyhd8ed1ab_0.conda#9ae618ad19f5b39955c9f2e43b8d03c3 +https://conda.anaconda.org/conda-forge/noarch/tifffile-2024.2.12-pyhd8ed1ab_0.conda#d5c8bef52be4e70c48b1400eec3eecc8 https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py39h52134e7_5.conda#e1f148e57d071b09187719df86f513c1 https://conda.anaconda.org/conda-forge/linux-64/scikit-image-0.22.0-py39hddac248_2.conda#8d502a4d2cbe5a45ff35ca8af8cbec0a https://conda.anaconda.org/conda-forge/noarch/seaborn-base-0.13.2-pyhd8ed1ab_0.conda#0918a9201e824211cdf444dbf8d55752 -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.2-py39hf3d152e_0.conda#18d40a5ada9a801cabaf5d47c15c6282 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.3-py39hf3d152e_0.conda#983f5b77540eb5aa00238e72ec9b1dfb https://conda.anaconda.org/conda-forge/noarch/seaborn-0.13.2-hd8ed1ab_0.conda#fd31ebf5867914de597f9961c478e482 https://conda.anaconda.org/conda-forge/noarch/numpydoc-1.6.0-pyhd8ed1ab_0.conda#191b8a622191a403700d16a2008e4e29 https://conda.anaconda.org/conda-forge/noarch/sphinx-copybutton-0.5.2-pyhd8ed1ab_0.conda#ac832cc43adc79118cf6e23f1f9b8995 @@ -260,20 +260,20 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.9.1-pyhd8ed1 # pip defusedxml @ https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl#sha256=a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61 # pip fastjsonschema @ https://files.pythonhosted.org/packages/9c/b9/79691036d4a8f9857e74d1728b23f34f583b81350a27492edda58d5604e1/fastjsonschema-2.19.1-py3-none-any.whl#sha256=3672b47bc94178c9f23dbb654bf47440155d4db9df5f7bc47643315f9c405cd0 # pip fqdn @ https://files.pythonhosted.org/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl#sha256=3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014 -# pip json5 @ https://files.pythonhosted.org/packages/70/ba/fa37123a86ae8287d6678535a944f9c3377d8165e536310ed6f6cb0f0c0e/json5-0.9.14-py2.py3-none-any.whl#sha256=740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f +# pip json5 @ https://files.pythonhosted.org/packages/7c/c3/da3b0c409453ae2d39bcfd04007249fa2f50005d365609a7497a4bbb81f1/json5-0.9.17-py2.py3-none-any.whl#sha256=f8ec1ecf985951d70f780f6f877c4baca6a47b6e61e02c4cd190138d10a7805a # pip jsonpointer @ https://files.pythonhosted.org/packages/12/f6/0232cc0c617e195f06f810534d00b74d2f348fe71b2118009ad8ad31f878/jsonpointer-2.4-py2.py3-none-any.whl#sha256=15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a # pip jupyterlab-pygments @ https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl#sha256=841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780 # pip mistune @ https://files.pythonhosted.org/packages/f0/74/c95adcdf032956d9ef6c89a9b8a5152bf73915f8c633f3e3d88d06bd699c/mistune-3.0.2-py3-none-any.whl#sha256=71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205 # pip overrides @ https://files.pythonhosted.org/packages/2c/ab/fc8290c6a4c722e5514d80f62b2dc4c4df1a68a41d1364e625c35990fcf3/overrides-7.7.0-py3-none-any.whl#sha256=c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49 # pip pandocfilters @ https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl#sha256=93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc # pip pkginfo @ https://files.pythonhosted.org/packages/b3/f2/6e95c86a23a30fa205ea6303a524b20cbae27fbee69216377e3d95266406/pkginfo-1.9.6-py3-none-any.whl#sha256=4b7a555a6d5a22169fcc9cf7bfd78d296b0361adad412a346c1226849af5e546 -# pip prometheus-client @ https://files.pythonhosted.org/packages/bb/9f/ad934418c48d01269fc2af02229ff64bcf793fd5d7f8f82dc5e7ea7ef149/prometheus_client-0.19.0-py3-none-any.whl#sha256=c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92 +# pip prometheus-client @ https://files.pythonhosted.org/packages/c7/98/745b810d822103adca2df8decd4c0bbe839ba7ad3511af3f0d09692fc0f0/prometheus_client-0.20.0-py3-none-any.whl#sha256=cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7 # pip ptyprocess @ https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35 # pip pycparser @ https://files.pythonhosted.org/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl#sha256=8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9 # pip python-json-logger @ https://files.pythonhosted.org/packages/35/a6/145655273568ee78a581e734cf35beb9e33a370b29c5d3c8fee3744de29f/python_json_logger-2.0.7-py3-none-any.whl#sha256=f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd # pip pyyaml @ https://files.pythonhosted.org/packages/7d/39/472f2554a0f1e825bd7c5afc11c817cd7a2f3657460f7159f691fbb37c51/PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c # pip rfc3986-validator @ https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl#sha256=2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9 -# pip rpds-py @ https://files.pythonhosted.org/packages/c2/e9/190521d63b504c12bdcffb27ea6aaac1dbb2521be983c3a2a0ab4a938b8c/rpds_py-0.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dfe07308b311a8293a0d5ef4e61411c5c20f682db6b5e73de6c7c8824272c256 +# pip rpds-py @ https://files.pythonhosted.org/packages/fd/ea/92231b62681961812e9fbd8ef9be7137856784406bf6a384976bb7b46472/rpds_py-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=ddc2f4dfd396c7bfa18e6ce371cba60e4cf9d2e5cdb71376aa2da264605b60b9 # pip send2trash @ https://files.pythonhosted.org/packages/a9/78/e4df1e080ed790acf3a704edf521006dd96b9841bd2e2a462c0d255e0565/Send2Trash-1.8.2-py3-none-any.whl#sha256=a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679 # pip sniffio @ https://files.pythonhosted.org/packages/c3/a0/5dba8ed157b0136607c7f2151db695885606968d1fae123dc3391e0cfdbf/sniffio-1.3.0-py3-none-any.whl#sha256=eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384 # pip soupsieve @ https://files.pythonhosted.org/packages/4c/f3/038b302fdfbe3be7da016777069f26ceefe11a681055ea1f7817546508e3/soupsieve-2.5-py3-none-any.whl#sha256=eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7 @@ -283,7 +283,7 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.9.1-pyhd8ed1 # pip webcolors @ https://files.pythonhosted.org/packages/d5/e1/3e9013159b4cbb71df9bd7611cbf90dc2c621c8aeeb677fc41dad72f2261/webcolors-1.13-py3-none-any.whl#sha256=29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf # pip webencodings @ https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl#sha256=a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78 # pip websocket-client @ https://files.pythonhosted.org/packages/1e/70/1e88138a9afbed1d37093b85f0bebc3011623c4f47c166431599fe9d6c93/websocket_client-1.7.0-py3-none-any.whl#sha256=f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588 -# pip anyio @ https://files.pythonhosted.org/packages/bf/cd/d6d9bb1dadf73e7af02d18225cbd2c93f8552e13130484f1c8dcfece292b/anyio-4.2.0-py3-none-any.whl#sha256=745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee +# pip anyio @ https://files.pythonhosted.org/packages/14/fd/2f20c40b45e4fb4324834aea24bd4afdf1143390242c0b33774da0e2e34f/anyio-4.3.0-py3-none-any.whl#sha256=048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8 # pip arrow @ https://files.pythonhosted.org/packages/f8/ed/e97229a566617f2ae958a6b13e7cc0f585470eac730a73e9e82c32a3cdd2/arrow-1.3.0-py3-none-any.whl#sha256=c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80 # pip beautifulsoup4 @ https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl#sha256=b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed # pip bleach @ https://files.pythonhosted.org/packages/ea/63/da7237f805089ecc28a3f36bca6a21c31fcbc2eb380f3b8f1be3312abd14/bleach-6.1.0-py3-none-any.whl#sha256=3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6 @@ -303,11 +303,11 @@ https://conda.anaconda.org/conda-forge/noarch/sphinxext-opengraph-0.9.1-pyhd8ed1 # pip argon2-cffi @ https://files.pythonhosted.org/packages/a4/6a/e8a041599e78b6b3752da48000b14c8d1e8a04ded09c88c714ba047f34f5/argon2_cffi-23.1.0-py3-none-any.whl#sha256=c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea # pip jsonschema @ https://files.pythonhosted.org/packages/39/9d/b035d024c62c85f2e2d4806a59ca7b8520307f34e0932fbc8cc75fe7b2d9/jsonschema-4.21.1-py3-none-any.whl#sha256=7996507afae316306f9e2290407761157c6f78002dcf7419acb99822143d1c6f # pip jupyter-client @ https://files.pythonhosted.org/packages/43/ae/5f4f72980765e2e5e02b260f9c53bcc706cefa7ac9c8d7240225c55788d4/jupyter_client-8.6.0-py3-none-any.whl#sha256=909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99 -# pip jupyterlite-pyodide-kernel @ https://files.pythonhosted.org/packages/08/19/2ef7099e28a9e411e1eb901edb089e43c0321128651c35c6051baba36577/jupyterlite_pyodide_kernel-0.2.2-py3-none-any.whl#sha256=d452e5a4fc5af1cf84073b339b0033e9d4726c9978fe414036604ddecf39ed10 +# pip jupyterlite-pyodide-kernel @ https://files.pythonhosted.org/packages/41/3f/b1e5d76beaddd94f47cc40b8430cca1e178c3acc53cce8556156991845ac/jupyterlite_pyodide_kernel-0.2.3-py3-none-any.whl#sha256=32b30d0f5ea5b87470cd36f824589e705c9bdaa8f7072d534aea04ec2c7993dc # pip jupyter-events @ https://files.pythonhosted.org/packages/e3/55/0c1aa72f4317e826a471dc4adc3036acd11d496ded68c4bbac2a88551519/jupyter_events-0.9.0-py3-none-any.whl#sha256=d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf # pip nbformat @ https://files.pythonhosted.org/packages/f4/e7/ef30a90b70eba39e675689b9eaaa92530a71d7435ab8f9cae520814e0caf/nbformat-5.9.2-py3-none-any.whl#sha256=1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9 # pip nbclient @ https://files.pythonhosted.org/packages/6b/3a/607149974149f847125c38a62b9ea2b8267eb74823bbf8d8c54ae0212a00/nbclient-0.9.0-py3-none-any.whl#sha256=a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15 -# pip nbconvert @ https://files.pythonhosted.org/packages/c9/ec/c120b21e7f884a701e12a241992754e719adaf430d0d6b30c6655776bc35/nbconvert-7.16.0-py3-none-any.whl#sha256=ad3dc865ea6e2768d31b7eb6c7ab3be014927216a5ece3ef276748dd809054c7 +# pip nbconvert @ https://files.pythonhosted.org/packages/dc/6f/2c4e3dafb36dff2c98a170c1d61275f2e2d6bfd0f07d25771c1c18a6a529/nbconvert-7.16.1-py3-none-any.whl#sha256=3188727dffadfdc9c6a1c7250729063d7bc78b355ad7aa023138afa030d1cd07 # pip jupyter-server @ https://files.pythonhosted.org/packages/25/d6/6ee093c967d11144aeb1b0b4952d30e51da8eb2737837ab612084c783a58/jupyter_server-2.12.5-py3-none-any.whl#sha256=184a0f82809a8522777cfb6b760ab6f4b1bb398664c5860a27cec696cb884923 -# pip jupyterlab-server @ https://files.pythonhosted.org/packages/a2/97/abbbe35fc67b6f9423309988f2e411f7cb117b08321866d3d8b720f4c0d4/jupyterlab_server-2.25.2-py3-none-any.whl#sha256=5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf +# pip jupyterlab-server @ https://files.pythonhosted.org/packages/ab/ac/a19c579bb8ab2a2aefcf47cd3787683e6e136378d7ab2602be3b8e628030/jupyterlab_server-2.25.3-py3-none-any.whl#sha256=c48862519fded9b418c71645d85a49b2f0ec50d032ba8316738e9276046088c1 # pip jupyterlite-sphinx @ https://files.pythonhosted.org/packages/9c/bd/1695eebeb376315c9fc5cbd41c54fb84bb69c68e69651bfc6f03aa4fe659/jupyterlite_sphinx-0.11.0-py3-none-any.whl#sha256=2a0762167e89ec6acd267c73bb90b528728fdba5e30390ea4fe37ddcec277191 From 6433ba33a5184639e3bb4e6aba458b741080484a Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 21 Feb 2024 13:57:36 +0100 Subject: [PATCH 38/55] fix imports --- sklearn/callback/_progressbar.py | 2 +- sklearn/callback/tests/test_progressbar.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 00f55958ae42c..ee5669acf5c60 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -4,7 +4,7 @@ from multiprocessing import Manager from threading import Thread -from ..utils import check_rich_support +from ..utils._optional_dependencies import check_rich_support from . import BaseCallback diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 4c4db3fd8f7ac..20aab8f4c4ab7 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -6,7 +6,7 @@ import pytest from sklearn.callback import ProgressBar -from sklearn.utils import check_rich_support +from sklearn.utils._optional_dependencies import check_rich_support from sklearn.utils._testing import SkipTest from ._utils import Estimator, MetaEstimator From 268d5cfd363e5c70ddac77b27967e3d767347520 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 23 Feb 2024 14:42:18 +0100 Subject: [PATCH 39/55] update lock files --- build_tools/azure/debian_atlas_32bit_lock.txt | 2 +- ...latest_conda_forge_mkl_linux-64_conda.lock | 2 +- ...pylatest_conda_forge_mkl_osx-64_conda.lock | 2 +- ...test_conda_mkl_no_openmp_osx-64_conda.lock | 69 +++++++++---------- ...st_pip_openblas_pandas_linux-64_conda.lock | 2 +- ...pylatest_pip_scipy_dev_linux-64_conda.lock | 2 +- .../pymin_conda_forge_mkl_win-64_conda.lock | 2 +- build_tools/circle/doc_linux-64_conda.lock | 2 +- .../doc_min_dependencies_linux-64_conda.lock | 2 +- 9 files changed, 42 insertions(+), 43 deletions(-) diff --git a/build_tools/azure/debian_atlas_32bit_lock.txt b/build_tools/azure/debian_atlas_32bit_lock.txt index 6f2cac31e4eb9..c171fa12d2080 100644 --- a/build_tools/azure/debian_atlas_32bit_lock.txt +++ b/build_tools/azure/debian_atlas_32bit_lock.txt @@ -6,7 +6,7 @@ # attrs==23.2.0 # via pytest -coverage==7.4.1 +coverage==7.4.2 # via pytest-cov cython==3.0.8 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index 7db1d5002e929..9d4b87879804f 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -168,7 +168,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_ https://conda.anaconda.org/conda-forge/linux-64/aws-c-auth-0.7.3-h28f7589_1.conda#97503d3e565004697f1651753aa95b9e https://conda.anaconda.org/conda-forge/linux-64/aws-c-mqtt-0.9.3-hb447be9_1.conda#c520669eb0be9269a5f0d8ef62531882 https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e -https://conda.anaconda.org/conda-forge/linux-64/coverage-7.4.1-py311h459d7ec_0.conda#9caf3270065a2d40fd9a443ba1568e96 +https://conda.anaconda.org/conda-forge/linux-64/coverage-7.4.2-py311h459d7ec_0.conda#4a7131a3590a262cc19df4aaa497d0f7 https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py311h459d7ec_0.conda#d66c9e36ab104f94e35b015c86c2fcb4 https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_0.conda#f36a7b2420c3fc3c48a3d609841d8fee https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock index e8c55ea9d3eb9..4e45857e600be 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock @@ -90,7 +90,7 @@ https://conda.anaconda.org/conda-forge/osx-64/tornado-6.4-py312h41838bb_0.conda# https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda#a92a6440c3fe7052d63244f3aba2a4a7 https://conda.anaconda.org/conda-forge/osx-64/cctools-973.0.1-h40f6528_16.conda#b7234c329d4503600b032f168f4b65e7 https://conda.anaconda.org/conda-forge/osx-64/clang-16.0.6-hdae98eb_5.conda#5f020dce5a00342141d87f952c9c0282 -https://conda.anaconda.org/conda-forge/osx-64/coverage-7.4.1-py312h41838bb_0.conda#4a89ca53df4faeca1b88d63f12267433 +https://conda.anaconda.org/conda-forge/osx-64/coverage-7.4.2-py312h41838bb_0.conda#a26d3ca323080d6203ea9395def98a72 https://conda.anaconda.org/conda-forge/osx-64/fonttools-4.49.0-py312h41838bb_0.conda#910043c784378419df3160b7661ee915 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/osx-64/libblas-3.9.0-20_osx64_mkl.conda#160fdc97a51d66d51dc782fb67d35205 diff --git a/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock b/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock index 144940c0bf92a..bb3db90badf1e 100644 --- a/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock @@ -17,7 +17,6 @@ https://repo.anaconda.com/pkgs/main/noarch/tzdata-2023d-h04d1e81_0.conda#fdb3195 https://repo.anaconda.com/pkgs/main/osx-64/xz-5.4.5-h6c40b1e_0.conda#351c5d33fe551018a2068e7a2ca8a6c1 https://repo.anaconda.com/pkgs/main/osx-64/zlib-1.2.13-h4dc903c_0.conda#d0202dd912bfb45d3422786531717882 https://repo.anaconda.com/pkgs/main/osx-64/ccache-3.7.9-hf120daa_0.conda#a01515a32e721c51d631283f991bc8ea -https://repo.anaconda.com/pkgs/main/osx-64/expat-2.5.0-hcec6c5f_0.conda#ce90fd42031d3c01944146f089a9130b https://repo.anaconda.com/pkgs/main/osx-64/intel-openmp-2023.1.0-ha357a0b_43548.conda#ba8a89ffe593eb88e4c01334753c40c3 https://repo.anaconda.com/pkgs/main/osx-64/lerc-3.0-he9d5cce_0.conda#aec2c3dbef836849c9260f05be04f3db https://repo.anaconda.com/pkgs/main/osx-64/libbrotlidec-1.0.9-hca72f7f_7.conda#b85983951745cc666d9a1b42894210b2 @@ -37,49 +36,49 @@ https://repo.anaconda.com/pkgs/main/osx-64/sqlite-3.41.2-h6c40b1e_0.conda#6947a5 https://repo.anaconda.com/pkgs/main/osx-64/zstd-1.5.5-hc035e20_0.conda#5e0b7ddb1b7dc6b630e1f9a03499c19c https://repo.anaconda.com/pkgs/main/osx-64/brotli-1.0.9-hca72f7f_7.conda#68e54d12ec67591deb2ffd70348fb00f https://repo.anaconda.com/pkgs/main/osx-64/libtiff-4.5.1-hcec6c5f_0.conda#e127a800ffd9d300ed7d5e1b026944ec -https://repo.anaconda.com/pkgs/main/osx-64/python-3.12.1-hd58486a_0.conda#92c7d155dcfbf64036ccccc3bd95f241 -https://repo.anaconda.com/pkgs/main/osx-64/coverage-7.2.2-py312h6c40b1e_0.conda#b6e4b9fba325047c07f3c9211ae91d1c +https://repo.anaconda.com/pkgs/main/osx-64/python-3.11.7-hf27a42d_0.conda#fe0cfacb8965d0a06f8098464d5a8402 +https://repo.anaconda.com/pkgs/main/osx-64/coverage-7.2.2-py311h6c40b1e_0.conda#e15605553450156cf75c3ae38a920475 https://repo.anaconda.com/pkgs/main/noarch/cycler-0.11.0-pyhd3eb1b0_0.conda#f5e365d2cdb66d547eb8c3ab93843aab https://repo.anaconda.com/pkgs/main/noarch/execnet-1.9.0-pyhd3eb1b0_0.conda#f895937671af67cebb8af617494b3513 https://repo.anaconda.com/pkgs/main/noarch/iniconfig-1.1.1-pyhd3eb1b0_0.tar.bz2#e40edff2c5708f342cef43c7f280c507 -https://repo.anaconda.com/pkgs/main/osx-64/joblib-1.2.0-py312hecd8cb5_0.conda#aeeb33f85c1e6776700b67a4762d2e6d -https://repo.anaconda.com/pkgs/main/osx-64/kiwisolver-1.4.4-py312hcec6c5f_0.conda#2ba6561ddd1d05936fe74f5d118ce7dd +https://repo.anaconda.com/pkgs/main/osx-64/joblib-1.2.0-py311hecd8cb5_0.conda#af8c1fcd4e8e0c6fa2a4f4ecda261dc9 +https://repo.anaconda.com/pkgs/main/osx-64/kiwisolver-1.4.4-py311hcec6c5f_0.conda#f2cf31e2a762f071fd6bc4d74ea2bfc8 https://repo.anaconda.com/pkgs/main/osx-64/lcms2-2.12-hf1fd2bf_0.conda#697aba7a3308226df7a93ccfeae16ffa -https://repo.anaconda.com/pkgs/main/osx-64/mdurl-0.1.0-py312hecd8cb5_0.conda#0d3a6bae224df024c474dfc062324218 -https://repo.anaconda.com/pkgs/main/osx-64/mkl-service-2.4.0-py312h6c40b1e_1.conda#b1ef860be9043b35c5e8d9388b858514 +https://repo.anaconda.com/pkgs/main/osx-64/mdurl-0.1.0-py311hecd8cb5_0.conda#18c799813d3fdd2ea68c807d54038820 +https://repo.anaconda.com/pkgs/main/osx-64/mkl-service-2.4.0-py311h6c40b1e_1.conda#f709b80c57a0fcc577319920d1b7228b https://repo.anaconda.com/pkgs/main/noarch/munkres-1.1.4-py_0.conda#148362ba07f92abab76999a680c80084 https://repo.anaconda.com/pkgs/main/osx-64/openjpeg-2.4.0-h66ea3da_0.conda#882833bd7befc5e60e6fba9c518c1b79 -https://repo.anaconda.com/pkgs/main/osx-64/packaging-23.1-py312hecd8cb5_0.conda#27f59725d093a50f366eaeba0db9ec61 -https://repo.anaconda.com/pkgs/main/osx-64/pluggy-1.0.0-py312hecd8cb5_1.conda#647fada22f1697691fdee90b52c99bcb -https://repo.anaconda.com/pkgs/main/osx-64/pygments-2.15.1-py312hecd8cb5_1.conda#76178b3f791217ae17fcb1a295ffdb84 -https://repo.anaconda.com/pkgs/main/osx-64/pyparsing-3.0.9-py312hecd8cb5_0.conda#d85cf2b81c6d9326a57a6418e14db258 +https://repo.anaconda.com/pkgs/main/osx-64/packaging-23.1-py311hecd8cb5_0.conda#4f5c491cd2de9d61f61c0ea3340ab46a +https://repo.anaconda.com/pkgs/main/osx-64/pluggy-1.0.0-py311hecd8cb5_1.conda#98e4da64cd934965a0caf4136280ff35 +https://repo.anaconda.com/pkgs/main/osx-64/pygments-2.15.1-py311hecd8cb5_1.conda#9e0f1e7667af6f469dfba22aa87dc6e2 +https://repo.anaconda.com/pkgs/main/osx-64/pyparsing-3.0.9-py311hecd8cb5_0.conda#a4262f849ecc82af69f58da0cbcaaf04 https://repo.anaconda.com/pkgs/main/noarch/python-tzdata-2023.3-pyhd3eb1b0_0.conda#479c037de0186d114b9911158427624e -https://repo.anaconda.com/pkgs/main/osx-64/pytz-2023.3.post1-py312hecd8cb5_0.conda#2636382c9a424f69cbc36b1c5dc1f2fc -https://repo.anaconda.com/pkgs/main/osx-64/setuptools-68.2.2-py312hecd8cb5_0.conda#64235f0c451427d86808c70c1c31cb8b +https://repo.anaconda.com/pkgs/main/osx-64/pytz-2023.3.post1-py311hecd8cb5_0.conda#32d107281d133e3935dfb6935153e438 +https://repo.anaconda.com/pkgs/main/osx-64/setuptools-68.2.2-py311hecd8cb5_0.conda#c5f526775f35d920e9d6099fa146d7a1 https://repo.anaconda.com/pkgs/main/noarch/six-1.16.0-pyhd3eb1b0_1.conda#34586824d411d36af2fa40e799c172d0 https://repo.anaconda.com/pkgs/main/noarch/threadpoolctl-2.2.0-pyh0d69192_0.conda#bbfdbae4934150b902f97daaf287efe2 https://repo.anaconda.com/pkgs/main/noarch/toml-0.10.2-pyhd3eb1b0_0.conda#cda05f5f6d8509529d1a2743288d197a -https://repo.anaconda.com/pkgs/main/osx-64/tornado-6.3.3-py312h6c40b1e_0.conda#49173b5a36c9134865221f29d4a73fb6 -https://repo.anaconda.com/pkgs/main/osx-64/wheel-0.41.2-py312hecd8cb5_0.conda#e7aea266d81142e2bb0bbc2280e64526 +https://repo.anaconda.com/pkgs/main/osx-64/tornado-6.3.3-py311h6c40b1e_0.conda#e98809ea222b3da0ebeae40bc73dfdb0 +https://repo.anaconda.com/pkgs/main/osx-64/wheel-0.41.2-py311hecd8cb5_0.conda#38cad06df66feef1e1a02e82587034fc https://repo.anaconda.com/pkgs/main/noarch/fonttools-4.25.0-pyhd3eb1b0_0.conda#bb9c5b5a6d892fca5efe4bf0203b6a48 -https://repo.anaconda.com/pkgs/main/osx-64/markdown-it-py-2.2.0-py312hecd8cb5_1.conda#bc2e2635a5c7fc25b591c4cd5216194b -https://repo.anaconda.com/pkgs/main/osx-64/numpy-base-1.26.3-py312h6f81483_0.conda#58c3bc6c19210583249b16d69f9bdb0a -https://repo.anaconda.com/pkgs/main/osx-64/pillow-10.2.0-py312h6c40b1e_0.conda#5a44bd28cf26fff2d6219e76a86db126 -https://repo.anaconda.com/pkgs/main/osx-64/pip-23.3.1-py312hecd8cb5_0.conda#efc3db40cac09f74bb480d28d3a0b260 -https://repo.anaconda.com/pkgs/main/osx-64/pytest-7.4.0-py312hecd8cb5_0.conda#b816a2439ba9b87524aec74d58e55b0a +https://repo.anaconda.com/pkgs/main/osx-64/markdown-it-py-2.2.0-py311hecd8cb5_1.conda#fcb5ce9cc6f0157d39022a57d4319729 +https://repo.anaconda.com/pkgs/main/osx-64/numpy-base-1.26.3-py311h53bf9ac_0.conda#b17bbb5b3686c48494ee3acffc8f81a9 +https://repo.anaconda.com/pkgs/main/osx-64/pillow-10.2.0-py311h6c40b1e_0.conda#74ec88cab7ca5e71fe5fddf1b421e954 +https://repo.anaconda.com/pkgs/main/osx-64/pip-23.3.1-py311hecd8cb5_0.conda#da6ad85d5bb1add3997502c106791ac3 +https://repo.anaconda.com/pkgs/main/osx-64/pytest-7.4.0-py311hecd8cb5_0.conda#8c5496a4a1f36160ac5556495faa4a24 https://repo.anaconda.com/pkgs/main/noarch/python-dateutil-2.8.2-pyhd3eb1b0_0.conda#211ee00320b08a1ac9fea6677649f6c9 -https://repo.anaconda.com/pkgs/main/osx-64/pytest-cov-4.1.0-py312hecd8cb5_1.conda#a33a24eb20359f464938e75b2f57e23a -https://repo.anaconda.com/pkgs/main/osx-64/pytest-xdist-3.5.0-py312hecd8cb5_0.conda#d1ecfb3691cceecb1f16bcfdf0b67bb5 -https://repo.anaconda.com/pkgs/main/osx-64/rich-13.3.5-py312hecd8cb5_1.conda#05de027713190752da0887054acbf016 -https://repo.anaconda.com/pkgs/main/osx-64/bottleneck-1.3.5-py312h32608ca_0.conda#f002daa9c28b94536e464359cc7d04e7 -https://repo.anaconda.com/pkgs/main/osx-64/contourpy-1.2.0-py312ha357a0b_0.conda#57d384ad07152375b40a6293f79e3f0c -https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-3.8.0-py312hecd8cb5_0.conda#64ffa3462aace0fc2d5fa5bff15f63f6 -https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-base-3.8.0-py312h7f12edd_0.conda#bda389e5a1ff69f763911cf90102893b -https://repo.anaconda.com/pkgs/main/osx-64/mkl_fft-1.3.8-py312h6c40b1e_0.conda#d59d01b940493f2b6a84aac922fd0c76 -https://repo.anaconda.com/pkgs/main/osx-64/mkl_random-1.2.4-py312ha357a0b_0.conda#c1ea9c8eee79a5af3399f3c31be0e9c6 -https://repo.anaconda.com/pkgs/main/osx-64/numpy-1.26.3-py312hac873b0_0.conda#d2310a3607112d4b042330d0140434ef -https://repo.anaconda.com/pkgs/main/osx-64/numexpr-2.8.7-py312hac873b0_0.conda#6303ba071636ef57fddf69eb6f440ec1 -https://repo.anaconda.com/pkgs/main/osx-64/scipy-1.11.4-py312h81688c2_0.conda#7d57b4c21a9261f97fa511e0940c5d93 -https://repo.anaconda.com/pkgs/main/osx-64/pandas-2.1.4-py312he282a81_0.conda#dcbed31bc94e03cc6f53312e0fb4eb49 -https://repo.anaconda.com/pkgs/main/osx-64/pyamg-4.2.3-py312h44cbcf4_0.conda#3bdc7be74087b3a5a83c520a74e1e8eb -# pip cython @ https://files.pythonhosted.org/packages/3d/8e/28f8c6109990eef7317ab7e43644092b49a88a39f9373dcd19318946df09/Cython-3.0.8-cp312-cp312-macosx_10_9_x86_64.whl#sha256=90d3fe31db55685d8cb97d43b0ec39ef614fcf660f83c77ed06aa670cb0e164f +https://repo.anaconda.com/pkgs/main/osx-64/pytest-cov-4.1.0-py311hecd8cb5_1.conda#b1e41a8eda3f119b39b13f3a4d0c5bf5 +https://repo.anaconda.com/pkgs/main/osx-64/pytest-xdist-3.5.0-py311hecd8cb5_0.conda#e892e4359ea4f0987e8268f7e7869680 +https://repo.anaconda.com/pkgs/main/osx-64/rich-13.3.5-py311hecd8cb5_0.conda#6360bb4bc83d7dfe34cd5c04e25690fa +https://repo.anaconda.com/pkgs/main/osx-64/bottleneck-1.3.7-py311hb3a5e46_0.conda#0ed3a6bc594cbac2b69dc4f57d8ae96d +https://repo.anaconda.com/pkgs/main/osx-64/contourpy-1.2.0-py311ha357a0b_0.conda#c9189b40e5b4be360aef22be336a4838 +https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-3.8.0-py311hecd8cb5_0.conda#f720f09a9d1bb976aa92a13180cf7133 +https://repo.anaconda.com/pkgs/main/osx-64/matplotlib-base-3.8.0-py311h41a4f6b_0.conda#da5175158820055096f25520004fb9b3 +https://repo.anaconda.com/pkgs/main/osx-64/mkl_fft-1.3.8-py311h6c40b1e_0.conda#7e70133e3cf6151d2826da7ae3af609f +https://repo.anaconda.com/pkgs/main/osx-64/mkl_random-1.2.4-py311ha357a0b_0.conda#b363dccbb0219bb2f810a05b9bde92fb +https://repo.anaconda.com/pkgs/main/osx-64/numpy-1.26.3-py311h728a8a3_0.conda#19c75a5ddc6d2b5e91515cfca46f0ba0 +https://repo.anaconda.com/pkgs/main/osx-64/numexpr-2.8.7-py311h728a8a3_0.conda#21a483a6825576049b1abda53076ef3e +https://repo.anaconda.com/pkgs/main/osx-64/scipy-1.11.4-py311h224febf_0.conda#c1db23a0c898869d0f4f02831f9e31e3 +https://repo.anaconda.com/pkgs/main/osx-64/pandas-2.1.4-py311hdb55bb0_0.conda#b118594fae66a7cd93c088f75de7faca +https://repo.anaconda.com/pkgs/main/osx-64/pyamg-4.2.3-py311h37a6a59_0.conda#5fca7d043dc68c1d7acc22aa03a24918 +# pip cython @ https://files.pythonhosted.org/packages/db/a7/f4a0bc9a80e23b380daa2ebb4879bf434aaa0b3b91f7ad8a7f9762b4bd1b/Cython-3.0.8-cp311-cp311-macosx_10_9_x86_64.whl#sha256=aae26f9663e50caf9657148403d9874eea41770ecdd6caf381d177c2b1bb82ba diff --git a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock index 7a703828a0256..ac71d5ebe2d21 100644 --- a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock @@ -64,7 +64,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda#685 # pip urllib3 @ https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl#sha256=450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d # pip zipp @ https://files.pythonhosted.org/packages/d9/66/48866fc6b158c81cc2bfecc04c480f105c6040e8b077bc54c634b4a67926/zipp-3.17.0-py3-none-any.whl#sha256=0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 # pip contourpy @ https://files.pythonhosted.org/packages/a9/ba/d8fd1380876f1e9114157606302e3644c85f6d116aeba354c212ee13edc7/contourpy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=11f8d2554e52f459918f7b8e6aa20ec2a3bce35ce95c1f0ef4ba36fbda306df5 -# pip coverage @ https://files.pythonhosted.org/packages/ff/e3/351477165426da841458f2c1b732360dd42da140920e3cd4b70676e5b77f/coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1 +# pip coverage @ https://files.pythonhosted.org/packages/f3/11/64ccdc31b1e668bad17a0174a840bdea4d22084bc2a82f745b5cab6d1212/coverage-7.4.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c00e54f0bd258ab25e7f731ca1d5144b0bf7bec0051abccd2bdcff65fa3262c9 # pip imageio @ https://files.pythonhosted.org/packages/02/25/66533a8390e3763cf8254dee143dbf8a830391ea60d2762512ba7f9ddfbe/imageio-2.34.0-py3-none-any.whl#sha256=08082bf47ccb54843d9c73fe9fc8f3a88c72452ab676b58aca74f36167e8ccba # pip importlib-metadata @ https://files.pythonhosted.org/packages/c0/8b/d8427f023c081a8303e6ac7209c16e6878f2765d5b59667f3903fbcfd365/importlib_metadata-7.0.1-py3-none-any.whl#sha256=4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e # pip importlib-resources @ https://files.pythonhosted.org/packages/93/e8/facde510585869b5ec694e8e0363ffe4eba067cb357a8398a55f6a1f8023/importlib_resources-6.1.1-py3-none-any.whl#sha256=e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6 diff --git a/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock b/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock index dd1d8657606a1..fff96a2c78ff0 100644 --- a/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_scipy_dev_linux-64_conda.lock @@ -30,7 +30,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py312h06a4308_0.conda#e1 # pip babel @ https://files.pythonhosted.org/packages/0d/35/4196b21041e29a42dc4f05866d0c94fa26c9da88ce12c38c2265e42c82fb/Babel-2.14.0-py3-none-any.whl#sha256=efb1a25b7118e67ce3a259bed20545c29cb68be8ad2c784c83689981b7a57287 # pip certifi @ https://files.pythonhosted.org/packages/ba/06/a07f096c664aeb9f01624f858c3add0a4e913d6c96257acb4fce61e7de14/certifi-2024.2.2-py3-none-any.whl#sha256=dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # pip charset-normalizer @ https://files.pythonhosted.org/packages/ee/fb/14d30eb4956408ee3ae09ad34299131fb383c47df355ddb428a7331cfa1e/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b -# pip coverage @ https://files.pythonhosted.org/packages/c3/92/f2d89715c3397e76fe365b1ecbb861d1279ff8d47d23635040a358bc75dc/coverage-7.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=fe558371c1bdf3b8fa03e097c523fb9645b8730399c14fe7721ee9c9e2a545d3 +# pip coverage @ https://files.pythonhosted.org/packages/b5/4f/02e6fa3518d0d32ce92d1dba11f48e17ac1522843dc4f6276eb9e6c03ef0/coverage-7.4.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=42a9e754aa250fe61f0f99986399cec086d7e7a01dd82fd863a20af34cbce962 # pip docutils @ https://files.pythonhosted.org/packages/26/87/f238c0670b94533ac0353a4e2a1a771a0cc73277b88bff23d3ae35a256c1/docutils-0.20.1-py3-none-any.whl#sha256=96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6 # pip execnet @ https://files.pythonhosted.org/packages/e8/9c/a079946da30fac4924d92dbc617e5367d454954494cf1e71567bcc4e00ee/execnet-2.0.2-py3-none-any.whl#sha256=88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41 # pip idna @ https://files.pythonhosted.org/packages/c2/e7/a82b05cf63a603df6e68d59ae6a68bf5064484a0718ea5033660af4b54a9/idna-3.6-py3-none-any.whl#sha256=c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f diff --git a/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock b/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock index e84c453933286..db722a64abcd0 100644 --- a/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock +++ b/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock @@ -77,7 +77,7 @@ https://conda.anaconda.org/conda-forge/win-64/tornado-6.4-py39h7a188e9_0.conda#1 https://conda.anaconda.org/conda-forge/win-64/unicodedata2-15.1.0-py39h7a188e9_0.conda#895188b9d9aa5962246b345b26091111 https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda#1cdea58981c5cbc17b51973bcaddcea7 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/win-64/coverage-7.4.1-py39h7a188e9_0.conda#dde632fc22f9e92101bd78cca27cd4ca +https://conda.anaconda.org/conda-forge/win-64/coverage-7.4.2-py39h7a188e9_0.conda#3da6df64c4d54dc4f9ae9513eccfc0d4 https://conda.anaconda.org/conda-forge/win-64/fonttools-4.49.0-py39h7a188e9_0.conda#842e830fa2ec96d47fa1e3408f58963a https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index 0c64a989d724f..667f4c8ec46e9 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -230,7 +230,7 @@ https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda# https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.0-py39hddac248_0.conda#95aaa7baa61432a1ce85dedb7b86d2dd https://conda.anaconda.org/conda-forge/noarch/patsy-0.5.6-pyhd8ed1ab_0.conda#a5b55d1cb110cdcedc748b5c3e16e687 https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.10-py39h927a070_0.conda#2c626921a52a9571bda297ef0fceb15a -https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.1-pyhd8ed1ab_0.conda#d15917f33140f8d2ac9ca44db7ec8a25 https://conda.anaconda.org/conda-forge/linux-64/pywavelets-1.4.1-py39h44dd56e_1.conda#d037c20e3da2e85f03ebd20ad480c359 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py39h474f0d3_2.conda#6ab241b2023730f6b41712dc1b503afa https://conda.anaconda.org/conda-forge/linux-64/blas-2.121-openblas.conda#4a279792fd8861a15705516a52872eb6 diff --git a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock index 9fea075856154..d7fdd03c17d22 100644 --- a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock +++ b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock @@ -206,7 +206,7 @@ https://conda.anaconda.org/conda-forge/noarch/dask-core-2024.2.0-pyhd8ed1ab_0.co https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda#614b81f8ed66c56b640faee7076ad14a https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-21_linux64_mkl.conda#0553cad80ef02be86c8e178eeecb6a34 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-21_linux64_mkl.conda#52837ab7fd5b43d3960c62e5c91958d6 -https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.0-pyhd8ed1ab_0.conda#134b2b57b7865d2316a7cce1915a51ed +https://conda.anaconda.org/conda-forge/noarch/pooch-1.8.1-pyhd8ed1ab_0.conda#d15917f33140f8d2ac9ca44db7ec8a25 https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-21_linux64_mkl.conda#0d45f03de7143f324b37454af46feb26 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.19.5-py39hd249d9e_3.tar.bz2#0cf333996ebdeeba8d1c8c1c0ee9eff9 https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda#54866f708d43002a514d0b9b0f84bc11 From d392b63df2a022572ae93b6d3cbbbbb8e9c64928 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 6 Mar 2024 15:01:53 +0100 Subject: [PATCH 40/55] debug ci --- build_tools/azure/test_script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 5a4341cac9e62..2adbb5a28cdb9 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -61,5 +61,5 @@ if [[ -n "$SELECTED_TESTS" ]]; then fi set -x -eval "$TEST_CMD --maxfail=10 --pyargs sklearn" +eval "$TEST_CMD --maxfail=7 --pyargs sklearn" set +x From 91777578db84f6fe9253d95d6896454741234809 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 6 Mar 2024 15:31:47 +0100 Subject: [PATCH 41/55] iter --- build_tools/azure/test_script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 2adbb5a28cdb9..56c105471297f 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -61,5 +61,5 @@ if [[ -n "$SELECTED_TESTS" ]]; then fi set -x -eval "$TEST_CMD --maxfail=7 --pyargs sklearn" +eval "$TEST_CMD -x --pyargs sklearn" set +x From 436bcadfbc615f588c48b76b73db659fd1b69680 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 6 Mar 2024 15:51:17 +0100 Subject: [PATCH 42/55] iter --- build_tools/azure/test_script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 56c105471297f..f7d06a5b5365b 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -50,7 +50,7 @@ fi if [[ "$PYTEST_XDIST_VERSION" != "none" ]]; then XDIST_WORKERS=$(python -c "import joblib; print(joblib.cpu_count(only_physical_cores=True))") - TEST_CMD="$TEST_CMD -n$XDIST_WORKERS" + TEST_CMD="$TEST_CMD" fi if [[ -n "$SELECTED_TESTS" ]]; then From 5bf660855663c5f843c8a3446d7fd450e2056f20 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 6 Mar 2024 17:09:54 +0100 Subject: [PATCH 43/55] iter --- ...latest_conda_forge_mkl_linux-64_conda.lock | 8 +- ...pylatest_conda_forge_mkl_osx-64_conda.lock | 16 ++- .../pymin_conda_forge_mkl_win-64_conda.lock | 85 ++++++----- ...e_openblas_ubuntu_2204_linux-64_conda.lock | 134 +++++++++++++----- build_tools/azure/pypy3_linux-64_conda.lock | 2 +- build_tools/azure/test_script.sh | 4 +- build_tools/circle/doc_linux-64_conda.lock | 4 +- .../doc_min_dependencies_linux-64_conda.lock | 2 +- 8 files changed, 172 insertions(+), 83 deletions(-) diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index 078fa6567594d..a3ccfa1f4c9c6 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -19,7 +19,6 @@ https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_5.cond https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.11-hd590300_1.conda#0bb492cca54017ea314b809b1ee3a176 https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 https://conda.anaconda.org/conda-forge/linux-64/aws-c-common-0.9.0-hd590300_0.conda#71b89db63b5b504e7afc8ad901172e1e -https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h9c3ff4c_4.tar.bz2#f4f75dc7038aaeb6eaae16a5ef5350b3 https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.27.0-hd590300_0.conda#f6afff0e9ee08d2f1b897881a4f38cdb https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 @@ -56,6 +55,7 @@ https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.1-hd590300_0.conda#51a753e64a3027bd7e23a189b1f6e91e https://conda.anaconda.org/conda-forge/linux-64/pixman-0.43.2-h59595ed_0.conda#71004cbf7924e19c02746ccde9fd7123 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 +https://conda.anaconda.org/conda-forge/linux-64/rdma-core-28.9-h59595ed_1.conda#aeffb7c06b5f65e55e6c637408dc4100 https://conda.anaconda.org/conda-forge/linux-64/re2-2023.03.02-h8c504da_0.conda#206f8fa808748f6e90599c3368a1114e https://conda.anaconda.org/conda-forge/linux-64/sleef-3.5.1-h9b69904_2.tar.bz2#6e016cf4c525d04a7bd038cee53ad3fd https://conda.anaconda.org/conda-forge/linux-64/snappy-1.1.10-h9fff704_0.conda#e6d228cd0bb74a51dd18f5bfce0b4115 @@ -95,11 +95,12 @@ https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679 https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 https://conda.anaconda.org/conda-forge/linux-64/s2n-1.3.49-h06160fa_0.conda#1d78349eb26366ecc034a4afe70a8534 https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc -https://conda.anaconda.org/conda-forge/linux-64/ucx-1.14.1-h0aa22dc_2.conda#221632ba76f3c2bb58f1db59911548e7 +https://conda.anaconda.org/conda-forge/linux-64/ucx-1.14.1-h64cca9d_5.conda#39aa3b356d10d7e5add0c540945a0944 https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/aws-c-io-0.13.32-he9a53bd_1.conda#8a24e5820f4a0ffd2ed9c4722cd5d7ca +https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.0.9-h166bdaf_9.conda#d47dee1856d9cb955b8076eeff304a5b https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda#32d16ad533c59bb0a3c5ffaf16110829 @@ -124,6 +125,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.con https://conda.anaconda.org/conda-forge/noarch/array-api-compat-1.4.1-pyhd8ed1ab_0.conda#adfa5bcaa427c1aae88572ac0d08c4c2 https://conda.anaconda.org/conda-forge/linux-64/aws-c-event-stream-0.3.1-h2e3709c_4.conda#2cf21b1cbc1c096a28ffa2892257a2c1 https://conda.anaconda.org/conda-forge/linux-64/aws-c-http-0.7.11-h00aa349_4.conda#cb932dff7328ff620ce8059c9968b095 +https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h166bdaf_9.conda#4601544b4982ba1861fa9b9c607b2c06 https://conda.anaconda.org/conda-forge/linux-64/ccache-4.9.1-h1fcd64f_0.conda#3620f564bcf28c3524951b6f64f5c5ac https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda#0876280e409658fc6f9e75d035960333 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 @@ -149,7 +151,7 @@ https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda# https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976 https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2024.1-pyhd8ed1ab_0.conda#98206ea9954216ee7540f0c773f2104d https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda#3eeeeb9e4827ace8c0c1419c85d590ad https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock index 157cf0b17fda3..ec6d71a4cb49d 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock @@ -24,7 +24,7 @@ https://conda.anaconda.org/conda-forge/osx-64/xorg-libxau-1.0.11-h0dc2134_0.cond https://conda.anaconda.org/conda-forge/osx-64/xorg-libxdmcp-1.1.3-h35c211d_0.tar.bz2#86ac76d6bf1cbb9621943eb3bd9ae36e https://conda.anaconda.org/conda-forge/osx-64/xz-5.2.6-h775f41a_0.tar.bz2#a72f9d4ea13d55d745ff1ed594747f10 https://conda.anaconda.org/conda-forge/osx-64/gmp-6.3.0-h93d8f39_0.conda#a4ffd4bfd88659cbecbd36b61594bf0d -https://conda.anaconda.org/conda-forge/osx-64/isl-0.25-hb486fe8_0.tar.bz2#45a9a46c78c0ea5c275b535f7923bde3 +https://conda.anaconda.org/conda-forge/osx-64/isl-0.26-imath32_h2e86a7b_101.conda#d06222822a9144918333346f145b68c6 https://conda.anaconda.org/conda-forge/osx-64/lerc-4.0.0-hb486fe8_0.tar.bz2#f9d6a4c82889d5ecedec1d90eb673c55 https://conda.anaconda.org/conda-forge/osx-64/libbrotlidec-1.1.0-h0dc2134_1.conda#9ee0bab91b2ca579e10353738be36063 https://conda.anaconda.org/conda-forge/osx-64/libbrotlienc-1.1.0-h0dc2134_1.conda#8a421fe09c6187f0eb5e2338a8a8be6d @@ -67,7 +67,7 @@ https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5 https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.9-py312hede676d_0.conda#e7cfe4322252a2d0786a064c214436ae https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda#8d652ea2ee8eaee02ed8dc820bc794aa https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 -https://conda.anaconda.org/conda-forge/osx-64/gfortran_impl_osx-64-12.3.0-h54fd467_1.conda#5f4d40236e204c6e62cd0a316244f316 +https://conda.anaconda.org/conda-forge/osx-64/gfortran_impl_osx-64-12.3.0-hc328e78_3.conda#b3d751dc7073bbfdfa9d863e39b9685d https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.5-py312h49ebfd2_1.conda#21f174a5cfb5964069c374171a979157 https://conda.anaconda.org/conda-forge/osx-64/ld64-609-ha02d983_16.conda#6dfb00e6cab263fe598d48df153d3288 @@ -78,7 +78,7 @@ https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda# https://conda.anaconda.org/conda-forge/osx-64/pillow-10.2.0-py312h0c70c2f_0.conda#0cc3674239ad12c6836cb4174f106c92 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976 https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2024.1-pyhd8ed1ab_0.conda#98206ea9954216ee7540f0c773f2104d https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda#3eeeeb9e4827ace8c0c1419c85d590ad https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142 @@ -113,14 +113,16 @@ https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.2.0-py312hbf0bb39_0.co https://conda.anaconda.org/conda-forge/osx-64/pandas-2.2.1-py312h83c8a23_0.conda#c562e07382cdc3194c21b8eca06460ff https://conda.anaconda.org/conda-forge/osx-64/scipy-1.12.0-py312h8adb940_2.conda#b16a9767f5f4b0a0ec8fb566e2c586f7 https://conda.anaconda.org/conda-forge/osx-64/blas-2.120-mkl.conda#b041a7677a412f3d925d8208936cb1e2 -https://conda.anaconda.org/conda-forge/osx-64/clang_osx-64-16.0.6-h8787910_2.conda#efd22736de32ae376efc729b0b5af8a2 +https://conda.anaconda.org/conda-forge/osx-64/clang_impl_osx-64-16.0.6-h8787910_10.conda#cf49a37b020f4016c52d2ed419120b67 https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.8.3-py312h1fe5000_0.conda#5f65fc4ce880d4c795e217d563a114ec https://conda.anaconda.org/conda-forge/osx-64/pyamg-5.0.1-py312h674694f_1.conda#e5b9c0f8b5c367467425ff34353ef761 +https://conda.anaconda.org/conda-forge/osx-64/clang_osx-64-16.0.6-hb91bd55_10.conda#397823847b6952fbd56bec0a766c7d22 +https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.3-py312hb401068_0.conda#7015bf84c9d39284c4746d814da2a0f1 https://conda.anaconda.org/conda-forge/osx-64/c-compiler-1.7.0-h282daa2_0.conda#4652f33fe8d895f61177e2783b289377 -https://conda.anaconda.org/conda-forge/osx-64/clangxx_osx-64-16.0.6-h1b7723c_2.conda#9b40be235bc26c0949debeca13a35305 +https://conda.anaconda.org/conda-forge/osx-64/clangxx_impl_osx-64-16.0.6-h6d92fbe_10.conda#f0746a628875a935eb1beba8f2e143a3 https://conda.anaconda.org/conda-forge/osx-64/gfortran_osx-64-12.3.0-h18f7dce_1.conda#436af2384c47aedb94af78a128e174f1 -https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.3-py312hb401068_0.conda#7015bf84c9d39284c4746d814da2a0f1 -https://conda.anaconda.org/conda-forge/osx-64/cxx-compiler-1.7.0-h7728843_0.conda#8abaa2694c1fba2b6bd3753d00a60415 +https://conda.anaconda.org/conda-forge/osx-64/clangxx_osx-64-16.0.6-hb91bd55_10.conda#b6e4f35ca158de82ad2ea0ec68c7f8f8 https://conda.anaconda.org/conda-forge/osx-64/gfortran-12.3.0-h2c809b3_1.conda#c48adbaa8944234b80ef287c37e329b0 +https://conda.anaconda.org/conda-forge/osx-64/cxx-compiler-1.7.0-h7728843_0.conda#8abaa2694c1fba2b6bd3753d00a60415 https://conda.anaconda.org/conda-forge/osx-64/fortran-compiler-1.7.0-h6c2ab21_0.conda#2c11db8b46df0a547997116f0fd54b8e https://conda.anaconda.org/conda-forge/osx-64/compilers-1.7.0-h694c41f_0.conda#3576aa54986a3e2a5370e4232b35c036 diff --git a/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock b/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock index da0d82faf9bca..6406945c32b3e 100644 --- a/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock +++ b/build_tools/azure/pymin_conda_forge_mkl_win-64_conda.lock @@ -4,13 +4,11 @@ @EXPLICIT https://conda.anaconda.org/conda-forge/win-64/ca-certificates-2024.2.2-h56e8100_0.conda#63da060240ab8087b60d1357051ea7d6 https://conda.anaconda.org/conda-forge/win-64/intel-openmp-2024.0.0-h57928b3_49841.conda#e3255c8cdaf1d52f15816d1970f9c77a -https://conda.anaconda.org/conda-forge/win-64/libexpat-2.5.0-h63175ca_1.conda#636cc3cbbd2e28bcfd2f73b2044aac2c https://conda.anaconda.org/conda-forge/win-64/mkl-include-2024.0.0-h66d3029_49657.conda#4477b53b9f7edc041962c92a5d5e9caa https://conda.anaconda.org/conda-forge/win-64/msys2-conda-epoch-20160418-1.tar.bz2#b0309b72560df66f71a9d5e34a5efdfa -https://conda.anaconda.org/conda-forge/win-64/python_abi-3.9-4_pypy39_pp73.conda#0e5639cb7ef9d5774a5ab2d48a011c77 +https://conda.anaconda.org/conda-forge/win-64/python_abi-3.9-4_cp39.conda#948b0d93d4ab1372d8fd45e1560afd47 https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda#161081fc7cec0bfda0d86d7cb595f8d8 https://conda.anaconda.org/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_0.tar.bz2#72608f6cd3e5898229c3ea16deb1ac43 -https://conda.anaconda.org/conda-forge/win-64/expat-2.5.0-h63175ca_1.conda#87c77fe1b445aedb5c6d207dd236fa3e https://conda.anaconda.org/conda-forge/win-64/m2w64-gmp-6.1.0-2.tar.bz2#53a1c73e1e3d185516d7e3af177596d9 https://conda.anaconda.org/conda-forge/win-64/m2w64-libwinpthread-git-5.0.0.4634.697f757-2.tar.bz2#774130a326dee16f1ceb05cc687ee4f0 https://conda.anaconda.org/conda-forge/win-64/vc14_runtime-14.38.33130-h82b7239_18.conda#8be79fdd2725ddf7bbf8a27a4c1f79ba @@ -18,12 +16,14 @@ https://conda.anaconda.org/conda-forge/win-64/m2w64-gcc-libs-core-5.3.0-7.tar.bz https://conda.anaconda.org/conda-forge/win-64/vc-14.3-hcf57466_18.conda#20e1e652a4c740fa719002a8449994a2 https://conda.anaconda.org/conda-forge/win-64/vs2015_runtime-14.38.33130-hcb4865c_18.conda#10d42885e3ed84e575b454db30f1aa93 https://conda.anaconda.org/conda-forge/win-64/bzip2-1.0.8-hcfcfb64_5.conda#26eb8ca6ea332b675e11704cce84a3be +https://conda.anaconda.org/conda-forge/win-64/icu-73.2-h63175ca_0.conda#0f47d9e3192d9e09ae300da0d28e0f56 https://conda.anaconda.org/conda-forge/win-64/lerc-4.0.0-h63175ca_0.tar.bz2#1900cb3cab5055833cfddb0ba233b074 https://conda.anaconda.org/conda-forge/win-64/libbrotlicommon-1.1.0-hcfcfb64_1.conda#f77f319fb82980166569e1280d5b2864 https://conda.anaconda.org/conda-forge/win-64/libdeflate-1.19-hcfcfb64_0.conda#002b1b723b44dbd286b9e3708762433c https://conda.anaconda.org/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2#2c96d1b6915b408893f9472569dee135 https://conda.anaconda.org/conda-forge/win-64/libiconv-1.17-hcfcfb64_2.conda#e1eb10b1cca179f2baa3601e4efc8712 https://conda.anaconda.org/conda-forge/win-64/libjpeg-turbo-3.0.0-hcfcfb64_1.conda#3f1b948619c45b1ca714d60c7389092c +https://conda.anaconda.org/conda-forge/win-64/libogg-1.3.4-h8ffe710_1.tar.bz2#04286d905a0dcb7f7d4a12bdfe02516d https://conda.anaconda.org/conda-forge/win-64/libsqlite-3.45.1-hcfcfb64_0.conda#c583c1d6999b7aa148eff3089e13c44b https://conda.anaconda.org/conda-forge/win-64/libwebp-base-1.3.2-hcfcfb64_0.conda#dcde8820959e64378d4e06147ffecfdd https://conda.anaconda.org/conda-forge/win-64/libzlib-1.2.13-hcfcfb64_5.conda#5fdb9c6a113b6b6cb5e517fd972d5f41 @@ -32,70 +32,83 @@ https://conda.anaconda.org/conda-forge/win-64/openssl-3.2.1-hcfcfb64_0.conda#158 https://conda.anaconda.org/conda-forge/win-64/pthreads-win32-2.9.1-hfa6e2cd_3.tar.bz2#e2da8758d7d51ff6aa78a14dfb9dbed4 https://conda.anaconda.org/conda-forge/win-64/tk-8.6.13-h5226925_1.conda#fc048363eb8f03cd1737600a5d08aafe https://conda.anaconda.org/conda-forge/win-64/xz-5.2.6-h8d14728_0.tar.bz2#515d77642eaa3639413c6b1bc3f94219 +https://conda.anaconda.org/conda-forge/win-64/gettext-0.21.1-h5728263_0.tar.bz2#299d4fd6798a45337042ff5a48219e5f +https://conda.anaconda.org/conda-forge/win-64/krb5-1.21.2-heb0366b_0.conda#6e8b0f22b4eef3b3cb3849bb4c3d47f9 https://conda.anaconda.org/conda-forge/win-64/libbrotlidec-1.1.0-hcfcfb64_1.conda#19ce3e1dacc7912b3d6ff40690ba9ae0 https://conda.anaconda.org/conda-forge/win-64/libbrotlienc-1.1.0-hcfcfb64_1.conda#71e890a0b361fd58743a13f77e1506b7 https://conda.anaconda.org/conda-forge/win-64/libpng-1.6.43-h19919ed_0.conda#77e398acc32617a0384553aea29e866b +https://conda.anaconda.org/conda-forge/win-64/libvorbis-1.3.7-h0e60522_0.tar.bz2#e1a22282de0169c93e4ffe6ce6acc212 https://conda.anaconda.org/conda-forge/win-64/libxml2-2.12.5-hc3477c8_0.conda#d8c3c1c8242db352f38cd1dc0bf44f77 https://conda.anaconda.org/conda-forge/win-64/m2w64-gcc-libs-5.3.0-7.tar.bz2#fe759119b8b3bfa720b8762c6fdc35de -https://conda.anaconda.org/conda-forge/win-64/sqlite-3.45.1-hcfcfb64_0.conda#3c6f2dc59bcde87ee1de006f22ecc40a -https://conda.anaconda.org/conda-forge/win-64/zlib-1.2.13-hcfcfb64_5.conda#a318e8622e11663f645cc7fa3260f462 +https://conda.anaconda.org/conda-forge/win-64/pcre2-10.42-h17e33f8_0.conda#59610c61da3af020289a806ec9c6a7fd +https://conda.anaconda.org/conda-forge/win-64/python-3.9.18-h4de0772_1_cpython.conda#c0bc0080c5ec044edae6dbfa97ab337f https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.5-h12be248_0.conda#792bb5da68bf0a6cac6a6072ecb8dbeb https://conda.anaconda.org/conda-forge/win-64/brotli-bin-1.1.0-hcfcfb64_1.conda#0105229d7c5fabaa840043a86c10ec64 -https://conda.anaconda.org/conda-forge/win-64/freetype-2.12.1-hdaf720e_2.conda#3761b23693f768dc75a8fd0a73ca053f -https://conda.anaconda.org/conda-forge/win-64/libhwloc-2.9.3-default_haede6df_1009.conda#87da045f6d26ce9fe20ad76a18f6a18a -https://conda.anaconda.org/conda-forge/win-64/libtiff-4.6.0-h6e2ebb7_2.conda#08d653b74ee2dec0131ad4259ffbb126 -https://conda.anaconda.org/conda-forge/win-64/pthread-stubs-0.4-hcd874cb_1001.tar.bz2#a1f820480193ea83582b13249a7e7bd9 -https://conda.anaconda.org/conda-forge/win-64/pypy3.9-7.3.15-h994e1e7_0.conda#f86b3b2eb4da9bd98dccc645ae4224ba -https://conda.anaconda.org/conda-forge/win-64/xorg-libxau-1.0.11-hcd874cb_0.conda#c46ba8712093cb0114404ae8a7582e1a -https://conda.anaconda.org/conda-forge/win-64/xorg-libxdmcp-1.1.3-hcd874cb_0.tar.bz2#46878ebb6b9cbd8afcf8088d7ef00ece -https://conda.anaconda.org/conda-forge/win-64/brotli-1.1.0-hcfcfb64_1.conda#f47f6db2528e38321fb00ae31674c133 -https://conda.anaconda.org/conda-forge/win-64/lcms2-2.16-h67d730c_0.conda#d3592435917b62a8becff3a60db674f6 -https://conda.anaconda.org/conda-forge/win-64/libxcb-1.15-hcd874cb_0.conda#090d91b69396f14afef450c285f9758c -https://conda.anaconda.org/conda-forge/win-64/openjpeg-2.5.2-h3d672ee_0.conda#7e7099ad94ac3b599808950cec30ad4e -https://conda.anaconda.org/conda-forge/win-64/python-3.9.18-1_73_pypy.conda#5cf0736ad57ababd38c03c9f15be1cf3 -https://conda.anaconda.org/conda-forge/win-64/tbb-2021.11.0-h91493d7_1.conda#21069f3ed16812f9f4f2700667b6ec86 https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda#0876280e409658fc6f9e75d035960333 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/win-64/cython-3.0.9-py39h3665ca7_0.conda#7d1341cbd501712db99ed8df851ccfbc +https://conda.anaconda.org/conda-forge/win-64/cython-3.0.9-py39h99910a6_0.conda#37456058d8e7e7e3e3f68771aff7e543 https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda#8d652ea2ee8eaee02ed8dc820bc794aa https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 +https://conda.anaconda.org/conda-forge/win-64/freetype-2.12.1-hdaf720e_2.conda#3761b23693f768dc75a8fd0a73ca053f https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 -https://conda.anaconda.org/conda-forge/win-64/kiwisolver-1.4.5-py39h314d263_1.conda#1507eb1add8775b41ab030de2289c071 -https://conda.anaconda.org/conda-forge/win-64/mkl-2024.0.0-h66d3029_49657.conda#006b65d9cd436247dfe053df772e041d +https://conda.anaconda.org/conda-forge/win-64/kiwisolver-1.4.5-py39h1f6ef14_1.conda#4fc5bd0a7b535252028c647cc27d6c87 +https://conda.anaconda.org/conda-forge/win-64/libclang13-15.0.7-default_h85b4d89_4.conda#c6b0181860717a08469a324c4180ff2d +https://conda.anaconda.org/conda-forge/win-64/libglib-2.78.4-h16e383f_0.conda#72dc4e1cdde0894015567c90f9c4e261 +https://conda.anaconda.org/conda-forge/win-64/libhwloc-2.9.3-default_haede6df_1009.conda#87da045f6d26ce9fe20ad76a18f6a18a +https://conda.anaconda.org/conda-forge/win-64/libtiff-4.6.0-h6e2ebb7_2.conda#08d653b74ee2dec0131ad4259ffbb126 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 -https://conda.anaconda.org/conda-forge/win-64/pillow-10.2.0-py39hf7d859a_0.conda#28921ab19a9dbf7d6f5b63d5ec3163c7 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 +https://conda.anaconda.org/conda-forge/win-64/pthread-stubs-0.4-hcd874cb_1001.tar.bz2#a1f820480193ea83582b13249a7e7bd9 +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142 https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2 https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.3.0-pyhc1e730c_0.conda#698d2d2b621640bddb9191f132967c9f https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2#f832c45a477c78bebd107098db465095 https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 -https://conda.anaconda.org/conda-forge/win-64/tornado-6.4-py39h7a188e9_0.conda#1e8c1eb1bc2e7d8fd8296f24d95858c2 -https://conda.anaconda.org/conda-forge/win-64/unicodedata2-15.1.0-py39h7a188e9_0.conda#895188b9d9aa5962246b345b26091111 +https://conda.anaconda.org/conda-forge/win-64/tornado-6.4-py39ha55989b_0.conda#d8f52e8e1d02f9a5901f9224e2ddf98f +https://conda.anaconda.org/conda-forge/win-64/unicodedata2-15.1.0-py39ha55989b_0.conda#20ec896e8d97f2ff8be1124e624dc8f2 https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda#1cdea58981c5cbc17b51973bcaddcea7 +https://conda.anaconda.org/conda-forge/win-64/xorg-libxau-1.0.11-hcd874cb_0.conda#c46ba8712093cb0114404ae8a7582e1a +https://conda.anaconda.org/conda-forge/win-64/xorg-libxdmcp-1.1.3-hcd874cb_0.tar.bz2#46878ebb6b9cbd8afcf8088d7ef00ece https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a -https://conda.anaconda.org/conda-forge/win-64/coverage-7.4.3-py39h7a188e9_1.conda#c4b84afe8288f5a630559a8fc0aedfbb -https://conda.anaconda.org/conda-forge/win-64/fonttools-4.49.0-py39h7a188e9_0.conda#842e830fa2ec96d47fa1e3408f58963a +https://conda.anaconda.org/conda-forge/win-64/brotli-1.1.0-hcfcfb64_1.conda#f47f6db2528e38321fb00ae31674c133 +https://conda.anaconda.org/conda-forge/win-64/coverage-7.4.3-py39ha55989b_1.conda#c68e9c43ed91b369a592d5268c9dac71 +https://conda.anaconda.org/conda-forge/win-64/glib-tools-2.78.4-h12be248_0.conda#9e2a4c1cace3fbdeb11f20578484ddaf https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.2-pyhd8ed1ab_0.conda#6f4399795892835bd192ea210ca69447 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc -https://conda.anaconda.org/conda-forge/win-64/libblas-3.9.0-21_win64_mkl.conda#ebba3846d11201fe54277e4965ba5250 -https://conda.anaconda.org/conda-forge/win-64/mkl-devel-2024.0.0-h57928b3_49657.conda#385b9dcf11c859acc506dcb40451f904 +https://conda.anaconda.org/conda-forge/win-64/lcms2-2.16-h67d730c_0.conda#d3592435917b62a8becff3a60db674f6 +https://conda.anaconda.org/conda-forge/win-64/libclang-15.0.7-default_hde6756a_4.conda#a621ea4ac3f826d02441369e73e53800 +https://conda.anaconda.org/conda-forge/win-64/libxcb-1.15-hcd874cb_0.conda#090d91b69396f14afef450c285f9758c +https://conda.anaconda.org/conda-forge/win-64/openjpeg-2.5.2-h3d672ee_0.conda#7e7099ad94ac3b599808950cec30ad4e https://conda.anaconda.org/conda-forge/noarch/pip-24.0-pyhd8ed1ab_0.conda#f586ac1e56c8638b64f9c8122a7b8a67 https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.2-pyhd8ed1ab_0.conda#40bd3ef942b9642a3eb20b0bbf92469b https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda#2cf4264fffb9e6eff6031c5b6884d61c +https://conda.anaconda.org/conda-forge/win-64/sip-6.7.12-py39h99910a6_0.conda#0cc5774390ada632ed7975203057c91c +https://conda.anaconda.org/conda-forge/win-64/tbb-2021.11.0-h91493d7_1.conda#21069f3ed16812f9f4f2700667b6ec86 +https://conda.anaconda.org/conda-forge/win-64/fonttools-4.49.0-py39ha55989b_0.conda#3db31ee7eada607a636bd6d6105f7919 +https://conda.anaconda.org/conda-forge/win-64/glib-2.78.4-h12be248_0.conda#0080f150ed83685497f841f4b70fca1f https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.2-pyhd8ed1ab_0.conda#c69b222e1a89945b80feb249d57d8949 -https://conda.anaconda.org/conda-forge/win-64/libcblas-3.9.0-21_win64_mkl.conda#38e5ec23bc2b62f9dd971143aa9dddb7 -https://conda.anaconda.org/conda-forge/win-64/liblapack-3.9.0-21_win64_mkl.conda#c4740f091cb75987390087934354a621 +https://conda.anaconda.org/conda-forge/win-64/mkl-2024.0.0-h66d3029_49657.conda#006b65d9cd436247dfe053df772e041d +https://conda.anaconda.org/conda-forge/win-64/pillow-10.2.0-py39h368b509_0.conda#706d6e5bbc4b5d2ac7b8a6077319294d +https://conda.anaconda.org/conda-forge/win-64/pyqt5-sip-12.12.2-py39h99910a6_5.conda#dffbcea794c524c471772a5f697c2aea https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.1.0-pyhd8ed1ab_0.conda#06eb685a3a0b146347a58dda979485da https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-3.5.0-pyhd8ed1ab_0.conda#d5f595da2daead898ca958ac62f0307b +https://conda.anaconda.org/conda-forge/win-64/gstreamer-1.22.9-hb4038d2_0.conda#0480eecdb44a71929d5e78bf1a8644fb +https://conda.anaconda.org/conda-forge/win-64/libblas-3.9.0-21_win64_mkl.conda#ebba3846d11201fe54277e4965ba5250 +https://conda.anaconda.org/conda-forge/win-64/mkl-devel-2024.0.0-h57928b3_49657.conda#385b9dcf11c859acc506dcb40451f904 +https://conda.anaconda.org/conda-forge/win-64/gst-plugins-base-1.22.9-h001b923_0.conda#304b9124de13767ea8c933f72f50b348 +https://conda.anaconda.org/conda-forge/win-64/libcblas-3.9.0-21_win64_mkl.conda#38e5ec23bc2b62f9dd971143aa9dddb7 +https://conda.anaconda.org/conda-forge/win-64/liblapack-3.9.0-21_win64_mkl.conda#c4740f091cb75987390087934354a621 https://conda.anaconda.org/conda-forge/win-64/liblapacke-3.9.0-21_win64_mkl.conda#a4844669ed07bb5b7f182e9ca4de2a70 -https://conda.anaconda.org/conda-forge/win-64/numpy-1.26.4-py39hd5b4b07_0.conda#b031bd8486c14d50e0007184d582dd7f +https://conda.anaconda.org/conda-forge/win-64/numpy-1.26.4-py39hddb5d58_0.conda#6e30ff8f2d3f59f45347dfba8bc22a04 +https://conda.anaconda.org/conda-forge/win-64/qt-main-5.15.8-h9e85ed6_19.conda#1e5fa5b05768a8eed9d8bb0bf5585b1f https://conda.anaconda.org/conda-forge/win-64/blas-devel-3.9.0-21_win64_mkl.conda#dfb57411138b9548b9d6c65f7fe6af32 -https://conda.anaconda.org/conda-forge/win-64/contourpy-1.2.0-py39h314d263_0.conda#757c2eb182c113ea1acccd5a4a1e5dcd -https://conda.anaconda.org/conda-forge/win-64/scipy-1.12.0-py39hd5b4b07_2.conda#c21ed783c77da3068d28c53fe9fc81d6 +https://conda.anaconda.org/conda-forge/win-64/contourpy-1.2.0-py39h1f6ef14_0.conda#9eeea323eacb6549cbb3df3d81181cb2 +https://conda.anaconda.org/conda-forge/win-64/pyqt-5.15.9-py39hb77abff_5.conda#5ed899124a51958336371ff01482b8fd +https://conda.anaconda.org/conda-forge/win-64/scipy-1.12.0-py39hddb5d58_2.conda#e421d27a09f9131514436f8233125766 https://conda.anaconda.org/conda-forge/win-64/blas-2.121-mkl.conda#87ae78b8197b890e5a429f91769dfac7 -https://conda.anaconda.org/conda-forge/win-64/matplotlib-base-3.8.3-py39h3a59091_0.conda#a1ee6167417a4b6324559028e33beec8 -https://conda.anaconda.org/conda-forge/win-64/matplotlib-3.8.3-py39h0d475fb_0.conda#67654e1e89909ac4ca2514fd3e68d291 +https://conda.anaconda.org/conda-forge/win-64/matplotlib-base-3.8.3-py39hf19769e_0.conda#e7a42adb568586ff4035d7ef2d06c4b1 +https://conda.anaconda.org/conda-forge/win-64/matplotlib-3.8.3-py39hcbf5309_0.conda#a4b5946f68ecaed034fa849b8d639e63 diff --git a/build_tools/azure/pymin_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock b/build_tools/azure/pymin_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock index 6fcf87cd1b8e2..5442920c586ff 100644 --- a/build_tools/azure/pymin_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock +++ b/build_tools/azure/pymin_conda_forge_openblas_ubuntu_2204_linux-64_conda.lock @@ -4,84 +4,132 @@ @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2024.2.2-hbcca054_0.conda#2f4327a1cbe7f022401b236e915a5fef +https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 +https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2#34893075a5c9e55cdafac56607368fc6 +https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb +https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_1.conda#6185f640c43843e5ad6fd1c5372c3f80 +https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_5.conda#f6f6600d18a4047b54f803cf708b868a -https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-4_pypy39_pp73.conda#c1b2f29111681a4036ed21eaa3f44620 +https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-4_cp39.conda#bfe4b3259a8ac6cdf0037752904da6a7 https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda#161081fc7cec0bfda0d86d7cb595f8d8 +https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 +https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2#562b26ba2e19059551a811e72ab7f793 https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_5.conda#d4ff227c46917d3b4565302a2bbb276b +https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.11-hd590300_1.conda#0bb492cca54017ea314b809b1ee3a176 +https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda#69b8b6202a07720f448be700e300ccf4 +https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 +https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 +https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda#cc47e1facc155f91abd89b11e48e72ff +https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2#30186d27e2c9fa62b45fb1476b7200e3 +https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2#a8832b479f93521a9e7b5b743803be51 https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2#76bbff344f0134279f225174e9064c8f https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1.conda#aec6c91c7371c26392a06708a73c70e5 https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda#1635570038840ee3f9c71d22aa5b8b6d https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_5.conda#7a6bd7a12a4bd359e2afe6c0fa1acace +https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda#d66573916ffcf376178462f1b61c941e https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 +https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 +https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2#6e8cc2173440d77708196c5b93771680 +https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2#15345e56d527b330e1cacbdf58676e8f +https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda#40b61aab5c7ba9ff276c41cfffe6b80b https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.conda#30de3fd9b3b602f7473f30e684eeea8c +https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda#5aa797f8787fe7a17d1b0821485b5adc https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad +https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 +https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.4-h59595ed_0.conda#3f1017b4141e943d9bc8739237f749e8 https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0 +https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1 https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.1-hd590300_0.conda#51a753e64a3027bd7e23a189b1f6e91e +https://conda.anaconda.org/conda-forge/linux-64/pixman-0.43.2-h59595ed_0.conda#71004cbf7924e19c02746ccde9fd7123 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a +https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.1-hd590300_0.conda#b462a33c0be1421532f28bfe8f4a7514 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.11-hd590300_0.conda#2c80dc38fface310c9bd81b17037fee5 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.3-h7f98852_0.tar.bz2#be93aabceefa2fac576e971aef407908 +https://conda.anaconda.org/conda-forge/linux-64/xorg-renderproto-0.11.1-h7f98852_1002.tar.bz2#06feff3d2634e3097ce2fe681474b534 https://conda.anaconda.org/conda-forge/linux-64/xorg-xextproto-7.3.0-h0b41bf4_1003.conda#bce9f945da8ad2ae9b1d7165a64d0f87 +https://conda.anaconda.org/conda-forge/linux-64/xorg-xf86vidmodeproto-2.3.1-h7f98852_1002.tar.bz2#3ceea9668625c18f19530de98b15d5b0 https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2#b4a4381d54784606820704f7b5f05a15 https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2#2161070d867d1b1204ea749c8eec4ef0 https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda#f07002e225d7a60a694d42a7bf5ff53f https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda#5fc11c6020d421960607d821310fcd4d +https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda#25cb5999faa414e5ccb2c1388f62d3d5 +https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2#4d331e44109e3f0e19b4cb8f9b82f3e1 +https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda#a1cfcc585f0c42bf8d5546bb1dfb668d +https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda#ee48bf17cc83a00f59ca1494d5646869 https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_5.conda#e73e9cfd1191783392131e6238bdb3e9 +https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.48-h71f35ed_0.conda#4d18d86916705d352d5f4adfb7f0edd3 https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.43-h2797004_0.conda#009981dd9cfcaa4dbfa25ffaed86bcae https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.45.1-h2797004_0.conda#fc4ccadfbf6d4784de88c41704792562 +https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.12.5-h232c23b_0.conda#c442ebfda7a475f5e78f1c8e45f1e919 +https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda#80bf3b277c120dd294b51d404b931a75 +https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda#679c8961826aa4b50653bce17ee52abe https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda#d453b98d9c83e71da0741bb0ff4d76bc +https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda#93ee23f12bc2e684548181256edd2cf6 https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda#68c34ec6149623be41a1933ab996a209 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda#04b88013080254850d6c01ed54810589 https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda#39f910d205726805a958da408ca194ba https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb -https://conda.anaconda.org/conda-forge/linux-64/gdbm-1.18-h0a1914f_2.tar.bz2#b77bc399b07a19c00fe12fdc95ee0297 +https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 +https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda#32d16ad533c59bb0a3c5ffaf16110829 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-h783c2da_0.conda#d86baf8740d1a906b9716f2a0bac2f2d https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a +https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hb3ce162_4.conda#8a35df3cbc0c8b12cc8af9473ae75eef https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.26-pthreads_h413a1c8_0.conda#760ae35415f5ba8b15d09df5afe8b23a +https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda#55ed21669b2015f77c180feb1dd41930 https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-17.0.6-h4dfa4b3_0.conda#c1665f9c1c9f6c93d8b4e492a6a39056 -https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.45.1-h2c6b66d_0.conda#93acf31b379acebada263b9bce3dc6ed +https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda#e87530d1b12dd7f4e0f856dc07358d60 +https://conda.anaconda.org/conda-forge/linux-64/nss-3.98-h1d7d5a4_0.conda#54b56c2fdf973656b748e0378900ec13 +https://conda.anaconda.org/conda-forge/linux-64/python-3.9.18-h0755675_1_cpython.conda#255a7002aeec7a067ff19b545aca6328 +https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda#9bfac7ccd94d54fd21a0501296d60424 +https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h8ee46fc_1.conda#632413adcd8bc16b515cab87a2932913 +https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-hd590300_1.conda#e995b155d938b6779da6ace6c6b13816 +https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.conda#90108a432fb5c6150ccfee3f03388656 https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda#49e482d882669206653b095f5206c05b +https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.16-pyhd8ed1ab_0.conda#def531a3ac77b7fb8c21d17bb5d0badb https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda#f27a24d46e3ea7b70a1f98e50c62508f +https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py39h3d6467e_1.conda#c48418c8b35f1d59ae9ae1174812b40a https://conda.anaconda.org/conda-forge/linux-64/ccache-4.9.1-h1fcd64f_0.conda#3620f564bcf28c3524951b6f64f5c5ac -https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.16-hb7c19ff_0.conda#51bb7010fc86f70eee639b4bb7a894f5 -https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-21_linux64_openblas.conda#0ac9f44fc096772b0aa092119b00c3ca -https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.26-pthreads_h7a3da1a_0.conda#bda28edbedb0ae5f0a9d3ebcb4290c1d -https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.2-h488ebb8_0.conda#7f2e286780f072ed750df46dc2631138 -https://conda.anaconda.org/conda-forge/linux-64/pypy3.9-7.3.15-h9557127_0.conda#0a12c57c7fefeb6407c1ff47aa0b35df -https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-21_linux64_openblas.conda#4a3816d06451c4946e2db26b86472cb6 -https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-21_linux64_openblas.conda#1a42f305615c3867684e049e85927531 -https://conda.anaconda.org/conda-forge/linux-64/python-3.9.18-1_73_pypy.conda#6e0143cd3dd940d3004cd857e37ccd81 -https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.16-pyhd8ed1ab_0.conda#def531a3ac77b7fb8c21d17bb5d0badb -https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py39hc10206b_1.conda#6ce3c8b3e1006bb5b927426ab0fc4196 https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda#0876280e409658fc6f9e75d035960333 https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda#7f4a9e3fcff3f6356ae99244a014da6a https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5cd86562580f274031ede6aa6aa24441 -https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.9-py39hc10206b_0.conda#6ec46cd4c47257085675df035236f31e -https://conda.anaconda.org/conda-forge/linux-64/docutils-0.20.1-py39h4162558_3.conda#50537992e6429628d78f7ff883f46f34 +https://conda.anaconda.org/conda-forge/linux-64/cython-3.0.9-py39h3d6467e_0.conda#82f4e576cbe74921703f91d3b43c8a73 +https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d +https://conda.anaconda.org/conda-forge/linux-64/docutils-0.20.1-py39hf3d152e_3.conda#09a48956e1c155907fd0d626f3e80f2e https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda#8d652ea2ee8eaee02ed8dc820bc794aa https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 +https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.4-hfc55251_0.conda#d184ba1bf15a2bbb3be6118c90fd487d https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda#1a76f09108576397c41c0b0c5bd84134 https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2#7de5386c8fea29e76b303f37dde4c352 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 -https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py39ha90811c_1.conda#25edffabcb0760fc1821597c4ce920db -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-21_linux64_openblas.conda#854c3c762b25ccfbaa1aa24ee34288e3 -https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.5-py39hf860d4a_0.conda#184bd7833ec48c4b2961db1419b405c2 +https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py39h7633fee_1.conda#c9f74d717e5a2847a9f8b779c54130f2 +https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.16-hb7c19ff_0.conda#51bb7010fc86f70eee639b4bb7a894f5 +https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-21_linux64_openblas.conda#0ac9f44fc096772b0aa092119b00c3ca +https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_ha2b6cf4_4.conda#898e0dd993afbed0d871b60c2eb33b83 +https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda#d4529f4dff3057982a7617c7ac58fde3 +https://conda.anaconda.org/conda-forge/linux-64/libpq-16.2-h33b98f1_0.conda#fe0e297faf462ee579c95071a5211665 +https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-255-h3516f8a_1.conda#3366af27f0b593544a6cd453c7932ac5 +https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.5-py39hd1e30aa_0.conda#9a9a22eb1f83c44953319ee3b027769f https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_0.conda#776a8dd9e824f77abac30e6ef43a8f7a https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2#2ba8498c1018c1e9c61eb99b973dfe19 -https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py39h6dedee3_0.conda#557d64563e84ff21b14f586c7f662b7f +https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.26-pthreads_h7a3da1a_0.conda#bda28edbedb0ae5f0a9d3ebcb4290c1d +https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.2-h488ebb8_0.conda#7f2e286780f072ed750df46dc2631138 https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 -https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py39hcf8a34e_0.conda#8a406ee5a979c2591f4c734d6fe4a958 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976 +https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2024.1-pyhd8ed1ab_0.conda#98206ea9954216ee7540f0c773f2104d https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda#3eeeeb9e4827ace8c0c1419c85d590ad @@ -91,33 +139,55 @@ https://conda.anaconda.org/conda-forge/noarch/snowballstemmer-2.2.0-pyhd8ed1ab_0 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_0.conda#da1d979339e2714c30a8e806a33ec087 https://conda.anaconda.org/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_1.tar.bz2#4759805cce2d914c38472f70bf4d8bcb https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.3.0-pyhc1e730c_0.conda#698d2d2b621640bddb9191f132967c9f +https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2#f832c45a477c78bebd107098db465095 https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96 -https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py39hf860d4a_0.conda#e7fded713fb466e1e0670afce1761b47 +https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py39hd1e30aa_0.conda#1e865e9188204cdfb1fd2531780add88 https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.10.0-pyha770c72_0.conda#16ae769069b380646c47142d719ef466 -https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hf860d4a_0.conda#f699157518d28d00c87542b4ec1273be +https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hd1e30aa_0.conda#1da984bbb6e765743e13388ba7b7b2c8 +https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda#9d7bcddf49cbf727730af10e71022c73 +https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.41-hd590300_0.conda#81f740407b45e3f9047b3174fa94eb9e +https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda#82b6df12252e6f32402b96dacc656fec +https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda#ed67c36f215b310412b2af935bf3e530 https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda#9669586875baeced8fc30c0826c3270e -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-21_linux64_openblas.conda#77cefbfb4d47ba8cafef8e3f768a4538 -https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39ha90811c_0.conda#f3b2afc64bf0cbe901a9b00d44611c61 -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py39hf860d4a_0.conda#fa0d38d44f69d5c8ca476beb24fb456e +https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py39hd1e30aa_0.conda#dd1b02484cc8c31d4093111a82b6efb2 +https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_0.conda#f36a7b2420c3fc3c48a3d609841d8fee https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda#746623a787e06191d80a2133e5daff17 https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.2-pyhd8ed1ab_0.conda#6f4399795892835bd192ea210ca69447 https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda#e7d8df6509ba635247ff9aea31134262 https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc +https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-21_linux64_openblas.conda#4a3816d06451c4946e2db26b86472cb6 +https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_hb11cfb5_4.conda#c90f4cbb57839c98fef8f830e4b9972f +https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-21_linux64_openblas.conda#1a42f305615c3867684e049e85927531 +https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-hd429924_1.conda#1dbcc04604fdf1e526e6d1b0b6938396 https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda#93a8e71256479c62074356ef6ebf501b +https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py39had0adad_0.conda#2972754dc054bb079d1d121918b5126f +https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda#ac902ff3c1c6d750dd0dfc93a974ab74 https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.2-pyhd8ed1ab_0.conda#40bd3ef942b9642a3eb20b0bbf92469b https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda#2cf4264fffb9e6eff6031c5b6884d61c -https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py39h6dedee3_2.conda#6c5d74bac41838f4377dfd45085e1fec +https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py39h3d6467e_0.conda#e667a3ab0df62c54e60e1843d2e6defb https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda#08807a87fa7af10754d46f63b368e016 -https://conda.anaconda.org/conda-forge/linux-64/blas-2.121-openblas.conda#4a279792fd8861a15705516a52872eb6 +https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.9-h98fc4e7_0.conda#bcc7157b06fce7f5e055402a8135dfd8 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda#5a6f6c00ef982a9bc83558d9ac8f64a0 https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.2-pyhd8ed1ab_0.conda#c69b222e1a89945b80feb249d57d8949 -https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.1-py39h3417b97_0.conda#30cc61452d2282645600558e119dbc5e -https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py39h5fd064f_1.conda#e364cfb3ffb590ccef24b5a92389e751 +https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-21_linux64_openblas.conda#854c3c762b25ccfbaa1aa24ee34288e3 +https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py39h474f0d3_0.conda#aa265f5697237aa13cc10f53fa8acc4f +https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py39h3d6467e_5.conda#93aff412f3e49fdb43361c0215cbd72d https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-3.5.0-pyhd8ed1ab_0.conda#d5f595da2daead898ca958ac62f0307b https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b https://conda.anaconda.org/conda-forge/noarch/rich-13.7.1-pyhd8ed1ab_0.conda#ba445bf767ae6f0d959ff2b40c20912b -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.3-py39h4e7d633_0.conda#0b15e2f7764b1f64a5f4156ba20b090e -https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.3-py39h4162558_0.conda#ccb335b71aedcf24c36b2546741fb5f8 +https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-21_linux64_openblas.conda#77cefbfb4d47ba8cafef8e3f768a4538 +https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39h7633fee_0.conda#ed71ad3e30eb03da363fb797419cce98 +https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda#614b81f8ed66c56b640faee7076ad14a +https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.1-py39hddac248_0.conda#85293a042c24a08e71b7608ee66b6134 +https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py39h474f0d3_2.conda#6ab241b2023730f6b41712dc1b503afa +https://conda.anaconda.org/conda-forge/linux-64/blas-2.121-openblas.conda#4a279792fd8861a15705516a52872eb6 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.3-py39he9076e7_0.conda#5456bdfe5809ebf5689eda6c808b686e +https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.0.1-py39hda80f44_1.conda#6df47699edb4d8d3365de2d189a456bc +https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda#54866f708d43002a514d0b9b0f84bc11 +https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py39h52134e7_5.conda#e1f148e57d071b09187719df86f513c1 +https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.3-py39hf3d152e_0.conda#983f5b77540eb5aa00238e72ec9b1dfb https://conda.anaconda.org/conda-forge/noarch/numpydoc-1.6.0-pyhd8ed1ab_0.conda#191b8a622191a403700d16a2008e4e29 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-applehelp-1.0.8-pyhd8ed1ab_0.conda#611a35a27914fac3aa37611a6fe40bb5 https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-devhelp-1.0.6-pyhd8ed1ab_0.conda#d7e4954df0d3aea2eacc7835ad12671d diff --git a/build_tools/azure/pypy3_linux-64_conda.lock b/build_tools/azure/pypy3_linux-64_conda.lock index d606f8f752a04..e81dab95c0bb9 100644 --- a/build_tools/azure/pypy3_linux-64_conda.lock +++ b/build_tools/azure/pypy3_linux-64_conda.lock @@ -72,7 +72,7 @@ https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py39h6dedee3_0.cond https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6 https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py39hcf8a34e_0.conda#8a406ee5a979c2591f4c734d6fe4a958 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/pypy-7.3.15-1_pypy39.conda#a418a6c16bd6f7ed56b92194214791a0 https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142 https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2 diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index f7d06a5b5365b..5a4341cac9e62 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -50,7 +50,7 @@ fi if [[ "$PYTEST_XDIST_VERSION" != "none" ]]; then XDIST_WORKERS=$(python -c "import joblib; print(joblib.cpu_count(only_physical_cores=True))") - TEST_CMD="$TEST_CMD" + TEST_CMD="$TEST_CMD -n$XDIST_WORKERS" fi if [[ -n "$SELECTED_TESTS" ]]; then @@ -61,5 +61,5 @@ if [[ -n "$SELECTED_TESTS" ]]; then fi set -x -eval "$TEST_CMD -x --pyargs sklearn" +eval "$TEST_CMD --maxfail=10 --pyargs sklearn" set +x diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index 9f8690e3050c4..f3df0cfdcf32b 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -44,6 +44,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_5.conda#7a6bd7a12a4bd359e2afe6c0fa1acace +https://conda.anaconda.org/conda-forge/linux-64/libhwy-1.0.7-h00ab1b0_0.conda#271c74eadb196f7ae588d95a11e9acd3 https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda#d66573916ffcf376178462f1b61c941e https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda#ea25936bb4080d843790b586850f82b8 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda#30fd6e37fe21f86f4bd26d6ee73eeec7 @@ -110,6 +111,7 @@ https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-12.3.0-he2b93b https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda#32d16ad533c59bb0a3c5ffaf16110829 https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-h783c2da_0.conda#d86baf8740d1a906b9716f2a0bac2f2d +https://conda.anaconda.org/conda-forge/linux-64/libjxl-0.10.1-h5b01ea3_0.conda#6a6a96a3cd66ff9514a22f1eea91e303 https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hb3ce162_4.conda#8a35df3cbc0c8b12cc8af9473ae75eef https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.26-pthreads_h413a1c8_0.conda#760ae35415f5ba8b15d09df5afe8b23a https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda#ef1910918dd895516a769ed36b5b3a4e @@ -225,7 +227,7 @@ https://conda.anaconda.org/conda-forge/noarch/rich-13.7.1-pyhd8ed1ab_0.conda#ba4 https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-21_linux64_openblas.conda#77cefbfb4d47ba8cafef8e3f768a4538 https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39h7633fee_0.conda#ed71ad3e30eb03da363fb797419cce98 https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda#614b81f8ed66c56b640faee7076ad14a -https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2024.1.1-py39hf9b8f0e_0.conda#9ddd29852457d1152ca235eb87bc74fb +https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2024.1.1-py39hf9b8f0e_1.conda#14a906ccb9a6f385fc4fe3c393246b49 https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda#b8853659d596f967c661f544dd89ede7 https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.1-py39hddac248_0.conda#85293a042c24a08e71b7608ee66b6134 https://conda.anaconda.org/conda-forge/noarch/patsy-0.5.6-pyhd8ed1ab_0.conda#a5b55d1cb110cdcedc748b5c3e16e687 diff --git a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock index da3cfe1c0d6b7..b0c693c119be7 100644 --- a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock +++ b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock @@ -147,7 +147,7 @@ https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#13 https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2#7205635cd71531943440fbfe3b6b5727 https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.8-py39hd1e30aa_0.conda#ec86403fde8793ac1c36f8afa3d15902 https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda#140a7f159396547e9799aa98f9f0742e -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda#b9a4dacf97241704529131a0dfc0494f https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda#3eeeeb9e4827ace8c0c1419c85d590ad https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.1-py39hd1e30aa_1.conda#37218233bcdc310e4fde6453bc1b40d8 From c52e828adaaf2257da08d5c75155eb75cac9d0fd Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Apr 2024 16:11:26 +0200 Subject: [PATCH 44/55] refactor into callback context --- sklearn/base.py | 61 ++----------- sklearn/callback/__init__.py | 10 +- sklearn/callback/_base.py | 108 ---------------------- sklearn/callback/_callback_context.py | 83 +++++++++++++++++ sklearn/callback/_computation_tree.py | 126 -------------------------- sklearn/callback/_progressbar.py | 121 +++++++++++++++---------- sklearn/callback/_task_tree.py | 95 +++++++++++++++++++ sklearn/callback/tests/_utils.py | 86 +++++++++++++----- 8 files changed, 326 insertions(+), 364 deletions(-) create mode 100644 sklearn/callback/_callback_context.py delete mode 100644 sklearn/callback/_computation_tree.py create mode 100644 sklearn/callback/_task_tree.py diff --git a/sklearn/base.py b/sklearn/base.py index 86b01b4e87439..3b2cb91c81693 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -15,8 +15,7 @@ from . import __version__ from ._config import config_context, get_config -from .callback import BaseCallback, build_computation_tree -from .callback._base import default_data +from .callback import BaseCallback, CallbackContext from .exceptions import InconsistentVersionWarning from .utils import _IS_32BIT from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr @@ -698,60 +697,20 @@ def _set_callbacks(self, callbacks): return self - def _eval_callbacks_on_fit_begin(self, *, tree_structure, data): - """Evaluate the `on_fit_begin` method of the callbacks. - - The computation tree is also built at this point. - - This method should be called after all data and parameters validation. - - Parameters - ---------- - tree_structure : list of dict - A description of the nested steps of computation of the estimator to build - the computation tree. It's a list of dict with keys "stage" and - "n_children". - - data : dict - Dictionary containing the training and validation data. The keys are - "X_train", "y_train", "sample_weight_train", "X_val", "y_val", - "sample_weight_val". The values are the corresponding data. If a key is - missing, the corresponding value is None. + def _init_callback_context(self): + """Initialize the callback context for the estimator. Returns ------- - root : ComputationNode instance - The root of the computation tree. + callback_fit_ctx : CallbackContext + The callback context for the estimator. """ - self._computation_tree = build_computation_tree( - estimator_name=self.__class__.__name__, - tree_structure=tree_structure, - parent=getattr(self, "_parent_node", None), - ) - if not hasattr(self, "_skl_callbacks"): - return self._computation_tree - - # Only call the on_fit_begin method of callbacks that are not - # propagated from a meta-estimator. - for callback in self._skl_callbacks: - if not callback._is_propagated(estimator=self): - callback.on_fit_begin(estimator=self, data={**default_data, **data}) - - return self._computation_tree - - def _eval_callbacks_on_fit_end(self): - """Evaluate the `on_fit_end` method of the callbacks.""" - if not hasattr(self, "_skl_callbacks") or not hasattr( - self, "_computation_tree" - ): - return + self._callback_fit_ctx = CallbackContext( + callbacks=getattr(self, "_skl_callbacks", []), + ) - # Only call the on_fit_end method of callbacks that are not - # propagated from a meta-estimator. - for callback in self._skl_callbacks: - if not callback._is_propagated(estimator=self): - callback.on_fit_end() + return self._callback_fit_ctx @property def _repr_html_(self): @@ -1558,7 +1517,7 @@ def wrapper(estimator, *args, **kwargs): try: return fit_method(estimator, *args, **kwargs) finally: - estimator._eval_callbacks_on_fit_end() + estimator._callback_fit_ctx.eval_on_fit_end() return wrapper diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 42fd228069c3e..0e5b7ddc0999b 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -6,14 +6,14 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -from ._base import BaseCallback, CallbackPropagatorMixin -from ._computation_tree import ComputationNode, build_computation_tree +from ._base import BaseCallback +from ._callback_context import CallbackContext from ._progressbar import ProgressBar +from ._task_tree import TaskNode __all__ = [ "BaseCallback", - "CallbackPropagatorMixin", - "build_computation_tree", - "ComputationNode", + "CallbackContext", + "TaskNode", "ProgressBar", ] diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index d9be734ba251a..aae5059ce9062 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -3,16 +3,6 @@ from abc import ABC, abstractmethod -# default values for the data dict passed to the callbacks -default_data = { - "X_train": None, - "y_train": None, - "sample_weight_train": None, - "X_val": None, - "y_val": None, - "sample_weight_val": None, -} - class BaseCallback(ABC): """Abstract class for the callbacks""" @@ -96,12 +86,6 @@ def auto_propagate(self): """ return False - def _is_propagated(self, estimator): - """Check if this callback attached to estimator has been propagated from a - meta-estimator. - """ - return self.auto_propagate and hasattr(estimator, "_parent_node") - # TODO: This is not used yet but will be necessary for next callbacks # Uncomment when needed # @property @@ -111,95 +95,3 @@ def _is_propagated(self, estimator): # @property # def request_from_reconstruction_attributes(self): # return False - - -class CallbackPropagatorMixin: - """Mixin class for meta-estimators expected to propagate callbacks.""" - - def _propagate_callbacks(self, sub_estimator, *, parent_node): - """Propagate the auto-propagated callbacks to a sub-estimator. - - Parameters - ---------- - sub_estimator : estimator instance - The sub-estimator to propagate the callbacks to. - - parent_node : ComputationNode instance - The computation node in this estimator to set as `parent_node` to the - computation tree of the sub-estimator. It must be the node where the fit - method of the sub-estimator is called. - """ - if hasattr(sub_estimator, "_skl_callbacks") and any( - callback.auto_propagate for callback in sub_estimator._skl_callbacks - ): - bad_callbacks = [ - callback.__class__.__name__ - for callback in sub_estimator._skl_callbacks - if callback.auto_propagate - ] - raise TypeError( - f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" - f" meta-estimator ({self.__class__.__name__}) can't have" - f" auto-propagated callbacks ({bad_callbacks})." - " Set them directly on the meta-estimator." - ) - - if not hasattr(self, "_skl_callbacks"): - return - - propagated_callbacks = [ - callback for callback in self._skl_callbacks if callback.auto_propagate - ] - - if not propagated_callbacks: - return - - sub_estimator._parent_node = parent_node - - sub_estimator._set_callbacks( - getattr(sub_estimator, "_skl_callbacks", []) + propagated_callbacks - ) - - -# Not a method of BaseEstimator because it might not be directly called from fit but -# by a non-method function called by fit -def _eval_callbacks_on_fit_iter_end(**kwargs): - """Evaluate the `on_fit_iter_end` method of the callbacks. - - This function must be called at the end of each computation node. - - Parameters - ---------- - kwargs : dict - Arguments passed to the callback. - - Returns - ------- - stop : bool - Whether or not to stop the fit at this node. - """ - estimator = kwargs.get("estimator") - node = kwargs.get("node") - - if not hasattr(estimator, "_skl_callbacks") or node is None: - return False - - # stopping_criterion and reconstruction_attributes can be costly to compute. - # They are passed as lambdas for lazy evaluation. We only actually - # compute them if a callback requests it. - # TODO: This is not used yet but will be necessary for next callbacks - # Uncomment when needed - # if any(cb.request_stopping_criterion for cb in estimator._skl_callbacks): - # kwarg = kwargs.pop("stopping_criterion", lambda: None)() - # kwargs["stopping_criterion"] = kwarg - - # if any( - # cb.request_from_reconstruction_attributes - # for cb in estimator._skl_callbacks - # ): - # kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() - # kwargs["from_reconstruction_attributes"] = kwarg - - return any( - callback.on_fit_iter_end(**kwargs) for callback in estimator._skl_callbacks - ) diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py new file mode 100644 index 0000000000000..6d96f2d7fd6b5 --- /dev/null +++ b/sklearn/callback/_callback_context.py @@ -0,0 +1,83 @@ +# License: BSD 3 clause +# Authors: the scikit-learn developers + +from ._task_tree import TaskNode + + +class CallbackContext: + def __init__(self, callbacks): + self.callbacks = callbacks + + def eval_on_fit_begin(self, *, estimator, max_subtasks=0, data): + self.task_node = TaskNode( + estimator_name=estimator.__class__.__name__, + name="fit", + max_subtasks=max_subtasks, + ) + + parent_task_node = getattr(estimator, "_parent_task_node", None) + if parent_task_node is not None: + self.task_node._merge_with(parent_task_node) + + for callback in self.callbacks: + # Only call the on_fit_begin method of callbacks that are not + # propagated from a meta-estimator. + if not (callback.auto_propagate and parent_task_node is not None): + callback.on_fit_begin(estimator, data=data) + + return self + + def eval_on_fit_iter_end(self, **kwargs): + return any( + callback.on_fit_iter_end(task_node=self.task_node, **kwargs) + for callback in self.callbacks + ) + + def eval_on_fit_end(self): + for callback in self.callbacks: + # Only call the on_fit_end method of callbacks that are not + # propagated from a meta-estimator. + if not (callback.auto_propagate and self.task_node.parent is not None): + callback.on_fit_end(task_node=self.task_node) + + def subcontext(self, task="", max_subtasks=0, idx=0, sub_estimator=None): + sub_ctx = CallbackContext(callbacks=self.callbacks) + + sub_ctx.task_node = self.task_node._add_child( + name=task, + max_subtasks=max_subtasks, + idx=idx, + ) + + if sub_estimator is not None: + sub_ctx._propagate_callbacks(sub_estimator=sub_estimator) + + return sub_ctx + + def _propagate_callbacks(self, sub_estimator): + bad_callbacks = [ + callback.__class__.__name__ + for callback in getattr(sub_estimator, "_skl_callbacks", []) + if callback.auto_propagate + ] + + if bad_callbacks: + raise TypeError( + f"The sub-estimator ({sub_estimator.__class__.__name__}) of a" + f" meta-estimator ({self.task_node.estimator_name}) can't have" + f" auto-propagated callbacks ({bad_callbacks})." + " Register them directly on the meta-estimator." + ) + + callbacks_to_propagate = [ + callback for callback in self.callbacks if callback.auto_propagate + ] + + if not callbacks_to_propagate: + return + + sub_estimator._parent_task_node = self.task_node + + sub_estimator._set_callbacks( + getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate + ) diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py deleted file mode 100644 index ba32e04400810..0000000000000 --- a/sklearn/callback/_computation_tree.py +++ /dev/null @@ -1,126 +0,0 @@ -# License: BSD 3 clause -# Authors: the scikit-learn developers - - -class ComputationNode: - """A node in a computation tree. - - Parameters - ---------- - estimator_name : str - The name of the estimator this computation node belongs to. - - stage : str, default=None - A description of the stage this computation node belongs to. - None means it's a leaf. - - n_children : int, default=None - The number of its children. None means it's a leaf. - - idx : int, default=0 - The index of this node among its siblings. - - parent : ComputationNode instance, default=None - The parent node. None means this is the root. - - Attributes - ---------- - children : list - The list of its children nodes. For a leaf, it's an empty list - """ - - def __init__( - self, - estimator_name, - stage=None, - n_children=None, - idx=0, - parent=None, - ): - # estimator_name and description are tuples because an estimator can be - # a sub-estimator of a meta-estimator. In that case, the root of the computation - # tree of the sub-estimator and a leaf of the computation tree of the - # meta-estimator correspond to the same computation step. Therefore, both - # nodes are merged into a single node, retaining the information of both. - self.estimator_name = (estimator_name,) - self.stage = (stage,) - - self.parent = parent - self.n_children = n_children - self.idx = idx - - self.children = [] - - @property - def depth(self): - """The depth of this node in the computation tree.""" - return 0 if self.parent is None else self.parent.depth + 1 - - @property - def path(self): - """List of all the nodes in the path from the root to this node.""" - return [self] if self.parent is None else self.parent.path + [self] - - def __iter__(self): - """Pre-order depth-first traversal""" - yield self - for node in self.children: - yield from node - - -def build_computation_tree(estimator_name, tree_structure, parent=None, idx=0): - """Build the computation tree from the description of the levels. - - Parameters - ---------- - estimator_name : str - The name of the estimator this computation tree belongs to. - - tree_structure : list of dict - The description of the stages of the computation tree. Each dict must have - the following keys: - - stage: str - A human readable description of the stage. - - n_children: int or None - The number of its children. None means it's a leaf. - - parent : ComputationNode instance, default=None - The parent node. None means this is the root. - - idx : int, default=0 - The index of this node among its siblings. - - Returns - ------- - computation_tree : ComputationNode instance - The root of the computation tree. - """ - this_stage = tree_structure[0] - - node = ComputationNode( - estimator_name=estimator_name, - parent=parent, - n_children=this_stage["n_children"], - stage=this_stage["stage"], - idx=idx, - ) - - if parent is not None and parent.n_children is None: - # parent node is a leaf of the computation tree of an outer estimator. It means - # that this node is the root of the computation tree of this estimator. They - # both correspond the same computation step, so we merge both nodes. - node.stage = parent.stage + node.stage - node.estimator_name = parent.estimator_name + node.estimator_name - node.parent = parent.parent - node.idx = parent.idx - parent.parent.children[node.idx] = node - - if node.n_children is not None: - for i in range(node.n_children): - node.children.append( - build_computation_tree( - estimator_name, tree_structure[1:], parent=node, idx=i - ) - ) - - return node diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index ee5669acf5c60..bba2c7ec9e3e8 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -21,10 +21,11 @@ def on_fit_begin(self, estimator, data): self.progress_monitor = _RichProgressMonitor(queue=self._queue) self.progress_monitor.start() - def on_fit_iter_end(self, *, estimator, node, **kwargs): - self._queue.put(node) + def on_fit_iter_end(self, *, task_node, **kwargs): + self._queue.put(task_node) - def on_fit_end(self): + def on_fit_end(self, *, task_node): + self._queue.put(task_node) self._queue.put(None) self.progress_monitor.join() @@ -55,7 +56,7 @@ class _RichProgressMonitor(Thread): """Thread monitoring the progress of an estimator with rich based display. The display is a list of nested rich tasks using `rich.Progress`. There is one for - each non-leaf node in the computation tree of the estimator. + each non-leaf node in the task tree of the estimator. Parameters ---------- @@ -75,103 +76,123 @@ def run(self): finished_style=Style(color="cyan"), ), TextColumn("[bright_magenta]{task.percentage:>3.0f}%"), - TimeRemainingColumn(), + TimeRemainingColumn(elapsed_when_finished=True), auto_refresh=False, ) # Holds the root of the tree of rich tasks (i.e. progress bars) that will be # created dynamically as the computation tree of the estimator is traversed. - self.root_task = None + self.root_rich_task = None with self.progress_ctx: - while node := self.queue.get(): - self._update_task_tree(node) + while task_node := self.queue.get(): + self._update_task_tree(task_node) self._update_tasks() self.progress_ctx.refresh() - def _update_task_tree(self, node): + def _update_task_tree(self, task_node): """Update the tree of tasks from a new node.""" - curr_task, parent_task = None, None + curr_rich_task, parent_task = None, None - for curr_node in node.path: + for curr_node in task_node.path: if curr_node.parent is None: # root node - if self.root_task is None: - self.root_task = TaskNode(curr_node, progress_ctx=self.progress_ctx) - curr_task = self.root_task + if self.root_rich_task is None: + self.root_rich_task = RichTaskNode( + curr_node, progress_ctx=self.progress_ctx + ) + curr_rich_task = self.root_rich_task elif curr_node.idx not in parent_task.children: - curr_task = TaskNode( + curr_rich_task = RichTaskNode( curr_node, progress_ctx=self.progress_ctx, parent=parent_task ) - parent_task.children[curr_node.idx] = curr_task + parent_task.children[curr_node.idx] = curr_rich_task else: # task already exists - curr_task = parent_task.children[curr_node.idx] - parent_task = curr_task + curr_rich_task = parent_task.children[curr_node.idx] + parent_task = curr_rich_task - # Mark the deepest task as finished (this is the one corresponding the + # Mark the deepest task as finished (this is the one corresponding to the # computation node that we just get from the queue). - curr_task.finished = True + curr_rich_task.finished = True def _update_tasks(self): """Loop through the tasks in their display order and update their progress.""" self.progress_ctx._ordered_tasks = [] - for task_node in self.root_task: - task = self.progress_ctx.tasks[task_node.task_id] + for rich_task_node in self.root_rich_task: + task = self.progress_ctx.tasks[rich_task_node.task_id] - if task_node.parent is not None and task_node.parent.finished: - # If the parent task is finished, then mark the current task as - # finished. It can happen if an estimator doesn't reach its max number - # of iterations (e.g. early stopping). - completed = task.total - else: - completed = sum(t.finished for t in task_node.children.values()) + total = task.total + + if rich_task_node.finished: + # It's possible that a task finishes without reaching its total + # (e.g. early stopping). We mark it as 100% completed. - if completed == task.total: - task_node.finished = True + if task.total is None: + # Indeterminate task is finished. Set total to an arbitrary + # value to render its completion as 100%. + completed = total = 1 + else: + completed = total + else: + completed = sum(t.finished for t in rich_task_node.children.values()) self.progress_ctx.update( - task_node.task_id, completed=completed, refresh=False + rich_task_node.task_id, completed=completed, total=total, refresh=False ) self.progress_ctx._ordered_tasks.append(task) -class TaskNode: +class RichTaskNode: """A node in the tree of rich tasks. Parameters ---------- - node : `ComputationNode` instance - The computation node this task corresponds to. + task_node : `TaskNode` instance + The task node of an estimator this task corresponds to. progress_ctx : `rich.Progress` instance The progress context to which this task belongs. - parent : `TaskNode` instance + parent : `RichTaskNode` instance The parent of this task. + + Attributes + ---------- + finished : bool + Whether the task is finished. + + task_id : int + The ID of the task in the Progress context. + + children : dict + A mapping from the index of a child to the child node `{idx: RichTaskNode}`. + For a leaf, it's an empty dictionary. """ - def __init__(self, node, progress_ctx, parent=None): - self.node_idx = node.idx + def __init__(self, task_node, progress_ctx, parent=None): + self.node_idx = task_node.idx self.parent = parent self.children = {} self.finished = False - if node.n_children is not None: - description = self._format_task_description(node) - self.task_id = progress_ctx.add_task(description, total=node.n_children) + if task_node.max_subtasks != 0: + description = self._format_task_description(task_node) + self.task_id = progress_ctx.add_task( + description, total=task_node.max_subtasks + ) - def _format_task_description(self, node): - """Return a formatted description for the task of the node.""" + def _format_task_description(self, task_node): + """Return a formatted description for the task.""" colors = ["bright_magenta", "cyan", "dark_orange"] - indent = f"{' ' * (node.depth)}" - style = f"[{colors[(node.depth)%len(colors)]}]" + indent = f"{' ' * (task_node.depth)}" + style = f"[{colors[(task_node.depth)%len(colors)]}]" - description = f"{node.estimator_name[0]} - {node.stage[0]}" - if node.parent is not None: - description += f" #{node.idx}" - if len(node.estimator_name) == 2: - description += f" | {node.estimator_name[1]} - {node.stage[1]}" + description = f"{task_node.estimator_name[0]} - {task_node.name[0]}" + if task_node.parent is not None: + description += f" #{task_node.idx}" + if len(task_node.estimator_name) == 2: + description += f" | {task_node.estimator_name[1]} - {task_node.name[1]}" return f"{style}{indent}{description}" diff --git a/sklearn/callback/_task_tree.py b/sklearn/callback/_task_tree.py new file mode 100644 index 0000000000000..62ee396a0414b --- /dev/null +++ b/sklearn/callback/_task_tree.py @@ -0,0 +1,95 @@ +# License: BSD 3 clause +# Authors: the scikit-learn developers + + +class TaskNode: + """A node in a task tree. + + Parameters + ---------- + estimator_name : str + The name of the estimator this task node belongs to. + + name : str, default=None + The name of the task this node represents. + + max_subtasks : int or None, default=0 + The maximum number of its children. 0 means it's a leaf. + None means the number of children is not known in advance. + + idx : int, default=0 + The index of this node among its siblings. + + parent : TaskNode instance, default=None + The parent node. None means this is the root. + + Note that the root task of an estimator can become an intermediate node + of a meta-estimator. + + Attributes + ---------- + children : dict + A mapping from the index of a child to the child node `{idx: TaskNode}`. + For a leaf, it's an empty dictionary. + """ + + def __init__( + self, + estimator_name, + name="fit", + max_subtasks=0, + idx=0, + parent=None, + ): + # estimator_name and name are tuples because an estimator can be + # a sub-estimator of a meta-estimator. In that case, the root of the task + # tree of the sub-estimator and a leaf of the task tree of the + # meta-estimator correspond to the same computation step. Therefore, both + # nodes are merged into a single node, retaining the information of both. + self.estimator_name = (estimator_name,) + self.name = (name,) + + self.max_subtasks = max_subtasks + self.idx = idx + self.parent = parent + + # Children stored in a dict indexed by their idx for easy access because the + # order in which self.children is populated is not guaranteed to follow the + # order of the idx du to parallelism. + self.children = {} + + def _add_child(self, name, max_subtasks, idx): + child = TaskNode( + estimator_name=self.estimator_name[-1], + name=name, + max_subtasks=max_subtasks, + idx=idx, + parent=self, + ) + self.children[idx] = child + + return child + + def _merge_with(self, task_node): + self.parent = task_node.parent + self.idx = task_node.idx + task_node.parent.children[self.idx] = self + + self.name = task_node.name + self.name + self.estimator_name = task_node.estimator_name + self.estimator_name + + @property + def depth(self): + """The depth of this node in the computation tree.""" + return 0 if self.parent is None else self.parent.depth + 1 + + @property + def path(self): + """List of all the nodes in the path from the root to this node.""" + return [self] if self.parent is None else self.parent.path + [self] + + def __iter__(self): + """Pre-order depth-first traversal""" + yield self + for node in self.children.values(): + yield from node diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 0e51d1c165b2e..460f2666f3b87 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -1,9 +1,10 @@ # License: BSD 3 clause # Authors: the scikit-learn developers +import time + from sklearn.base import BaseEstimator, _fit_context, clone -from sklearn.callback import BaseCallback, CallbackPropagatorMixin -from sklearn.callback._base import _eval_callbacks_on_fit_iter_end +from sklearn.callback import BaseCallback from sklearn.utils.parallel import Parallel, delayed @@ -43,18 +44,20 @@ def __init__(self, max_iter=20): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( - tree_structure=[ - {"stage": "fit", "n_children": self.max_iter}, - {"stage": "iter", "n_children": None}, - ], + callback_ctx = self._init_callback_context().eval_on_fit_begin( + estimator=self, + max_subtasks=self.max_iter, data={"X_train": X, "y_train": y}, ) for i in range(self.max_iter): - if _eval_callbacks_on_fit_iter_end( + subcontext = callback_ctx.subcontext(idx=i) + + time.sleep(0.05) # Computation intensive task + + if subcontext.eval_on_fit_iter_end( estimator=self, - node=root.children[i], + data={"X_train": X, "y_train": y}, ): break @@ -63,7 +66,37 @@ def fit(self, X, y): return self -class MetaEstimator(BaseEstimator, CallbackPropagatorMixin): +class WhileEstimator(BaseEstimator): + _parameter_constraints: dict = {} + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X, y): + callback_ctx = self._init_callback_context().eval_on_fit_begin( + estimator=self, max_subtasks=None, data={"X_train": X, "y_train": y} + ) + + i = 0 + + while True: + subcontext = callback_ctx.subcontext(idx=i) + + time.sleep(0.05) # Computation intensive task + + if subcontext.eval_on_fit_iter_end( + estimator=self, + data={"X_train": X, "y_train": y}, + ): + break + + if i == 20: + break + + i += 1 + + return self + + +class MetaEstimator(BaseEstimator): _parameter_constraints: dict = {} def __init__( @@ -77,29 +110,34 @@ def __init__( @_fit_context(prefer_skip_nested_validation=False) def fit(self, X, y): - root = self._eval_callbacks_on_fit_begin( - tree_structure=[ - {"stage": "fit", "n_children": self.n_outer}, - {"stage": "outer", "n_children": self.n_inner}, - {"stage": "inner", "n_children": None}, - ], - data={"X_train": X, "y_train": y}, + callback_ctx = self._init_callback_context().eval_on_fit_begin( + estimator=self, max_subtasks=self.n_outer, data={"X_train": X, "y_train": y} ) Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( - delayed(_func)(self, self.estimator, X, y, node) - for _, node in enumerate(root.children) + delayed(_func)(i, self, self.estimator, X, y, callback_ctx=callback_ctx) + for i in range(self.n_outer) ) return self -def _func(meta_estimator, inner_estimator, X, y, parent_node): - for _, node in enumerate(parent_node.children): +def _func(outer_idx, meta_estimator, inner_estimator, X, y, *, callback_ctx): + outer_ctx = callback_ctx.subcontext( + task="outer", max_subtasks=meta_estimator.n_inner, idx=outer_idx + ) + + for i in range(meta_estimator.n_inner): est = clone(inner_estimator) - meta_estimator._propagate_callbacks(est, parent_node=node) + inner_ctx = outer_ctx.subcontext(task="inner", idx=i, sub_estimator=est) est.fit(X, y) - _eval_callbacks_on_fit_iter_end(estimator=meta_estimator, node=node) + inner_ctx.eval_on_fit_iter_end( + estimator=meta_estimator, + data={"X_train": X, "y_train": y}, + ) - _eval_callbacks_on_fit_iter_end(estimator=meta_estimator, node=parent_node) + outer_ctx.eval_on_fit_iter_end( + estimator=meta_estimator, + data={"X_train": X, "y_train": y}, + ) From 619bf9fe5f6eddc34ec588a126fa29181e3e5722 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Apr 2024 16:19:14 +0200 Subject: [PATCH 45/55] cln base --- sklearn/callback/_base.py | 41 ++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index aae5059ce9062..93f9a7422f099 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -17,19 +17,14 @@ def on_fit_begin(self, estimator, *, data): The estimator the callback is set on. data : dict - Dictionary containing the training and validation data. The keys are - "X_train", "y_train", "sample_weight_train", "X_val", "y_val", - "sample_weight_val". The values are the corresponding data. If a key is - missing, the corresponding value is None. + Dictionary containing the training and validation data. The possible + keys are "X_train", "y_train", "sample_weight_train", "X_val", "y_val" + and "sample_weight_val". """ @abstractmethod - def on_fit_end(self): - """Method called at the end of the fit method of the estimator.""" - - @abstractmethod - def on_fit_iter_end(self, estimator, node, **kwargs): - """Method called at the end of each computation node of the estimator. + def on_fit_iter_end(self, estimator, task_node, **kwargs): + """Method called at the end of each task of the estimator. Parameters ---------- @@ -37,8 +32,8 @@ def on_fit_iter_end(self, estimator, node, **kwargs): The caller estimator. It might differ from the estimator passed to the `on_fit_begin` method for auto-propagated callbacks. - node : ComputationNode instance - The caller computation node. + task_node : TaskNode instance + The caller task node. **kwargs : dict arguments passed to the callback. Possible keys are @@ -74,6 +69,18 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) Whether or not to stop the current level of iterations at this node. """ + @abstractmethod + def on_fit_end(self, task_node): + """Method called at the end of the fit method of the estimator. + + Parameters + ---------- + task_node : TaskNode instance + The task node corresponding to the whole `fit` task. This is usually the + root of the task tree of the estimator but it can be an intermediate node + if the estimator is a sub-estimator of a meta-estimator. + """ + @property def auto_propagate(self): """Whether or not this callback should be propagated to sub-estimators. @@ -85,13 +92,3 @@ def auto_propagate(self): the meta-estimator and its sub-estimators. """ return False - - # TODO: This is not used yet but will be necessary for next callbacks - # Uncomment when needed - # @property - # def request_stopping_criterion(self): - # return False - - # @property - # def request_from_reconstruction_attributes(self): - # return False From 390065c11bdb5472ba58cf4485c0303e7d707bde Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Apr 2024 16:19:33 +0200 Subject: [PATCH 46/55] lint --- sklearn/callback/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 93f9a7422f099..d8a2e9471cec1 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -72,7 +72,7 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) @abstractmethod def on_fit_end(self, task_node): """Method called at the end of the fit method of the estimator. - + Parameters ---------- task_node : TaskNode instance From 183b74ed0587ac0ca23eab106716877a207b81f9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Apr 2024 16:33:43 +0200 Subject: [PATCH 47/55] fix test progressbar --- sklearn/callback/tests/_utils.py | 4 ++-- sklearn/callback/tests/test_progressbar.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 460f2666f3b87..8fae15ec1e63e 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -53,7 +53,7 @@ def fit(self, X, y): for i in range(self.max_iter): subcontext = callback_ctx.subcontext(idx=i) - time.sleep(0.05) # Computation intensive task + time.sleep(0.001) # Computation intensive task if subcontext.eval_on_fit_iter_end( estimator=self, @@ -80,7 +80,7 @@ def fit(self, X, y): while True: subcontext = callback_ctx.subcontext(idx=i) - time.sleep(0.05) # Computation intensive task + time.sleep(0.001) # Computation intensive task if subcontext.eval_on_fit_iter_end( estimator=self, diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 20aab8f4c4ab7..a084ec1b32b2b 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -9,16 +9,17 @@ from sklearn.utils._optional_dependencies import check_rich_support from sklearn.utils._testing import SkipTest -from ._utils import Estimator, MetaEstimator +from ._utils import Estimator, WhileEstimator, MetaEstimator @pytest.mark.parametrize("n_jobs", [1, 2]) @pytest.mark.parametrize("prefer", ["threads", "processes"]) -def test_progressbar(n_jobs, prefer, capsys): +@pytest.mark.parametrize("InnerEstimator", [Estimator, WhileEstimator]) +def test_progressbar(n_jobs, prefer, InnerEstimator, capsys): """Check the output of the progress bars and their completion.""" pytest.importorskip("rich") - est = Estimator() + est = InnerEstimator() meta_est = MetaEstimator(est, n_jobs=n_jobs, prefer=prefer) meta_est._set_callbacks(ProgressBar()) meta_est.fit(None, None) @@ -39,7 +40,7 @@ def test_progressbar(n_jobs, prefer, capsys): def test_progressbar_requires_rich_error(): """Check that we raise an informative error when rich is not installed.""" try: - check_rich_support("test_fetch_openml_requires_pandas") + check_rich_support("test_progressbar_requires_rich_error") except ImportError: err_msg = "Progressbar requires rich" with pytest.raises(ImportError, match=err_msg): From 34ca96dc696f6f89424fb6e7bf93c8c9ea168c37 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 3 Apr 2024 18:10:15 +0200 Subject: [PATCH 48/55] begin migrating tests --- sklearn/callback/_task_tree.py | 21 +++++- ...ck_methods.py => test_callback_context.py} | 9 +++ .../callback/tests/test_computation_tree.py | 73 ------------------- sklearn/callback/tests/test_progressbar.py | 2 +- sklearn/callback/tests/test_task_tree.py | 60 +++++++++++++++ 5 files changed, 87 insertions(+), 78 deletions(-) rename sklearn/callback/tests/{test_base_estimator_callback_methods.py => test_callback_context.py} (92%) delete mode 100644 sklearn/callback/tests/test_computation_tree.py create mode 100644 sklearn/callback/tests/test_task_tree.py diff --git a/sklearn/callback/_task_tree.py b/sklearn/callback/_task_tree.py index 62ee396a0414b..72e1763ebeb4e 100644 --- a/sklearn/callback/_task_tree.py +++ b/sklearn/callback/_task_tree.py @@ -17,8 +17,8 @@ class TaskNode: The maximum number of its children. 0 means it's a leaf. None means the number of children is not known in advance. - idx : int, default=0 - The index of this node among its siblings. + idx : int, default=None + The index of this node among its siblings. None means this is the root. parent : TaskNode instance, default=None The parent node. None means this is the root. @@ -38,7 +38,7 @@ def __init__( estimator_name, name="fit", max_subtasks=0, - idx=0, + idx=None, parent=None, ): # estimator_name and name are tuples because an estimator can be @@ -58,7 +58,20 @@ def __init__( # order of the idx du to parallelism. self.children = {} - def _add_child(self, name, max_subtasks, idx): + def _add_child(self, *, name, max_subtasks, idx): + if idx in self.children: + raise ValueError( + f"Child of task node {self.name} of estimator {self.estimator_name} " + f"with index {idx} already exists." + ) + + if len(self.children) == self.max_subtasks: + raise ValueError( + f"Cannot add child to task node {self.name} of estimator " + f"{self.estimator_name} because it already has its maximum " + f"number of children ({self.max_subtasks})." + ) + child = TaskNode( estimator_name=self.estimator_name[-1], name=name, diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_callback_context.py similarity index 92% rename from sklearn/callback/tests/test_base_estimator_callback_methods.py rename to sklearn/callback/tests/test_callback_context.py index 838f117e65812..3dd4d2f99d919 100644 --- a/sklearn/callback/tests/test_base_estimator_callback_methods.py +++ b/sklearn/callback/tests/test_callback_context.py @@ -42,6 +42,15 @@ def test_set_callbacks_error(callbacks): estimator._set_callbacks(callbacks) +def test_init_callback_context(): + """Sanity check for the `_init_callback_context` method.""" + estimator = Estimator() + callback_ctx = estimator._init_callback_context() + + assert hasattr(estimator, "_callback_fit_context") + assert hasattr(callback_ctx, "callbacks") + + def test_propagate_callbacks(): """Sanity check for the `_propagate_callbacks` method.""" not_propagated_callback = TestingCallback() diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py deleted file mode 100644 index de2442d4851a3..0000000000000 --- a/sklearn/callback/tests/test_computation_tree.py +++ /dev/null @@ -1,73 +0,0 @@ -# License: BSD 3 clause -# Authors: the scikit-learn developers - -import numpy as np - -from sklearn.callback import build_computation_tree - -TREE_STRUCTURE = [ - {"stage": "stage0", "n_children": 3}, - {"stage": "stage1", "n_children": 5}, - {"stage": "stage2", "n_children": 7}, - {"stage": "stage3", "n_children": None}, -] - - -def test_computation_tree(): - """Check the construction of the computation tree.""" - computation_tree = build_computation_tree( - estimator_name="estimator", tree_structure=TREE_STRUCTURE - ) - assert computation_tree.estimator_name == ("estimator",) - assert computation_tree.parent is None - assert computation_tree.idx == 0 - - assert len(computation_tree.children) == computation_tree.n_children == 3 - assert [node.idx for node in computation_tree.children] == list(range(3)) - - for node1 in computation_tree.children: - assert len(node1.children) == 5 - assert [n.idx for n in node1.children] == list(range(5)) - - for node2 in node1.children: - assert len(node2.children) == 7 - assert [n.idx for n in node2.children] == list(range(7)) - - for node3 in node2.children: - assert not node3.children - - -def test_n_nodes(): - """Check that the number of node in a computation tree corresponds to what we expect - from the level descriptions. - """ - computation_tree = build_computation_tree( - estimator_name="", tree_structure=TREE_STRUCTURE - ) - - n_children_per_level = [stage["n_children"] for stage in TREE_STRUCTURE[:-1]] - expected_n_nodes = 1 + np.sum(np.cumprod(n_children_per_level)) - - actual_n_nodes = sum(1 for _ in computation_tree) - - assert actual_n_nodes == expected_n_nodes - - -def test_path(): - """Check that the path from the root to a node is correct.""" - computation_tree = build_computation_tree( - estimator_name="", tree_structure=TREE_STRUCTURE - ) - - assert computation_tree.path == [computation_tree] - - node = computation_tree.children[1].children[2].children[3] - expected_path = [ - computation_tree, - computation_tree.children[1], - computation_tree.children[1].children[2], - node, - ] - assert node.path == expected_path - - assert all(node.depth == i for i, node in enumerate(expected_path)) diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index a084ec1b32b2b..59ef911d94d71 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -9,7 +9,7 @@ from sklearn.utils._optional_dependencies import check_rich_support from sklearn.utils._testing import SkipTest -from ._utils import Estimator, WhileEstimator, MetaEstimator +from ._utils import Estimator, MetaEstimator, WhileEstimator @pytest.mark.parametrize("n_jobs", [1, 2]) diff --git a/sklearn/callback/tests/test_task_tree.py b/sklearn/callback/tests/test_task_tree.py new file mode 100644 index 0000000000000..5cb9a1df96bc0 --- /dev/null +++ b/sklearn/callback/tests/test_task_tree.py @@ -0,0 +1,60 @@ +# License: BSD 3 clause +# Authors: the scikit-learn developers + +import numpy as np + +from sklearn.callback import TaskNode + + +def _make_task_tree(n_children, n_grandchildren): + root = TaskNode( + estimator_name="estimator", name="root task", max_subtasks=n_children + ) + + for i in range(n_children): + child = root._add_child(name="child task", max_subtasks=n_grandchildren, idx=i) + + for j in range(n_grandchildren): + child._add_child(name="grandchild task", max_subtasks=0, idx=j) + + return root + + +def test_task_tree(): + root = _make_task_tree(n_children=3, n_grandchildren=5) + + assert root.max_subtasks == 3 + assert root.idx is None + assert root.parent is None + + assert len(root.children) == 3 + assert all(len(child.children) == 5 for child in root.children.values()) + + # 1 root, 3 children, 3 * 5 grandchildren + expected_n_nodes = np.sum(np.cumprod([1, 3, 5])) + actual_n_nodes = sum(1 for _ in root) + assert actual_n_nodes == expected_n_nodes + + +def test_path(): + root = _make_task_tree(n_children=3, n_grandchildren=5) + + assert root.path == [root] + + # pick a node + node = root.children[1].children[2] + expected_path = [root, root.children[1], node] + assert node.path == expected_path + + +def test_depth(): + root = _make_task_tree(n_children=3, n_grandchildren=5) + + assert root.depth == 0 + + assert all(child.depth == 1 for child in root.children.values()) + assert all( + grandchild.depth == 2 + for child in root.children.values() + for grandchild in child.children.values() + ) From a2d497559260e4cf19a60cf5080cccc7c5f3f6a0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 4 Apr 2024 23:02:44 +0200 Subject: [PATCH 49/55] improve callback context API --- sklearn/base.py | 17 ++- sklearn/callback/_base.py | 7 +- sklearn/callback/_callback_context.py | 72 +++++++----- sklearn/callback/_progressbar.py | 33 +++--- sklearn/callback/_task_tree.py | 130 +++++++++++---------- sklearn/callback/tests/_utils.py | 53 +++++---- sklearn/callback/tests/test_progressbar.py | 7 +- sklearn/callback/tests/test_task_tree.py | 61 ++++++---- 8 files changed, 221 insertions(+), 159 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 3b2cb91c81693..18dcb7ef8eafe 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -697,17 +697,30 @@ def _set_callbacks(self, callbacks): return self - def _init_callback_context(self): + def _init_callback_context(self, task_name="fit"): """Initialize the callback context for the estimator. + Parameters + ---------- + task_name : str, default='fit' + The name of the root task. + Returns ------- callback_fit_ctx : CallbackContext The callback context for the estimator. """ + # We don't initialize the callback context during _set_callbacks but in fit + # because in the future we might want to have callbacks in predict/transform + # which would require their own context. self._callback_fit_ctx = CallbackContext( callbacks=getattr(self, "_skl_callbacks", []), + estimator_name=self.__class__.__name__, + task_name=task_name, + parent_estimator_task_node=getattr( + self, "_parent_estimator_task_node", None + ), ) return self._callback_fit_ctx @@ -1517,7 +1530,7 @@ def wrapper(estimator, *args, **kwargs): try: return fit_method(estimator, *args, **kwargs) finally: - estimator._callback_fit_ctx.eval_on_fit_end() + estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator) return wrapper diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index d8a2e9471cec1..5ded089d6b54e 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -14,7 +14,7 @@ def on_fit_begin(self, estimator, *, data): Parameters ---------- estimator : estimator instance - The estimator the callback is set on. + The estimator the callback is registered on. data : dict Dictionary containing the training and validation data. The possible @@ -70,11 +70,14 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) """ @abstractmethod - def on_fit_end(self, task_node): + def on_fit_end(self, estimator, task_node): """Method called at the end of the fit method of the estimator. Parameters ---------- + estimator : estimator instance + The estimator the callback is registered on. + task_node : TaskNode instance The task node corresponding to the whole `fit` task. This is usually the root of the task tree of the estimator but it can be an intermediate node diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py index 6d96f2d7fd6b5..a5d1f1de1eb9b 100644 --- a/sklearn/callback/_callback_context.py +++ b/sklearn/callback/_callback_context.py @@ -5,56 +5,70 @@ class CallbackContext: - def __init__(self, callbacks): + def __init__( + self, + callbacks, + estimator_name="", + task_name="", + task_id=0, + max_tasks=1, + parent_task_node=None, + parent_estimator_task_node=None, + ): self.callbacks = callbacks + self.estimator_name = estimator_name - def eval_on_fit_begin(self, *, estimator, max_subtasks=0, data): self.task_node = TaskNode( - estimator_name=estimator.__class__.__name__, - name="fit", - max_subtasks=max_subtasks, + task_name=task_name, + task_id=task_id, + max_tasks=max_tasks, + estimator_name=self.estimator_name, ) - parent_task_node = getattr(estimator, "_parent_task_node", None) if parent_task_node is not None: - self.task_node._merge_with(parent_task_node) + # This task is a subtask of another task of a same estimator + parent_task_node._add_child(self.task_node) + elif parent_estimator_task_node is not None: + # This task is the root task of the estimator which itself corresponds to + # a leaf task of a meta-estimator. Both tasks actually represent the same + # task so we merge both task nodes into a single task node, attaching the + # task tree of the sub-estimator to the task tree of the meta-estimator on + # the way. + self.task_node._merge_with(parent_estimator_task_node) + + def subcontext(self, task_name="", task_id=0, max_tasks=1): + return CallbackContext( + callbacks=self.callbacks, + estimator_name=self.estimator_name, + task_name=task_name, + task_id=task_id, + max_tasks=max_tasks, + parent_task_node=self.task_node, + ) + def eval_on_fit_begin(self, estimator, *, data): for callback in self.callbacks: # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. - if not (callback.auto_propagate and parent_task_node is not None): + if not (callback.auto_propagate and self.task_node.parent is not None): callback.on_fit_begin(estimator, data=data) return self - def eval_on_fit_iter_end(self, **kwargs): + def eval_on_fit_iter_end(self, estimator, **kwargs): return any( - callback.on_fit_iter_end(task_node=self.task_node, **kwargs) + callback.on_fit_iter_end(estimator, self.task_node, **kwargs) for callback in self.callbacks ) - def eval_on_fit_end(self): + def eval_on_fit_end(self, estimator): for callback in self.callbacks: # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. if not (callback.auto_propagate and self.task_node.parent is not None): - callback.on_fit_end(task_node=self.task_node) - - def subcontext(self, task="", max_subtasks=0, idx=0, sub_estimator=None): - sub_ctx = CallbackContext(callbacks=self.callbacks) - - sub_ctx.task_node = self.task_node._add_child( - name=task, - max_subtasks=max_subtasks, - idx=idx, - ) + callback.on_fit_end(estimator, task_node=self.task_node) - if sub_estimator is not None: - sub_ctx._propagate_callbacks(sub_estimator=sub_estimator) - - return sub_ctx - - def _propagate_callbacks(self, sub_estimator): + def propagate_callbacks(self, sub_estimator): bad_callbacks = [ callback.__class__.__name__ for callback in getattr(sub_estimator, "_skl_callbacks", []) @@ -76,8 +90,10 @@ def _propagate_callbacks(self, sub_estimator): if not callbacks_to_propagate: return - sub_estimator._parent_task_node = self.task_node + sub_estimator._parent_estimator_task_node = self.task_node sub_estimator._set_callbacks( getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate ) + + return self diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index bba2c7ec9e3e8..58616055c7100 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -16,15 +16,15 @@ class ProgressBar(BaseCallback): def __init__(self): check_rich_support("Progressbar") - def on_fit_begin(self, estimator, data): + def on_fit_begin(self, estimator, *, data): self._queue = Manager().Queue() self.progress_monitor = _RichProgressMonitor(queue=self._queue) self.progress_monitor.start() - def on_fit_iter_end(self, *, task_node, **kwargs): + def on_fit_iter_end(self, estimator, task_node, **kwargs): self._queue.put(task_node) - def on_fit_end(self, *, task_node): + def on_fit_end(self, estimator, task_node): self._queue.put(task_node) self._queue.put(None) self.progress_monitor.join() @@ -92,7 +92,7 @@ def run(self): def _update_task_tree(self, task_node): """Update the tree of tasks from a new node.""" - curr_rich_task, parent_task = None, None + curr_rich_task, parent_rich_task = None, None for curr_node in task_node.path: if curr_node.parent is None: # root node @@ -101,14 +101,14 @@ def _update_task_tree(self, task_node): curr_node, progress_ctx=self.progress_ctx ) curr_rich_task = self.root_rich_task - elif curr_node.idx not in parent_task.children: + elif curr_node.task_id not in parent_rich_task.children: curr_rich_task = RichTaskNode( - curr_node, progress_ctx=self.progress_ctx, parent=parent_task + curr_node, progress_ctx=self.progress_ctx, parent=parent_rich_task ) - parent_task.children[curr_node.idx] = curr_rich_task + parent_rich_task.children[curr_node.task_id] = curr_rich_task else: # task already exists - curr_rich_task = parent_task.children[curr_node.idx] - parent_task = curr_rich_task + curr_rich_task = parent_rich_task.children[curr_node.task_id] + parent_rich_task = curr_rich_task # Mark the deepest task as finished (this is the one corresponding to the # computation node that we just get from the queue). @@ -170,7 +170,6 @@ class RichTaskNode: """ def __init__(self, task_node, progress_ctx, parent=None): - self.node_idx = task_node.idx self.parent = parent self.children = {} self.finished = False @@ -188,13 +187,15 @@ def _format_task_description(self, task_node): indent = f"{' ' * (task_node.depth)}" style = f"[{colors[(task_node.depth)%len(colors)]}]" - description = f"{task_node.estimator_name[0]} - {task_node.name[0]}" - if task_node.parent is not None: - description += f" #{task_node.idx}" - if len(task_node.estimator_name) == 2: - description += f" | {task_node.estimator_name[1]} - {task_node.name[1]}" + task_desc = f"{task_node.estimator_name} - {task_node.task_name}" + id_mark = f" #{task_node.task_id}" if task_node.parent is not None else "" + prev_task_desc = ( + f"{task_node.prev_estimator_name} - {task_node.prev_task_name} | " + if task_node.prev_estimator_name is not None + else "" + ) - return f"{style}{indent}{description}" + return f"{style}{indent}{prev_task_desc}{task_desc}{id_mark}" def __iter__(self): """Pre-order depth-first traversal, excluding leaves.""" diff --git a/sklearn/callback/_task_tree.py b/sklearn/callback/_task_tree.py index 72e1763ebeb4e..f3773a6fc2eae 100644 --- a/sklearn/callback/_task_tree.py +++ b/sklearn/callback/_task_tree.py @@ -7,89 +7,93 @@ class TaskNode: Parameters ---------- - estimator_name : str - The name of the estimator this task node belongs to. - - name : str, default=None + task_name : str The name of the task this node represents. - max_subtasks : int or None, default=0 - The maximum number of its children. 0 means it's a leaf. - None means the number of children is not known in advance. - - idx : int, default=None + task_id : int The index of this node among its siblings. None means this is the root. + An identifier for this task that distinguishes it from its siblings. - parent : TaskNode instance, default=None - The parent node. None means this is the root. + max_tasks : int or None + The maximum number of its siblings. 0 means it's a leaf. + None means the maximum number of siblings is not known in advance. - Note that the root task of an estimator can become an intermediate node - of a meta-estimator. + estimator_name : str + The name of the estimator this task node belongs to. Attributes ---------- - children : dict - A mapping from the index of a child to the child node `{idx: TaskNode}`. + parent : TaskNode instance or None + The parent node. None means this is the root. + + Note that it's dynamic since the root task of an estimator can become an + intermediate node of a meta-estimator. + + children_map : dict + A mapping from the task_id of a child to the child node `{task_id: TaskNode}`. For a leaf, it's an empty dictionary. + + max_subtasks : int or None + The maximum number of subtasks of this node. 0 means it's a leaf. None + means the maximum number of subtasks is not known in advance. + + prev_estimator_name : str or None + The estimator name of the node this node was merged with. None if was not + merged with another node. + + prev_task_name : str + The task name of the node this node was merged with. None if was not + merged with another node. """ - def __init__( - self, - estimator_name, - name="fit", - max_subtasks=0, - idx=None, - parent=None, - ): - # estimator_name and name are tuples because an estimator can be - # a sub-estimator of a meta-estimator. In that case, the root of the task - # tree of the sub-estimator and a leaf of the task tree of the - # meta-estimator correspond to the same computation step. Therefore, both - # nodes are merged into a single node, retaining the information of both. - self.estimator_name = (estimator_name,) - self.name = (name,) - - self.max_subtasks = max_subtasks - self.idx = idx - self.parent = parent - - # Children stored in a dict indexed by their idx for easy access because the - # order in which self.children is populated is not guaranteed to follow the - # order of the idx du to parallelism. - self.children = {} - - def _add_child(self, *, name, max_subtasks, idx): - if idx in self.children: + def __init__(self, *, task_name, task_id, max_tasks, estimator_name): + self.task_name = task_name + self.task_id = task_id + self.max_tasks = max_tasks + self.estimator_name = estimator_name + + self.parent = None + self.children_map = {} + self.max_subtasks = 0 + + # When an estimator is a sub-estimator of a meta-estimator, the root task of + # the estimator is merged with the corresponding leaf task of the + # meta-estimator because both correspond to the same computation step. + # The root task of the estimator takes the place of the leaf task of the + # meta-estimator in the task tree but we keep the information about the + # leaf task it was merged with to fully describe the merged node. + self.prev_estimator_name = None + self.prev_task_name = None + + def _add_child(self, task_node): + if task_node.task_id in self.children_map: raise ValueError( - f"Child of task node {self.name} of estimator {self.estimator_name} " - f"with index {idx} already exists." + f"Task node {self.task_name} of estimator {self.estimator_name} " + f"already has a child with task_id={task_node.task_id}." ) - if len(self.children) == self.max_subtasks: + if len(self.children_map) == task_node.max_tasks: raise ValueError( - f"Cannot add child to task node {self.name} of estimator " + f"Cannot add child to task node {self.task_name} of estimator " f"{self.estimator_name} because it already has its maximum " - f"number of children ({self.max_subtasks})." + f"number of children ({task_node.max_tasks})." ) - child = TaskNode( - estimator_name=self.estimator_name[-1], - name=name, - max_subtasks=max_subtasks, - idx=idx, - parent=self, - ) - self.children[idx] = child - - return child + self.children_map[task_node.task_id] = task_node + self.max_subtasks = task_node.max_tasks + task_node.parent = self def _merge_with(self, task_node): + # Set the parent of the sub-estimator's root task node to the parent + # of the meta-estimator's leaf task node self.parent = task_node.parent - self.idx = task_node.idx - task_node.parent.children[self.idx] = self + self.task_id = task_node.task_id + self.max_tasks = task_node.max_tasks + task_node.parent.children_map[self.task_id] = self - self.name = task_node.name + self.name - self.estimator_name = task_node.estimator_name + self.estimator_name + # Keep information about the node it was merged with + self.prev_task_name = task_node.task_name + self.prev_estimator_name = task_node.estimator_name @property def depth(self): @@ -104,5 +108,5 @@ def path(self): def __iter__(self): """Pre-order depth-first traversal""" yield self - for node in self.children.values(): - yield from node + for task_node in self.children_map.values(): + yield from task_node diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 8fae15ec1e63e..ea313866c7f79 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -39,21 +39,20 @@ def on_fit_iter_end(self, estimator, node, **kwargs): class Estimator(BaseEstimator): _parameter_constraints: dict = {} - def __init__(self, max_iter=20): + def __init__(self, max_iter=20, computation_intensity=0.001): self.max_iter = max_iter + self.computation_intensity = computation_intensity @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X, y): + def fit(self, X=None, y=None): callback_ctx = self._init_callback_context().eval_on_fit_begin( - estimator=self, - max_subtasks=self.max_iter, - data={"X_train": X, "y_train": y}, + estimator=self, data={"X_train": X, "y_train": y} ) for i in range(self.max_iter): - subcontext = callback_ctx.subcontext(idx=i) + subcontext = callback_ctx.subcontext(task_id=i, max_tasks=self.max_iter) - time.sleep(0.001) # Computation intensive task + time.sleep(self.computation_intensity) # Computation intensive task if subcontext.eval_on_fit_iter_end( estimator=self, @@ -69,18 +68,20 @@ def fit(self, X, y): class WhileEstimator(BaseEstimator): _parameter_constraints: dict = {} + def __init__(self, computation_intensity=0.001): + self.computation_intensity = computation_intensity + @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X, y): + def fit(self, X=None, y=None): callback_ctx = self._init_callback_context().eval_on_fit_begin( - estimator=self, max_subtasks=None, data={"X_train": X, "y_train": y} + estimator=self, data={"X_train": X, "y_train": y} ) i = 0 - while True: - subcontext = callback_ctx.subcontext(idx=i) + subcontext = callback_ctx.subcontext(task_id=i, max_tasks=None) - time.sleep(0.001) # Computation intensive task + time.sleep(self.computation_intensity) # Computation intensive task if subcontext.eval_on_fit_iter_end( estimator=self, @@ -109,27 +110,35 @@ def __init__( self.prefer = prefer @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X, y): + def fit(self, X=None, y=None): callback_ctx = self._init_callback_context().eval_on_fit_begin( - estimator=self, max_subtasks=self.n_outer, data={"X_train": X, "y_train": y} + estimator=self, data={"X_train": X, "y_train": y} ) Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( - delayed(_func)(i, self, self.estimator, X, y, callback_ctx=callback_ctx) + delayed(_func)( + self, + self.estimator, + X, + y, + callback_ctx=callback_ctx.subcontext( + task_name="outer", task_id=i, max_tasks=self.n_outer + ), + ) for i in range(self.n_outer) ) return self -def _func(outer_idx, meta_estimator, inner_estimator, X, y, *, callback_ctx): - outer_ctx = callback_ctx.subcontext( - task="outer", max_subtasks=meta_estimator.n_inner, idx=outer_idx - ) - +def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx): for i in range(meta_estimator.n_inner): est = clone(inner_estimator) - inner_ctx = outer_ctx.subcontext(task="inner", idx=i, sub_estimator=est) + + inner_ctx = callback_ctx.subcontext( + task_name="inner", task_id=i, max_tasks=meta_estimator.n_inner + ).propagate_callbacks(sub_estimator=est) + est.fit(X, y) inner_ctx.eval_on_fit_iter_end( @@ -137,7 +146,7 @@ def _func(outer_idx, meta_estimator, inner_estimator, X, y, *, callback_ctx): data={"X_train": X, "y_train": y}, ) - outer_ctx.eval_on_fit_iter_end( + callback_ctx.eval_on_fit_iter_end( estimator=meta_estimator, data={"X_train": X, "y_train": y}, ) diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 59ef911d94d71..6fba1f5b4ae9b 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -22,7 +22,7 @@ def test_progressbar(n_jobs, prefer, InnerEstimator, capsys): est = InnerEstimator() meta_est = MetaEstimator(est, n_jobs=n_jobs, prefer=prefer) meta_est._set_callbacks(ProgressBar()) - meta_est.fit(None, None) + meta_est.fit() captured = capsys.readouterr() @@ -30,7 +30,10 @@ def test_progressbar(n_jobs, prefer, InnerEstimator, capsys): for i in range(4): assert re.search(rf"MetaEstimator - outer #{i}", captured.out) for i in range(3): - assert re.search(rf"MetaEstimator - inner #{i} | Estimator - fit", captured.out) + assert re.search( + rf"MetaEstimator - inner #{i} | {est.__class__.__name__} - fit", + captured.out, + ) # Check that all bars are 100% complete assert re.search(r"100%", captured.out) diff --git a/sklearn/callback/tests/test_task_tree.py b/sklearn/callback/tests/test_task_tree.py index 5cb9a1df96bc0..9ea9c3f7c012d 100644 --- a/sklearn/callback/tests/test_task_tree.py +++ b/sklearn/callback/tests/test_task_tree.py @@ -8,14 +8,26 @@ def _make_task_tree(n_children, n_grandchildren): root = TaskNode( - estimator_name="estimator", name="root task", max_subtasks=n_children + task_name="root task", task_id=0, max_tasks=1, estimator_name="estimator" ) for i in range(n_children): - child = root._add_child(name="child task", max_subtasks=n_grandchildren, idx=i) + child = TaskNode( + task_name="child task", + task_id=i, + max_tasks=n_children, + estimator_name="estimator", + ) + root._add_child(child) for j in range(n_grandchildren): - child._add_child(name="grandchild task", max_subtasks=0, idx=j) + grandchild = TaskNode( + task_name="grandchild task", + task_id=j, + max_tasks=n_grandchildren, + estimator_name="estimator", + ) + child._add_child(grandchild) return root @@ -23,38 +35,39 @@ def _make_task_tree(n_children, n_grandchildren): def test_task_tree(): root = _make_task_tree(n_children=3, n_grandchildren=5) - assert root.max_subtasks == 3 - assert root.idx is None assert root.parent is None + assert root.depth == 0 + assert len(root.children_map) == 3 + + for child in root.children_map.values(): + assert child.parent is root + assert child.depth == 1 + assert len(child.children_map) == 5 + assert root.max_subtasks == child.max_tasks - assert len(root.children) == 3 - assert all(len(child.children) == 5 for child in root.children.values()) + for grandchild in child.children_map.values(): + assert grandchild.parent is child + assert grandchild.depth == 2 + assert len(grandchild.children_map) == 0 + assert child.max_subtasks == grandchild.max_tasks - # 1 root, 3 children, 3 * 5 grandchildren + # 1 root + 1 * 3 children + 1 * 3 * 5 grandchildren expected_n_nodes = np.sum(np.cumprod([1, 3, 5])) actual_n_nodes = sum(1 for _ in root) assert actual_n_nodes == expected_n_nodes + # None of the nodes should have been merged with another node + assert all(node.prev_estimator_name is None for node in root) + assert all(node.prev_task_name is None for node in root) + def test_path(): root = _make_task_tree(n_children=3, n_grandchildren=5) assert root.path == [root] - # pick a node - node = root.children[1].children[2] - expected_path = [root, root.children[1], node] - assert node.path == expected_path - + # pick an arbitrary node + node = root.children_map[1].children_map[2] -def test_depth(): - root = _make_task_tree(n_children=3, n_grandchildren=5) - - assert root.depth == 0 - - assert all(child.depth == 1 for child in root.children.values()) - assert all( - grandchild.depth == 2 - for child in root.children.values() - for grandchild in child.children.values() - ) + expected_path = [root, root.children_map[1], node] + assert node.path == expected_path From 4f2f36990943e3eead01349832ec3b888b0e709e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Tue, 4 Jun 2024 15:44:57 +0200 Subject: [PATCH 50/55] merge --- build_tools/azure/debian_atlas_32bit_lock.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build_tools/azure/debian_atlas_32bit_lock.txt b/build_tools/azure/debian_atlas_32bit_lock.txt index 92b4b864fb9f5..ba957463e60bc 100644 --- a/build_tools/azure/debian_atlas_32bit_lock.txt +++ b/build_tools/azure/debian_atlas_32bit_lock.txt @@ -6,7 +6,7 @@ # attrs==23.2.0 # via pytest -coverage==7.5.2 +coverage==7.5.3 # via pytest-cov cython==3.0.10 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt @@ -14,7 +14,7 @@ iniconfig==2.0.0 # via pytest joblib==1.2.0 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt -meson==1.4.0 +meson==1.4.1 # via meson-python meson-python==0.16.0 # via -r build_tools/azure/debian_atlas_32bit_requirements.txt From 61cf977aaf2202a09859f3122b386afd7b41f3ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 7 Jun 2024 17:11:21 +0200 Subject: [PATCH 51/55] rework callback context internals + migrate tests --- sklearn/base.py | 55 +----- sklearn/callback/__init__.py | 7 +- sklearn/callback/_base.py | 30 ++-- sklearn/callback/_callback_context.py | 163 ++++++++++++++---- sklearn/callback/_mixin.py | 51 ++++++ sklearn/callback/_progressbar.py | 17 +- sklearn/callback/tests/_utils.py | 23 ++- .../callback/tests/test_callback_context.py | 63 +++---- sklearn/callback/tests/test_progressbar.py | 28 ++- 9 files changed, 263 insertions(+), 174 deletions(-) create mode 100644 sklearn/callback/_mixin.py diff --git a/sklearn/base.py b/sklearn/base.py index 4b407b977fed5..a325c73dfa52d 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -15,7 +15,6 @@ from . import __version__ from ._config import config_context, get_config -from .callback import BaseCallback, CallbackContext from .exceptions import InconsistentVersionWarning from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr from .utils._metadata_requests import _MetadataRequester, _routing_enabled @@ -674,57 +673,6 @@ class attribute, which is a dictionary `param_name: list of constraints`. See caller_name=self.__class__.__name__, ) - def _set_callbacks(self, callbacks): - """Set callbacks for the estimator. - - Parameters - ---------- - callbacks : callback or list of callbacks - the callbacks to set. - - Returns - ------- - self : estimator instance - The estimator instance itself. - """ - if not isinstance(callbacks, list): - callbacks = [callbacks] - - if not all(isinstance(callback, BaseCallback) for callback in callbacks): - raise TypeError("callbacks must be subclasses of BaseCallback.") - - self._skl_callbacks = callbacks - - return self - - def _init_callback_context(self, task_name="fit"): - """Initialize the callback context for the estimator. - - Parameters - ---------- - task_name : str, default='fit' - The name of the root task. - - Returns - ------- - callback_fit_ctx : CallbackContext - The callback context for the estimator. - """ - # We don't initialize the callback context during _set_callbacks but in fit - # because in the future we might want to have callbacks in predict/transform - # which would require their own context. - - self._callback_fit_ctx = CallbackContext( - callbacks=getattr(self, "_skl_callbacks", []), - estimator_name=self.__class__.__name__, - task_name=task_name, - parent_estimator_task_node=getattr( - self, "_parent_estimator_task_node", None - ), - ) - - return self._callback_fit_ctx - @property def _repr_html_(self): """HTML representation of estimator. @@ -1570,7 +1518,8 @@ def wrapper(estimator, *args, **kwargs): try: return fit_method(estimator, *args, **kwargs) finally: - estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator) + if hasattr(estimator, "_callback_fit_ctx"): + estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator) return wrapper diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py index 0e5b7ddc0999b..a493454bdbb22 100644 --- a/sklearn/callback/__init__.py +++ b/sklearn/callback/__init__.py @@ -6,14 +6,17 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -from ._base import BaseCallback +from ._base import AutoPropagatedProtocol, CallbackProtocol from ._callback_context import CallbackContext +from ._mixin import CallbackSupportMixin from ._progressbar import ProgressBar from ._task_tree import TaskNode __all__ = [ - "BaseCallback", + "AutoPropagatedProtocol", + "CallbackProtocol", "CallbackContext", + "CallbackSupportMixin", "TaskNode", "ProgressBar", ] diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 5ded089d6b54e..1a9180a76042f 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -1,13 +1,13 @@ # License: BSD 3 clause # Authors: the scikit-learn developers -from abc import ABC, abstractmethod +from typing import Protocol, runtime_checkable -class BaseCallback(ABC): - """Abstract class for the callbacks""" +@runtime_checkable +class CallbackProtocol(Protocol): + """Protocol for the callbacks""" - @abstractmethod def on_fit_begin(self, estimator, *, data): """Method called at the beginning of the fit method of the estimator. @@ -22,7 +22,6 @@ def on_fit_begin(self, estimator, *, data): and "sample_weight_val". """ - @abstractmethod def on_fit_iter_end(self, estimator, task_node, **kwargs): """Method called at the end of each task of the estimator. @@ -69,7 +68,6 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) Whether or not to stop the current level of iterations at this node. """ - @abstractmethod def on_fit_end(self, estimator, task_node): """Method called at the end of the fit method of the estimator. @@ -84,14 +82,16 @@ def on_fit_end(self, estimator, task_node): if the estimator is a sub-estimator of a meta-estimator. """ + +@runtime_checkable +class AutoPropagatedProtocol(Protocol): + """Protocol for the auto-propagated callbacks""" + @property - def auto_propagate(self): - """Whether or not this callback should be propagated to sub-estimators. - - An auto-propagated callback (from a meta-estimator to its sub-estimators) must - be set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will - only be called at the beginning and end of the fit method of the meta-estimator, - while its `on_fit_iter_end` method will be called at each computation node of - the meta-estimator and its sub-estimators. + def max_estimator_depth(self): + """The maximum number of nested estimators at which the callback should be + propagated. + + If set to None, the callback is propagated to sub-estimators at all nesting + levels. """ - return False diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py index a5d1f1de1eb9b..a5ebacb08d456 100644 --- a/sklearn/callback/_callback_context.py +++ b/sklearn/callback/_callback_context.py @@ -1,98 +1,187 @@ # License: BSD 3 clause # Authors: the scikit-learn developers +from . import AutoPropagatedProtocol from ._task_tree import TaskNode class CallbackContext: - def __init__( - self, - callbacks, - estimator_name="", - task_name="", - task_id=0, - max_tasks=1, - parent_task_node=None, - parent_estimator_task_node=None, - ): - self.callbacks = callbacks - self.estimator_name = estimator_name - - self.task_node = TaskNode( + """Task level context for the callbacks. + + This class is responsible for managing the callbacks and task tree of an estimator. + """ + + @classmethod + def _from_estimator(cls, estimator, *, task_name, task_id, max_tasks=1): + """Private constructor to create a root context. + + Parameters + ---------- + estimator : estimator instance + The estimator this context is responsible for. + + task_name : str + The name of the task this context is responsible for. + + task_id : int + The id of the task this context is responsible for. + + max_tasks : int, default=1 + The maximum number of tasks that can be siblings of the task this context is + responsible for. + """ + new_ctx = cls.__new__(cls) + + # We don't store the estimator in the context to avoid circular references + # because the estimator already holds a reference to the context. + new_ctx._callbacks = getattr(estimator, "_skl_callbacks", []) + new_ctx._estimator_name = estimator.__class__.__name__ + + new_ctx._task_node = TaskNode( task_name=task_name, task_id=task_id, max_tasks=max_tasks, - estimator_name=self.estimator_name, + estimator_name=new_ctx._estimator_name, ) - if parent_task_node is not None: - # This task is a subtask of another task of a same estimator - parent_task_node._add_child(self.task_node) - elif parent_estimator_task_node is not None: + if hasattr(estimator, "_parent_callback_ctx"): # This task is the root task of the estimator which itself corresponds to # a leaf task of a meta-estimator. Both tasks actually represent the same # task so we merge both task nodes into a single task node, attaching the # task tree of the sub-estimator to the task tree of the meta-estimator on # the way. - self.task_node._merge_with(parent_estimator_task_node) + parent_ctx = estimator._parent_callback_ctx + new_ctx._task_node._merge_with(parent_ctx._task_node) + new_ctx._estimator_depth = parent_ctx._estimator_depth + 1 + else: + new_ctx._estimator_depth = 0 + + return new_ctx + + @classmethod + def _from_parent(cls, parent_context, *, task_name, task_id, max_tasks=1): + """Private constructor to create a sub-context. + + Parameters + ---------- + parent_context : CallbackContext instance + The parent context of the new context. + + task_name : str + The name of the task this context is responsible for. + + task_id : int + The id of the task this context is responsible for. + + max_tasks : int, default=1 + The maximum number of tasks that can be siblings of the task this context is + responsible for. + """ + new_ctx = cls.__new__(cls) + + new_ctx._callbacks = parent_context._callbacks + new_ctx._estimator_name = parent_context._estimator_name + new_ctx._estimator_depth = parent_context._estimator_depth + + new_ctx._task_node = TaskNode( + task_name=task_name, + task_id=task_id, + max_tasks=max_tasks, + estimator_name=new_ctx._estimator_name, + ) + + # This task is a subtask of another task of a same estimator + parent_context._task_node._add_child(new_ctx._task_node) + + return new_ctx def subcontext(self, task_name="", task_id=0, max_tasks=1): - return CallbackContext( - callbacks=self.callbacks, - estimator_name=self.estimator_name, + """Create a context for a subtask of the current task. + + Parameters + ---------- + task_name : str, default="" + The name of the subtask. + + task_id : int, default=0 + An identifier of the subtask. Usually a number between 0 and + `max_tasks - 1`, but can be any identifier. + + max_tasks : int, default=1 + The maximum number of tasks that can be siblings of the subtask. + """ + return CallbackContext._from_parent( + parent_context=self, task_name=task_name, task_id=task_id, max_tasks=max_tasks, - parent_task_node=self.task_node, ) def eval_on_fit_begin(self, estimator, *, data): - for callback in self.callbacks: + """Evaluate the on_fit_begin method of the callbacks.""" + for callback in self._callbacks: # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. - if not (callback.auto_propagate and self.task_node.parent is not None): + if not ( + isinstance(callback, AutoPropagatedProtocol) + and self._task_node.parent is not None + ): callback.on_fit_begin(estimator, data=data) return self def eval_on_fit_iter_end(self, estimator, **kwargs): + """Evaluate the on_fit_iter_end method of the callbacks.""" return any( - callback.on_fit_iter_end(estimator, self.task_node, **kwargs) - for callback in self.callbacks + callback.on_fit_iter_end(estimator, self._task_node, **kwargs) + for callback in self._callbacks ) def eval_on_fit_end(self, estimator): - for callback in self.callbacks: + """Evaluate the on_fit_end method of the callbacks.""" + for callback in self._callbacks: # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. - if not (callback.auto_propagate and self.task_node.parent is not None): - callback.on_fit_end(estimator, task_node=self.task_node) + if not ( + isinstance(callback, AutoPropagatedProtocol) + and self._task_node.parent is not None + ): + callback.on_fit_end(estimator, task_node=self._task_node) def propagate_callbacks(self, sub_estimator): + """Propagate the callbacks to a sub-estimator.""" bad_callbacks = [ callback.__class__.__name__ for callback in getattr(sub_estimator, "_skl_callbacks", []) - if callback.auto_propagate + if isinstance(callback, AutoPropagatedProtocol) ] if bad_callbacks: raise TypeError( f"The sub-estimator ({sub_estimator.__class__.__name__}) of a" - f" meta-estimator ({self.task_node.estimator_name}) can't have" + f" meta-estimator ({self._task_node.estimator_name}) can't have" f" auto-propagated callbacks ({bad_callbacks})." " Register them directly on the meta-estimator." ) callbacks_to_propagate = [ - callback for callback in self.callbacks if callback.auto_propagate + callback + for callback in self._callbacks + if isinstance(callback, AutoPropagatedProtocol) + and ( + callback.max_estimator_depth is None + or self._estimator_depth < callback.max_estimator_depth + ) ] if not callbacks_to_propagate: - return + return self - sub_estimator._parent_estimator_task_node = self.task_node + # We store the parent context in the sub-estimator to be able to merge the + # task trees of the sub-estimator and the meta-estimator. + sub_estimator._parent_callback_ctx = self - sub_estimator._set_callbacks( + sub_estimator.set_callbacks( getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate ) diff --git a/sklearn/callback/_mixin.py b/sklearn/callback/_mixin.py new file mode 100644 index 0000000000000..4311e1331dda8 --- /dev/null +++ b/sklearn/callback/_mixin.py @@ -0,0 +1,51 @@ +from ._base import CallbackProtocol +from ._callback_context import CallbackContext + + +class CallbackSupportMixin: + """Mixin class to add callback support to an estimator.""" + + def set_callbacks(self, callbacks): + """Set callbacks for the estimator. + + Parameters + ---------- + callbacks : callback or list of callbacks + the callbacks to set. + + Returns + ------- + self : estimator instance + The estimator instance itself. + """ + if not isinstance(callbacks, list): + callbacks = [callbacks] + + if not all(isinstance(callback, CallbackProtocol) for callback in callbacks): + raise TypeError("callbacks must follow the CallbackProtocol protocol.") + + self._skl_callbacks = callbacks + + return self + + def init_callback_context(self, task_name="fit"): + """Initialize the callback context for the estimator. + + Parameters + ---------- + task_name : str, default='fit' + The name of the root task. + + Returns + ------- + callback_fit_ctx : CallbackContext + The callback context for the estimator. + """ + # We don't initialize the callback context during _set_callbacks but in fit + # because in the future we might want to have callbacks in predict/transform + # which would require their own context. + self._callback_fit_ctx = CallbackContext._from_estimator( + estimator=self, task_name=task_name, task_id=0, max_tasks=1 + ) + + return self._callback_fit_ctx diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 58616055c7100..4ff1328e49f81 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -5,17 +5,24 @@ from threading import Thread from ..utils._optional_dependencies import check_rich_support -from . import BaseCallback -class ProgressBar(BaseCallback): - """Callback that displays progress bars for each iterative steps of an estimator.""" +class ProgressBar: + """Callback that displays progress bars for each iterative steps of an estimator. - auto_propagate = True + Parameters + ---------- + max_estimator_depth : int, default=1 + The maximum number of nested levels of estimators to display progress bars for. + By default, only the progress bars of the outermost estimator are displayed. + If set to None, all levels are displayed. + """ - def __init__(self): + def __init__(self, max_estimator_depth=1): check_rich_support("Progressbar") + self.max_estimator_depth = max_estimator_depth + def on_fit_begin(self, estimator, *, data): self._queue = Manager().Queue() self.progress_monitor = _RichProgressMonitor(queue=self._queue) diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index ea313866c7f79..3917030e263b1 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -4,11 +4,11 @@ import time from sklearn.base import BaseEstimator, _fit_context, clone -from sklearn.callback import BaseCallback +from sklearn.callback import CallbackSupportMixin from sklearn.utils.parallel import Parallel, delayed -class TestingCallback(BaseCallback): +class TestingCallback: def on_fit_begin(self, estimator, *, data): pass @@ -20,23 +20,20 @@ def on_fit_iter_end(self, estimator, node, **kwargs): class TestingAutoPropagatedCallback(TestingCallback): - auto_propagate = True + max_estimator_depth = None class NotValidCallback: - """Unvalid callback since it does not inherit from `BaseCallback`.""" + """Unvalid callback since it's missing a method from the protocol.'""" def on_fit_begin(self, estimator, *, data): pass # pragma: no cover - def on_fit_end(self): - pass # pragma: no cover - def on_fit_iter_end(self, estimator, node, **kwargs): pass # pragma: no cover -class Estimator(BaseEstimator): +class Estimator(CallbackSupportMixin, BaseEstimator): _parameter_constraints: dict = {} def __init__(self, max_iter=20, computation_intensity=0.001): @@ -45,7 +42,7 @@ def __init__(self, max_iter=20, computation_intensity=0.001): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self._init_callback_context().eval_on_fit_begin( + callback_ctx = self.init_callback_context().eval_on_fit_begin( estimator=self, data={"X_train": X, "y_train": y} ) @@ -65,7 +62,7 @@ def fit(self, X=None, y=None): return self -class WhileEstimator(BaseEstimator): +class WhileEstimator(CallbackSupportMixin, BaseEstimator): _parameter_constraints: dict = {} def __init__(self, computation_intensity=0.001): @@ -73,7 +70,7 @@ def __init__(self, computation_intensity=0.001): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self._init_callback_context().eval_on_fit_begin( + callback_ctx = self.init_callback_context().eval_on_fit_begin( estimator=self, data={"X_train": X, "y_train": y} ) @@ -97,7 +94,7 @@ def fit(self, X=None, y=None): return self -class MetaEstimator(BaseEstimator): +class MetaEstimator(CallbackSupportMixin, BaseEstimator): _parameter_constraints: dict = {} def __init__( @@ -111,7 +108,7 @@ def __init__( @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self._init_callback_context().eval_on_fit_begin( + callback_ctx = self.init_callback_context().eval_on_fit_begin( estimator=self, data={"X_train": X, "y_train": y} ) diff --git a/sklearn/callback/tests/test_callback_context.py b/sklearn/callback/tests/test_callback_context.py index 3dd4d2f99d919..b1e1a6baafeeb 100644 --- a/sklearn/callback/tests/test_callback_context.py +++ b/sklearn/callback/tests/test_callback_context.py @@ -21,10 +21,10 @@ ], ) def test_set_callbacks(callbacks): - """Sanity check for the `_set_callbacks` method.""" + """Sanity check for the `set_callbacks` method.""" estimator = Estimator() - set_callbacks_return = estimator._set_callbacks(callbacks) + set_callbacks_return = estimator.set_callbacks(callbacks) assert hasattr(estimator, "_skl_callbacks") expected_callbacks = [callbacks] if not isinstance(callbacks, list) else callbacks @@ -35,34 +35,37 @@ def test_set_callbacks(callbacks): @pytest.mark.parametrize("callbacks", [None, NotValidCallback()]) def test_set_callbacks_error(callbacks): - """Check the error message when not passing a valid callback to `_set_callbacks`.""" + """Check the error message when not passing a valid callback to `set_callbacks`.""" estimator = Estimator() - with pytest.raises(TypeError, match="callbacks must be subclasses of BaseCallback"): - estimator._set_callbacks(callbacks) + with pytest.raises( + TypeError, match="callbacks must follow the CallbackProtocol protocol." + ): + estimator.set_callbacks(callbacks) def test_init_callback_context(): - """Sanity check for the `_init_callback_context` method.""" + """Sanity check for the `init_callback_context` method.""" estimator = Estimator() - callback_ctx = estimator._init_callback_context() + callback_ctx = estimator.init_callback_context() - assert hasattr(estimator, "_callback_fit_context") - assert hasattr(callback_ctx, "callbacks") + assert hasattr(estimator, "_callback_fit_ctx") + assert hasattr(callback_ctx, "_callbacks") def test_propagate_callbacks(): - """Sanity check for the `_propagate_callbacks` method.""" + """Sanity check for the `propagate_callbacks` method.""" not_propagated_callback = TestingCallback() propagated_callback = TestingAutoPropagatedCallback() estimator = Estimator() metaestimator = MetaEstimator(estimator) - metaestimator._set_callbacks([not_propagated_callback, propagated_callback]) + metaestimator.set_callbacks([not_propagated_callback, propagated_callback]) - metaestimator._propagate_callbacks(estimator, parent_node=None) + callback_ctx = metaestimator.init_callback_context() + callback_ctx.propagate_callbacks(estimator) - assert hasattr(estimator, "_parent_node") + assert hasattr(estimator, "_parent_callback_ctx") assert not_propagated_callback not in estimator._skl_callbacks assert propagated_callback in estimator._skl_callbacks @@ -71,7 +74,11 @@ def test_propagate_callback_no_callback(): """Check that no callback is propagated if there's no callback.""" estimator = Estimator() metaestimator = MetaEstimator(estimator) - metaestimator._propagate_callbacks(estimator, parent_node=None) + + callback_ctx = metaestimator.init_callback_context() + assert len(callback_ctx._callbacks) == 0 + + callback_ctx.propagate_callbacks(estimator) assert not hasattr(metaestimator, "_skl_callbacks") assert not hasattr(estimator, "_skl_callbacks") @@ -82,35 +89,11 @@ def test_auto_propagated_callbacks(): sub-estimator of a meta-estimator. """ estimator = Estimator() - estimator._set_callbacks(TestingAutoPropagatedCallback()) - + estimator.set_callbacks(TestingAutoPropagatedCallback()) meta_estimator = MetaEstimator(estimator=estimator) match = ( - r"sub-estimators .*of a meta-estimator .*can't have auto-propagated callbacks" + r"sub-estimator .*of a meta-estimator .*can't have auto-propagated callbacks" ) with pytest.raises(TypeError, match=match): meta_estimator.fit(X=None, y=None) - - -def test_eval_callbacks_on_fit_begin(): - """Check that `_eval_callbacks_on_fit_begin` creates the computation tree.""" - estimator = Estimator()._set_callbacks(TestingCallback()) - assert not hasattr(estimator, "_computation_tree") - - tree_structure = [ - {"stage": "fit", "n_children": 10}, - {"stage": "iter", "n_children": None}, - ] - estimator._eval_callbacks_on_fit_begin(tree_structure=tree_structure, data={}) - assert hasattr(estimator, "_computation_tree") - - -def test_no_callback_early_stop(): - """Check that `eval_callbacks_on_fit_iter_end` doesn't trigger early stopping - when there's no callback. - """ - estimator = Estimator() - estimator.fit(X=None, y=None) - - assert estimator.n_iter_ == estimator.max_iter diff --git a/sklearn/callback/tests/test_progressbar.py b/sklearn/callback/tests/test_progressbar.py index 6fba1f5b4ae9b..9acd47e30d5ab 100644 --- a/sklearn/callback/tests/test_progressbar.py +++ b/sklearn/callback/tests/test_progressbar.py @@ -15,25 +15,35 @@ @pytest.mark.parametrize("n_jobs", [1, 2]) @pytest.mark.parametrize("prefer", ["threads", "processes"]) @pytest.mark.parametrize("InnerEstimator", [Estimator, WhileEstimator]) -def test_progressbar(n_jobs, prefer, InnerEstimator, capsys): +@pytest.mark.parametrize("max_estimator_depth", [1, 2, None]) +def test_progressbar(n_jobs, prefer, InnerEstimator, max_estimator_depth, capsys): """Check the output of the progress bars and their completion.""" pytest.importorskip("rich") + n_inner = 2 + n_outer = 3 + est = InnerEstimator() - meta_est = MetaEstimator(est, n_jobs=n_jobs, prefer=prefer) - meta_est._set_callbacks(ProgressBar()) + meta_est = MetaEstimator( + est, n_outer=n_outer, n_inner=n_inner, n_jobs=n_jobs, prefer=prefer + ) + meta_est.set_callbacks(ProgressBar(max_estimator_depth=max_estimator_depth)) meta_est.fit() captured = capsys.readouterr() assert re.search(r"MetaEstimator - fit", captured.out) - for i in range(4): + for i in range(n_outer): assert re.search(rf"MetaEstimator - outer #{i}", captured.out) - for i in range(3): - assert re.search( - rf"MetaEstimator - inner #{i} | {est.__class__.__name__} - fit", - captured.out, - ) + + # Progress bars of inner estimators are displayed only if max_estimator_depth > 1 + # (or None, which means all levels are displayed) + if max_estimator_depth is None or max_estimator_depth > 1: + for i in range(n_inner): + assert re.search( + rf"MetaEstimator - inner #{i} | {est.__class__.__name__} - fit", + captured.out, + ) # Check that all bars are 100% complete assert re.search(r"100%", captured.out) From 81824ce0bf94812229e4d7e26368915e6db651fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 7 Jun 2024 17:20:04 +0200 Subject: [PATCH 52/55] lint --- build_tools/update_environments_and_lock_files.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index d12f20d737e56..5147b61a08455 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -256,7 +256,8 @@ def remove_from(alist, to_remove): "channel": "conda-forge", "conda_dependencies": remove_from( common_dependencies, ["pandas", "rich", "pyamg"] - ) + [ + ) + + [ "wheel", "pip", ], From 3fa6021075afb6e76772896e2319cf3002a1cba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 7 Jun 2024 18:23:57 +0200 Subject: [PATCH 53/55] iter --- sklearn/callback/_task_tree.py | 12 ++--- sklearn/callback/tests/test_task_tree.py | 68 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/sklearn/callback/_task_tree.py b/sklearn/callback/_task_tree.py index f3773a6fc2eae..3f97b2ac282c5 100644 --- a/sklearn/callback/_task_tree.py +++ b/sklearn/callback/_task_tree.py @@ -11,12 +11,12 @@ class TaskNode: The name of the task this node represents. task_id : int - The index of this node among its siblings. None means this is the root. - An identifier for this task that distinguishes it from its siblings. + An identifier for this task that distinguishes it from its siblings. Usually + the index of this node among its siblings. max_tasks : int or None - The maximum number of its siblings. 0 means it's a leaf. - None means the maximum number of siblings is not known in advance. + The maximum number of its siblings. None means the maximum number of siblings + is not known in advance. estimator_name : str The name of the estimator this task node belongs to. @@ -38,11 +38,11 @@ class TaskNode: means the maximum number of subtasks is not known in advance. prev_estimator_name : str or None - The estimator name of the node this node was merged with. None if was not + The estimator name of the node this node was merged with. None if it was not merged with another node. prev_task_name : str - The task name of the node this node was merged with. None if was not + The task name of the node this node was merged with. None if it was not merged with another node. """ diff --git a/sklearn/callback/tests/test_task_tree.py b/sklearn/callback/tests/test_task_tree.py index 9ea9c3f7c012d..5085d50d06e83 100644 --- a/sklearn/callback/tests/test_task_tree.py +++ b/sklearn/callback/tests/test_task_tree.py @@ -2,6 +2,7 @@ # Authors: the scikit-learn developers import numpy as np +import pytest from sklearn.callback import TaskNode @@ -33,6 +34,7 @@ def _make_task_tree(n_children, n_grandchildren): def test_task_tree(): + """Check that the task tree is correctly built.""" root = _make_task_tree(n_children=3, n_grandchildren=5) assert root.parent is None @@ -62,6 +64,7 @@ def test_task_tree(): def test_path(): + """Sanity check for the path property.""" root = _make_task_tree(n_children=3, n_grandchildren=5) assert root.path == [root] @@ -71,3 +74,68 @@ def test_path(): expected_path = [root, root.children_map[1], node] assert node.path == expected_path + + +def test_add_task(): + """Check that informative error messages are raised when adding tasks.""" + root = TaskNode(task_name="root task", task_id=0, max_tasks=1, estimator_name="est") + + # Before adding new task, it's considered a leaf + assert root.max_subtasks == 0 + + root._add_child( + TaskNode(task_name="child task", task_id=0, max_tasks=2, estimator_name="est") + ) + assert root.max_subtasks == 2 + assert len(root.children_map) == 1 + + # root already has a child with id 0 + with pytest.raises( + ValueError, match=r"Task node .* already has a child with task_id=0" + ): + root._add_child( + TaskNode( + task_name="child task", task_id=0, max_tasks=2, estimator_name="est" + ) + ) + + root._add_child( + TaskNode(task_name="child task", task_id=1, max_tasks=2, estimator_name="est") + ) + assert len(root.children_map) == 2 + + # root can have at most 2 children + with pytest.raises(ValueError, match=r"Cannot add child to task node"): + root._add_child( + TaskNode( + task_name="child task", task_id=2, max_tasks=2, estimator_name="est" + ) + ) + + +def test_merge_with(): + outer_root = TaskNode( + task_name="root", task_id=0, max_tasks=1, estimator_name="outer" + ) + + # Add a child task within the same estimator + outer_child = TaskNode( + task_name="child", task_id="id", max_tasks=2, estimator_name="outer" + ) + outer_root._add_child(outer_child) + + # The root task of the inner estimator is merged with (and effectively replaces) + # a leaf of the outer estimator because they correspond to the same formal task. + inner_root = TaskNode( + task_name="root", task_id=0, max_tasks=1, estimator_name="inner" + ) + inner_root._merge_with(outer_child) + + assert inner_root.parent is outer_root + assert inner_root.task_id == outer_child.task_id + assert outer_child not in outer_root.children_map.values() + assert inner_root in outer_root.children_map.values() + + # The name and estimator name of the tasks it was merged with are stored + assert inner_root.prev_task_name == outer_child.task_name + assert inner_root.prev_estimator_name == outer_child.estimator_name From 82a9d131ff67e10352eb4eadaa6392c21fa0a03c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 7 Jun 2024 18:51:35 +0200 Subject: [PATCH 54/55] add rich to pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4599d7f6302fc..cfda7c26268f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ docs = [ "matplotlib>=3.3.4", "scikit-image>=0.17.2", "pandas>=1.1.5", + "rich>=13.6.0", "seaborn>=0.9.0", "memory_profiler>=0.57.0", "sphinx>=7.3.7", @@ -71,6 +72,7 @@ examples = [ "matplotlib>=3.3.4", "scikit-image>=0.17.2", "pandas>=1.1.5", + "rich>=13.6.0", "seaborn>=0.9.0", "pooch>=1.6.0", "plotly>=5.14.0", @@ -79,6 +81,7 @@ tests = [ "matplotlib>=3.3.4", "scikit-image>=0.17.2", "pandas>=1.1.5", + "rich>=13.6.0", "pytest>=7.1.2", "pytest-cov>=2.9.0", "ruff>=0.2.1", From 7d41b317cc6886e37672d3e3bc5938e8dae2add9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Mon, 10 Jun 2024 14:44:26 +0200 Subject: [PATCH 55/55] fix docstrings and make callback hooks private for the end user --- sklearn/callback/_base.py | 18 +++---- sklearn/callback/_callback_context.py | 76 ++++++++++++++++++++++++--- sklearn/callback/_progressbar.py | 6 +-- sklearn/callback/tests/_utils.py | 10 ++-- 4 files changed, 86 insertions(+), 24 deletions(-) diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py index 1a9180a76042f..cdb434de08d2a 100644 --- a/sklearn/callback/_base.py +++ b/sklearn/callback/_base.py @@ -8,13 +8,13 @@ class CallbackProtocol(Protocol): """Protocol for the callbacks""" - def on_fit_begin(self, estimator, *, data): + def _on_fit_begin(self, estimator, *, data): """Method called at the beginning of the fit method of the estimator. Parameters ---------- estimator : estimator instance - The estimator the callback is registered on. + The estimator calling this callback hook. data : dict Dictionary containing the training and validation data. The possible @@ -22,14 +22,14 @@ def on_fit_begin(self, estimator, *, data): and "sample_weight_val". """ - def on_fit_iter_end(self, estimator, task_node, **kwargs): + def _on_fit_iter_end(self, estimator, task_node, **kwargs): """Method called at the end of each task of the estimator. Parameters ---------- estimator : estimator instance - The caller estimator. It might differ from the estimator passed to the - `on_fit_begin` method for auto-propagated callbacks. + The estimator calling this callback hook. It might differ from the estimator + passed to the `on_fit_begin` method for auto-propagated callbacks. task_node : TaskNode instance The caller task node. @@ -64,17 +64,17 @@ class (e.g. LogisticRegressionCV -> LogisticRegression) Returns ------- - stop : bool or None - Whether or not to stop the current level of iterations at this node. + stop : bool + Whether or not to stop the current level of iterations at this task node. """ - def on_fit_end(self, estimator, task_node): + def _on_fit_end(self, estimator, task_node): """Method called at the end of the fit method of the estimator. Parameters ---------- estimator : estimator instance - The estimator the callback is registered on. + The estimator calling this callback hook. task_node : TaskNode instance The task node corresponding to the whole `fit` task. This is usually the diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py index a5ebacb08d456..c4ea17d95cd3c 100644 --- a/sklearn/callback/_callback_context.py +++ b/sklearn/callback/_callback_context.py @@ -118,7 +118,18 @@ def subcontext(self, task_name="", task_id=0, max_tasks=1): ) def eval_on_fit_begin(self, estimator, *, data): - """Evaluate the on_fit_begin method of the callbacks.""" + """Evaluate the _on_fit_begin method of the callbacks. + + Parameters + ---------- + estimator : estimator instance + The estimator calling this callback hook. + + data : dict + Dictionary containing the training and validation data. The possible + keys are "X_train", "y_train", "sample_weight_train", "X_val", "y_val" + and "sample_weight_val". + """ for callback in self._callbacks: # Only call the on_fit_begin method of callbacks that are not # propagated from a meta-estimator. @@ -126,19 +137,64 @@ def eval_on_fit_begin(self, estimator, *, data): isinstance(callback, AutoPropagatedProtocol) and self._task_node.parent is not None ): - callback.on_fit_begin(estimator, data=data) + callback._on_fit_begin(estimator, data=data) return self def eval_on_fit_iter_end(self, estimator, **kwargs): - """Evaluate the on_fit_iter_end method of the callbacks.""" + """Evaluate the _on_fit_iter_end method of the callbacks. + + Parameters + ---------- + estimator : estimator instance + The estimator calling this callback hook. + + **kwargs : dict + arguments passed to the callback. Possible keys are + + - data: dict + Dictionary containing the training and validation data. The keys are + "X_train", "y_train", "sample_weight_train", "X_val", "y_val", + "sample_weight_val". The values are the corresponding data. If a key is + missing, the corresponding value is None. + + - stopping_criterion: float + Usually iterations stop when `stopping_criterion <= tol`. + This is only provided at the innermost level of iterations. + + - tol: float + Tolerance for the stopping criterion. + This is only provided at the innermost level of iterations. + + - from_reconstruction_attributes: estimator instance + A ready to predict, transform, etc ... estimator as if the fit stopped + at this node. Usually it's a copy of the caller estimator with the + necessary attributes set but it can sometimes be an instance of another + class (e.g. LogisticRegressionCV -> LogisticRegression) + + - fit_state: dict + Model specific quantities updated during fit. This is not meant to be + used by generic callbacks but by a callback designed for a specific + estimator instead. + + Returns + ------- + stop : bool + Whether or not to stop the current level of iterations at this task node. + """ return any( - callback.on_fit_iter_end(estimator, self._task_node, **kwargs) + callback._on_fit_iter_end(estimator, self._task_node, **kwargs) for callback in self._callbacks ) def eval_on_fit_end(self, estimator): - """Evaluate the on_fit_end method of the callbacks.""" + """Evaluate the _on_fit_end method of the callbacks. + + Parameters + ---------- + estimator : estimator instance + The estimator calling this callback hook. + """ for callback in self._callbacks: # Only call the on_fit_end method of callbacks that are not # propagated from a meta-estimator. @@ -146,10 +202,16 @@ def eval_on_fit_end(self, estimator): isinstance(callback, AutoPropagatedProtocol) and self._task_node.parent is not None ): - callback.on_fit_end(estimator, task_node=self._task_node) + callback._on_fit_end(estimator, task_node=self._task_node) def propagate_callbacks(self, sub_estimator): - """Propagate the callbacks to a sub-estimator.""" + """Propagate the callbacks to a sub-estimator. + + Parameters + ---------- + sub_estimator : estimator instance + The estimator to which the callbacks should be propagated. + """ bad_callbacks = [ callback.__class__.__name__ for callback in getattr(sub_estimator, "_skl_callbacks", []) diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py index 4ff1328e49f81..2ed8f4e4e00b5 100644 --- a/sklearn/callback/_progressbar.py +++ b/sklearn/callback/_progressbar.py @@ -23,15 +23,15 @@ def __init__(self, max_estimator_depth=1): self.max_estimator_depth = max_estimator_depth - def on_fit_begin(self, estimator, *, data): + def _on_fit_begin(self, estimator, *, data): self._queue = Manager().Queue() self.progress_monitor = _RichProgressMonitor(queue=self._queue) self.progress_monitor.start() - def on_fit_iter_end(self, estimator, task_node, **kwargs): + def _on_fit_iter_end(self, estimator, task_node, **kwargs): self._queue.put(task_node) - def on_fit_end(self, estimator, task_node): + def _on_fit_end(self, estimator, task_node): self._queue.put(task_node) self._queue.put(None) self.progress_monitor.join() diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 3917030e263b1..8fbb25d9c09e8 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -9,13 +9,13 @@ class TestingCallback: - def on_fit_begin(self, estimator, *, data): + def _on_fit_begin(self, estimator, *, data): pass - def on_fit_end(self): + def _on_fit_end(self): pass - def on_fit_iter_end(self, estimator, node, **kwargs): + def _on_fit_iter_end(self, estimator, node, **kwargs): pass @@ -26,10 +26,10 @@ class TestingAutoPropagatedCallback(TestingCallback): class NotValidCallback: """Unvalid callback since it's missing a method from the protocol.'""" - def on_fit_begin(self, estimator, *, data): + def _on_fit_begin(self, estimator, *, data): pass # pragma: no cover - def on_fit_iter_end(self, estimator, node, **kwargs): + def _on_fit_iter_end(self, estimator, node, **kwargs): pass # pragma: no cover