sunheycho committed
Commit 51940dd · Parent: a1797f0

Refactor: Convert product comparison to use LangChain instead of TinyLlama


- Replace TinyLlama direct model usage with LangChain ChatOpenAI
- Update all agents to use LangChain's invoke() method
- Add proper OpenAI API key environment variable handling
- Maintain graceful fallback when OPENAI_API_KEY is not available
- Re-enable LangChain dependencies in requirements.txt
- All agents now use gpt-3.5-turbo via LangChain for better performance (core call pattern sketched below)
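For reference, a minimal standalone sketch of the call pattern all agents share after this change (assumes only the packages pinned in requirements.txt; `ChatOpenAI` also reads OPENAI_API_KEY from the environment when no key is passed explicitly):

import json
import os
from langchain_openai import ChatOpenAI

# Mirror the commit's guard: only build the LLM when a key is available
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7) if os.environ.get("OPENAI_API_KEY") else None

if llm is not None:
    # invoke() returns an AIMessage; the generated text lives on .content
    response = llm.invoke("Summarize this product as JSON.")
    response_text = response.content if hasattr(response, "content") else str(response)
    try:
        data = json.loads(response_text)
    except json.JSONDecodeError:
        data = {}  # the agents fall back to line-by-line parsing here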

Files changed (2)
  1. product_comparison.py +64 -46
  2. requirements.txt +2 -0
product_comparison.py CHANGED
@@ -156,46 +156,31 @@ class BaseAgent:
     def __init__(self, name, llm=None):
         self.name = name
 
-        # Use TinyLlama as the default LLM if none is provided
+        # Use LangChain ChatOpenAI as the default LLM if none is provided
         if llm is None:
             try:
-                # Initialize TinyLlama model
-                model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-                tokenizer = AutoTokenizer.from_pretrained(model_name)
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=torch.float16,
-                ).to("cuda" if torch.cuda.is_available() else "cpu")
-
-                # Wrap the model for LangChain
-                self.llm = self._create_llm_wrapper(model, tokenizer)
+                if LANGCHAIN_AVAILABLE and ChatOpenAI is not None:
+                    # Initialize ChatOpenAI with environment variable for API key
+                    api_key = os.environ.get('OPENAI_API_KEY')
+                    if api_key:
+                        self.llm = ChatOpenAI(
+                            model="gpt-3.5-turbo",
+                            temperature=0.7,
+                            api_key=api_key
+                        )
+                        print(f"Initialized {name} with ChatOpenAI (gpt-3.5-turbo)")
+                    else:
+                        print(f"Warning: OPENAI_API_KEY not found. {name} will use fallback mode.")
+                        self.llm = None
+                else:
+                    print(f"Warning: LangChain not available. {name} will use fallback mode.")
+                    self.llm = None
             except Exception as e:
-                print(f"Error loading TinyLlama: {e}")
-                # Fall back to a simple string output
+                print(f"Error initializing ChatOpenAI for {name}: {e}")
                 self.llm = None
         else:
             self.llm = llm
 
-    def _create_llm_wrapper(self, model, tokenizer):
-        """Create a simple wrapper for the LLM model"""
-        # This is a simplified wrapper - in production, you'd use a proper LangChain integration
-        def generate_text(prompt, max_tokens=512):
-            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-            with torch.no_grad():
-                output = model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    temperature=0.7,
-                    do_sample=True
-                )
-            response = tokenizer.decode(output[0], skip_special_tokens=True)
-            # Remove the prompt from the response
-            if response.startswith(prompt):
-                response = response[len(prompt):].strip()
-            return response
-
-        return generate_text
-
     def log(self, session_id, message):
         """Log a message to the session"""
         return session_manager.add_message(session_id, message, agent_type=self.name)
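A quick way to exercise the fallback path above (hypothetical driver snippet, assuming product_comparison.py is importable as a module):

import os
import product_comparison

os.environ.pop("OPENAI_API_KEY", None)              # simulate a missing key
agent = product_comparison.BaseAgent("demo-agent")  # prints the fallback warning
assert agent.llm is None                            # downstream code must handle llm=None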
@@ -331,17 +316,26 @@ class ImageProcessingAgent(BaseAgent):
         - model: Any model information that can be determined
         - color: The main color of the product
         - key_features: List of notable visual features
-        """
+
+        Return only valid JSON format."""
 
         try:
-            response = self.llm(prompt)
+            # Use LangChain's invoke method
+            response = self.llm.invoke(prompt)
+
+            # Extract content from LangChain response
+            if hasattr(response, 'content'):
+                response_text = response.content
+            else:
+                response_text = str(response)
+
             # Try to parse as JSON
             try:
-                extracted = json.loads(response)
+                extracted = json.loads(response_text)
                 return extracted
             except json.JSONDecodeError:
                 # If LLM output is not valid JSON, extract key information using simple parsing
-                lines = response.split('\n')
+                lines = response_text.split('\n')
                 extracted = {}
                 for line in lines:
                     if ':' in line:
@@ -470,14 +464,22 @@ class FeatureExtractionAgent(BaseAgent):
         """
 
         try:
-            response = self.llm(prompt)
+            # Use LangChain's invoke method
+            response = self.llm.invoke(prompt)
+
+            # Extract content from LangChain response
+            if hasattr(response, 'content'):
+                response_text = response.content
+            else:
+                response_text = str(response)
+
             # Try to parse as JSON
             try:
-                specs = json.loads(response)
+                specs = json.loads(response_text)
                 return specs
             except json.JSONDecodeError:
                 # If LLM output is not valid JSON, extract key information using simple parsing
-                lines = response.split('\n')
+                lines = response_text.split('\n')
                 specs = {}
                 for line in lines:
                     if ':' in line:
@@ -607,10 +609,18 @@ class ComparisonAgent(BaseAgent):
         """
 
         try:
-            response = self.llm(prompt)
+            # Use LangChain's invoke method
+            response = self.llm.invoke(prompt)
+
+            # Extract content from LangChain response
+            if hasattr(response, 'content'):
+                response_text = response.content
+            else:
+                response_text = str(response)
+
             # Try to parse as JSON
             try:
-                comparison = json.loads(response)
+                comparison = json.loads(response_text)
                 return comparison
             except json.JSONDecodeError:
                 # If LLM output is not valid JSON, extract key sections using simple parsing
@@ -618,7 +628,7 @@ class ComparisonAgent(BaseAgent):
             current_section = None
             section_content = []
 
-            lines = response.split('\n')
+            lines = response_text.split('\n')
             for line in lines:
                 line = line.strip()
                 if not line:
@@ -807,14 +817,22 @@ class RecommendationAgent(BaseAgent):
         """
 
         try:
-            response = self.llm(prompt)
+            # Use LangChain's invoke method
+            response = self.llm.invoke(prompt)
+
+            # Extract content from LangChain response
+            if hasattr(response, 'content'):
+                response_text = response.content
+            else:
+                response_text = str(response)
+
             # Try to parse as JSON
             try:
-                recommendation = json.loads(response)
+                recommendation = json.loads(response_text)
                 return recommendation
             except json.JSONDecodeError:
                 # If LLM output is not valid JSON, extract key information using simple parsing
-                lines = response.split('\n')
+                lines = response_text.split('\n')
                 recommendation = {}
 
                 # Look for recommendation indicator
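The invoke-and-parse block above is now duplicated across all four agents; a sketch of how it could be factored into one helper (not part of this commit, mirroring the diff's own fallback logic):

import json

def parse_llm_response(response):
    """Extract text from a LangChain message and parse it as JSON, with line-based fallback."""
    text = response.content if hasattr(response, "content") else str(response)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Same 'key: value' line parsing the agents use when output is not valid JSON
        parsed = {}
        for line in text.split("\n"):
            if ":" in line:
                key, _, value = line.partition(":")
                parsed[key.strip()] = value.strip()
        return parsed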
 
requirements.txt CHANGED
@@ -47,3 +47,5 @@ openai>=1.30.0
 # LangChain (RAG pipeline)
 langchain>=0.2.6
 langchain-openai>=0.1.16
+langchain-community>=0.2.6
+langchain-experimental>=0.0.60