Spaces:
Running
Running
| # aiclient.py | |
| import os | |
| import time | |
| import json | |
| from typing import List, Dict, Optional, Union, AsyncGenerator | |
| from openai import AsyncOpenAI | |
| from starlette.responses import StreamingResponse | |
| from observability import log_execution ,LLMObservabilityManager | |
| import psycopg2 | |
| import requests | |
| from functools import lru_cache | |
| import logging | |
| import pandas as pd | |
| logger = logging.getLogger(__name__) | |
def get_model_info():
    """Fetch the OpenRouter model catalogue and return it as a DataFrame.

    Tries the live API first and writes the result to ``model_info.json``
    as a cache; on any failure it falls back to that cached file.

    Returns:
        pd.DataFrame with one row per model (includes ``id`` and ``pricing``
        columns used for cost accounting), or None if both the API and the
        cache are unavailable.
    """
    try:
        # Explicit timeout so a hung API call cannot block import/startup,
        # and raise_for_status so an HTTP error page is not parsed as data.
        resp = requests.get(
            'https://openrouter.ai/api/v1/models',
            headers={'accept': 'application/json'},
            timeout=10,
        )
        resp.raise_for_status()
        model_info_dict = resp.json()["data"]
        # Cache the catalogue so later runs can work offline.
        with open('model_info.json', 'w') as json_file:
            json.dump(model_info_dict, json_file, indent=4)
    except Exception as e:
        logger.error(f"Failed to fetch model info: {e}. Loading from file.")
        if not os.path.exists('model_info.json'):
            logger.error("No model info file found")
            return None
        with open('model_info.json', 'r') as json_file:
            model_info_dict = json.load(json_file)
    return pd.DataFrame(model_info_dict)
class AIClient:
    """Streams chat completions from OpenRouter and records observability data."""

    def __init__(self):
        self.client = AsyncOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ['OPENROUTER_API_KEY']
        )
        self.observability_manager = LLMObservabilityManager()
        # Model catalogue (id + pricing) used for per-request cost accounting;
        # may be None when the catalogue could not be fetched or cached.
        self.model_info = get_model_info()

    #@log_execution
    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "openai/gpt-4o-mini",
        max_tokens: int = 32000,
        conversation_id: Optional[str] = None,
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding content deltas as they arrive.

        Args:
            messages: OpenAI-style chat messages ({'role', 'content'} dicts).
                An empty list short-circuits with no API call.
            model: OpenRouter model id; replaced by the model reported in the
                stream's usage chunk before logging.
            max_tokens: Completion token cap passed through to the API.
            conversation_id: Grouping key for observability; "default" when omitted.
            user: Identifier recorded with the observation.

        Yields:
            Incremental response text chunks.
        """
        if not messages:
            return
        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"
        latency = 0.0  # always defined before the finally-block logging
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                stream=True,
                stream_options={"include_usage": True}
            )
            # Time until the stream object is returned (not full completion).
            latency = time.time() - start_time
            async for chunk in response:
                # With include_usage the final chunk carries usage and an
                # EMPTY choices list — guard before indexing (the original
                # chunk.choices[0] raised IndexError on that chunk).
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_response += content
                    yield content
                if chunk.usage:
                    model = chunk.model  # actual model that served the request
                    usage["completion_tokens"] = chunk.usage.completion_tokens
                    usage["prompt_tokens"] = chunk.usage.prompt_tokens
                    usage["total_tokens"] = chunk.usage.total_tokens
                    logger.debug("usage=%s model=%s", usage, model)
        except Exception as e:
            status = "error"
            full_response = str(e)
            latency = time.time() - start_time
            logger.error(f"Error in generate_response: {e}")
        finally:
            # Best-effort observability logging; never let it break the stream.
            try:
                cost = 0.0
                # Cost is only computable when the pricing catalogue loaded
                # and actually contains this model; previously a missing
                # entry raised here and the observation was never recorded.
                if self.model_info is not None:
                    pricing_rows = self.model_info[self.model_info.id == model]["pricing"].values
                    if len(pricing_rows):
                        pricing_data = pricing_rows[0]
                        cost = (float(pricing_data["completion"]) * float(usage["completion_tokens"])
                                + float(pricing_data["prompt"]) * float(usage["prompt_tokens"]))
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id or "default",
                    status=status,
                    # System prompts are excluded from the logged request payload.
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                logger.error(f"Error logging observation: {obs_error}")
class DatabaseManager:
    """Manages database operations against the Supabase Postgres instance."""

    def __init__(self):
        # Connection parameters; credentials are taken from the environment.
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        """Insert one (user_id, user_query, response) row into ai_document_generator.

        Opens a fresh connection per call and closes it afterwards:
        psycopg2's connection context manager only commits/rolls back the
        transaction — it does NOT close the connection — so the original
        code leaked one pooled connection per call.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:  # commit on success, roll back on error
                with conn.cursor() as cur:
                    insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                    """
                    cur.execute(insert_query, (user_id, user_query, response))
        finally:
            conn.close()