import os
from time import time

import huggingface_hub
import streamlit as st

from config import config
from functioncall import ModelInference


def init_llm():
    # Authenticate with the Hugging Face Hub and load the model with the
    # configured chat template.
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm


def function_agent(prompt):
    """Runs the function-calling agent over the prompt and returns its output."""
    try:
        return llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {e}"


def output_agent(context, user_input):
    """Takes the output of the RAG and generates a final response."""
    try:
        config.status.update(label=":bulb: Preparing answer..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {e}"


def query_agent(prompt):
    """Wraps the prompt in the system prompt and runs inference on it."""
    try:
        config.status.update(label=":brain: Starting inference..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict())
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {e}"


def get_response(input_text: str):
    """Main entry point: runs the function-calling agent, then composes the final answer."""
    agent_resp = function_agent(input_text)
    output = output_agent(agent_resp, input_text)
    return output
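
# Usage sketch (assumptions: a valid config.hf_token, a prompt_assets/ directory
# next to this file, and a config.status object whose update() also works outside
# Streamlit, as in headless mode):
#   answer = get_response("What functions can you call?")
#   print(answer)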


def main():
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)
    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                st.write(get_response(input_text))
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


def main_headless(prompt: str):
    start = time()
    # Print the response in blue ANSI color, then report the elapsed time.
    print("\033[94m" + get_response(prompt) + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)


# Initialize the model once at import time; the agent functions above reference
# this module-level global.
llm = init_llm()

if __name__ == "__main__":
    if config.headless:
        import fire

        fire.Fire(main_headless)
    else:
        main()
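
# How to run (a sketch; the filename "app.py" is an assumption):
#   streamlit run app.py           # interactive Streamlit UI
#   python app.py "your prompt"    # headless mode when config.headless is truthy;
#                                  # fire passes the CLI argument as `prompt`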