Spaces:
Running
Running
import streamlit as st | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
# Load the fine-tuned model and tokenizer once; st.cache_resource keeps the
# returned objects across Streamlit reruns instead of reloading them each time.
@st.cache_resource
def load_model(model_directory="Vaishu16/QC_fine_tuned_model"):
    """Load the fine-tuned causal language model and its tokenizer.

    Streamlit re-executes this whole script on every widget interaction.
    Without ``st.cache_resource`` the weights would be fetched and
    instantiated again on each click; with it, the first call's result is
    reused for later calls with the same ``model_directory``. Cached
    resources are shared across sessions, so callers must not mutate the
    returned objects.

    Args:
        model_directory: Hugging Face Hub repo id or local path handed to
            ``from_pretrained``. Defaults to the app's fine-tuned model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_directory)
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    return model, tokenizer
# Fetch the model/tokenizer pair and wrap them in a text-generation pipeline
# that the "Complete Question" handler calls below.
model, tokenizer = load_model()

question_completion_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 keeps inference on the CPU
)
# Starter prompts rendered as one-click buttons; clicking one pre-fills the
# question box with that text.
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]
# Page header and a one-paragraph description of what the app does.
st.title("Question Completion Model")
st.write(
    "Enter a partial question related to financial statements, "
    "corporate governance, or company reports, and the app will "
    "intelligently complete it based on learned patterns!"
)

# Seed the persisted value of the question box on the first run only;
# later runs keep whatever a suggestion button stored there.
if "partial_question" not in st.session_state:
    st.session_state["partial_question"] = ""
# Function to update input when a suggestion is clicked
def update_input(selected_question):
    """Store a clicked suggestion as the pre-filled value of the question box.

    The question text input is rendered later in the same script run and
    uses ``st.session_state.partial_question`` as its default value, so the
    new text is picked up without calling ``st.rerun()``. That call only
    forced a second full execution of the script on every click.

    Args:
        selected_question: The suggested partial question that was clicked.
    """
    st.session_state.partial_question = selected_question
# Suggestion buttons: a header, then each suggestion placed round-robin
# into one of five columns.
st.write("### Suggested Partial Questions:")
cols = st.columns(5)
for idx, suggestion in enumerate(suggested_questions):
    target_column = cols[idx % 5]
    clicked = target_column.button(suggestion, key=f"suggested_{idx}")
    if clicked:
        update_input(suggestion)
# Free-text question box; its default comes from session state, which a
# suggestion button may have just overwritten.
partial_question = st.text_input(
    "Enter a partial question:",
    value=st.session_state.partial_question,
)
# Button to generate 3 completed questions

# Characters that Streamlit's markdown renderer would otherwise interpret:
# emphasis/code/strikethrough/table markers, link brackets, raw-HTML angle
# brackets, and "$", which starts LaTeX math (common in financial text).
_MARKDOWN_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\`*_[]<>$~|"})


def _format_completion(text):
    """Return *text* on one line with markdown-special characters escaped.

    Whitespace runs (including newlines the model may emit) are collapsed
    to single spaces so the surrounding ``**...**`` stays inside one
    paragraph. Special characters are backslash-escaped so they display
    literally instead of being rendered as markup.
    """
    return " ".join(text.split()).translate(_MARKDOWN_ESCAPES)


if st.button("Complete Question"):
    # Send the trimmed prompt: trailing whitespace tends to degrade
    # generation with byte-level BPE tokenizers.
    prompt = partial_question.strip()
    if prompt:
        outputs = question_completion_pipeline(
            prompt,
            max_length=60,
            num_return_sequences=3,  # Generate 3 different completions
            do_sample=True,
            truncation=True,
        )
        st.write("### Completed Questions:")
        for output in outputs:
            completed_question = output["generated_text"]
            # Display non-clickable text, escaped so it renders verbatim.
            st.markdown(f"**{_format_completion(completed_question)}**")
    else:
        st.warning("Please enter a partial question.")