import logging
import os
import re
import faiss
import numpy as np
from dotenv import load_dotenv
import httpx
from langdetect import detect
from deep_translator import GoogleTranslator

try:
    import pymssql
    PYMSSQL_AVAILABLE = True
except ImportError:
    PYMSSQL_AVAILABLE = False
    logging.warning("pymssql not available - database features will be limited")

import pickle
import json
from sentence_transformers import SentenceTransformer, util
from tenacity import retry, stop_after_attempt, wait_exponential
from typing import Dict, List, Any, Optional
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Load environment variables
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# Database connection parameters
DB_SERVER = os.getenv("DB_SERVER")
DB_DATABASE = os.getenv("DB_DATABASE")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")

EMBEDDINGS_PATH = "hockey_embeddings.npy"
METADATA_PATH = "hockey_metadata.json"
INDEX_PATH = "hockey_faiss_index.index"

if not OPENROUTER_API_KEY:
    # Don't raise here; let API calls fail gracefully instead
    logging.warning("OPENROUTER_API_KEY not set in environment - API calls will fail")

if not all([DB_SERVER, DB_DATABASE, DB_USER, DB_PASSWORD]):
    logging.warning("Database connection parameters missing in .env file - running without database")
    DB_AVAILABLE = False
else:
    DB_AVAILABLE = PYMSSQL_AVAILABLE

# In-memory conversation history, keyed by "role|team"
conversation_histories = {}

# Lazy-loaded SentenceTransformer model and FAISS index
sentence_model = None
faiss_index = None
embeddings_np = None
metadata = []


class HockeyFoodDBConnector:
    def __init__(self):
        self.connection = None

    def connect(self):
        """Connect to the HockeyFood database using pymssql."""
        if not DB_AVAILABLE:
            logging.info("Database not available - using preloaded embeddings only")
            return False
        try:
            self.connection = pymssql.connect(
                server=DB_SERVER,
                user=DB_USER,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                timeout=30,
                as_dict=True,
            )
            logging.info(f"Successfully connected to database: {DB_DATABASE}")
            return True
        except Exception as e:
            logging.error(f"Database connection failed: {str(e)}")
            return False

    def disconnect(self):
        """Close the database connection."""
        if self.connection:
            self.connection.close()
            logging.info("Database connection closed")

    def execute_query(self, query: str, params: Optional[tuple] = None):
        """Execute a query and return all rows, or an empty list on failure."""
        if self.connection is None:
            logging.error("Query attempted without an open database connection")
            return []
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, params or ())
            return cursor.fetchall()
        except Exception as e:
            logging.error(f"Query execution failed: {str(e)}")
            return []

    def get_exercise_data(self):
        """Get Exercise table data: Title -> Text."""
        query = """
            SELECT Id, Title, Text, InternalTitle, Organisation, Rules
            FROM [Main].[Exercise]
            WHERE DeletedAt IS NULL AND Title IS NOT NULL AND Text IS NOT NULL
        """
        return self.execute_query(query)

    def get_serie_data(self):
        """Get Serie table data: Title -> Description."""
        query = """
            SELECT Id, Title, Description
            FROM [Main].[Serie]
            WHERE DeletedAt IS NULL AND Title IS NOT NULL AND Description IS NOT NULL
        """
        return self.execute_query(query)

    def get_multimedia_data(self):
        """Get Multimedia table data: Title -> URL."""
        query = """
            SELECT Id, Title, Url, Description
            FROM [Media].[Multimedia]
            WHERE Title IS NOT NULL AND Url IS NOT NULL
        """
        return self.execute_query(query)

    def get_all_tables(self):
        """List all base tables in the database (useful for debugging)."""
        query = """
            SELECT TABLE_SCHEMA, TABLE_NAME
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_TYPE = 'BASE TABLE'
            ORDER BY TABLE_NAME
        """
        return self.execute_query(query)
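
# Standalone usage sketch for the connector (illustrative only; assumes valid
# credentials in .env and is deliberately left as a comment so nothing runs on
# import):
#
#     connector = HockeyFoodDBConnector()
#     if connector.connect():
#         try:
#             for table in connector.get_all_tables():
#                 print(table["TABLE_SCHEMA"], table["TABLE_NAME"])
#         finally:
#             connector.disconnect()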


def load_resources():
    global sentence_model, faiss_index, embeddings_np, metadata
    # Check if running on HuggingFace Spaces and adjust behavior accordingly
    is_huggingface = os.getenv("SPACE_ID") is not None
    if sentence_model is None:
        try:
            sentence_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
            logging.info("Loaded SentenceTransformer model.")
        except Exception as e:
            if is_huggingface:
                logging.warning(f"Failed to load SentenceTransformer on HuggingFace: {e}")
                sentence_model = None
                return  # Exit gracefully on HuggingFace
            else:
                logging.error(f"Failed to load SentenceTransformer: {e}")
                raise
    if faiss_index is None or embeddings_np is None or not metadata:
        if not (os.path.exists(EMBEDDINGS_PATH) and os.path.exists(METADATA_PATH) and os.path.exists(INDEX_PATH)):
            if DB_AVAILABLE:
                logging.info("Generating embeddings from HockeyFood database...")
                generate_embeddings_from_db()
            else:
                logging.warning("No preloaded embeddings found and database not available - running without content recommendations")
                embeddings_np = None
                metadata = []
                faiss_index = None
        else:
            # Load existing embeddings
            embeddings_np = np.load(EMBEDDINGS_PATH)
            with open(METADATA_PATH, "r") as f:
                metadata = json.load(f)
            try:
                faiss_index = faiss.read_index(INDEX_PATH)
            except Exception as e:
                logging.warning(f"Failed to load FAISS index: {e}. Regenerating...")
                dimension = embeddings_np.shape[1]
                faiss_index = faiss.IndexFlatIP(dimension)
                faiss.normalize_L2(embeddings_np)
                faiss_index.add(embeddings_np)
                faiss.write_index(faiss_index, INDEX_PATH)
            logging.info(f"Loaded {embeddings_np.shape[0]} embeddings")


def generate_embeddings_from_db():
    """Generate embeddings from the HockeyFood database tables."""
    global embeddings_np, metadata, faiss_index
    db_connector = HockeyFoodDBConnector()
    if not db_connector.connect():
        raise RuntimeError("Could not connect to HockeyFood database")
    try:
        # First, see what tables actually exist
        all_tables = db_connector.get_all_tables()
        table_names = [f"{t.get('TABLE_SCHEMA', '')}.{t.get('TABLE_NAME', '')}" for t in all_tables]
        logging.info(f"Available tables in database: {table_names}")

        embeddings = []
        metadata = []

        # Process Exercise table (Title -> Text)
        logging.info("Processing Exercise table...")
        exercise_data = db_connector.get_exercise_data()
        for row in exercise_data:
            content = f"{row['Title']}: {row['Text']}"
            if row['Organisation']:
                content += f" Organisation: {row['Organisation']}"
            if row['Rules']:
                content += f" Rules: {row['Rules']}"
            embedding = sentence_model.encode(content, convert_to_tensor=False)
            embeddings.append(embedding)
            metadata.append({
                "id": f"exercise_{row['Id']}",
                "type": "exercise",
                "title": row['Title'][:100],
                "content": (row['Text'][:200] + "...") if len(row['Text']) > 200 else row['Text'],
                "source_table": "Exercise",
            })

        # Process Serie table (Title -> Description)
        logging.info("Processing Serie table...")
        serie_data = db_connector.get_serie_data()
        for row in serie_data:
            content = f"{row['Title']}: {row['Description']}"
            embedding = sentence_model.encode(content, convert_to_tensor=False)
            embeddings.append(embedding)
            metadata.append({
                "id": f"serie_{row['Id']}",
                "type": "serie",
                "title": row['Title'][:100],
                "content": (row['Description'][:200] + "...") if len(row['Description']) > 200 else row['Description'],
                "source_table": "Serie",
            })

        # Process Multimedia table (Title -> URL)
        logging.info("Processing Multimedia table...")
        multimedia_data = db_connector.get_multimedia_data()
        for row in multimedia_data:
            content = f"{row['Title']}"
            if row.get('Description'):
                content += f": {row['Description']}"
            embedding = sentence_model.encode(content, convert_to_tensor=False)
            embeddings.append(embedding)
            metadata.append({
                "id": f"multimedia_{row['Id']}",
                "type": "multimedia",
                "title": row['Title'][:100],
                "url": row['Url'],
                "source_table": "Multimedia",
            })

        if embeddings:
            embeddings_np = np.array(embeddings, dtype=np.float32)
            dimension = embeddings_np.shape[1]
            faiss_index = faiss.IndexFlatIP(dimension)
            faiss.normalize_L2(embeddings_np)
            faiss_index.add(embeddings_np)
            # Save embeddings, metadata, and the FAISS index to disk
            np.save(EMBEDDINGS_PATH, embeddings_np)
            with open(METADATA_PATH, "w") as f:
                json.dump(metadata, f, indent=2)
            faiss.write_index(faiss_index, INDEX_PATH)
            logging.info(f"Generated and saved {len(embeddings)} embeddings from HockeyFood database")
        else:
            logging.error("No valid data found in database tables")
            raise RuntimeError("No valid data found in database tables")
    finally:
        db_connector.disconnect()
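
# Rebuild sketch (illustrative): to force-regenerate the cached artifacts from
# the database by hand, assuming DB credentials are configured, delete
# hockey_embeddings.npy / hockey_metadata.json / hockey_faiss_index.index and
# call load_resources() again; it falls through to generate_embeddings_from_db().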

# Hockey-specific translation dictionary (Dutch -> English)
hockey_translation_dict = {
    "schiettips": "shooting tips",
    "schieten": "shooting",
    "backhand": "backhand",
    "backhandschoten": "backhand shooting",
    "achterhand": "backhand",
    "veldhockey": "field hockey",
    "strafcorner": "penalty corner",
    "sleepflick": "drag flick",
    "doelman": "goalkeeper",
    "aanvaller": "forward",
    "verdediger": "defender",
    "middenvelder": "midfielder",
    "stickbeheersing": "stick handling",
    "balbeheersing": "ball control",
    "hockeyoefeningen": "hockey drills",
    "oefeningen": "drills",
    "kinderen": "kids",
    "verbeteren": "improve",
}

# Hockey keywords for domain detection
hockey_keywords = [
    "hockey", "field hockey", "veldhockey", "match", "wedstrijd", "game", "spel", "goal", "doelpunt",
    "score", "scoren", "ball", "bal", "stick", "hockeystick", "field", "veld", "turf", "kunstgras",
    "shooting", "schieten", "schiet", "backhand shooting", "backhandschoten", "passing", "passen",
    "backhand", "achterhand", "forehand", "voorhand", "drag flick", "sleeppush", "push pass",
    "training", "oefening", "exercise", "oefenen", "drill", "oefensessie", "practice", "praktijk",
    "coach", "trainer", "goalkeeper", "doelman", "keeper", "goalie", "defender", "verdediger",
    "midfielder", "middenvelder", "forward", "aanvaller", "striker", "spits",
]

# Greetings for detection
greetings = [
    "hey", "hello", "hi", "hiya", "yo", "what's up", "sup", "good morning", "good afternoon",
    "good evening", "good night", "howdy", "greetings", "morning", "evening", "hallo", "hoi",
    "goedemorgen", "goedemiddag", "goedenavond", "goedennacht", "hé", "joe", "moi", "dag",
    "goedendag",
]


def preprocess_prompt(prompt: str, user_lang: str) -> tuple[str, str]:
    """Preprocess the prompt; return (processed_prompt, original_prompt)."""
    if not prompt or not isinstance(prompt, str):
        return prompt, prompt
    prompt_lower = prompt.lower().strip()
    if user_lang == "nl":
        # Apply hockey-specific term translations before machine translation
        for dutch_term, english_term in hockey_translation_dict.items():
            prompt_lower = re.sub(rf'\b{re.escape(dutch_term)}\b', english_term, prompt_lower)
        try:
            translated = GoogleTranslator(source="nl", target="en").translate(prompt_lower)
            return translated if translated else prompt_lower, prompt
        except Exception as e:
            logging.error(f"Translation error: {str(e)}")
            return prompt_lower, prompt
    return prompt_lower, prompt


def is_in_domain(prompt: str) -> bool:
    """Check whether the prompt is hockey-related."""
    if not prompt or not isinstance(prompt, str):
        return False
    prompt_lower = prompt.lower().strip()
    # Match whole keywords, plus a crude stem (keyword minus its last character)
    has_hockey_keywords = any(
        re.search(rf'\b{re.escape(word)}\b|\b{re.escape(word[:-1])}\w*\b', prompt_lower)
        for word in hockey_keywords
    )
    if sentence_model is not None:
        try:
            prompt_embedding = sentence_model.encode(prompt_lower, convert_to_tensor=True)
            hockey_reference = "Field hockey training, drills, strategies, rules, techniques, or tutorials"
            hockey_embedding = sentence_model.encode(hockey_reference, convert_to_tensor=True)
            similarity = util.cos_sim(prompt_embedding, hockey_embedding).item()
            return has_hockey_keywords or similarity > 0.3
        except Exception as e:
            logging.warning(f"Semantic similarity check failed: {e}")
    return has_hockey_keywords


def is_greeting_or_vague(prompt: str, user_lang: str) -> bool:
    """Check whether the prompt is a greeting or too vague to act on."""
    if not prompt or not isinstance(prompt, str):
        return True
    prompt_lower = prompt.lower().strip()
    # Word-boundary match so short greetings like "hi" don't match inside words
    is_greeting = any(
        re.search(rf'\b{re.escape(greeting)}\b', prompt_lower)
        for greeting in greetings
    )
    has_hockey_keywords = any(
        re.search(rf'\b{re.escape(word)}\b|\b{re.escape(word[:-1])}\w*\b', prompt_lower)
        for word in hockey_keywords
    )
    return is_greeting and not has_hockey_keywords


def search_hockey_content(english_query: str, dutch_query: str) -> list:
    """Search HockeyFood database content using semantic similarity.

    Note: dutch_query is currently unused; the FAISS index is searched with
    the English query only.
    """
    if not is_in_domain(english_query):
        logging.info("Query is out of domain, skipping database search.")
        return []
    if sentence_model is None or faiss_index is None or not metadata:
        logging.info("Search resources not available, skipping content search.")
        return []
    try:
        # Encode and L2-normalize the query so inner product equals cosine similarity
        english_embedding = sentence_model.encode(english_query, convert_to_tensor=False)
        english_embedding = np.array(english_embedding).astype("float32").reshape(1, -1)
        faiss.normalize_L2(english_embedding)
        # Search the FAISS index for the top 5 results
        distances, indices = faiss_index.search(english_embedding, 5)
        results = []
        for idx, sim in zip(indices[0], distances[0]):
            # FAISS pads with -1 when fewer than k vectors exist, so require idx >= 0
            if 0 <= idx < len(metadata) and sim > 0.3:  # similarity threshold
                item = metadata[idx]
                result = {
                    "title": item["title"],
                    "type": item["type"],
                    "source_table": item["source_table"],
                    "similarity": float(sim),
                }
                # Multimedia items carry a URL; everything else carries a content snippet
                if item["type"] == "multimedia" and "url" in item:
                    result["url"] = item["url"]
                else:
                    result["content"] = item.get("content", "")
                results.append(result)
        logging.info(f"Found {len(results)} relevant content items")
        return results
    except Exception as e:
        logging.error(f"Content search error: {e}")
        return []
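
# Search sketch (illustrative; requires load_resources() to have populated the
# model and index first, and the query string is made up):
#
#     hits = search_hockey_content("backhand shooting drills", "")
#     for hit in hits:
#         print(round(hit["similarity"], 2), hit["title"])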


def get_conversation_history(user_role: str, user_team: str) -> str:
    """Get the formatted conversation history for a user session."""
    session_key = f"{user_role}|{user_team}"
    history = conversation_histories.get(session_key, [])
    # Format only the three most recent exchanges
    formatted_history = "\n".join([f"User: {q}\nCoach: {a}" for q, a in history[-3:]])
    return formatted_history


def update_conversation_history(user_role: str, user_team: str, question: str, answer: str):
    """Append an exchange to the user session's history (capped at 3 entries)."""
    session_key = f"{user_role}|{user_team}"
    history = conversation_histories.get(session_key, [])
    history.append((question, answer))
    conversation_histories[session_key] = history[-3:]


def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """Translate text between languages; return the input unchanged on failure."""
    if not text or not isinstance(text, str) or source_lang == target_lang:
        return text
    try:
        return GoogleTranslator(source=source_lang, target=target_lang).translate(text)
    except Exception as e:
        logging.error(f"Translation error: {str(e)}")
        return text


async def agentic_hockey_chat(user_active_role: str, user_team: str, user_prompt: str) -> dict:
    """Main chat function with HockeyFood database integration."""
    logging.info(f"Processing question: {user_prompt}, role: {user_active_role}, team: {user_team}")

    # Sanitize the user prompt
    if not user_prompt or not isinstance(user_prompt, str):
        logging.error("Invalid or empty user_prompt.")
        return {"ai_response": "Question cannot be empty.", "recommended_content_details": []}
    user_prompt = re.sub(r'\s+', ' ', user_prompt.strip())

    # Detect the user's language; default to English outside en/nl
    try:
        user_lang = detect(user_prompt)
        if user_lang not in ["en", "nl"]:
            user_lang = "en"
    except Exception:
        user_lang = "en"

    # Get both the translated and the original prompt
    processing_prompt, original_prompt = preprocess_prompt(user_prompt, user_lang)
    logging.info(f"Processing prompt: {processing_prompt}")

    # Handle greetings
    if is_greeting_or_vague(user_prompt, user_lang):
        if user_lang == "en":
            answer = "Hello! How can I assist you with hockey, training, or other topics?"
        else:
            answer = "Hallo! Waarmee kan ik je helpen met betrekking tot hockey, training of andere onderwerpen?"
        update_conversation_history(user_active_role, user_team, user_prompt, answer)
        return {"ai_response": answer, "recommended_content_details": []}

    # Reject out-of-domain questions
    if not is_in_domain(processing_prompt):
        if user_lang == "en":
            answer = (
                "Sorry, I can only assist with questions about hockey, such as training, drills, "
                "strategies, rules, and tutorials. Please ask a hockey-related question!"
            )
        else:
            answer = (
                "Sorry, ik kan alleen helpen met vragen over hockey, zoals training, oefeningen, "
                "strategieën, regels en tutorials. Stel me een hockeygerelateerde vraag!"
            )
        update_conversation_history(user_active_role, user_team, user_prompt, answer)
        return {"ai_response": answer, "recommended_content_details": []}
    history = get_conversation_history(user_active_role, user_team)
    system_prompt = (
        f"You are an AI Assistant Bot specialized in field hockey, including training, drills, strategies, rules, and more. "
        f"You communicate with a {user_active_role} from the team {user_team}. "
        f"Provide concise, practical, and specific answers tailored to the user's role and team. "
        f"Focus on field hockey-related topics such as training, drills, strategies, rules, and tutorials.\n\n"
        f"Recent conversation:\n{history or 'No previous conversations.'}\n\n"
        f"Answer the following question in English:\n{processing_prompt}"
    )

    # Use a tighter token limit when running on HuggingFace Spaces
    is_huggingface = os.getenv("SPACE_ID") is not None
    max_tokens = 150 if is_huggingface else 200

    payload = {
        "model": "openai/gpt-4o",
        "messages": [
            {"role": "system", "content": system_prompt}
        ],
        "max_tokens": max_tokens,
        "temperature": 0.3,
        "top_p": 0.9,
    }
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        if not OPENROUTER_API_KEY:
            return {
                "ai_response": "OpenRouter API key not configured. Please set OPENROUTER_API_KEY environment variable.",
                "recommended_content_details": [],
            }
        logging.info("Making OpenRouter API call...")
        async with httpx.AsyncClient(timeout=60) as client:  # generous timeout for model responses
            response = await client.post(OPENROUTER_API_URL, json=payload, headers=headers)
            response.raise_for_status()
            data = response.json()
            answer = data.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
            if not answer:
                logging.error("No answer received from OpenRouter API.")
                return {"ai_response": "No answer received from the API.", "recommended_content_details": []}

            # Strip URLs from the answer, then translate to the user's language
            answer = re.sub(r'https?://\S+', '', answer).strip()
            answer = translate_text(answer, "en", user_lang)

            # Search the HockeyFood database for recommended content (if available)
            recommended_content = []
            if sentence_model is not None and faiss_index is not None and metadata:
                logging.info("Searching HockeyFood database for relevant content...")
                recommended_content = search_hockey_content(
                    processing_prompt, original_prompt if user_lang == "nl" else ""
                )
            else:
                logging.info("Embeddings not available - running without content recommendations")

            # Format recommended content details, with URLs from the Multimedia table
            recommended_content_details = []
            for item in recommended_content:
                content_detail = {
                    "title": item["title"],
                    "type": item["type"],
                    "source": item["source_table"],
                }
                # URL for multimedia items, content snippet for everything else
                if item["type"] == "multimedia" and "url" in item:
                    content_detail["url"] = item["url"]
                else:
                    content_detail["content"] = item.get("content", "")
                recommended_content_details.append(content_detail)

            update_conversation_history(user_active_role, user_team, user_prompt, answer)
            return {"ai_response": answer, "recommended_content_details": recommended_content_details}
    except httpx.HTTPStatusError as e:
        logging.error(f"OpenRouter API error: Status {e.response.status_code}")
        return {
            "ai_response": f"API error: {e.response.status_code} - {e.response.text}",
            "recommended_content_details": [],
        }
    except httpx.TimeoutException:
        logging.error("OpenRouter API timeout")
        return {"ai_response": "Request timed out. Please try again.", "recommended_content_details": []}
    except httpx.NetworkError as e:
        logging.error(f"Network error: {str(e)}")
        return {
            "ai_response": "Network error occurred. Please check your connection and try again.",
            "recommended_content_details": [],
        }
    except Exception as e:
        logging.error(f"Internal error: {str(e)}")
        return {"ai_response": f"Internal error: {str(e)}", "recommended_content_details": []}


# Initialize resources on import, with a graceful fallback for HuggingFace
try:
    load_resources()
    logging.info("Successfully initialized Original_OpenAPI_DB with HockeyFood database integration")
except Exception as e:
    logging.warning(f"Failed to initialize full resources: {e}")
    logging.info("Running in limited mode - basic hockey advice available without database features")
    # Set safe defaults for HuggingFace deployment
    sentence_model = None
    faiss_index = None
    embeddings_np = None
    metadata = []
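

# Minimal local demo (a sketch): the role/team/question values below are purely
# illustrative, and it assumes OPENROUTER_API_KEY is configured. Runs only when
# the module is executed directly, never on import.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        result = await agentic_hockey_chat("coach", "U12", "How do I improve backhand shooting?")
        print(result["ai_response"])
        for item in result["recommended_content_details"]:
            print("-", item.get("title"), item.get("url", ""))

    asyncio.run(_demo())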