Vaishu16's picture
Update app.py
67cac15 verified
raw
history blame
2.44 kB
import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Fine-tuned question-completion model hosted on the Hugging Face Hub.
@st.cache_resource  # shared across reruns and sessions, so the weights are loaded only once
def load_model():
    """Load the fine-tuned causal LM and its tokenizer.

    Both are fetched from ``Vaishu16/QC_fine_tuned_model`` via ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    repo_id = "Vaishu16/QC_fine_tuned_model"
    loaded_model = AutoModelForCausalLM.from_pretrained(repo_id)
    loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return loaded_model, loaded_tokenizer
# Load model and tokenizer (load_model is itself cached with st.cache_resource).
model, tokenizer = load_model()


@st.cache_resource  # Streamlit re-executes this script on every interaction
def _create_completion_pipeline(_model, _tokenizer):
    """Build the text-generation pipeline once and reuse it on every rerun.

    The leading underscores tell Streamlit not to hash these arguments. They are
    the already-cached model/tokenizer objects, so one cached pipeline suffices.

    Args:
        _model: causal LM returned by ``load_model``.
        _tokenizer: matching tokenizer returned by ``load_model``.

    Returns:
        A ``"text-generation"`` pipeline pinned to the CPU (``device=-1``).
    """
    return pipeline(
        "text-generation",
        model=_model,
        tokenizer=_tokenizer,
        device=-1,
    )


# Create a pipeline (built on the first run only, then served from the cache).
question_completion_pipeline = _create_completion_pipeline(model, tokenizer)
# Starter prompts shown to the user as one-click buttons; each is an
# unfinished question for the model to complete.
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]
# --- Streamlit UI: page heading and short usage description -----------------
st.title("Question Completion Model")
st.write(
    "Enter a partial question related to financial statements, corporate "
    "governance, or company reports, and the app will intelligently complete "
    "it based on learned patterns!"
)
# Create the session-state slot for the current partial question only on the
# first run of a session; later reruns keep whatever value is already stored.
st.session_state.setdefault("partial_question", "")
# Called when one of the suggestion buttons is clicked.
def update_input(selected_question):
    """Make *selected_question* the current partial question and rerun the script.

    Stores the text under the ``partial_question`` session-state key and then
    calls ``st.rerun()`` so the page is redrawn with the new value.
    """
    st.session_state["partial_question"] = selected_question
    st.rerun()
# Display suggested partial questions as one-click buttons laid out in columns.
st.write("### Suggested Partial Questions:")
# Use at most 5 columns, never more than there are suggestions, and at least
# one (st.columns rejects 0). Any extra suggestions wrap back to the first column.
num_suggestion_cols = max(1, min(5, len(suggested_questions)))
cols = st.columns(num_suggestion_cols)
for i, question in enumerate(suggested_questions):
    # The stable per-index key keeps button identity fixed across reruns.
    if cols[i % num_suggestion_cols].button(question, key=f"suggested_{i}"):
        update_input(question)
# Text input box. It is bound to st.session_state["partial_question"] through its
# key, so a clicked suggestion always replaces whatever the user has typed, even
# when the same suggestion is clicked twice. No value= argument is passed: the
# initial value comes from session state.
partial_question = st.text_input("Enter a partial question:", key="partial_question")
# Markdown control characters in generated text that st.markdown would otherwise
# interpret: "$" opens LaTeX math, "*"/"_" toggle emphasis (breaking the bold
# wrapper), "`" opens code, "[" / "]" start links, "~" strikethrough, "\" escapes.
_MARKDOWN_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\`*_[]$~"})


def _format_completion(text):
    """Return *text* as one bold Markdown line that renders literally.

    Whitespace runs, including any newlines the model emits, are collapsed to
    single spaces so the surrounding ``**`` markers stay on one line. Markdown
    control characters are backslash-escaped. Empty or whitespace-only text
    yields ``""`` rather than ``"****"``, which Markdown would render as a
    horizontal rule.
    """
    flattened = " ".join(text.split())
    if not flattened:
        return ""
    return f"**{flattened.translate(_MARKDOWN_ESCAPES)}**"


# Button to generate 3 completed questions
if st.button("Complete Question"):
    # Feed the model the trimmed prompt; stray leading/trailing whitespace is
    # noise rather than part of the question.
    prompt = partial_question.strip()
    if prompt:
        outputs = question_completion_pipeline(
            prompt,
            # NOTE(review): in transformers, max_length bounds prompt + generated
            # tokens together (and with truncation=True also caps the prompt);
            # consider max_new_tokens if long prompts leave no room to complete.
            max_length=60,
            num_return_sequences=3,  # Generate 3 different completions
            do_sample=True,
            truncation=True,
        )
        st.write("### Completed Questions:")
        for output in outputs:
            # generated_text is displayed as non-clickable, escaped bold text.
            st.markdown(_format_completion(output["generated_text"]))
    else:
        st.warning("Please enter a partial question.")