import io
import os

import torch
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Global variables are no longer needed; shared objects live in Streamlit session state.

# Load a PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # Guard with "or ''" so pages that yield no text don't break the join
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text
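
# A quick local smoke test might look like this (the file name is hypothetical):
#   with open("paper.pdf", "rb") as f:
#       print(load_pdf(io.BytesIO(f.read()))[:500])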

# Split the text into overlapping chunks
def split_text(text):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    return text_splitter.split_text(text)
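
# With chunk_size=1000 and chunk_overlap=200, consecutive chunks share roughly
# 200 characters, so sentences that straddle a chunk boundary stay retrievable.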

# Create the FAISS vector store
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # embedding model
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)
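
# If re-embedding on every upload becomes slow, the index could be persisted;
# a minimal sketch (the "faiss_index" directory name is an assumption):
#   kb = create_knowledge_base(chunks)
#   kb.save_local("faiss_index")
#   kb = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)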

# Load the Hugging Face model and wrap it in a text-generation pipeline
def load_model():
    model_name = "google/gemma-2-2b"  # Hugging Face model ID
    access_token = os.getenv("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=access_token, clean_up_tokenization_spaces=False
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)

        # Use GPU 0 if available, otherwise fall back to CPU (-1)
        device = 0 if torch.cuda.is_available() else -1

        # do_sample=True is required for the temperature setting to take effect
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            temperature=0.1,
            do_sample=True,
            device=device
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
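
# Note: google/gemma-2-2b is a gated model, so the Gemma license must be
# accepted on the Hub and a valid token exposed as HF_TOKEN (e.g. as a Space
# secret); otherwise from_pretrained fails with an authorization error.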

# Produce a model answer for a user question
def get_response_from_model(prompt):
    try:
        if "knowledge_base" not in st.session_state:
            return "No PDF has been uploaded yet."
        if "qa_chain" not in st.session_state:
            return "QA chain is not initialized."

        # Retrieve the chunks most similar to the question
        docs = st.session_state.knowledge_base.similarity_search(prompt)
        print("docs:", docs)      # debug output: retrieval works up to here
        print("prompt:", prompt)  # debug output: retrieval works up to here

        # Use the chain's invoke() method, passing the retrieved documents
        response = st.session_state.qa_chain.invoke({
            "input_documents": docs,
            "question": prompt
        })

        # invoke() returns a dict; the generated text lives under "output_text"
        answer = response.get("output_text", "") if isinstance(response, dict) else str(response)

        # Strip any prompt scaffolding the model echoes back
        if "Helpful Answer:" in answer:
            answer = answer.split("Helpful Answer:")[1].strip()
        return answer
    except Exception as e:
        return f"Error: {str(e)}"

# Page UI
def main():
    st.title("Welcome to GemmaPaperQA")

    # PDF upload section
    with st.expander("Upload Your Paper", expanded=True):
        paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
        if paper:
            st.write(f"Upload complete! File name: {paper.name}")

            # Check the file size without moving the file pointer
            file_size = paper.size
            if file_size > 10 * 1024 * 1024:  # 10 MB limit
                st.error("File is too large! Please upload a file smaller than 10MB.")
                return

            # Preview the extracted PDF text
            with st.spinner('Processing PDF...'):
                try:
                    paper.seek(0)
                    contents = paper.read()
                    pdf_file = io.BytesIO(contents)
                    text = load_pdf(pdf_file)
                    if len(text.strip()) == 0:
                        st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                        return

                    st.text_area("Preview of extracted text", text[:1000], height=200)
                    st.write(f"Total characters extracted: {len(text)}")

                    if st.button("Create Knowledge Base"):
                        chunks = split_text(text)
                        st.session_state.knowledge_base = create_knowledge_base(chunks)
                        print("knowledge_base:", st.session_state.knowledge_base)  # debug output
                        if st.session_state.knowledge_base is None:
                            st.error("Failed to create knowledge base.")
                            return

                        # load_model() handles its own exceptions and returns None on failure
                        pipe = load_model()
                        if pipe is None:
                            st.error("Error loading model.")
                            return

                        llm = HuggingFacePipeline(pipeline=pipe)
                        st.session_state.qa_chain = load_qa_chain(llm, chain_type="map_rerank")
                        st.success("Knowledge base created! You can now ask questions.")
                except Exception as e:
                    st.error(f"Failed to process the PDF: {str(e)}")

    # Question-and-answer section
    if "knowledge_base" in st.session_state and "qa_chain" in st.session_state:
        with st.expander("Ask Questions", expanded=True):
            prompt = st.text_input("Chat here!")
            if prompt:
                response = get_response_from_model(prompt)
                if response:
                    st.write(f"**Assistant**: {response}")

# Run the app
if __name__ == "__main__":
    main()
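
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py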