# gemma_sprint/app.py
import io
import os
import torch
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
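# Assumed dependencies, inferred from the imports above (versions are not
# pinned in the source): streamlit, PyPDF2, torch, transformers, langchain,
# langchain-huggingface, langchain-community, faiss-cpu, sentence-transformers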
# Global variables are no longer needed; state lives in st.session_state instead.

# Load a PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for image-only pages, so fall back to ""
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text
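# Usage sketch (the file path is illustrative, not part of the app):
#     with open("paper.pdf", "rb") as f:
#         text = load_pdf(f)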
# ํ…์ŠคํŠธ๋ฅผ ์ฒญํฌ๋กœ ๋ถ„ํ• 
def split_text(text):
text_splitter = CharacterTextSplitter(
separator="\n",
chunk_size=1000,
chunk_overlap=200,
length_function=len
)
return text_splitter.split_text(text)
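# Example: with these settings a ~2,300-character text yields roughly three
# chunks of at most 1,000 characters, each sharing about 200 characters with
# its neighbor, so an answer that straddles a chunk boundary stays retrievable.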
# Build a FAISS vector store (the knowledge base) from the chunks
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # embedding model, stated explicitly
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)
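# NOTE (sketch, not in the original app): the index could be persisted rather
# than rebuilt on every upload; "faiss_index" is an illustrative folder name.
#     kb = create_knowledge_base(chunks)
#     kb.save_local("faiss_index")
#     kb = FAISS.load_local("faiss_index", embeddings,
#                           allow_dangerous_deserialization=True)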
# Load the Hugging Face model
def load_model():
    model_name = "google/gemma-2-2b"  # Hugging Face model ID
    access_token = os.getenv("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=access_token, clean_up_tokenization_spaces=False
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
        # Device selection: GPU 0 if available, otherwise CPU (-1)
        device = 0 if torch.cuda.is_available() else -1
        # do_sample=True so that the temperature setting actually takes effect
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            temperature=0.1,
            do_sample=True,
            device=device,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
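# NOTE (sketch): Streamlit reruns the whole script on every interaction, so
# the model is reloaded each time. Wrapping the loader in st.cache_resource
# would keep a single instance alive across reruns:
#     @st.cache_resource
#     def load_model_cached():
#         return load_model()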
# ๋ชจ๋ธ ์‘๋‹ต ์ฒ˜๋ฆฌ
def get_response_from_model(prompt):
try:
if "knowledge_base" not in st.session_state:
return "No PDF has been uploaded yet."
if "qa_chain" not in st.session_state:
return "QA chain is not initialized."
docs = st.session_state.knowledge_base.similarity_search(prompt)
print("docs:", docs) # ์ด๊นŒ์ง„ ๋๋Š”๋ฐ
print("prompt:", prompt) # ์ด๊นŒ์ง„ ๋๋Š”๋ฐ
# Chain์˜ invoke() ๋ฉ”์†Œ๋“œ ์‚ฌ์šฉ (input_documents๋กœ ์ „๋‹ฌ)
response = st.session_state.qa_chain.invoke({
"input_documents": docs,
"question": prompt
})
try:
if "Helpful Answer:" in response:
response = response.split("Helpful Answer:")[1].strip()
except ValueError as e:
print(f"ValueError occurred: {e}")
return f"Error: Invalid response format - {e}"
return response
except Exception as e:
return f"Error: {str(e)}"
# ํŽ˜์ด์ง€ UI
def main():
st.title("Welcome to GemmaPaperQA")
# PDF ์—…๋กœ๋“œ ์„น์…˜
with st.expander("Upload Your Paper", expanded=True):
paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
if paper:
st.write(f"Upload complete! File name: {paper.name}")
# ํŒŒ์ผ ํฌ๊ธฐ ํ™•์ธ
file_size = paper.size # ํŒŒ์ผ ํฌ๊ธฐ๋ฅผ ํŒŒ์ผ ํฌ์ธํ„ฐ ์ด๋™ ์—†์ด ํ™•์ธ
if file_size > 10 * 1024 * 1024: # 10MB ์ œํ•œ
st.error("File is too large! Please upload a file smaller than 10MB.")
return
# PDF ํ…์ŠคํŠธ ๋ฏธ๋ฆฌ๋ณด๊ธฐ
with st.spinner('Processing PDF...'):
try:
paper.seek(0)
contents = paper.read()
pdf_file = io.BytesIO(contents)
text = load_pdf(pdf_file)
if len(text.strip()) == 0:
st.error("The PDF appears to have no extractable text. Please check the file and try again.")
return
st.text_area("Preview of extracted text", text[:1000], height=200)
st.write(f"Total characters extracted: {len(text)}")
if st.button("Create Knowledge Base"):
chunks = split_text(text)
st.session_state.knowledge_base = create_knowledge_base(chunks)
print("knowledge_base:", st.session_state.knowledge_base)
if st.session_state.knowledge_base is None:
st.error("Failed to create knowledge base.")
return
try:
pipe = load_model()
except Exception as e:
st.error(f"Error loading model: {e}")
return
llm = HuggingFacePipeline(pipeline=pipe)
st.session_state.qa_chain = load_qa_chain(llm, chain_type="map_rerank")
st.success("Knowledge base created! You can now ask questions.")
except Exception as e:
st.error(f"Failed to process the PDF: {str(e)}")
# ์งˆ๋ฌธ-์‘๋‹ต ์„น์…˜
if "knowledge_base" in st.session_state and "qa_chain" in st.session_state:
with st.expander("Ask Questions", expanded=True):
prompt = st.text_input("Chat here!")
if prompt:
response = get_response_from_model(prompt)
if response:
st.write(f"**Assistant**: {response}")
# ์•ฑ ์‹คํ–‰
if __name__ == "__main__":
main()
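# To run locally (sketch; HF_TOKEN must grant access to google/gemma-2-2b):
#     HF_TOKEN=<your-token> streamlit run app.py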