WilliamGazeley committed
Commit · 9e2a95f · 1 Parent(s): 1b19641
Migrate to Ollama
Browse files
- src/app.py +27 -21
- src/config.py +8 -3
- src/functioncall.py +14 -42
src/app.py CHANGED
@@ -1,12 +1,11 @@
 import os
+from time import time
 import huggingface_hub
 import streamlit as st
 from config import config
 from utils import get_assistant_message
 from functioncall import ModelInference
-from prompter import PromptManager
 
-print("Why, hello there!", flush=True)
 
 @st.cache_resource(show_spinner="Loading model..")
 def init_llm():
@@ -14,40 +13,44 @@ def init_llm():
     llm = ModelInference(chat_template=config.chat_template)
     return llm
 
+
 def get_response(prompt):
     try:
         return llm.generate_function_call(
-            prompt,
-            config.chat_template,
-            config.num_fewshot,
-            config.max_depth
+            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
     except Exception as e:
         return f"An error occurred: {str(e)}"
-
+
+
 def get_output(context, user_input):
     try:
         config.status.update(label=":bulb: Preparing answer..")
-
-
-
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        prompt_path = os.path.join(script_dir, 'prompt_assets', 'output_sys_prompt.yml')
+        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
+        sys_prompt = (
+            llm.prompter.format_yaml_prompt(prompt_schema, dict())
+            + f"Information:\n{context}"
+        )
         convo = [
             {"role": "system", "content": sys_prompt},
             {"role": "user", "content": user_input},
         ]
         response = llm.run_inference(convo)
-        return
+        return response
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
+
 def main():
     st.title("LLM-ADE 9B Demo")
-
+
     input_text = st.text_area("Enter your text here:", value="", height=200)
-
+
     if st.button("Generate"):
         if input_text:
-            with st.status(
+            with st.status("Generating response...") as status:
                 config.status = status
                 agent_resp = get_response(input_text)
                 st.write(get_output(agent_resp, input_text))
@@ -55,17 +58,20 @@ def main():
         else:
             st.warning("Please enter some text to generate a response.")
 
+
 llm = init_llm()
 
-
-
-
-
-
+
+def main_headless(prompt: str):
+    start = time()
+    agent_resp = get_response(prompt)
+    print("\033[94m" + get_output(agent_resp, prompt) + "\033[0m")
+    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
+
 
 if __name__ == "__main__":
-    print(f"Test env vars: {os.getenv('TEST_SECRET')}")
     if config.headless:
-
+        import fire
+        fire.Fire(main_headless)
     else:
         main()
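Note on the new headless path: main_headless is exposed through python-fire, so the arguments it declares become CLI arguments when headless mode is enabled. A minimal, self-contained sketch of that pattern (hypothetical demo function, not part of this repo):

import fire

def demo(prompt: str) -> str:
    # Stand-in for get_response/get_output; fire.Fire maps CLI args onto the
    # function's parameters, e.g. `python demo.py --prompt "hello"` or
    # `python demo.py hello`.
    return f"You asked: {prompt}"

if __name__ == "__main__":
    fire.Fire(demo)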
src/config.py CHANGED
@@ -2,12 +2,18 @@ from pydantic import Field
 from pydantic_settings import BaseSettings
 from typing import Dict, Any
 
+class MockStatus():
+    # Required for headless mode
+    def update(self, *args, **kwargs):
+        print("MockStatus update called with args: ", args, " and kwargs: ", kwargs)
+
 class Config(BaseSettings):
     hf_token: str = Field(...)
-    hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
+    hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev") # We need this because I can't get the model template out of the ollama model
+    ollama_model: str = Field("llama3")
     headless: bool = Field(False, description="Run in headless mode.")
 
-    status: Any =
+    status: Any = MockStatus()
 
     az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
     az_search_api_key: str = Field(...)
@@ -17,7 +23,6 @@ class Config(BaseSettings):
 
     chat_template: str = Field("chatml", description="Chat template for prompt formatting")
     num_fewshot: int | None = Field(None, description="Option to use json mode examples")
-    load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
     max_depth: int = Field(3, description="Maximum number of recursive iteration")
 
 config = Config(_env_file=".env")
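For context, the new MockStatus and ollama_model fields support the headless path: config.status defaults to a MockStatus that simply prints status updates until the Streamlit app swaps in a real st.status object. A small usage sketch (assumes required fields such as hf_token and az_search_api_key are supplied via .env or environment variables; pydantic-settings matches env vars to field names case-insensitively, so OLLAMA_MODEL=mistral would override the default):

from config import config

print(config.ollama_model)          # "llama3" unless overridden via env/.env
config.status.update(label="...")   # MockStatus just prints; the Streamlit app
                                    # later replaces config.status with st.status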
src/functioncall.py CHANGED
@@ -13,6 +13,7 @@ from transformers import (
 import functions
 from prompter import PromptManager
 from validator import validate_function_call_schema
+from langchain_community.chat_models import ChatOllama
 
 from utils import (
     inference_logger,
@@ -22,26 +23,12 @@ from utils import (
 )
 
 class ModelInference:
-    def __init__(self, chat_template: str
+    def __init__(self, chat_template: str):
         self.prompter = PromptManager()
-
-
-
-
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True,
-        )
-        self.model = AutoModelForCausalLM.from_pretrained(
-            config.hf_model,
-            trust_remote_code=True,
-            return_dict=True,
-            quantization_config=self.bnb_config,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2",
-            device_map="auto",
-        )
-
+
+        self.model = ChatOllama(model=config.ollama_model,
+                                temperature=0.0, format='json')
+
         self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "left"
@@ -49,24 +36,18 @@ class ModelInference:
         if self.tokenizer.chat_template is None:
             print("No chat template defined, getting chat_template...")
             self.tokenizer.chat_template = get_chat_template(chat_template)
-
-        inference_logger.info(self.model.config)
-        inference_logger.info(self.model.generation_config)
-        inference_logger.info(self.tokenizer.special_tokens_map)
 
-    def process_completion_and_validate(self, completion, chat_template):
-
-        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
 
-
-
+    def process_completion_and_validate(self, completion, chat_template):
+        if completion:
+            validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
 
         if validation:
             inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
-            return tool_calls,
+            return tool_calls, completion, error_message
         else:
             tool_calls = None
-            return tool_calls,
+            return tool_calls, completion, error_message
         else:
             inference_logger.warning("Assistant message is None")
             raise ValueError("Assistant message is None")
@@ -86,19 +67,10 @@ class ModelInference:
         inputs = self.tokenizer.apply_chat_template(
             prompt,
             add_generation_prompt=True,
-
-        )
-
-        tokens = self.model.generate(
-            inputs.to(self.model.device),
-            max_new_tokens=1500,
-            temperature=0.8,
-            repetition_penalty=1.2,
-            do_sample=True,
-            eos_token_id=self.tokenizer.eos_token_id
+            tokenize=False,
         )
-        completion = self.
-        return completion
+        completion = self.model.invoke(inputs, format='json')
+        return completion.content
 
     def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
         try:
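The core of the migration is visible in run_inference: the chat template is now rendered to a plain string (tokenize=False) and passed to ChatOllama.invoke, which returns a message whose .content carries the JSON-formatted completion. A standalone sketch of that call pattern (assumes a local Ollama server with the model already pulled, e.g. via `ollama pull llama3`):

from langchain_community.chat_models import ChatOllama

# format='json' asks Ollama to constrain the output to valid JSON.
llm = ChatOllama(model="llama3", temperature=0.0, format="json")
completion = llm.invoke("Return a JSON object with a single key 'ok' set to true.")
print(completion.content)  # e.g. '{"ok": true}'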