import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Load the fine-tuned model and tokenizer
@st.cache_resource  # Cache model to avoid reloading
def load_model():
    model_directory = "Vaishu16/QC_fine_tuned_model"
    model = AutoModelForCausalLM.from_pretrained(model_directory)
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    return model, tokenizer

# Load model and tokenizer
model, tokenizer = load_model()

# Create a pipeline
question_completion_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1
)

# Suggested partial questions
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of"
]

# Sidebar Settings
st.sidebar.header("🔧 Settings")
max_length = st.sidebar.slider("Max Length", 30, 100, 60, step=5)
num_return_sequences = st.sidebar.slider("Number of Completions", 1, 5, 3)
sampling = st.sidebar.checkbox("Enable Sampling", value=True)

# Main UI
st.title("🤖 Question Completion Model")
st.write("Enter a partial question related to **financial statements, corporate governance, or company reports**, and the model will intelligently complete it!")

# Session state for input question
if "partial_question" not in st.session_state:
    st.session_state.partial_question = ""

# Function to update input when a suggestion is clicked
def update_input(selected_question):
    st.session_state.partial_question = selected_question
    st.rerun()

# Display suggested partial questions in sidebar
st.sidebar.subheader("💡 Suggested Partial Questions")
for i, question in enumerate(suggested_questions):
    if st.sidebar.button(question, key=f"suggested_{i}"):
        update_input(question)

# Text input box
st.subheader("✏️ Enter Partial Question")
partial_question = st.text_input("Enter a partial question:", st.session_state.partial_question)

# Generate button
if st.button("🚀 Complete Question"):
    if partial_question.strip():
        # Greedy decoding only supports a single returned sequence,
        # so fall back to one completion when sampling is disabled
        effective_num_return_sequences = num_return_sequences if sampling else 1

        with st.spinner("⏳ Generating completed questions... Please wait!"):
            outputs = question_completion_pipeline(
                partial_question,
                max_length=max_length,
                num_return_sequences=effective_num_return_sequences,
                do_sample=sampling,
                truncation=True
            )

        st.subheader("✅ Completed Questions")
        completed_texts = [output["generated_text"] for output in outputs]

        # Display completed questions inside an expander
        with st.expander("🔍 View Generated Questions"):
            for idx, completed_question in enumerate(completed_texts):
                st.markdown(f"**{idx + 1}. {completed_question}**")

        # Copy option
        st.text_area("📋 Copy Completed Questions:", "\n".join(completed_texts), height=150)
    else:
        st.warning("⚠️ Please enter a partial question.")
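
# Usage note: to try the app locally, save this script (the file name below is
# illustrative, not prescribed by the source) and launch it with the Streamlit CLI:
#   streamlit run app.py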