|
"""HTTP-based Supabase connector for OCR Arena votes. |
|
|
|
This module provides a connection to Supabase using HTTP requests, |
|
avoiding the dependency issues with the supabase client library. |
|
""" |
|
import logging |
|
import requests |
|
import json |
|
import math |
|
from typing import Dict, Any, List |
|
from dotenv import load_dotenv |
|
import os |
|
|
|
load_dotenv()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

logger = logging.getLogger(__name__)

if not SUPABASE_URL or not SUPABASE_KEY:
    # Warn instead of raising so the module still imports. Without these
    # values every request below targets "None/rest/v1/..." and fails with an
    # error that is far removed from the real cause.
    logger.warning(
        "SUPABASE_URL and/or SUPABASE_KEY are not set; "
        "Supabase requests will fail"
    )

# Root of the Supabase REST (PostgREST) API; tables live under it as paths.
API_BASE_URL = f"{SUPABASE_URL}/rest/v1"

# Sent with every request: the project key is used both as the API key and as
# the bearer token.
HEADERS = {
    "apikey": SUPABASE_KEY,
    "Authorization": f"Bearer {SUPABASE_KEY}",
    "Content-Type": "application/json"
}
|
|
|
def test_connection(*, timeout: float = 10.0) -> bool:
    """Check that the Supabase REST endpoint answers requests.

    Sends a single-row GET to the ``ocr_votes`` table. A 200 (table readable)
    or 404 (server answered but the table was not found) is treated as a
    working connection. Any other status, or a transport error, counts as a
    failure.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server answered with 200 or 404, False otherwise.
    """
    table_url = f"{API_BASE_URL}/ocr_votes"
    try:
        # limit=1: only the status code matters here, so avoid downloading
        # the whole table (rows carry full OCR outputs and image data).
        response = requests.get(
            table_url, headers=HEADERS, params={"limit": 1}, timeout=timeout
        )
    except requests.RequestException as e:
        logger.error(f"❌ Supabase connection test failed: {e}")
        return False

    if response.status_code in (200, 404):
        logger.info("✅ Supabase connection test successful")
        return True

    logger.error(f"❌ Supabase connection failed: {response.status_code}")
    return False
|
|
|
def test_table_exists(table_name: str = "ocr_votes", *, timeout: float = 10.0) -> bool:
    """Check whether *table_name* can be read through the REST API.

    Args:
        table_name: Table to probe, appended to ``API_BASE_URL`` as a path.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True only if a single-row GET on the table returns 200. Any other
        status or a transport error returns False (and is logged).
    """
    table_url = f"{API_BASE_URL}/{table_name}"
    try:
        # limit=1: existence/readability is all we need, not the table contents.
        response = requests.get(
            table_url, headers=HEADERS, params={"limit": 1}, timeout=timeout
        )
    except requests.RequestException as e:
        logger.error(f"❌ Error testing table access: {e}")
        return False

    if response.status_code == 200:
        logger.info(f"✅ Table '{table_name}' exists and is accessible")
        return True

    logger.warning(f"⚠️ Table '{table_name}' may not exist: {response.status_code}")
    return False
|
|
|
def add_vote(
    username: str,
    model_a: str,
    model_b: str,
    model_a_output: str,
    model_b_output: str,
    vote: str,
    image_url: str,
    *,
    timeout: float = 30.0,
) -> Dict[str, Any]:
    """Insert one vote row into the ``ocr_votes`` table.

    Args:
        username: Who cast the vote.
        model_a: Name of the model shown on side A.
        model_b: Name of the model shown on side B.
        model_a_output: Text produced by model A.
        model_b_output: Text produced by model B.
        vote: The chosen side. Elsewhere in this module only ``"model_a"`` and
            ``"model_b"`` are counted as wins.
        image_url: URL (or data URI) of the image that was OCR'd.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The inserted row as echoed back by the server. If the server returns
        no usable body, the locally built payload is returned instead.

    Raises:
        RuntimeError: If the server answers with anything other than
            201 Created.
        requests.RequestException: On transport errors (connection failure,
            timeout).
    """
    from datetime import datetime

    # Local wall-clock time with no UTC offset. The format is kept unchanged
    # so new rows match rows already stored in the table.
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    data = {
        "username": username,
        "model_a": model_a,
        "model_b": model_b,
        "model_a_output": model_a_output,
        "model_b_output": model_b_output,
        "vote": vote,
        "image_url": image_url,
        "timestamp": timestamp
    }

    # Ask PostgREST to echo the inserted row back. Without this header a
    # successful insert returns 201 with an empty body.
    headers = {**HEADERS, "Prefer": "return=representation"}
    table_url = f"{API_BASE_URL}/ocr_votes"

    try:
        response = requests.post(table_url, headers=headers, json=data, timeout=timeout)
        if response.status_code != 201:
            raise RuntimeError(
                f"Insert failed with status {response.status_code}: {response.text}"
            )
    except Exception as e:
        logger.error(f"❌ Error adding vote: {e}")
        raise

    # From here on the insert has succeeded. A problem parsing the echo must
    # not be reported as a failed insert.
    logger.info("✅ Vote added successfully")
    try:
        rows = response.json()
    except ValueError:
        # Empty or non-JSON body. Catching ValueError covers both
        # json.JSONDecodeError and simplejson's decode error, either of which
        # requests may raise depending on what is installed.
        return data

    if isinstance(rows, list) and rows:
        return rows[0]
    return data
|
|
|
def get_all_votes(*, page_size: int = 1000, timeout: float = 30.0) -> List[Dict[str, Any]]:
    """Fetch every row of the ``ocr_votes`` table.

    The Supabase REST API caps how many rows one request returns (1000 by
    default). A single unpaginated GET therefore silently drops rows once the
    table grows past that cap. Instead, rows are fetched page by page with
    ``limit``/``offset`` until the server returns an empty page.

    Args:
        page_size: Rows requested per page. The server may return fewer if
            its own cap is lower. The offset advances by the number of rows
            actually received, so no rows are skipped in that case.
        timeout: Seconds to wait for each page before giving up.

    Returns:
        All rows, or an empty list if any request fails or any page cannot be
        parsed. Partial results are never returned.

    Raises:
        ValueError: If *page_size* is smaller than 1.
    """
    if page_size < 1:
        raise ValueError("page_size must be at least 1")

    table_url = f"{API_BASE_URL}/ocr_votes"
    votes: List[Dict[str, Any]] = []

    # NOTE(review): no explicit ``order`` is requested. Row order (which
    # calculate_elo_ratings_from_votes depends on) and page boundaries
    # therefore rely on the database's natural order. Consider ordering by a
    # unique, monotonically increasing column.
    try:
        while True:
            response = requests.get(
                table_url,
                headers=HEADERS,
                params={"limit": page_size, "offset": len(votes)},
                timeout=timeout,
            )
            if response.status_code != 200:
                logger.error(f"Failed to get votes: {response.status_code}")
                return []

            try:
                page = response.json()
            except ValueError:
                logger.warning("Could not parse JSON response")
                return []

            if not isinstance(page, list):
                logger.warning("Could not parse JSON response")
                return []
            if not page:
                return votes
            votes.extend(page)
    except requests.RequestException as e:
        logger.error(f"❌ Error getting votes: {e}")
        return []
|
|
|
def test_add_sample_vote() -> bool:
    """Write one fixed gemini-vs-mistral vote to the database as a smoke test.

    The sample uses short markdown outputs and a tiny inline base64 JPEG as
    the image. Any exception from ``add_vote`` is logged and reported as
    failure rather than propagated.

    Returns:
        True if the vote was inserted, False otherwise.
    """
    sample = dict(
        username="test_user",
        model_a="gemini",
        model_b="mistral",
        model_a_output="# Test Gemini Output\n\nThis is a **test** markdown from Gemini.",
        model_b_output="## Test Mistral Output\n\nThis is a *test* markdown from Mistral.",
        vote="model_a",
        image_url="data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=",
    )

    try:
        inserted = add_vote(**sample)
        logger.info(f"✅ Sample vote added: {inserted}")
    except Exception as e:
        logger.error(f"❌ Error adding sample vote: {e}")
        return False
    return True
|
|
|
def get_vote_statistics() -> Dict[str, Any]:
    """Summarise how many votes each tracked model has won.

    A vote counts as a win for the model on the side named by its ``vote``
    field (``"model_a"`` or ``"model_b"``). Only gemini, mistral and openai
    are tallied. Every fetched row counts towards ``total_votes``, whatever
    its outcome.

    Returns:
        ``total_votes``, then ``<model>_votes`` win counts, then
        ``<model>_percentage`` (wins as a percentage of ``total_votes``, or 0
        when there are no votes). An empty dict is returned if anything
        raises.
    """
    tracked = ("gemini", "mistral", "openai")
    try:
        all_votes = get_all_votes()
        total = len(all_votes)
        wins = {name: 0 for name in tracked}

        for row in all_votes:
            side = row.get('vote')
            side_a = row.get('model_a')
            side_b = row.get('model_b')

            if side == 'model_a':
                winner = side_a
            elif side == 'model_b':
                winner = side_b
            else:
                continue

            if winner in tracked:
                wins[winner] += 1

        def share(count):
            return (count / total * 100) if total > 0 else 0

        return {
            "total_votes": total,
            "gemini_votes": wins["gemini"],
            "mistral_votes": wins["mistral"],
            "openai_votes": wins["openai"],
            "gemini_percentage": share(wins["gemini"]),
            "mistral_percentage": share(wins["mistral"]),
            "openai_percentage": share(wins["openai"])
        }
    except Exception as e:
        logger.error(f"❌ Error getting vote statistics: {e}")
        return {}
|
|
|
def calculate_elo_rating(rating_a: float, rating_b: float, result_a: float, k_factor: int = 32) -> tuple[float, float]:
    """Return both players' updated Elo ratings after a single game.

    Each player's expected score is ``1 / (1 + 10 ** ((opponent - own) / 400))``.
    Each rating then moves by ``k_factor * (actual - expected)``, where player
    B's actual score is ``1 - result_a``.

    Args:
        rating_a: Player A's rating before the game.
        rating_b: Player B's rating before the game.
        result_a: Player A's score: 1 for a win, 0.5 for a draw, 0 for a loss.
        k_factor: Largest possible rating change from one game.

    Returns:
        ``(new_rating_a, new_rating_b)``.
    """
    def expected(own: float, other: float) -> float:
        return 1 / (1 + 10 ** ((other - own) / 400))

    result_b = 1 - result_a
    return (
        rating_a + k_factor * (result_a - expected(rating_a, rating_b)),
        rating_b + k_factor * (result_b - expected(rating_b, rating_a)),
    )
|
|
|
def calculate_elo_ratings_from_votes(
    votes: List[Dict[str, Any]],
    initial_rating: float = 1500,
    k_factor: int = 32,
) -> Dict[str, float]:
    """Replay the vote history and return each model's resulting Elo rating.

    Votes are applied in list order, so the result depends on the order of
    *votes*.

    Args:
        votes: Vote rows with ``model_a``, ``model_b`` and ``vote`` keys, as
            stored by ``add_vote``.
        initial_rating: Starting rating for every model.
        k_factor: K-factor passed to ``calculate_elo_rating`` for each game.

    Returns:
        Mapping of model name to rating. gemini, mistral and openai are
        always present. Any other model that appears in a vote is added at
        *initial_rating* on first sight, instead of aborting with KeyError.
        Rows missing a model name or vote, or whose ``vote`` is neither
        ``"model_a"`` nor ``"model_b"``, are skipped.
    """
    elo_ratings: Dict[str, float] = {
        "gemini": initial_rating,
        "mistral": initial_rating,
        "openai": initial_rating
    }

    for vote in votes:
        model_a = vote.get('model_a')
        model_b = vote.get('model_b')
        vote_choice = vote.get('vote')

        if not (model_a and model_b and vote_choice):
            continue

        if vote_choice == 'model_a':
            result_a = 1
        elif vote_choice == 'model_b':
            result_a = 0
        else:
            continue

        # Models outside the default set start at the initial rating instead
        # of raising KeyError and discarding the whole computation.
        rating_a = elo_ratings.setdefault(model_a, initial_rating)
        rating_b = elo_ratings.setdefault(model_b, initial_rating)

        new_rating_a, new_rating_b = calculate_elo_rating(
            rating_a, rating_b, result_a, k_factor
        )
        # Assign A before B, as before. If model_a == model_b, B's value wins.
        elo_ratings[model_a] = new_rating_a
        elo_ratings[model_b] = new_rating_b

    return elo_ratings
|
|
|
if __name__ == "__main__":
    # No handler is configured when this runs as a script, so logging's
    # last-resort handler would drop the INFO-level "✅ ..." messages.
    # Configure basic logging so they are visible.
    logging.basicConfig(level=logging.INFO)
    print(test_connection())
    print(test_add_sample_vote())