-
-
Notifications
You must be signed in to change notification settings - Fork 75
/
minichain.py
70 lines (55 loc) · 2.47 KB
/
minichain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Any, Callable, Dict, Iterable, Optional
from spacy.util import SimpleFrozenDict
from ..compat import has_minichain, minichain
from ..registry import registry
def _check_installation() -> None:
    """Verify that the optional `minichain` package is available.

    Does nothing when the `has_minichain` flag (imported from the package's
    compat module) is truthy; otherwise raises with a pointer to the MiniChain
    installation instructions.

    RAISES (ValueError): If `minichain` is not installed.
    """
    if has_minichain:
        return
    raise ValueError(
        "The MiniChain backend requires `minichain` to be installed, which it is not. See "
        "https://github.com/srush/MiniChain for installation instructions."
    )
@registry.llm_queries("spacy.RunMiniChain.v1")
def query_minichain() -> (
    Callable[["minichain.backend.Backend", Iterable[str]], Iterable[str]]
):
    """Create the default query function for MiniChain backends.

    RETURNS (Callable[["minichain.backend.Backend", Iterable[str]], Iterable[str]]):
        Callable that takes a MiniChain backend and an iterable of prompt
        texts, forwards each text unchanged to the backend and returns the
        results as a list in the same order as the prompts.
    """

    def prompt(
        backend: "minichain.backend.Backend", prompts: Iterable[str]
    ) -> Iterable[str]:
        # Pass-through prompt function wrapped with `minichain.prompt(backend)`:
        # its body hands the raw prompt text straight to the model.
        @minichain.prompt(backend)
        def _prompt(model: "minichain.base.Prompt.Model", prompt_text: str) -> str:
            return model(prompt_text)

        responses = []
        for text in prompts:
            # Each prompt is evaluated via `.run()`; results keep input order.
            responses.append(_prompt(text).run())
        return responses

    return prompt
@registry.llm_backends("spacy.MiniChain.v1")
def backend_minichain(
    api: str,
    query: Optional[
        Callable[["minichain.backend.Backend", Iterable[str]], Iterable[str]]
    ] = None,
    config: Dict[Any, Any] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Create a callable that sends prompts to a MiniChain backend.

    api (str): Name of an attribute of `minichain.backend` (e.g. "OpenAI").
        That attribute is called with `config` to create the backend instance.
    query (Optional[Callable[["minichain.backend.Backend", Iterable[str]], Iterable[str]]]):
        Function that sends prompts to the backend. When None, the callable
        returned by `query_minichain()` is used.
    config (Dict[Any, Any]): Keyword arguments passed to the backend
        constructor.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Callable that takes an
        iterable of prompts and returns whatever `query` produces for them.
    RAISES (ValueError): If `minichain` is not installed.
    RAISES (KeyError): If `minichain.backend` has no attribute named `api`.
    """
    _check_installation()
    if not hasattr(minichain.backend, api):
        raise KeyError(
            f"The requested API {api} is not available in `minichain.backend`."
        )

    # The backend is created once, here, and shared by every invocation of
    # the returned callable.
    backend_cls = getattr(minichain.backend, api)
    backend = backend_cls(**config)
    run_query = query if query is not None else query_minichain()

    def _query(prompts: Iterable[str]) -> Iterable[str]:
        return run_query(backend, prompts)

    return _query