import io
import os

import streamlit as st
import torch
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Global variables are no longer needed; we use Streamlit session state instead.


# Load the PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() may return None for pages without a text layer
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text


# Split the text into overlapping chunks
def split_text(text):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)


# Create the FAISS vector store
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # embedding model, stated explicitly
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)


# Load the Hugging Face model
def load_model():
    model_name = "google/gemma-2-2b"  # Hugging Face model ID
    access_token = os.getenv("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=access_token, clean_up_tokenization_spaces=False
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)

        # Device setup: GPU 0 if available, otherwise CPU
        device = 0 if torch.cuda.is_available() else -1

        # do_sample=True so that the temperature setting takes effect
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            temperature=0.1,
            do_sample=True,
            device=device,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


# Handle a question and return the model's answer
def get_response_from_model(prompt):
    try:
        if "knowledge_base" not in st.session_state:
            return "No PDF has been uploaded yet."
        if "qa_chain" not in st.session_state:
            return "QA chain is not initialized."

        docs = st.session_state.knowledge_base.similarity_search(prompt)
        print("docs:", docs)      # debug: retrieval worked up to here
        print("prompt:", prompt)  # debug: retrieval worked up to here

        # Call the chain via invoke(), passing the retrieved documents as input_documents.
        # The chain returns a dict; the generated answer lives under "output_text".
        result = st.session_state.qa_chain.invoke({
            "input_documents": docs,
            "question": prompt,
        })
        response = result.get("output_text", "") if isinstance(result, dict) else str(result)

        # Strip the prompt scaffolding if the model echoed it back
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:")[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"


# Page UI
def main():
    st.title("Welcome to GemmaPaperQA")

    # PDF upload section
    with st.expander("Upload Your Paper", expanded=True):
        paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
        if paper:
            st.write(f"Upload complete! File name: {paper.name}")

            # Check the file size without moving the file pointer (10 MB limit)
            file_size = paper.size
            if file_size > 10 * 1024 * 1024:
                st.error("File is too large! Please upload a file smaller than 10MB.")
                return

            # Preview the extracted PDF text
            with st.spinner("Processing PDF..."):
                try:
                    paper.seek(0)
                    contents = paper.read()
                    pdf_file = io.BytesIO(contents)
                    text = load_pdf(pdf_file)

                    if len(text.strip()) == 0:
                        st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                        return

                    st.text_area("Preview of extracted text", text[:1000], height=200)
                    st.write(f"Total characters extracted: {len(text)}")

                    if st.button("Create Knowledge Base"):
                        chunks = split_text(text)
                        st.session_state.knowledge_base = create_knowledge_base(chunks)
                        print("knowledge_base:", st.session_state.knowledge_base)
                        if st.session_state.knowledge_base is None:
                            st.error("Failed to create knowledge base.")
                            return

                        # load_model() swallows its own exceptions and returns None on failure
                        pipe = load_model()
                        if pipe is None:
                            st.error("Error loading model.")
                            return

                        llm = HuggingFacePipeline(pipeline=pipe)
                        st.session_state.qa_chain = load_qa_chain(llm, chain_type="map_rerank")
                        st.success("Knowledge base created! You can now ask questions.")
                except Exception as e:
                    st.error(f"Failed to process the PDF: {str(e)}")

    # Question-answering section
    if "knowledge_base" in st.session_state and "qa_chain" in st.session_state:
        with st.expander("Ask Questions", expanded=True):
            prompt = st.text_input("Chat here!")
            if prompt:
                response = get_response_from_model(prompt)
                if response:
                    st.write(f"**Assistant**: {response}")


# Run the app
if __name__ == "__main__":
    main()
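
# Usage note (assumption: this script is saved as app.py; adjust the name as needed):
#   streamlit run app.py
# google/gemma-2-2b is a gated model on the Hugging Face Hub, so HF_TOKEN should be
# set to an access token that has been granted access to it before launching, e.g.:
#   export HF_TOKEN=<your_token>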