# Author: Vaishu16 (Hugging Face Space)
# Last commit: 9e160b4 "Update app.py" (verified)
import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Load the fine-tuned model and tokenizer
@st.cache_resource  # Cache model to avoid reloading on every Streamlit rerun
def load_model(model_directory: str = "Vaishu16/QC_fine_tuned_model"):
    """Load the fine-tuned causal language model and its tokenizer.

    ``st.cache_resource`` keys the cache on the arguments, so each distinct
    ``model_directory`` is loaded once and then reused across reruns instead
    of being reloaded every time the script executes.

    Args:
        model_directory: Hugging Face Hub repo id or local path passed to
            ``from_pretrained``. Defaults to the original fine-tuned
            checkpoint, so ``load_model()`` behaves exactly as before.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_directory)
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    return model, tokenizer
# Load model and tokenizer (load_model is cached, so reruns reuse the same objects)
model, tokenizer = load_model()


@st.cache_resource  # Build the pipeline once instead of on every script rerun
def _build_question_completion_pipeline():
    """Return a CPU text-generation pipeline around the cached model and tokenizer.

    Streamlit re-executes this whole script on every widget interaction;
    caching the pipeline avoids re-wrapping the model each time.
    """
    cached_model, cached_tokenizer = load_model()
    return pipeline(
        "text-generation",
        model=cached_model,
        tokenizer=cached_tokenizer,
        device=-1,  # -1 runs the pipeline on the CPU
    )


# Create a pipeline
question_completion_pipeline = _build_question_completion_pipeline()
# Partial-question prompts offered as one-click suggestions in the sidebar.
# Each one is an unfinished question for the model to complete.
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]
# Sidebar settings controlling how completions are generated
st.sidebar.header("🔧 Settings")
# Forwarded to the pipeline as ``max_length`` (range 30-100, step 5, default 60).
max_length = st.sidebar.slider(
    "Max Length",
    30,
    100,
    60,
    step=5,
    help="Maximum number of tokens in each completed question, counting the partial question itself.",
)
# How many alternative completions to return (range 1-5, default 3).
num_return_sequences = st.sidebar.slider("Number of Completions", 1, 5, 3)
# Forwarded to the pipeline as ``do_sample``.
sampling = st.sidebar.checkbox("Enable Sampling", value=True)
# Main UI: page title and a short description of the expected input
st.title("🤖 Question Completion Model")
st.write("Enter a partial question related to **financial statements, corporate governance, or company reports**, and the model will intelligently complete it!")
# Session state for input question. The text box below is bound to this key,
# so its value survives reruns and suggestion clicks can overwrite it.
if "partial_question" not in st.session_state:
    st.session_state.partial_question = ""


def update_input(selected_question):
    """Put *selected_question* into the partial-question text box.

    Registered as a button ``on_click`` callback. Streamlit runs callbacks
    before the script reruns, so assigning the text box's keyed state here is
    allowed. Assigning it after the widget has been created in a run raises
    ``StreamlitAPIException``. The click itself triggers the rerun, so no
    explicit ``st.rerun()`` is needed.
    """
    st.session_state.partial_question = selected_question


# Display suggested partial questions in sidebar
st.sidebar.subheader("💡 Suggested Partial Questions")
for i, question in enumerate(suggested_questions):
    st.sidebar.button(
        question,
        key=f"suggested_{i}",
        on_click=update_input,
        args=(question,),
    )

# Text input box. Keying it to "partial_question" makes its state live in
# st.session_state. Without a key, the widget's identity depended on its
# value, so re-clicking an unchanged suggestion after editing had no effect.
st.subheader("✏️ Enter Partial Question")
partial_question = st.text_input("Enter a partial question:", key="partial_question")
# Generate button: run the pipeline on the partial question and show the results
if st.button("🚀 Complete Question"):
    if partial_question.strip():
        generation_kwargs = {
            "max_length": max_length,
            "num_return_sequences": num_return_sequences,
            "do_sample": sampling,
            "truncation": True,
        }
        if not sampling and num_return_sequences > 1:
            # Greedy decoding yields a single sequence, and transformers raises
            # ValueError for num_return_sequences > 1 without sampling or beams.
            # Beam search with one beam per requested completion returns that
            # many deterministic alternatives instead.
            generation_kwargs["num_beams"] = num_return_sequences
        with st.spinner("⏳ Generating completed questions... Please wait!"):
            outputs = question_completion_pipeline(partial_question, **generation_kwargs)
        st.subheader("✅ Completed Questions")
        # "generated_text" holds the prompt followed by the model's continuation
        completed_texts = [output["generated_text"] for output in outputs]
        # Display completed questions inside an expander
        with st.expander("🔍 View Generated Questions"):
            for idx, completed_question in enumerate(completed_texts, start=1):
                st.markdown(f"**{idx}. {completed_question}**")
        # Copy option: all completions, one per line, in a selectable text area
        st.text_area("📋 Copy Completed Questions:", "\n".join(completed_texts), height=150)
    else:
        st.warning("⚠️ Please enter a partial question.")