learning-llm.py
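"""A small local retrieval-augmented chatbot.

Uses a llama.cpp model for generation and embeddings, a Chroma vector store
as memory, and a question-vs-statement classifier to decide whether a prompt
should be answered (via RetrievalQA) or stored as new information.
"""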
from typing import Any, List, Optional
from langchain.embeddings import LlamaCppEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.llms.llamacpp import LlamaCpp
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.language_models import LLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from sentence_type_classifier import SentenceTypeClassifier
from langchain.text_splitter import CharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain.document_loaders import TextLoader
import os
from enum import Enum
import copy
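

# Local GGML model files (llama.cpp format) available to this script.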
class LocalModels(Enum):
    MANTICORE = "manticore-f16_q4_0.bin"
    LLAMA_13B = "llama-2-13b-q4_0.bin"
    LLAMA_7B = "llama-2-7b-q4_0.bin"
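

# Per-model llama.cpp parameters (context size, batch size, sampling, etc.).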
parameter_sets = {
    LocalModels.LLAMA_7B.value: {
        'temperature': 0.7,
        'max_tokens': 5000,
        'n_batch': 512,
        'n_ctx': 12000,
        'n_gpu_layers': 1
    },
    LocalModels.LLAMA_13B.value: {
        'temperature': 0.7,
        'max_tokens': 5000,
        'n_batch': 512,
        'n_ctx': 4000,
        'n_gpu_layers': 1,
        'repeat_penalty': 1.3,
        'verbose': False
    }
}
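

# Thin LLM wrapper that streams tokens from llama.cpp, prints them as they
# arrive, and stops early when the model begins emitting a '###' marker.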
class ChatAny(LLM):
    llm: LlamaCpp
    db: Chroma

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        result = []
        print("Answer:")
        for text in self.llm.stream(prompt):
            if text.startswith('###'):
                break
            result.append(text)
            print(text, end='')
        # Join the streamed chunks into the full response.
        return ''.join(result)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "llama-2"
class Chatbot:
    _model_location: str = None
    _model_parameters: dict = None
    _db: Chroma = None
    _chain: BaseRetrievalQA = None
    _classifier_path: str = None
    _token_file: str = None
    _classifier: SentenceTypeClassifier = None

    def __init__(self, model_location: str, model_parameters: dict, classifier_path: str, token_file: str):
        self._model_location = model_location
        self._model_parameters = model_parameters
        self._classifier_path = classifier_path
        self._token_file = token_file

    @property
    def classifier(self):
        if self._classifier is None:
            self._classifier = SentenceTypeClassifier(self._classifier_path, self._token_file)
        return self._classifier
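
    # Lazily construct the RetrievalQA chain the first time it is needed.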
    @property
    def chain(self):
        if self._chain is None:
            # Retrieve the 5 most relevant stored texts for each question.
            retriever = self.db.as_retriever(search_kwargs={'k': 5})
            llama = LlamaCpp(model_path=self._model_location, top_k=4, **self._model_parameters)
            llm = ChatAny(llm=llama, db=self.db)
            self._chain = RetrievalQA.from_llm(
                llm=llm,
                retriever=retriever,
                verbose=False
            )
        return self._chain
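
    # Each model gets its own persisted Chroma collection, since embeddings
    # produced by different models are not comparable.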
    @property
    def db(self):
        if self._db is None:
            # LlamaCppEmbeddings accepts only a subset of the LLM parameters,
            # so filter the shared parameter dict down to its known fields.
            parameters = copy.deepcopy(self._model_parameters)
            allowed_tokens = list(LlamaCppEmbeddings.__fields__.keys())
            for k in self._model_parameters:
                if k not in allowed_tokens:
                    del parameters[k]
            llama = LlamaCppEmbeddings(model_path=self._model_location, **parameters)
            basename = os.path.basename(llama.model_path).split('.')[0]
            chroma_path = f'./chromadb_learning/{basename}'
            self._db = Chroma(persist_directory=chroma_path, embedding_function=llama, collection_metadata={"hnsw:space": "cosine"})
        return self._db
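
    # New facts the user states are stored directly in the vector store.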
    def add_information(self, text):
        self.db.add_texts([text])

    def query(self, prompt: str):
        classification_result = self.classifier.classify(prompt)
        if classification_result == 1:
            # Classified as a question: answer it with the retrieval chain.
            return self.chain.run(prompt)
        # Classified as a statement: remember it for later retrieval.
        self.add_information(prompt)
        print("Okay. I'll remember that.")
        return "Okay. I'll remember that."

    def start_chat(self):
        prompt = ''
        while prompt.lower() not in ['bye', 'exit', 'quit', 'goodbye']:
            if prompt.strip() != '':
                self.query(prompt)
            prompt = input('\n\n\nPrompt: ')
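

# Entry point: run an interactive chat with the local LLaMA-2 13B model and
# the Hugging Face question-vs-statement classifier.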
if __name__ == '__main__':
    m = LocalModels.LLAMA_13B.value
    chatbot: Chatbot = Chatbot(
        os.path.join('/Users/davidmorton/source/models/', m),
        parameter_sets[m],
        "shahrukhx01/question-vs-statement-classifier",
        '.token')
    chatbot.start_chat()