Skip to content

Commit

Permalink
Added two special cases to the Model Registry Workflow API (#4225)
Browse files Browse the repository at this point in the history
* Added first case registering previously saved modelsin native format

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Removed context-manager statement and added print statement for delineating output

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Added PyFuncModel example as the last case scenario

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Fixed minor typos

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Added versions to conda_env dynamicallly

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Minor sentence restructure

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Incoroporated some style suggestions

Signed-off-by: Jules Damji <dmatrix@comcast.net>

* Fixed a minor missing subject

Signed-off-by: Jules Damji <dmatrix@comcast.net>
  • Loading branch information
dmatrix committed Apr 8, 2021
1 parent 83b665a commit c635d1a
Showing 1 changed file with 252 additions and 0 deletions.
252 changes: 252 additions & 0 deletions docs/source/model-registry.rst
Expand Up @@ -349,3 +349,255 @@ You can either delete specific versions of a registered model or you can delete
# Delete a registered model along with all its versions
client.delete_registered_model(name="sk-learn-random-forest-reg-model")
While the above workflow API demonstrates interactions with the Model Registry, two exceptional cases require attention.
One is when you have existing ML models saved from training without the use of MLflow. Serialized and persisted on disk
in sklearn's pickled format, you want to register this model with the Model Registry. The second is when you use
an ML framework without a built-in MLflow model flavor support, for instance, `vaderSentiment,` and want to register the model.


Registering a Saved Model
^^^^^^^^^^^^^^^^^^^^^^^^^
Not everyone will start their model training with MLflow. So you may have some models trained before the use of MLflow.
Instead of retraining the models, all you want to do is register your saved models with the Model Registry.

This code snippet creates a sklearn model, which we assume that you had created and saved in native pickle format.


.. note::
The sklearn library and pickle versions with which the model was saved should be compatible with the
current MLflow supported built-in sklearn model flavor.

.. code-block:: py
import numpy as np
import pickle
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
# source: https://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html
# Load the diabetes dataset
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
# Use only one feature
diabetes_X = diabetes_X[:, np.newaxis, 2]
# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
# Split the targets into training/testing sets
diabetes_y_train = diabetes_y[:-20]
diabetes_y_test = diabetes_y[-20:]
def print_predictions(m, y_pred):
# The coefficients
print('Coefficients: \n', m.coef_)
# The mean squared error
print('Mean squared error: %.2f'
% mean_squared_error(diabetes_y_test, y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'
% r2_score(diabetes_y_test, y_pred))
# Create linear regression object
lr_model = linear_model.LinearRegression()
# Train the model using the training sets
lr_model.fit(diabetes_X_train, diabetes_y_train)
# Make predictions using the testing set
diabetes_y_pred = lr_model.predict(diabetes_X_test)
print_predictions(lr_model, diabetes_y_pred)
# save the model in the native sklearn format
filename = 'lr_model.pkl'
pickle.dump(lr_model, open(filename, 'wb'))
.. code-block:: text
Coefficients:
[938.23786125]
Mean squared error: 2548.07
Coefficient of determination: 0.47
Once saved in pickled format, we can load the sklearn model into memory using pickle API and
register the loaded model with the Model Registry.

.. code-block:: py
import mlflow
# load the model into memory
loaded_model = pickle.load(open(filename, 'rb'))
# log and register the model using MLflow scikit-learn API
mlflow.set_tracking_uri("sqlite:///mlruns.db")
reg_model_name = "SklearnLinearRegression"
print("--")
mlflow.sklearn.log_model(loaded_model, "sk_learn",
serialization_format="cloudpickle",
registered_model_name=reg_model_name)
.. code-block:: text
--
Successfully registered model 'SklearnLinearRegression'.
2021/04/02 16:30:57 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.
Model name: SklearnLinearRegression, version 1
Created version '1' of model 'SklearnLinearRegression'.
Now, using MLflow fluent APIs, we reload the model from the Model Registry and score.

.. code-block:: py
# load the model from the Model Registry and score
model_uri = f"models:/{reg_model_name}/1"
loaded_model = mlflow.sklearn.load_model(model_uri)
print("--")
# Make predictions using the testing set
diabetes_y_pred = loaded_model.predict(diabetes_X_test)
print_predictions(loaded_model, diabetes_y_pred)
.. code-block:: text
--
Coefficients:
[938.23786125]
Mean squared error: 2548.07
Coefficient of determination: 0.47
Registering an Unsupported Machine Learning Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In some cases, you might use a machine learning framework without its built-in MLflow Model flavor support.
For instance, the `vaderSentiment` library is a standard Natural Language Processing (NLP) library used
for sentiment analysis. Since it lacks a built-in MLflow Model flavor, you cannot log or register the model
using MLflow Model fluent APIs.

To work around this problem, you can create an instance of a :py:mod:`mlflow.pyfunc` model flavor and embed your NLP model
inside it, allowing you to save, log or register the model. Once registered, load the model from the Model Registry
and score using the :py:func:`predict <mlflow.pyfunc.PyFuncModel.predict>` function.

The code sections below demonstrate how to create a ``PythonFuncModel`` class with a ``vaderSentiment`` model embedded in it,
save, log, register, and load from the Model Registry and score.

.. note::
To use this example, you will need to ``pip install vaderSentiment``.

.. code-block:: py
from sys import version_info
import cloudpickle
import pandas as pd
import mlflow.pyfunc
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
#
# Good and readable paper from the authors of this package
# http://comp.social.gatech.edu/papers/icwsm14.vader.hutto.pdf
#
INPUT_TEXTS = [{'text': "This is a bad movie. You don't want to see it! :-)"},
{'text': "Ricky Gervais is smart, witty, and creative!!!!!! :D"},
{'text': "LOL, this guy fell off a chair while sleeping and snoring in a meeting"},
{'text': "Men shoots himself while trying to steal a dog, OMG"},
{'text': "Yay!! Another good phone interview. I nailed it!!"},
{'text': "This is INSANE! I can't believe it. How could you do such a horrible thing?"}]
PYTHON_VERSION = "{major}.{minor}.{micro}".format(major=version_info.major,
minor=version_info.minor,
micro=version_info.micro)
def score_model(model):
# Use inference to predict output from the customized PyFunc model
for i, text in enumerate(INPUT_TEXTS):
text = INPUT_TEXTS[i]['text']
m_input = pd.DataFrame([text])
scores = loaded_model.predict(m_input)
print(f"<{text}> -- {str(scores[0])}")
# Define a class and extend from PythonModel
class SocialMediaAnalyserModel(mlflow.pyfunc.PythonModel):
def __init__(self):
super().__init__()
# embed your vader model instance
self._analyser = SentimentIntensityAnalyzer()
# preprocess the input with prediction from the vader sentiment model
def _score(self, txt):
prediction_scores = self._analyser.polarity_scores(txt)
return prediction_scores
def predict(self, context, model_input):
# Apply the preprocess function from the vader model to score
model_output = model_input.apply(lambda col: self._score(col))
return model_output
model_path = "vader"
reg_model_name = "PyFuncVaderSentiments"
vader_model = SocialMediaAnalyserModel()
# Set the tracking URI to use local SQLAlchemy db file and start the run
# Log MLflow entities and save the model
mlflow.set_tracking_uri("sqlite:///mlruns.db")
# Save the conda environment for this model.
conda_env = {
'channels': ['defaults', 'conda-forge'],
'dependencies': [
'python={}'.format(PYTHON_VERSION),
'pip'],
'pip': [
'mlflow',
'cloudpickle=={}'.format(cloudpickle.__version__),
'vaderSentiment==3.3.2'
],
'name': 'mlflow-env'
}
# Save the model
with mlflow.start_run(run_name="Vader Sentiment Analysis") as run:
model_path = f"{model_path}-{run.info.run_uuid}"
mlflow.log_param("algorithm", "VADER")
mlflow.log_param("total_sentiments", len(INPUT_TEXTS))
mlflow.pyfunc.save_model(path=model_path, python_model=vader_model, conda_env=conda_env)
# Use the saved model path to log and register into the model registry
mlflow.pyfunc.log_model(artifact_path=model_path,
python_model=vader_model,
registered_model_name=reg_model_name,
conda_env=conda_env)
# Load the model from the model registry and score
model_uri = f"models:/{reg_model_name}/1"
loaded_model = mlflow.pyfunc.load_model(model_uri)
score_model(loaded_model)
.. code-block:: text
Successfully registered model 'PyFuncVaderSentiments'.
2021/04/05 10:34:15 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.
Created version '1' of model 'PyFuncVaderSentiments'.
<This is a bad movie. You don't want to see it! :-)> -- {'neg': 0.307, 'neu': 0.552, 'pos': 0.141, 'compound': -0.4047}
<Ricky Gervais is smart, witty, and creative!!!!!! :D> -- {'neg': 0.0, 'neu': 0.316, 'pos': 0.684, 'compound': 0.8957}
<LOL, this guy fell off a chair while sleeping and snoring in a meeting> -- {'neg': 0.0, 'neu': 0.786, 'pos': 0.214, 'compound': 0.5473}
<Men shoots himself while trying to steal a dog, OMG> -- {'neg': 0.262, 'neu': 0.738, 'pos': 0.0, 'compound': -0.4939}
<Yay!! Another good phone interview. I nailed it!!> -- {'neg': 0.0, 'neu': 0.446, 'pos': 0.554, 'compound': 0.816}
<This is INSANE! I can't believe it. How could you do such a horrible thing?> -- {'neg': 0.357, 'neu': 0.643, 'pos': 0.0, 'compound': -0.8034}

0 comments on commit c635d1a

Please sign in to comment.