David Ko committed
Commit cbaf1c3 · Parent: 0524412

Replace Llama model with OpenAI API for question answering

Files changed (3):
  1. README.md +1 -0
  2. api.py +38 -52
  3. requirements.txt +4 -1
README.md CHANGED
@@ -83,6 +83,7 @@ This project follows a phased development approach:
  - **YOLOv8**: Fast and accurate object detection
  - **DETR**: DEtection TRansformer for object detection
  - **ViT**: Vision Transformer for image classification
+ - **OpenAI API**: For natural language processing and question answering about detected objects

  ## API Endpoints
 
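As context for the new README bullet, here is a minimal client-side sketch of how the question-answering feature might be called once the service is running. The route, port, and JSON field names are hypothetical assumptions and need to be checked against the routes documented under "API Endpoints".

```python
# Hypothetical client call; the "/query" route, port, and field names are
# illustrative assumptions, not part of this commit.
import requests

payload = {
    "vision_results": [{"label": "dog", "confidence": 0.92}],  # example detections
    "user_query": "How many animals are in the image?",
}
resp = requests.post("http://localhost:8000/query", json=payload, timeout=60)
print(resp.json())
```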
api.py CHANGED
@@ -161,33 +161,34 @@ except Exception as e:
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Using device: {device}")

- # LLM model (using an open-access model instead of Llama 4 which requires authentication)
- llm_model = None
- llm_tokenizer = None
+ # OpenAI API setup (used instead of the Llama model)
+ import os
+ import openai
+
+ # Set the OpenAI API key
+ openai_api_key = os.environ.get("OPENAI_API_KEY", "")
+ if not openai_api_key:
+     print("Warning: OPENAI_API_KEY environment variable not set")
+
+ # Set up the OpenAI client
  try:
-     from transformers import AutoModelForCausalLM, AutoTokenizer
-
-     print("Loading LLM model... This may take a moment.")
-     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Using TinyLlama as an open-access alternative
+     from openai import OpenAI

-     llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
-     llm_model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         torch_dtype=torch.float16,
-         # Removing options that require accelerate package
-         # device_map="auto",
-         # load_in_8bit=True
-     ).to(device)
-     print("LLM model loaded successfully")
+     print("Setting up OpenAI client...")
+     if openai_api_key:
+         openai_client = OpenAI(api_key=openai_api_key)
+         print("OpenAI client initialized successfully")
+     else:
+         openai_client = None
+         print("OpenAI client not initialized due to missing API key")
  except Exception as e:
-     print(f"Error loading LLM model: {e}")
-     llm_model = None
-     llm_tokenizer = None
+     print(f"Error setting up OpenAI client: {e}")
+     openai_client = None

  def process_llm_query(vision_results, user_query):
-     """Process a query with the LLM model using vision results and user text"""
-     if llm_model is None or llm_tokenizer is None:
-         return {"error": "LLM model not available"}
+     """Process a query with the OpenAI API using vision results and user text"""
+     if openai_client is None:
+         return {"error": "OpenAI API not available. Please set OPENAI_API_KEY environment variable."}

      # Summarize the result data (to keep within token limits)
      summarized_results = []
@@ -205,52 +206,37 @@ def process_llm_query(vision_results, user_query):
          summarized_results.append(summary)

      # Create a prompt combining vision results and user query
-     prompt = f"""You are an AI assistant analyzing image detection results.
-     Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)}
+     system_message = "You are an AI assistant analyzing image detection results."
+     user_message = f"""Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)}

      User question: {user_query}

      Please provide a detailed analysis based on the detected objects and the user's question.
      """

-     # Tokenize and generate response
+     # Call the OpenAI API
      try:
          start_time = time.time()

-         # Check the token length and cap it if needed
-         tokens = llm_tokenizer.encode(prompt)
-         if len(tokens) > 1500:  # safety margin
-             prompt = f"""You are an AI assistant analyzing image detection results.
-             The image contains {len(summarized_results)} detected objects.
-
-             User question: {user_query}
-
-             Please provide a general analysis based on the user's question.
-             """
-
-         inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
-         with torch.no_grad():
-             output = llm_model.generate(
-                 **inputs,
-                 max_new_tokens=512,
-                 temperature=0.7,
-                 top_p=0.9,
-                 do_sample=True
-             )
-
-         response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
-
-         # Remove the prompt from the response
-         if response_text.startswith(prompt):
-             response_text = response_text[len(prompt):].strip()
+         response = openai_client.chat.completions.create(
+             model="gpt-4",  # or another model such as "gpt-3.5-turbo"
+             messages=[
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": user_message}
+             ],
+             max_tokens=500,
+             temperature=0.7,
+             top_p=0.9
+         )

+         response_text = response.choices[0].message.content
          inference_time = time.time() - start_time

          return {
              "response": response_text,
              "performance": {
                  "inference_time": round(inference_time, 3),
-                 "device": "GPU" if torch.cuda.is_available() else "CPU"
+                 "model": "OpenAI API"
              }
          }
      except Exception as e:
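
A minimal sketch of exercising the rewritten `process_llm_query` locally, assuming `api.py` is importable and `OPENAI_API_KEY` is exported; the detection entries are invented for illustration, since the summarization logic lies outside this hunk.

```python
# Local smoke-test sketch (assumes OPENAI_API_KEY is set and api.py is importable;
# the detection entries below are invented for illustration).
from api import process_llm_query

fake_vision_results = [
    {"label": "person", "confidence": 0.97, "bbox": [34, 50, 210, 400]},
    {"label": "bicycle", "confidence": 0.88, "bbox": [180, 220, 460, 420]},
]

result = process_llm_query(fake_vision_results, "What is the person doing?")
if "error" in result:
    print(result["error"])        # e.g. API key not configured
else:
    print(result["response"])     # model's answer about the detections
    print(result["performance"])  # {"inference_time": ..., "model": "OpenAI API"}
```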
requirements.txt CHANGED
@@ -19,7 +19,10 @@ fastapi>=0.100.0
  uvicorn[standard]>=0.22.0
  python-multipart>=0.0.5

- # Llama 4 integration
+ # OpenAI API integration (replacing Llama)
+ openai>=1.0.0
+
+ # Llama 4 integration (legacy)
  accelerator>=0.20.0
  bitsandbytes>=0.41.0
  sentencepiece>=0.1.99
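
Once the updated requirements are installed, the new OpenAI dependency can be checked with a one-off call; this is only a sketch that assumes `OPENAI_API_KEY` is exported, and the model name is just an example. If the legacy Llama stack is still needed, note that `accelerator` on PyPI is not the same package as Hugging Face's `accelerate`.

```python
# Dependency smoke test (sketch): assumes `pip install "openai>=1.0.0"` has run
# and OPENAI_API_KEY is exported; the model name is just an example.
import os
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
reply = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Reply with 'ok' if you can read this."}],
    max_tokens=5,
)
print(reply.choices[0].message.content)
```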