# Hugging Face Space app (page status at capture time: Running)
import streamlit as st | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
# Load the fine-tuned model and tokenizer.
# Cached with st.cache_resource: Streamlit re-executes this whole script on
# every widget interaction, and without the cache the weights would be
# re-loaded on each rerun.
@st.cache_resource
def load_model(model_directory="Vaishu16/QC_fine_tuned_model"):
    """Load the question-completion causal LM and its tokenizer.

    Args:
        model_directory: Hugging Face Hub repo id or local directory passed to
            ``from_pretrained``. Defaults to the fine-tuned question-completion
            checkpoint used by this app.

    Returns:
        A ``(model, tokenizer)`` tuple. Repeated calls with the same
        ``model_directory`` return the cached objects instead of reloading.
    """
    model = AutoModelForCausalLM.from_pretrained(model_directory)
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    return model, tokenizer
# Obtain the model/tokenizer pair and wrap it in a text-generation pipeline.
# device=-1 tells transformers to run inference on the CPU.
model, tokenizer = load_model()
question_completion_pipeline = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    device=-1,
)
# Starter prompts offered as one-click buttons in the sidebar.
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]
# Sidebar settings: generation controls consumed by the "Complete Question"
# handler further down.
st.sidebar.header("🔧 Settings")
# Slider args are (label, min, max, default). This value is passed to the
# pipeline as `max_length`, which in transformers counts prompt tokens too.
max_length = st.sidebar.slider("Max Length", 30, 100, 60, step=5)
num_return_sequences = st.sidebar.slider("Number of Completions", 1, 5, 3)
# When unchecked, generation is deterministic (greedy / beam search).
sampling = st.sidebar.checkbox("Enable Sampling", value=True)
# Main UI: page title and a one-line description of what to type.
st.title("🤖 Question Completion Model")
st.write("Enter a partial question related to **financial statements, corporate governance, or company reports**, and the model will intelligently complete it!")
# Ensure the persisted input text exists; it is only created when missing,
# so a value set by a suggestion button survives later reruns.
st.session_state.setdefault("partial_question", "")
# Helper invoked when one of the sidebar suggestion buttons is clicked.
def update_input(selected_question):
    """Store ``selected_question`` as the current input text and rerun.

    ``st.rerun()`` stops the current script run and starts a new one, so
    nothing after this call executes in the current run.
    """
    st.session_state["partial_question"] = selected_question
    st.rerun()
# Display suggested partial questions in the sidebar. Clicking one stores it
# in session state (via update_input), which pre-fills the input box.
st.sidebar.subheader("💡 Suggested Partial Questions")
for i, question in enumerate(suggested_questions):
    # Explicit index-based keys keep each button's widget identity unique.
    if st.sidebar.button(question, key=f"suggested_{i}"):
        update_input(question)
# Text input box, defaulting to the text stored in session state (set by the
# sidebar suggestion buttons).
st.subheader("✏️ Enter Partial Question")
partial_question = st.text_input("Enter a partial question:", st.session_state.partial_question)
# Generate button: run the model on the entered prompt and show the results.
if st.button("🚀 Complete Question"):
    # Use the same stripped text for the emptiness check and for generation,
    # so stray leading/trailing whitespace is not fed to the model.
    prompt = partial_question.strip()
    if prompt:
        # transformers raises ValueError for num_return_sequences > 1 under
        # greedy decoding (do_sample=False, num_beams=1). When sampling is
        # disabled, use beam search with one beam per requested completion so
        # several deterministic candidates can still be returned.
        num_beams = 1 if sampling else num_return_sequences
        with st.spinner("⏳ Generating completed questions... Please wait!"):
            outputs = question_completion_pipeline(
                prompt,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                do_sample=sampling,
                num_beams=num_beams,
                truncation=True,
            )
        st.subheader("✅ Completed Questions")
        # Each pipeline output is a dict; "generated_text" holds the
        # prompt followed by the model's continuation.
        completed_texts = [output["generated_text"] for output in outputs]
        # Display completed questions inside an expander, numbered from 1.
        with st.expander("📜 View Generated Questions"):
            for idx, completed_question in enumerate(completed_texts, start=1):
                st.markdown(f"**{idx}. {completed_question}**")
        # Copy option: all completions in one plain-text box for easy copying.
        st.text_area("📋 Copy Completed Questions:", "\n".join(completed_texts), height=150)
    else:
        st.warning("⚠️ Please enter a partial question.")