from __future__ import annotations

import os

from gpt4all import GPT4All

from .models import get_models
from ..typing import Messages
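
# Cache of available models, populated lazily on the first completion call.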
MODEL_LIST: dict[str, dict] | None = None


def find_model_dir(model_file: str) -> str:
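    # Search order: the project-level "models" directory, then the legacy
    # directory next to this module, then a recursive scan of the current
    # working directory.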
    local_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(os.path.dirname(local_dir))

    new_model_dir = os.path.join(project_dir, "models")
    new_model_file = os.path.join(new_model_dir, model_file)
    if os.path.isfile(new_model_file):
        return new_model_dir

    old_model_dir = os.path.join(local_dir, "models")
    old_model_file = os.path.join(old_model_dir, model_file)
    if os.path.isfile(old_model_file):
        return old_model_dir

    working_dir = "./"
    for root, _dirs, files in os.walk(working_dir):
        if model_file in files:
            return root
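
    # Nothing found anywhere: fall back to the project directory so that a
    # subsequent download lands in the expected place.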
    return new_model_dir


class LocalProvider:
    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
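        # Resolve the requested model name against the model list, which is
        # loaded once and cached at module level.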
        global MODEL_LIST
        if MODEL_LIST is None:
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')

        model_config = MODEL_LIST[model]
        model_file = model_config["path"]
        model_dir = find_model_dir(model_file)
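
        # If the weights are missing, offer an interactive download rather
        # than failing outright.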
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            print(f'Model file "models/{model_file}" not found.')
            download = input(f"Do you want to download {model_file}? [y/n]: ")
            if download in ["y", "Y"]:
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')
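
        # allow_download=False: any required download was already handled above.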
        llm = GPT4All(
            model_name=model_file,
            verbose=False,
            allow_download=False,
            model_path=model_dir,
        )
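
        # Combine all system messages into one prompt; fall back to a generic
        # assistant persona when the conversation has none.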
        system_message = "\n".join(
            message["content"] for message in messages if message["role"] == "system"
        )
        if not system_message:
            system_message = "A chat between a curious user and an artificial intelligence assistant."
        prompt_template = "USER: {0}\nASSISTANT: "
        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "

        def should_not_stop(token_id: int, token: str):
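            # Halt generation once the model starts to emit the next "USER:" turn.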
            return "USER" not in token
        with llm.chat_session(system_message, prompt_template):
            if stream:
                for token in llm.generate(conversation, streaming=True, callback=should_not_stop):
                    yield token
            else:
                yield llm.generate(conversation, callback=should_not_stop)
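

# Minimal usage sketch (the model name below is hypothetical; use a key that
# get_models() actually returns):
#
#     for chunk in LocalProvider.create_completion(
#         "mistral-7b-instruct",
#         [{"role": "user", "content": "Hello!"}],
#         stream=True,
#     ):
#         print(chunk, end="")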