forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
1759 lines (1493 loc) · 76.8 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
The ``mlflow.sklearn`` module provides an API for logging and loading scikit-learn models. This
module exports scikit-learn models with the following flavors:
Python (native) `pickle <https://scikit-learn.org/stable/modules/model_persistence.html>`_ format
This is the main flavor that can be loaded back into scikit-learn.
:py:mod:`mlflow.pyfunc`
Produced for use by generic pyfunc-based deployment tools and batch inference.
NOTE: The `mlflow.pyfunc` flavor is only added for scikit-learn models that define `predict()`,
since `predict()` is required for pyfunc model inference.
"""
import inspect
import functools
import os
import logging
import numpy as np
import pickle
import yaml
import warnings
import weakref
from collections import defaultdict, OrderedDict
from packaging.version import Version
import mlflow
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import ModelInputExample, _save_example
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, INTERNAL_ERROR
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import _inspect_original_var_name
from mlflow.utils.autologging_utils import get_instance_method_first_arg_value
from mlflow.utils.environment import (
_mlflow_conda_env,
_validate_env_arguments,
_process_pip_requirements,
_process_conda_env,
_CONDA_ENV_FILE_NAME,
_REQUIREMENTS_FILE_NAME,
_CONSTRAINTS_FILE_NAME,
)
from mlflow.utils import gorilla
from mlflow.utils.requirements_utils import _get_pinned_requirement
from mlflow.utils.file_utils import write_to
from mlflow.utils.docstring_utils import format_docstring, LOG_MODEL_PARAM_DOCS
from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING
from mlflow.utils.model_utils import _get_flavor_configuration
from mlflow.utils.autologging_utils import (
autologging_integration,
safe_patch,
INPUT_EXAMPLE_SAMPLE_ROWS,
resolve_input_example_and_signature,
_get_new_training_session_class,
MlflowAutologgingQueueingClient,
disable_autologging,
update_wrapper_extended,
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
FLAVOR_NAME = "sklearn"
SERIALIZATION_FORMAT_PICKLE = "pickle"
SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle"
SUPPORTED_SERIALIZATION_FORMATS = [SERIALIZATION_FORMAT_PICKLE, SERIALIZATION_FORMAT_CLOUDPICKLE]
_logger = logging.getLogger(__name__)
_SklearnTrainingSession = _get_new_training_session_class()
def _gen_estimators_to_patch():
from mlflow.sklearn.utils import (
_all_estimators,
_get_meta_estimators_for_autologging,
)
_, estimators_to_patch = zip(*_all_estimators())
# Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
# for patching if they are not already included in the output of `all_estimators()`
estimators_to_patch = set(estimators_to_patch).union(
set(_get_meta_estimators_for_autologging())
)
# Exclude certain preprocessing & feature manipulation estimators from patching. These
# estimators represent data manipulation routines (e.g., normalization, label encoding)
# rather than ML algorithms. Accordingly, we should not create MLflow runs and log
# parameters / metrics for these routines, unless they are captured as part of an ML pipeline
# (via `sklearn.pipeline.Pipeline`)
excluded_module_names = [
"sklearn.preprocessing",
"sklearn.impute",
"sklearn.feature_extraction",
"sklearn.feature_selection",
]
excluded_class_names = [
"sklearn.compose._column_transformer.ColumnTransformer",
]
return [
estimator
for estimator in estimators_to_patch
if not any(
estimator.__module__.startswith(excluded_module_name)
or (estimator.__module__ + "." + estimator.__name__) in excluded_class_names
for excluded_module_name in excluded_module_names
)
]
def get_default_pip_requirements(include_cloudpickle=False):
"""
:return: A list of default pip requirements for MLflow Models produced by this flavor.
Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
that, at minimum, contains these requirements.
"""
pip_deps = [_get_pinned_requirement("scikit-learn", module="sklearn")]
if include_cloudpickle:
pip_deps += [_get_pinned_requirement("cloudpickle")]
return pip_deps
def get_default_conda_env(include_cloudpickle=False):
"""
:return: The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`.
"""
return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements(include_cloudpickle))
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def save_model(
sk_model,
path,
conda_env=None,
mlflow_model=None,
serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
pip_requirements=None,
extra_pip_requirements=None,
):
"""
Save a scikit-learn model to a path on the local file system. Produces an MLflow Model
containing the following flavors:
- :py:mod:`mlflow.sklearn`
- :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
that define `predict()`, since `predict()` is required for pyfunc model inference.
:param sk_model: scikit-learn model to be saved.
:param path: Local path where the model is to be saved.
:param conda_env: {{ conda_env }}
:param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
:param serialization_format: The format in which to serialize the model. This should be one of
the formats listed in
``mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS``. The Cloudpickle
format, ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``,
provides better cross-system compatibility by identifying and
packaging code dependencies with the serialized model.
:param signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models.signature import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
:param input_example: Input example provides one or several instances of valid
model input. The example can be used as a hint of what data to feed the
model. The given example will be converted to a Pandas DataFrame and then
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
.. code-block:: python
:caption: Example
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
sk_model = tree.DecisionTreeClassifier()
sk_model = sk_model.fit(iris.data, iris.target)
# Save the model in cloudpickle format
# set path to location for persistence
sk_path_dir_1 = ...
mlflow.sklearn.save_model(
sk_model, sk_path_dir_1,
serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE)
# save the model in pickle format
# set path to location for persistence
sk_path_dir_2 = ...
mlflow.sklearn.save_model(sk_model, sk_path_dir_2,
serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE)
"""
import sklearn
_validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
raise MlflowException(
message=(
"Unrecognized serialization format: {serialization_format}. Please specify one"
" of the following supported formats: {supported_formats}.".format(
serialization_format=serialization_format,
supported_formats=SUPPORTED_SERIALIZATION_FORMATS,
)
),
error_code=INVALID_PARAMETER_VALUE,
)
if os.path.exists(path):
raise MlflowException(
message="Path '{}' already exists".format(path), error_code=RESOURCE_ALREADY_EXISTS
)
os.makedirs(path)
if mlflow_model is None:
mlflow_model = Model()
if signature is not None:
mlflow_model.signature = signature
if input_example is not None:
_save_example(mlflow_model, input_example, path)
model_data_subpath = "model.pkl"
model_data_path = os.path.join(path, model_data_subpath)
_save_model(
sk_model=sk_model,
output_path=model_data_path,
serialization_format=serialization_format,
)
# `PyFuncModel` only works for sklearn models that define `predict()`.
if hasattr(sk_model, "predict"):
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.sklearn",
model_path=model_data_subpath,
env=_CONDA_ENV_FILE_NAME,
)
mlflow_model.add_flavor(
FLAVOR_NAME,
pickled_model=model_data_subpath,
sklearn_version=sklearn.__version__,
serialization_format=serialization_format,
)
mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
if conda_env is None:
if pip_requirements is None:
include_cloudpickle = serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE
default_reqs = get_default_pip_requirements(include_cloudpickle)
# To ensure `_load_pyfunc` can successfully load the model during the dependency
# inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
inferred_reqs = mlflow.models.infer_pip_requirements(
model_data_path,
FLAVOR_NAME,
fallback=default_reqs,
)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
default_reqs,
pip_requirements,
extra_pip_requirements,
)
else:
conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
# Save `constraints.txt` if necessary
if pip_constraints:
write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
# Save `requirements.txt`
write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def log_model(
sk_model,
artifact_path,
conda_env=None,
serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
registered_model_name=None,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
pip_requirements=None,
extra_pip_requirements=None,
):
"""
Log a scikit-learn model as an MLflow artifact for the current run. Produces an MLflow Model
containing the following flavors:
- :py:mod:`mlflow.sklearn`
- :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
that define `predict()`, since `predict()` is required for pyfunc model inference.
:param sk_model: scikit-learn model to be saved.
:param artifact_path: Run-relative artifact path.
:param conda_env: {{ conda_env }}
:param serialization_format: The format in which to serialize the model. This should be one of
the formats listed in
``mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS``. The Cloudpickle
format, ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``,
provides better cross-system compatibility by identifying and
packaging code dependencies with the serialized model.
:param registered_model_name: If given, create a model version under
``registered_model_name``, also creating a registered model if one
with the given name does not exist.
:param signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models.signature import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
:param input_example: Input example provides one or several instances of valid
model input. The example can be used as a hint of what data to feed the
model. The given example will be converted to a Pandas DataFrame and then
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.
:param await_registration_for: Number of seconds to wait for the model version to finish
being created and is in ``READY`` status. By default, the function
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
import mlflow
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
sk_model = tree.DecisionTreeClassifier()
sk_model = sk_model.fit(iris.data, iris.target)
# set the artifact_path to location where experiment artifacts will be saved
#log model params
mlflow.log_param("criterion", sk_model.criterion)
mlflow.log_param("splitter", sk_model.splitter)
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
conda_env=conda_env,
serialization_format=serialization_format,
registered_model_name=registered_model_name,
signature=signature,
input_example=input_example,
await_registration_for=await_registration_for,
pip_requirements=pip_requirements,
extra_pip_requirements=extra_pip_requirements,
)
def _load_model_from_local_file(path, serialization_format):
"""Load a scikit-learn model saved as an MLflow artifact on the local file system.
:param path: Local filesystem path to the MLflow Model saved with the ``sklearn`` flavor
:param serialization_format: The format in which the model was serialized. This should be one of
the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
"""
# TODO: we could validate the scikit-learn version here
if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
raise MlflowException(
message=(
"Unrecognized serialization format: {serialization_format}. Please specify one"
" of the following supported formats: {supported_formats}.".format(
serialization_format=serialization_format,
supported_formats=SUPPORTED_SERIALIZATION_FORMATS,
)
),
error_code=INVALID_PARAMETER_VALUE,
)
with open(path, "rb") as f:
# Models serialized with Cloudpickle cannot necessarily be deserialized using Pickle;
# That's why we check the serialization format of the model before deserializing
if serialization_format == SERIALIZATION_FORMAT_PICKLE:
return pickle.load(f)
elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
import cloudpickle
return cloudpickle.load(f)
def _load_pyfunc(path):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
:param path: Local filesystem path to the MLflow Model with the ``sklearn`` flavor.
"""
if os.path.isfile(path):
# Scikit-learn models saved in older versions of MLflow (<= 1.9.1) specify the ``data``
# field within the pyfunc flavor configuration. For these older models, the ``path``
# parameter of ``_load_pyfunc()`` refers directly to a serialized scikit-learn model
# object. In this case, we assume that the serialization format is ``pickle``, since
# the model loading procedure in older versions of MLflow used ``pickle.load()``.
serialization_format = SERIALIZATION_FORMAT_PICKLE
else:
# In contrast, scikit-learn models saved in versions of MLflow > 1.9.1 do not
# specify the ``data`` field within the pyfunc flavor configuration. For these newer
# models, the ``path`` parameter of ``load_pyfunc()`` refers to the top-level MLflow
# Model directory. In this case, we parse the model path from the MLmodel's pyfunc
# flavor configuration and attempt to fetch the serialization format from the
# scikit-learn flavor configuration
try:
sklearn_flavor_conf = _get_flavor_configuration(
model_path=path, flavor_name=FLAVOR_NAME
)
serialization_format = sklearn_flavor_conf.get(
"serialization_format", SERIALIZATION_FORMAT_PICKLE
)
except MlflowException:
_logger.warning(
"Could not find scikit-learn flavor configuration during model loading process."
" Assuming 'pickle' serialization format."
)
serialization_format = SERIALIZATION_FORMAT_PICKLE
pyfunc_flavor_conf = _get_flavor_configuration(
model_path=path, flavor_name=pyfunc.FLAVOR_NAME
)
path = os.path.join(path, pyfunc_flavor_conf["model_path"])
return _load_model_from_local_file(path=path, serialization_format=serialization_format)
class _SklearnCustomModelPicklingError(pickle.PicklingError):
"""
Exception for describing error raised during pickling custom sklearn estimator
"""
def __init__(self, sk_model, original_exception):
"""
:param sk_model: The custom sklearn model to be pickled
:param original_exception: The original exception raised
"""
super(_SklearnCustomModelPicklingError, self).__init__(
f"Pickling custom sklearn model {sk_model.__class__.__name__} failed "
f"when saving model: {str(original_exception)}"
)
self.original_exception = original_exception
def _dump_model(pickle_lib, sk_model, out):
try:
pickle_lib.dump(sk_model, out)
except (pickle.PicklingError, TypeError, AttributeError) as e:
if sk_model.__class__ not in _gen_estimators_to_patch():
raise _SklearnCustomModelPicklingError(sk_model, e)
else:
raise
def _save_model(sk_model, output_path, serialization_format):
"""
:param sk_model: The scikit-learn model to serialize.
:param output_path: The file path to which to write the serialized model.
:param serialization_format: The format in which to serialize the model. This should be one of
the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
"""
with open(output_path, "wb") as out:
if serialization_format == SERIALIZATION_FORMAT_PICKLE:
_dump_model(pickle, sk_model, out)
elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
import cloudpickle
_dump_model(cloudpickle, sk_model, out)
else:
raise MlflowException(
message="Unrecognized serialization format: {serialization_format}".format(
serialization_format=serialization_format
),
error_code=INTERNAL_ERROR,
)
def load_model(model_uri, dst_path=None):
"""
Load a scikit-learn model from a local file or a run.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
- ``models:/<model_name>/<model_version>``
- ``models:/<model_name>/<stage>``
For more information about supported URI schemes, see
`Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
artifact-locations>`_.
:param dst_path: The local filesystem path to which to download the model artifact.
This directory must already exist. If unspecified, a local output
path will be created.
:return: A scikit-learn model.
.. code-block:: python
:caption: Example
import mlflow.sklearn
sk_model = mlflow.sklearn.load_model("runs:/96771d893a5e46159d9f3b49bf9013e2/sk_models")
# use Pandas DataFrame to make predictions
pandas_df = ...
predictions = sk_model.predict(pandas_df)
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
sklearn_model_artifacts_path = os.path.join(local_model_path, flavor_conf["pickled_model"])
serialization_format = flavor_conf.get("serialization_format", SERIALIZATION_FORMAT_PICKLE)
return _load_model_from_local_file(
path=sklearn_model_artifacts_path, serialization_format=serialization_format
)
# The `_apis_autologging_disabled` contains APIs which is incompatible with autologging,
# when user call these APIs, autolog is temporarily disabled.
_apis_autologging_disabled = [
"cross_validate",
"cross_val_predict",
"cross_val_score",
"learning_curve",
"permutation_test_score",
"validation_curve",
]
class _AutologgingMetricsManager:
"""
This class is designed for holding information which is used by autologging metrics
It will hold information of:
(1) a map of "prediction result object id" to a tuple of dataset name(the dataset is
the one which generate the prediction result) and run_id.
Note: We need this map instead of setting the run_id into the "prediction result object"
because the object maybe a numpy array which does not support additional attribute
assignment.
(2) _log_post_training_metrics_enabled flag, in the following method scope:
`model.fit`, `eval_and_log_metrics`, `model.score`,
in order to avoid nested/duplicated autologging metric, when run into these scopes,
we need temporarily disable the metric autologging.
(3) _eval_dataset_info_map, it is a double level map:
`_eval_dataset_info_map[run_id][eval_dataset_var_name]` will get a list, each
element in the list is an id of "eval_dataset" instance.
This data structure is used for:
* generating unique dataset name key when autologging metric. For each eval dataset object,
if they have the same eval_dataset_var_name, but object ids are different,
then they will be assigned different name (via appending index to the
eval_dataset_var_name) when autologging.
(4) _metric_api_call_info, it is a double level map:
`_metric_api_call_info[run_id][metric_name]` wil get a list of tuples, each tuple is:
(logged_metric_key, metric_call_command_string)
each call command string is like `metric_fn(arg1, arg2, ...)`
This data structure is used for:
* storing the call arguments dict for each metric call, we need log them into metric_info
artifact file.
Note: this class is not thread-safe.
Design rule for this class:
Because this class instance is a global instance, in order to prevent memory leak, it should
only holds IDs and other small objects references. This class internal data structure should
avoid reference to user dataset variables or model variables.
"""
def __init__(self):
self._pred_result_id_to_dataset_name_and_run_id = {}
self._eval_dataset_info_map = defaultdict(lambda: defaultdict(list))
self._metric_api_call_info = defaultdict(lambda: defaultdict(list))
self._log_post_training_metrics_enabled = True
self._metric_info_artifact_need_update = defaultdict(lambda: False)
def should_log_post_training_metrics(self):
"""
Check whether we should run patching code for autologging post training metrics.
This checking should surround the whole patched code due to the safe guard checking,
See following note.
Note: It includes checking `_SklearnTrainingSession.is_active()`, This is a safe guarding
for meta-estimator (e.g. GridSearchCV) case:
running GridSearchCV.fit, the nested `estimator.fit` will be called in parallel,
but, the _autolog_training_status is a global status without thread-safe lock protecting.
This safe guarding will prevent code run into this case.
"""
return not _SklearnTrainingSession.is_active() and self._log_post_training_metrics_enabled
def disable_log_post_training_metrics(self):
class LogPostTrainingMetricsDisabledScope:
def __enter__(inner_self): # pylint: disable=no-self-argument
# pylint: disable=attribute-defined-outside-init
inner_self.old_status = self._log_post_training_metrics_enabled
self._log_post_training_metrics_enabled = False
# pylint: disable=no-self-argument
def __exit__(inner_self, exc_type, exc_val, exc_tb):
self._log_post_training_metrics_enabled = inner_self.old_status
return LogPostTrainingMetricsDisabledScope()
@staticmethod
def get_run_id_for_model(model):
return getattr(model, "_mlflow_run_id", None)
@staticmethod
def is_metric_value_loggable(metric_value):
"""
check whether the specified `metric_value` is a numeric value which can be logged
as an MLflow metric.
"""
return isinstance(metric_value, (int, float, np.number)) and not isinstance(
metric_value, (bool, np.bool)
)
def register_model(self, model, run_id):
"""
In `patched_fit`, we need register the model with the run_id used in `patched_fit`
So that in following metric autologging, the metric will be logged into the registered
run_id
"""
model._mlflow_run_id = run_id
@staticmethod
def gen_name_with_index(name, index):
assert index >= 0
if index == 0:
return name
else:
# Use '-' as the separator between name and index,
# The '-' is not valid character in python var name
# so it can prevent name conflicts after appending index.
return f"{name}-{index + 1}"
def register_prediction_input_dataset(self, model, eval_dataset):
"""
Register prediction input dataset into eval_dataset_info_map, it will do:
1. inspect eval dataset var name.
2. check whether eval_dataset_info_map already registered this eval dataset.
will check by object id.
3. register eval dataset with id.
4. return eval dataset name with index.
Note: this method include inspecting argument variable name.
So should be called directly from the "patched method", to ensure it capture
correct argument variable name.
"""
eval_dataset_name = _inspect_original_var_name(
eval_dataset, fallback_name="unknown_dataset"
)
eval_dataset_id = id(eval_dataset)
run_id = self.get_run_id_for_model(model)
registered_dataset_list = self._eval_dataset_info_map[run_id][eval_dataset_name]
for i, id_i in enumerate(registered_dataset_list):
if eval_dataset_id == id_i:
index = i
break
else:
index = len(registered_dataset_list)
if index == len(registered_dataset_list):
# register new eval dataset
registered_dataset_list.append(eval_dataset_id)
return self.gen_name_with_index(eval_dataset_name, index)
def register_prediction_result(self, run_id, eval_dataset_name, predict_result):
"""
Register the relationship
id(prediction_result) --> (eval_dataset_name, run_id)
into map `_pred_result_id_to_dataset_name_and_run_id`
"""
value = (eval_dataset_name, run_id)
prediction_result_id = id(predict_result)
self._pred_result_id_to_dataset_name_and_run_id[prediction_result_id] = value
def clean_id(id_):
_AUTOLOGGING_METRICS_MANAGER._pred_result_id_to_dataset_name_and_run_id.pop(id_, None)
# When the `predict_result` object being GCed, its ID may be reused, so register a finalizer
# to clear the ID from the dict for preventing wrong ID mapping.
weakref.finalize(predict_result, clean_id, prediction_result_id)
@staticmethod
def gen_metric_call_command(self_obj, metric_fn, *call_pos_args, **call_kwargs):
"""
Generate metric function call command string like `metric_fn(arg1, arg2, ...)`
Note: this method include inspecting argument variable name.
So should be called directly from the "patched method", to ensure it capture
correct argument variable name.
:param self_obj: If the metric_fn is a method of an instance (e.g. `model.score`),
the `self_obj` represent the instance.
:param metric_fn: metric function.
:param call_pos_args: the positional arguments of the metric function call. If `metric_fn`
is instance method, then the `call_pos_args` should exclude the first `self` argument.
:param call_kwargs: the keyword arguments ofthe metric function call.
"""
arg_list = []
def arg_to_str(arg):
if arg is None or np.isscalar(arg):
if isinstance(arg, str) and len(arg) > 32:
# truncate too long string
return repr(arg[:32] + "...")
return repr(arg)
else:
# dataset arguments or other non-scalar type argument
return _inspect_original_var_name(arg, fallback_name=f"<{arg.__class__.__name__}>")
param_sig = inspect.signature(metric_fn).parameters
arg_names = list(param_sig.keys())
if self_obj is not None:
# If metric_fn is a method of an instance, e.g. `model.score`,
# then the first argument is `self` which we need exclude it.
arg_names.pop(0)
if self_obj is not None:
call_fn_name = f"{self_obj.__class__.__name__}.{metric_fn.__name__}"
else:
call_fn_name = metric_fn.__name__
# Attach param signature key for positinal param values
for arg_name, arg in zip(arg_names, call_pos_args):
arg_list.append(f"{arg_name}={arg_to_str(arg)}")
for arg_name, arg in call_kwargs.items():
arg_list.append(f"{arg_name}={arg_to_str(arg)}")
arg_list_str = ", ".join(arg_list)
return f"{call_fn_name}({arg_list_str})"
def register_metric_api_call(self, run_id, metric_name, dataset_name, call_command):
"""
This method will do:
(1) Generate and return metric key, format is:
{metric_name}[-{call_index}]_{eval_dataset_name}
metric_name is generated by metric function name, if multiple calls on the same
metric API happen, the following calls will be assigned with an increasing "call index".
(2) Register the metric key with the "call command" information into
`_AUTOLOGGING_METRICS_MANAGER`. See doc of `gen_metric_call_command` method for
details of "call command".
"""
call_cmd_list = self._metric_api_call_info[run_id][metric_name]
index = len(call_cmd_list)
metric_name_with_index = self.gen_name_with_index(metric_name, index)
metric_key = f"{metric_name_with_index}_{dataset_name}"
call_cmd_list.append((metric_key, call_command))
# Set the flag to true, represent the metric info in this run need update.
# Later when `log_eval_metric` called, it will generate a new metric_info artifact
# and overwrite the old artifact.
self._metric_info_artifact_need_update[run_id] = True
return metric_key
def get_run_id_and_dataset_name_for_metric_api_call(self, call_pos_args, call_kwargs):
"""
Given a metric api call (include the called metric function, and call arguments)
Register the call information (arguments dict) into the `metric_api_call_arg_dict_list_map`
and return a tuple of (run_id, eval_dataset_name)
"""
call_arg_list = list(call_pos_args) + list(call_kwargs.values())
dataset_id_list = self._pred_result_id_to_dataset_name_and_run_id.keys()
# Note: some metric API the arguments is not like `y_true`, `y_pred`
# e.g.
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html#sklearn.metrics.silhouette_score
for arg in call_arg_list:
if arg is not None and not np.isscalar(arg) and id(arg) in dataset_id_list:
dataset_name, run_id = self._pred_result_id_to_dataset_name_and_run_id[id(arg)]
break
else:
return None, None
return run_id, dataset_name
def log_post_training_metric(self, run_id, key, value):
"""
Log the metric into the specified mlflow run.
and it will also update the metric_info artifact if needed.
"""
# Note: if the case log the same metric key multiple times,
# newer value will overwrite old value
client = mlflow.tracking.MlflowClient()
client.log_metric(run_id=run_id, key=key, value=value)
if self._metric_info_artifact_need_update[run_id]:
call_commands_list = []
for v in self._metric_api_call_info[run_id].values():
call_commands_list.extend(v)
call_commands_list.sort(key=lambda x: x[0])
dict_to_log = OrderedDict(call_commands_list)
client.log_dict(run_id=run_id, dictionary=dict_to_log, artifact_file="metric_info.json")
self._metric_info_artifact_need_update[run_id] = False
# The global `_AutologgingMetricsManager` instance which holds information used in
# post-training metric autologging. See doc of class `_AutologgingMetricsManager` for details.
_AUTOLOGGING_METRICS_MANAGER = _AutologgingMetricsManager()
_metric_api_excluding_list = ["check_scoring", "get_scorer", "make_scorer"]
def _get_metric_name_list():
"""
Return metric function name list in `sklearn.metrics` module
"""
from sklearn import metrics
metric_list = []
for metric_method_name in metrics.__all__:
# excludes plot_* methods
# exclude class (e.g. metrics.ConfusionMatrixDisplay)
metric_method = getattr(metrics, metric_method_name)
if (
metric_method_name not in _metric_api_excluding_list
and not inspect.isclass(metric_method)
and callable(metric_method)
and not metric_method_name.startswith("plot_")
):
metric_list.append(metric_method_name)
return metric_list
def _patch_estimator_method_if_available(flavor_name, class_def, func_name, patched_fn, manage_run):
if not hasattr(class_def, func_name):
return
original = gorilla.get_original_attribute(
class_def, func_name, bypass_descriptor_protocol=False
)
# Retrieve raw attribute while bypassing the descriptor protocol
raw_original_obj = gorilla.get_original_attribute(
class_def, func_name, bypass_descriptor_protocol=True
)
if raw_original_obj == original and (callable(original) or isinstance(original, property)):
# normal method or property decorated method
safe_patch(flavor_name, class_def, func_name, patched_fn, manage_run=manage_run)
elif hasattr(raw_original_obj, "delegate_names") or hasattr(raw_original_obj, "check"):
# sklearn delegated method
safe_patch(flavor_name, raw_original_obj, "fn", patched_fn, manage_run=manage_run)
else:
# unsupported method type. skip patching
pass
@autologging_integration(FLAVOR_NAME)
def autolog(
log_input_examples=False,
log_model_signatures=True,
log_models=True,
disable=False,
exclusive=False,
disable_for_unsupported_versions=False,
silent=False,
max_tuning_runs=5,
log_post_training_metrics=True,
): # pylint: disable=unused-argument
"""
Enables (or disables) and configures autologging for scikit-learn estimators.
**When is autologging performed?**
Autologging is performed when you call:
- ``estimator.fit()``
- ``estimator.fit_predict()``
- ``estimator.fit_transform()``
**Logged information**
**Parameters**
- Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
is called with ``deep=True``. This means when you fit a meta estimator that chains
a series of estimators, the parameters of these child estimators are also logged.
**Training metrics**
- A training score obtained by ``estimator.score``. Note that the training score is
computed using parameters given to ``fit()``.
- Common metrics for classifier:
- `precision score`_
.. _precision score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html
- `recall score`_
.. _recall score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html
- `f1 score`_
.. _f1 score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
- `accuracy score`_
.. _accuracy score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
If the classifier has method ``predict_proba``, we additionally log:
- `log loss`_
.. _log loss:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html
- `roc auc score`_
.. _roc auc score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
- Common metrics for regressor:
- `mean squared error`_
.. _mean squared error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
- root mean squared error
- `mean absolute error`_
.. _mean absolute error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html
- `r2 score`_
.. _r2 score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html
.. _post training metrics:
**Post training metrics**
When users call metric APIs after model training, MLflow tries to capture the metric API
results and log them as MLflow metrics to the Run associated with the model. The following
types of scikit-learn metric APIs are supported:
- model.score
- metric APIs defined in the `sklearn.metrics` module
For post training metrics autologging, the metric key format is:
"{metric_name}[-{call_index}]_{dataset_name}"
- If the metric function is from `sklearn.metrics`, the MLflow "metric_name" is the
metric function name. If the metric function is `model.score`, then "metric_name" is
"{model_class_name}_score".
- If multiple calls are made to the same scikit-learn metric API, each subsequent call
adds a "call_index" (starting from 2) to the metric key.
- MLflow uses the prediction input dataset variable name as the "dataset_name" in the
metric key. The "prediction input dataset variable" refers to the variable which was
used as the first argument of the associated `model.predict` or `model.score` call.
Note: MLflow captures the "prediction input dataset" instance in the outermost call
frame and fetches the variable name in the outermost call frame. If the "prediction