-
Notifications
You must be signed in to change notification settings - Fork 10
/
mpt-7b-instruct-quantized.py
111 lines (88 loc) · 4.13 KB
/
mpt-7b-instruct-quantized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Databricks notebook source
# Install extra packages into the notebook environment. Both packages are also
# listed in the pip requirements of the model logged further down.
!pip install accelerate
!pip install bitsandbytes
# NOTE(review): presumably restarts the Python process so the packages installed
# above become importable in the following cells - confirm against Databricks docs.
dbutils.library.restartPython()
# COMMAND ----------
import pandas as pd
import numpy as np
import transformers
import accelerate
import mlflow
import torch
# COMMAND ----------
from huggingface_hub import snapshot_download

# Fetch a snapshot of the MPT instruct repository from the Hugging Face Hub.
# The returned location is kept so it can be attached to the logged model
# as an artifact later in the notebook.
MPT_REPO_ID = "mosaicml/mpt-7b-instruct"
snapshot_location = snapshot_download(repo_id=MPT_REPO_ID)
# COMMAND ----------
class MPTQuantized(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that serves a causal language model loaded in 8-bit.

    The tokenizer and model are loaded from the ``repository`` artifact in
    ``load_context``; ``predict`` generates a completion for a single prompt.
    """

    def load_context(self, context):
        """Initialize the tokenizer and language model from the logged artifact.

        Args:
            context: MLflow model context; ``context.artifacts['repository']``
                must point at the model repository snapshot.
        """
        repository = context.artifacts['repository']

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            repository,
            padding_side="left",
        )
        # NOTE(review): trust_remote_code=True lets transformers run Python code
        # shipped inside the model repository - only use with a trusted snapshot.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            repository,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            load_in_8bit=True,
            trust_remote_code=True,
        )

    def predict(self, context, model_input):
        """Generate a completion for the first row of ``model_input``.

        Only the first row is used; additional rows are ignored.

        Args:
            context: MLflow model context (unused here).
            model_input: Mapping-like input (e.g. a pandas DataFrame) with a
                ``prompt`` column and optional ``temperature`` (default 1.0)
                and ``max_tokens`` (default 100) columns.

        Returns:
            The generated text with the prompt tokens removed.
        """
        prompt = model_input["prompt"][0]
        # Values read from a DataFrame are numpy scalars; convert them to plain
        # Python numbers before handing them to ``generate``.
        temperature = float(model_input.get("temperature", [1.0])[0])
        max_tokens = int(model_input.get("max_tokens", [100])[0])

        # Encode the input and generate the prediction
        encoded_input = self.tokenizer.encode(prompt, return_tensors='pt').to("cuda")
        output = self.model.generate(
            encoded_input,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
        )

        # The output sequence starts with the prompt tokens; the encoded input
        # already tells us how many there are, so no second tokenization is needed.
        prompt_length = encoded_input.shape[-1]
        return self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
# COMMAND ----------
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec

# Define input and output schema: one string prompt plus the two sampling
# controls on the way in, a single string column on the way out.
input_schema = Schema([
    ColSpec(DataType.string, "prompt"),
    ColSpec(DataType.double, "temperature"),
    ColSpec(DataType.long, "max_tokens"),
])
output_schema = Schema([ColSpec(DataType.string)])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# Define input example (reused below when logging the model)
input_example = pd.DataFrame({
    "prompt": ["what is ML?"],
    "temperature": [0.5],
    "max_tokens": [100],
})

# Log the model with its details such as artifacts, pip requirements and input example
with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        "model",
        python_model=MPTQuantized(),
        artifacts={'repository': snapshot_location},
        pip_requirements=["torch", "transformers", "accelerate", "bitsandbytes", "einops", "sentencepiece"],
        input_example=input_example,
        signature=signature,
    )
# COMMAND ----------
# Register model in MLflow Model Registry.
# Because the model is large, registration can take longer than MLflow's default
# maximum wait time of 300 seconds, so wait longer before giving up.
REGISTRATION_TIMEOUT_SECONDS = 1800
result = mlflow.register_model(
    f"runs:/{run.info.run_id}/model",
    "mpt-7b-instruct-quantized",
    await_registration_for=REGISTRATION_TIMEOUT_SECONDS,
)
# Note: if the wait time is still exceeded, MLflow raises an exception saying the
# max wait time has been exceeded. That is only a time-out notification and does
# not necessarily mean the registration failed - confirm the actual status by
# checking the model registry directly.
# COMMAND ----------
# Load the latest registered version of the model back as a generic pyfunc model
REGISTERED_MODEL_URI = "models:/mpt-7b-instruct-quantized/latest"
loaded_model = mlflow.pyfunc.load_model(REGISTERED_MODEL_URI)
# COMMAND ----------
# Send one sample request through the loaded model; the bare call on the last
# line is left as an expression so the notebook cell displays its result.
input_example = pd.DataFrame(
    {
        "prompt": ["what is ML?"],
        "temperature": [0.5],
        "max_tokens": [100],
    }
)
loaded_model.predict(input_example)