# ft-dpo-gradio / app.py
import json
import os
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import Repository
from langchain import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from utils import force_git_push
def replace_template(template: str, data: dict) -> str:
    """Replace `{key}`-style template variables with values from `data`."""
    for key, value in data.items():
        template = template.replace(f"{{{key}}}", str(value))
    return template
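# Illustrative usage: replace_template("Hello, {name}!", {"name": "Ada"})
# returns "Hello, Ada!"; placeholders missing from `data` are left intact.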
def json_to_dict(json_file: str) -> dict:
    """Load a JSON file into a dict."""
    with open(json_file, "r") as f:
        return json.load(f)
def generate_response(chatbot: ConversationChain, input: str, count=1) -> List[str]:
    """Generate `count` responses from a `langchain` conversation chain."""
    return [chatbot.predict(input=input) for _ in range(count)]
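# Note: each predict() call records its exchange in the chain's shared
# memory, so successive candidates condition on the earlier ones until the
# memory is re-synced to the chosen path in _select_response below.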
def generate_responses(chatbots: List[ConversationChain], inputs: List[str]) -> List[str]:
    """Generate responses for a list of `langchain` chatbots in parallel."""
    results = []
    with ThreadPoolExecutor(max_workers=100) as executor:
        for result in executor.map(
            generate_response,
            chatbots,
            inputs,
            [NUM_RESPONSES] * len(inputs),
        ):
            results += result
    return results
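# Illustrative fan-out: with 2 chatbots, 2 inputs and NUM_RESPONSES = 3,
# executor.map pairs chatbots[i] with inputs[i], and the flattened result
# contains 6 responses, grouped in input order.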
if Path(".env").is_file():
    load_dotenv(".env")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
FORCE_PUSH = os.getenv("FORCE_PUSH")
HF_TOKEN = os.getenv("HF_TOKEN")
PROMPT_TEMPLATES = Path("prompt_templates")
NUM_RESPONSES = 3 # Number of responses to generate per interaction
DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)
TOTAL_CNT = 3 # How many user inputs to collect
PUSH_FREQUENCY = 60
def asynchronous_push(f_stop):
    """Commit and push collected data, then re-arm the timer."""
    if repo.is_repo_clean():
        print("Repo currently clean. Ignoring push_to_hub")
    else:
        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Auto commit by space")
        if FORCE_PUSH == "yes":
            force_git_push(repo)
        else:
            repo.git_push()
    if not f_stop.is_set():
        # Call again in PUSH_FREQUENCY seconds
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
f_stop = threading.Event()
asynchronous_push(f_stop)
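# The push loop re-arms itself every PUSH_FREQUENCY seconds; calling
# f_stop.set() (e.g. on shutdown) lets the current cycle finish and stops it.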
# prompt_01.json is expected to yield exactly two values, in this order:
# the prompt's input variables, then the raw template string.
input_vars, prompt_tpl = json_to_dict(PROMPT_TEMPLATES / "prompt_01.json").values()
prompt_data = json_to_dict(PROMPT_TEMPLATES / "data_01.json")
prompt_tpl = replace_template(prompt_tpl, prompt_data)
prompt = PromptTemplate(template=prompt_tpl, input_variables=input_vars)
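# Hypothetical shapes for the two files (the actual keys live under
# prompt_templates/ and are not shown in this snippet):
#   prompt_01.json -> {"input_variables": ["history", "input"], "template": "{persona}\n{history}\nHuman: {input}"}
#   data_01.json   -> {"persona": "You are a helpful assistant."}
# replace_template fills the static {persona}-style slots up front, leaving
# only the PromptTemplate's runtime input variables unresolved.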
chatbot = ConversationChain(
    llm=HuggingFaceHub(
        repo_id="Open-Orca/Mistral-7B-OpenOrca",
        model_kwargs={"temperature": 1},
        huggingfacehub_api_token=HF_TOKEN,
    ),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferMemory(ai_prefix="Assistant"),
)
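# temperature=1 keeps decoding stochastic, which is presumably what makes
# the NUM_RESPONSES candidate completions differ from one another.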
demo = gr.Blocks()
with demo:
    # We keep track of state as a JSON
    state_dict = {
        "conversation_id": str(uuid.uuid4()),
        "cnt": 0,
        "data": [],
        "past_user_inputs": [],
        "generated_responses": [],
    }
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# Talk to the assistant")
    state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
    # Generate model predictions for the latest user input
    def _predict(txt, state):
        start = time.time()
        responses = generate_response(chatbot, txt, count=NUM_RESPONSES)
        print(f"Time taken to generate {len(responses)} responses: {time.time() - start:.2f} seconds")
        state["cnt"] += 1
        metadata = {"cnt": state["cnt"], "text": txt}
        for idx, response in enumerate(responses):
            metadata[f"response_{idx + 1}"] = response
        state["data"].append(metadata)
        state["past_user_inputs"].append(txt)
        # Pad with an empty string: the newest user input has no selected
        # response yet.
        past_conversation_string = "<br />".join(
            [
                "<br />".join(["Human 😃: " + user_input, "Assistant 🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])
            ]
        )
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True, choices=responses, interactive=True, value=responses[0]),
            gr.update(value=past_conversation_string),
            state,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    def _select_response(selected_response, state):
        done = state["cnt"] == TOTAL_CNT
        state["generated_responses"].append(selected_response)
        state["data"][-1]["selected_response"] = selected_response
        if done:
            # Flush the finished conversation to the local JSONL dataset
            with open(DATA_FILE, "a") as jsonlfile:
                json_data_with_assignment_id = [
                    json.dumps(
                        dict(
                            {
                                # `assignmentId` is never seeded in state_dict,
                                # so default to "" rather than raising KeyError
                                "assignmentId": state.get("assignmentId", ""),
                                "conversation_id": state["conversation_id"],
                            },
                            **datum,
                        )
                    )
                    for datum in state["data"]
                ]
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
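        # Each emitted record looks roughly like (illustrative values only):
        # {"assignmentId": "", "conversation_id": "...", "cnt": 1,
        #  "text": "...", "response_1": "...", "response_2": "...",
        #  "response_3": "...", "selected_response": "..."}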
        toggle_example_submit = gr.update(visible=not done)
        past_conversation_string = "<br />".join(
            [
                "<br />".join(["😃: " + user_input, "🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"])
            ]
        )
        # Reveal the final submit button once all turns are collected
        toggle_final_submit = gr.update(visible=done)
        if done:
            # Wipe the memory for the next conversation
            chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
        else:
            # Sync the model's memory with the conversation path that was
            # actually taken: generate_response left all candidate exchanges
            # in memory, so rebuild it from the selected responses only.
            chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
            for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"]):
                chatbot.memory.save_context({"input": user_input}, {"output": model_response})
        text_input = gr.update(visible=False) if done else gr.update(visible=True)
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            text_input,
            gr.update(visible=False),
            state,
            gr.update(value=past_conversation_string),
            toggle_example_submit,
            toggle_final_submit,
        )
    # Input fields
    past_conversation = gr.Markdown()
    text_input = gr.Textbox(placeholder="Enter a statement", show_label=False)
    select_response = gr.Radio(
        choices=[None, None],
        visible=False,
        label="Choose the most helpful and honest response",
    )
    select_response_button = gr.Button("Select Response", visible=False)
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    select_response_button.click(
        _select_response,
        inputs=[select_response, state],
        outputs=[
            select_response,
            example_submit,
            text_input,
            select_response_button,
            state,
            past_conversation,
            example_submit,
            final_submit,
        ],
    )
    submit_ex_button.click(
        _predict,
        inputs=[text_input, state],
        outputs=[
            text_input,
            select_response_button,
            select_response,
            past_conversation,
            state,
            example_submit,
            final_submit,
            state_display,
        ],
    )
    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
    )
demo.launch()