Fiqa commited on
Commit
6d0a4ef
·
verified ·
1 Parent(s): 229f93a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -83
app.py CHANGED
@@ -1,83 +1,87 @@
1
- import streamlit as st
2
- import PyPDF2
3
- from langchain.llms import HuggingFaceHub
4
- import pptx
5
- import os
6
- from langchain.vectorstores.cassandra import Cassandra
7
- from langchain.indexes.vectorstore import VectorStoreIndexWrapper
8
- from langchain.embeddings import OpenAIEmbeddings
9
- import cassio
10
- from langchain.text_splitter import CharacterTextSplitter
11
-
12
-
13
-
14
-
15
- # Secure API keys (replace with environment variables in deployment)
16
- ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
17
- ASTRA_DB_ID = os.getenv("ASTRA_DB_ID")
18
- HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
19
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
-
21
-
22
- # Initialize Astra DB connection
23
- cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)
24
-
25
- # Initialize LLM & Embeddings
26
- hf_llm = HuggingFaceHub(repo_id="google/flan-t5-large", model_kwargs={"temperature": 0, "max_length": 64})
27
- embedding =OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
28
-
29
- # Initialize vector store
30
- astra_vector_store = Cassandra(embedding=embedding, table_name="qa_mini_demo")
31
-
32
- def extract_text_from_pdf(uploaded_file):
33
- """Extract text from a PDF file."""
34
- text = ""
35
- pdf_reader = PyPDF2.PdfReader(uploaded_file)
36
- for page in pdf_reader.pages:
37
- page_text = page.extract_text()
38
- if page_text: # Avoid NoneType error
39
- text += page_text + "\n"
40
- return text
41
-
42
- def extract_text_from_ppt(uploaded_file):
43
- """Extract text from a PowerPoint file."""
44
- text = ""
45
- presentation = pptx.Presentation(uploaded_file)
46
- for slide in presentation.slides:
47
- for shape in slide.shapes:
48
- if hasattr(shape, "text"):
49
- text += shape.text + "\n"
50
- return text
51
-
52
- def main():
53
- st.title("Chat with Documents")
54
-
55
- uploaded_file = st.file_uploader("Upload a PDF or PPT file", type=["pdf", "pptx"])
56
- extract_button = st.button("Extract Text")
57
-
58
- extracted_text = ""
59
- if extract_button and uploaded_file is not None:
60
- if uploaded_file.name.endswith(".pdf"):
61
- extracted_text = extract_text_from_pdf(uploaded_file)
62
- elif uploaded_file.name.endswith(".pptx"):
63
- extracted_text = extract_text_from_ppt(uploaded_file)
64
-
65
- if extracted_text:
66
- text_splitter = CharacterTextSplitter(separator="\n", chunk_size=800, chunk_overlap=200, length_function=len)
67
- texts = text_splitter.split_text(extracted_text)
68
- astra_vector_store.add_texts(texts)
69
-
70
- # Ensure the vector store index is initialized properly
71
- astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)
72
-
73
- query = st.text_input("Enter your query")
74
- submit_query = st.button("Submit Query")
75
- if submit_query:
76
-
77
-
78
- value = astra_vector_index.query(query, llm=hf_llm)
79
-
80
- st.write(f"Response: {value}")
81
-
82
- if __name__ == "__main__":
83
- main()
 
 
 
 
 
1
+ import streamlit as st
2
+ import PyPDF2
3
+ from langchain.llms import HuggingFaceHub
4
+ import pptx
5
+ import os
6
+ from langchain.vectorstores.cassandra import Cassandra
7
+ from langchain.indexes.vectorstore import VectorStoreIndexWrapper
8
+ from langchain.embeddings import OpenAIEmbeddings
9
+ import cassio
10
+ from langchain.text_splitter import CharacterTextSplitter
11
+ from huggingface_hub import login
12
+
13
+
14
+
15
+
16
+
17
+
18
+
19
+ # Secure API keys (replace with environment variables in deployment)
20
+ ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
21
+ ASTRA_DB_ID = os.getenv("ASTRA_DB_ID")
22
+ HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
23
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
24
+ login(token=HUGGINGFACE_API_KEY)
25
+
26
+ # Initialize Astra DB connection
27
+ cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)
28
+
29
+ # Initialize LLM & Embeddings
30
+ hf_llm = HuggingFaceHub(repo_id="google/flan-t5-large", model_kwargs={"temperature": 0, "max_length": 64})
31
+ embedding =OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
32
+
33
+ # Initialize vector store
34
+ astra_vector_store = Cassandra(embedding=embedding, table_name="qa_mini_demo")
35
+
36
+ def extract_text_from_pdf(uploaded_file):
37
+ """Extract text from a PDF file."""
38
+ text = ""
39
+ pdf_reader = PyPDF2.PdfReader(uploaded_file)
40
+ for page in pdf_reader.pages:
41
+ page_text = page.extract_text()
42
+ if page_text: # Avoid NoneType error
43
+ text += page_text + "\n"
44
+ return text
45
+
46
+ def extract_text_from_ppt(uploaded_file):
47
+ """Extract text from a PowerPoint file."""
48
+ text = ""
49
+ presentation = pptx.Presentation(uploaded_file)
50
+ for slide in presentation.slides:
51
+ for shape in slide.shapes:
52
+ if hasattr(shape, "text"):
53
+ text += shape.text + "\n"
54
+ return text
55
+
56
+ def main():
57
+ st.title("Chat with Documents")
58
+
59
+ uploaded_file = st.file_uploader("Upload a PDF or PPT file", type=["pdf", "pptx"])
60
+ extract_button = st.button("Extract Text")
61
+
62
+ extracted_text = ""
63
+ if extract_button and uploaded_file is not None:
64
+ if uploaded_file.name.endswith(".pdf"):
65
+ extracted_text = extract_text_from_pdf(uploaded_file)
66
+ elif uploaded_file.name.endswith(".pptx"):
67
+ extracted_text = extract_text_from_ppt(uploaded_file)
68
+
69
+ if extracted_text:
70
+ text_splitter = CharacterTextSplitter(separator="\n", chunk_size=800, chunk_overlap=200, length_function=len)
71
+ texts = text_splitter.split_text(extracted_text)
72
+ astra_vector_store.add_texts(texts)
73
+
74
+ # Ensure the vector store index is initialized properly
75
+ astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)
76
+
77
+ query = st.text_input("Enter your query")
78
+ submit_query = st.button("Submit Query")
79
+ if submit_query:
80
+
81
+
82
+ value = astra_vector_index.query(query, llm=hf_llm)
83
+
84
+ st.write(f"Response: {value}")
85
+
86
+ if __name__ == "__main__":
87
+ main()