"""Streamlit app that completes partial finance-related questions.

Loads the fine-tuned causal LM ``Vaishu16/QC_fine_tuned_model`` via
``transformers``, offers a few suggested question stems as buttons, and
generates several sampled completions of the partial question the user enters.
"""

import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

MODEL_DIRECTORY = "Vaishu16/QC_fine_tuned_model"
# Passed to the pipeline as ``max_length``: the total token budget, which
# includes the prompt tokens, not just the newly generated ones.
MAX_LENGTH = 60
# Number of sampled completions shown per request.
NUM_COMPLETIONS = 3
# Suggestion buttons are laid out across this many columns (wrapping if more).
SUGGESTION_COLUMNS = 5


# Load the fine-tuned model and tokenizer
@st.cache_resource  # Cache model to avoid reloading on every rerun
def load_model():
    """Load the fine-tuned causal LM and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. Cached with ``st.cache_resource`` so
        the weights are loaded once and shared across reruns.
    """
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIRECTORY)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIRECTORY)
    return model, tokenizer


@st.cache_resource
def load_pipeline():
    """Build the CPU (``device=-1``) text-generation pipeline once.

    Streamlit re-executes this whole script on every interaction, so creating
    the pipeline at module level would rebuild it each time; caching it makes
    reruns reuse the same object.
    """
    model, tokenizer = load_model()
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=-1
    )


# Load model and tokenizer (cached) and the generation pipeline (cached)
model, tokenizer = load_model()
question_completion_pipeline = load_pipeline()

# Suggested partial questions
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]

# Streamlit UI
st.title("Question Completion Model")
st.write(
    "Enter a partial question related to financial statements, corporate "
    "governance, or company reports, and the app will intelligently complete "
    "it based on learned patterns!"
)

# The text input below is keyed on this session-state entry, so writing to it
# (e.g. from the suggestion callback) changes what the input box displays.
if "partial_question" not in st.session_state:
    st.session_state.partial_question = ""


def update_input(selected_question):
    """Put *selected_question* into the partial-question text input.

    Registered as a button ``on_click`` callback. Callbacks run before the
    script reruns, so the keyed text input shows the new value on that rerun
    without calling ``st.rerun()`` (which is a no-op inside callbacks).
    """
    st.session_state.partial_question = selected_question


# Display suggested partial questions
st.write("### Suggested Partial Questions:")
cols = st.columns(SUGGESTION_COLUMNS)
for i, question in enumerate(suggested_questions):
    # Place buttons in columns; clicking one fills the text input.
    cols[i % SUGGESTION_COLUMNS].button(
        question,
        key=f"suggested_{i}",
        on_click=update_input,
        args=(question,),
    )

# Text input box, bound to st.session_state.partial_question through its key
partial_question = st.text_input(
    "Enter a partial question:", key="partial_question"
)

# Button to generate the completed questions
if st.button("Complete Question"):
    # Feed the model the stripped text: trailing whitespace in a prompt
    # tends to degrade causal-LM continuations.
    prompt = partial_question.strip()
    if prompt:
        # If the tokenizer defines no pad token, transformers falls back to
        # the EOS token and warns on every sampled call; pass the same
        # fallback explicitly instead.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        with st.spinner("Generating completions..."):
            outputs = question_completion_pipeline(
                prompt,
                max_length=MAX_LENGTH,
                num_return_sequences=NUM_COMPLETIONS,  # distinct sampled completions
                do_sample=True,
                truncation=True,
                pad_token_id=pad_token_id,
            )
        st.write("### Completed Questions:")
        for output in outputs:
            # Collapse newlines/extra spaces so the bold markers wrap the whole
            # completion; a blank line inside **...** would break the markup.
            completed_question = " ".join(output["generated_text"].split())
            st.markdown(f"**{completed_question}**")  # Display non-clickable text
    else:
        st.warning("Please enter a partial question.")