deepakkarkala committed
Commit a90237d · 1 Parent(s): 0e9898e

Loading model async

Files changed (1)
  1. app.py +79 -21
app.py CHANGED
@@ -1,6 +1,8 @@
+import asyncio
 import io
 import logging
 import os
+import threading
 import uuid
 
 import streamlit as st
@@ -15,34 +17,68 @@ from transformers.image_utils import load_image
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Capture logs
-log_stream = io.StringIO()
-logging.basicConfig(stream=log_stream, level=logging.INFO)
-
+#log_stream = io.StringIO()
+#logging.basicConfig(stream=log_stream, level=logging.INFO)
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 if "session_id" not in st.session_state:
     st.session_state["session_id"] = str(uuid.uuid4()) # Generate unique session ID
 
 
-@st.cache_resource # Streamlit Caching decorator
+# Async function to load the model
+async def load_model_embedding_async():
+    st.session_state["loading_model_embedding"] = True # Show loading status
+    await asyncio.sleep(0.1) # Allow UI updates
+    model_embedding = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
+    st.session_state["model_embedding"] = model_embedding
+    st.session_state["loading_model_embedding"] = False # Model is ready
+
+
+# Function to run async function in a separate thread
 def load_model_embedding():
-    #docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha")
-    #docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0")
-    docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
-    return docs_retrieval_model
-model_embedding = load_model_embedding()
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(load_model_embedding_async())
 
-@st.cache_resource # Streamlit Caching decorator
-def load_model_vlm():
+
+# Start model loading in a background thread
+if "model_embedding" not in st.session_state:
+    with st.status("Loading embedding model... ⏳"):
+        threading.Thread(target=load_model_embedding, daemon=True).start()
+
+
+
+# Async function to load the model
+async def load_model_vlm_async():
+    st.session_state["loading_model_vlm"] = True # Show loading status
+    await asyncio.sleep(0.1) # Allow UI updates
+
     checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
-    processor = AutoProcessor.from_pretrained(checkpoint)
+    processor_vlm = AutoProcessor.from_pretrained(checkpoint)
     quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-    model = AutoModelForVision2Seq.from_pretrained(
+    model_vlm = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        #torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
-    return model, processor
-model_vlm, processor_vlm = load_model_vlm()
+
+    st.session_state["model_vlm"] = model_vlm
+    st.session_state["processor_vlm"] = processor_vlm
+    st.session_state["loading_model_vlm"] = False # Model is ready
+
+
+# Function to run async function in a separate thread
+def load_model_vlm():
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(load_model_vlm_async())
+
+
+# Start model loading in a background thread
+if "model_vlm" not in st.session_state:
+    with st.status("Loading VLM model... ⏳"):
+        threading.Thread(target=load_model_vlm, daemon=True).start()
+
 
 
 
@@ -64,7 +100,7 @@ with st.sidebar:
     "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"
 
 st.title("📝 Image Q&A with VLM")
-st.text_area("Logs:", log_stream.getvalue(), height=200)
+#st.text_area("Logs:", log_stream.getvalue(), height=200)
 
 uploaded_pdf = st.file_uploader("Upload PDF file", type=("pdf"))
 query = st.text_input(
@@ -73,16 +109,34 @@ query = st.text_input(
     disabled=not uploaded_pdf,
 )
 
+if st.session_state.get("loading_model_embedding", True):
+    st.warning("Loading Embedding model....")
+else:
+    st.success("Embedding Model loaded successfully! 🎉")
+
+if st.session_state.get("loading_model_vlm", True):
+    st.warning("Loading VLM model....")
+else:
+    st.success("VLM Model loaded successfully! 🎉")
+
+
+
+
+
+
+
 images = []
 images_folder = "data/" + st.session_state["session_id"] + "/"
 index_name = "index_" + st.session_state["session_id"]
 
 
-if uploaded_pdf and "is_index_complete" not in st.session_state:
+
+if uploaded_pdf and "model_embedding" in st.session_state and "is_index_complete" not in st.session_state:
     images = convert_from_bytes(uploaded_pdf.getvalue())
-    save_images_to_local(images, output_folder=images_folder)
+    save_images_to_local(images, output_folder=images_folder)
+
     # index documents using the document retrieval model
-    model_embedding.index(
+    st.session_state["model_embedding"].index(
        input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True
    )
     logging.info(f"{len(images)} number of images extracted from PDF and indexed")
@@ -90,13 +144,17 @@ if uploaded_pdf and "is_index_complete" not in st.session_state:
 
 
 
-if uploaded_pdf and query:
-    docs_retrieved = model_embedding.search(query, k=1)
+if uploaded_pdf and query and "model_embedding" in st.session_state and "model_vlm" in st.session_state:
+    docs_retrieved = st.session_state["model_embedding"].search(query, k=1)
     logging.info(f"{len(docs_retrieved)} number of images retrieved as relevant to query")
     image_id = docs_retrieved[0]["doc_id"]
     logging.info(f"Image id:{image_id} retrieved" )
     image_similar_to_query = images[image_id]
 
+
+    model_vlm, processor_vlm = st.session_state["model_vlm"], st.session_state["processor_vlm"]
+
+
     # Create input messages
     system_prompt = "You are an AI assistant. Your task is reply to user questions based on the provided image context."
     chat_template = [
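
For readers skimming the diff, the core of the change is one pattern: each heavy model is loaded on a daemon thread that drives its own asyncio event loop and publishes the result into st.session_state, while the main Streamlit script keeps rerunning and simply checks for the keys. The sketch below is a minimal stand-alone illustration of that pattern and is not part of the commit: session_state here is a plain dict standing in for st.session_state, and the string assigned to "model" is a placeholder for the real RAGMultiModalModel.from_pretrained(...) / AutoModelForVision2Seq.from_pretrained(...) calls.

import asyncio
import threading
import time

# Stand-in for st.session_state; in the Space this is Streamlit's per-session state.
session_state = {}


# Hypothetical loader: the sleep mirrors the app's await asyncio.sleep(0.1),
# and the string is a placeholder for the real from_pretrained(...) call.
async def load_model_async():
    session_state["loading_model"] = True   # flag the UI can read
    await asyncio.sleep(0.1)
    session_state["model"] = "loaded-model-object"
    session_state["loading_model"] = False  # model is ready


# Run the coroutine on a fresh event loop owned by the worker thread.
def load_model():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(load_model_async())


# Kick off loading in a daemon thread so the main script is not blocked.
if "model" not in session_state:
    threading.Thread(target=load_model, daemon=True).start()

# The main script keeps going and polls the state, the way the app gates
# indexing and retrieval on the model_embedding / model_vlm keys.
while "model" not in session_state:
    time.sleep(0.05)
print("model ready:", session_state["model"])

The design trade-off follows from Streamlit's execution model: the script reruns top to bottom on every interaction, so the commit drops @st.cache_resource and instead caches the loaded objects in st.session_state, gating indexing ("model_embedding" in st.session_state) and retrieval (additionally "model_vlm" in st.session_state) until the background threads have filled those keys. One hedged caveat: touching st.session_state from a thread Streamlit did not create can produce "missing ScriptRunContext" warnings in some Streamlit versions, so the loading flags may need a thread-context workaround depending on the deployed version.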