File size: 9,036 Bytes
b45bb89 28673b1 b45bb89 28673b1 b45bb89 cb144d5 b45bb89 28673b1 b45bb89 cb144d5 b45bb89 28673b1 b45bb89 28673b1 b45bb89 28673b1 b45bb89 28673b1 b45bb89 28673b1 b45bb89 28673b1 b45bb89 28673b1 b45bb89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
"""HTTP-based Supabase connector for OCR Arena votes.
This module provides a connection to Supabase using HTTP requests,
avoiding the dependency issues with the supabase client library.
"""
import json
import logging
import math
import os
from collections import Counter
from typing import Any, Dict, List, Optional, Sequence

import requests
from dotenv import load_dotenv
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
logger = logging.getLogger(__name__)

# Without these, every request below fails with a confusing error
# (the base URL would literally be "None/rest/v1"), so flag it at import time.
if not SUPABASE_URL or not SUPABASE_KEY:
    logger.warning(
        "⚠️ SUPABASE_URL and/or SUPABASE_KEY is not set; Supabase requests will fail"
    )

# Supabase API configuration.
# Strip a trailing slash so "https://x.supabase.co/" does not become ".co//rest/v1".
_BASE_URL = SUPABASE_URL.rstrip("/") if SUPABASE_URL else SUPABASE_URL
API_BASE_URL = f"{_BASE_URL}/rest/v1"
# PostgREST auth: the project key is sent both as `apikey` and as a bearer token.
HEADERS = {
    "apikey": SUPABASE_KEY,
    "Authorization": f"Bearer {SUPABASE_KEY}",
    "Content-Type": "application/json",
}
def test_connection(*, timeout: float = 10.0) -> bool:
    """Check that the Supabase REST endpoint answers with the configured key.

    Sends a single GET to the ``ocr_votes`` endpoint. A 200 (table exists) or a
    404 (table missing, but the server answered) counts as a working
    connection. Any other status, or any exception (network error, malformed
    URL, timeout), is logged and reported as ``False``; this never raises.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server answered with 200 or 404, otherwise False.
    """
    table_url = f"{API_BASE_URL}/ocr_votes"
    try:
        # limit=1: only reachability matters. Without it the entire table
        # (model outputs and image data URLs included) is downloaded.
        # timeout: requests waits forever by default.
        response = requests.get(
            table_url, headers=HEADERS, params={"limit": 1}, timeout=timeout
        )
    except Exception as e:
        logger.error(f"❌ Supabase connection test failed: {e}")
        return False

    if response.status_code in (200, 404):
        logger.info("✅ Supabase connection test successful")
        return True
    logger.error(f"❌ Supabase connection failed: {response.status_code}")
    return False
def test_table_exists(table_name: str = "ocr_votes", *, timeout: float = 10.0) -> bool:
    """Check whether ``table_name`` is readable through the REST API.

    Only an HTTP 200 counts as "exists and accessible"; any other status is
    logged as a warning and any exception is logged as an error. Both yield
    ``False``; this never raises.

    Args:
        table_name: Table to probe (interpolated directly into the URL path).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True on HTTP 200, otherwise False.
    """
    table_url = f"{API_BASE_URL}/{table_name}"
    try:
        # limit=1 avoids downloading the whole table just to see a 200;
        # timeout keeps the probe from hanging forever.
        response = requests.get(
            table_url, headers=HEADERS, params={"limit": 1}, timeout=timeout
        )
    except Exception as e:
        logger.error(f"❌ Error testing table access: {e}")
        return False

    if response.status_code == 200:
        logger.info(f"✅ Table '{table_name}' exists and is accessible")
        return True
    logger.warning(f"⚠️ Table '{table_name}' may not exist: {response.status_code}")
    return False
def add_vote(
    username: str,
    model_a: str,
    model_b: str,
    model_a_output: str,
    model_b_output: str,
    vote: str,
    image_url: str,
    *,
    timeout: float = 30.0,
) -> Dict[str, Any]:
    """Insert one vote row into the ``ocr_votes`` table.

    The row is stamped with the current local time formatted as
    ``YYYY-MM-DD HH:MM:SS`` (naive, no timezone).

    Args:
        username: Voter identifier stored with the row.
        model_a: Name of the model shown as option A.
        model_b: Name of the model shown as option B.
        model_a_output: Output text produced by model A.
        model_b_output: Output text produced by model B.
        vote: The voter's choice; other functions in this module interpret
            ``"model_a"`` / ``"model_b"``.
        image_url: URL (or data URL) of the image that was OCR'd.
        timeout: Seconds to wait for the insert request.

    Returns:
        The first row of the response body when the server sends back a
        non-empty JSON list, otherwise the payload that was sent.

    Raises:
        RuntimeError: The server answered with a status other than 201.
        Exception: Transport errors from ``requests`` are logged and re-raised.
    """
    from datetime import datetime

    # NOTE(review): naive local time; confirm consumers don't expect UTC.
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    data = {
        "username": username,
        "model_a": model_a,
        "model_b": model_b,
        "model_a_output": model_a_output,
        "model_b_output": model_b_output,
        "vote": vote,
        "image_url": image_url,
        "timestamp": timestamp,
    }
    table_url = f"{API_BASE_URL}/ocr_votes"

    try:
        response = requests.post(table_url, headers=HEADERS, json=data, timeout=timeout)
        if response.status_code != 201:
            raise RuntimeError(
                f"Insert failed with status {response.status_code}: {response.text}"
            )
    except Exception as e:
        logger.error(f"❌ Error adding vote: {e}")
        raise

    logger.info("✅ Vote added successfully")

    # PostgREST replies 201 with an empty body unless the request asks for
    # `Prefer: return=representation`, so an empty body is the normal case.
    if not response.content:
        return data
    try:
        # requests may raise a ValueError subclass that is not
        # json.JSONDecodeError (e.g. when simplejson is installed).
        body = response.json()
    except ValueError:
        return data
    # Only a non-empty list carries the inserted row; anything else (e.g. a
    # dict) would have raised KeyError on `[0]` after a successful insert.
    if isinstance(body, list) and body:
        return body[0]
    return data
def get_all_votes(
    *, page_size: int = 1000, timeout: float = 10.0
) -> List[Dict[str, Any]]:
    """Fetch every row of the ``ocr_votes`` table.

    Supabase/PostgREST caps each response at a server-side maximum row count
    (1000 by default on Supabase), so a single GET silently truncates larger
    tables. Rows are therefore fetched page by page with ``limit``/``offset``
    until an empty page comes back. Stopping on an empty page (rather than a
    short page) stays correct even if the server cap is below ``page_size``.

    Rows are ordered by ``timestamp`` ascending. This keeps pages consistent
    and hands order-sensitive callers (the ELO computation) chronological data.

    Args:
        page_size: Rows requested per page; must be >= 1.
        timeout: Seconds to wait for each page request.

    Returns:
        All vote rows, or ``[]`` if any page fails (non-200 status, invalid
        JSON, or a request error). Partial results are never returned.

    Raises:
        ValueError: ``page_size`` is smaller than 1.
    """
    if page_size < 1:
        raise ValueError("page_size must be a positive integer")

    table_url = f"{API_BASE_URL}/ocr_votes"
    votes: List[Dict[str, Any]] = []
    offset = 0
    try:
        while True:
            # NOTE(review): `timestamp` is not unique. Rows sharing a second at
            # a page boundary could in principle shift between pages. Append a
            # unique column (e.g. a primary key, if the table has one) to
            # `order` to make paging fully stable.
            params = {"order": "timestamp.asc", "limit": page_size, "offset": offset}
            response = requests.get(
                table_url, headers=HEADERS, params=params, timeout=timeout
            )
            if response.status_code != 200:
                logger.error(f"Failed to get votes: {response.status_code}")
                return []
            try:
                page = response.json()
            except ValueError:
                logger.warning("Could not parse JSON response")
                return []
            if not page:
                return votes
            votes.extend(page)
            offset += len(page)
    except Exception as e:
        logger.error(f"❌ Error getting votes: {e}")
        return []
def test_add_sample_vote() -> bool:
    """Smoke-test the insert path by writing a fixed sample vote.

    Note: this goes through ``add_vote`` and so writes a real row into
    ``ocr_votes``.

    Returns:
        True if ``add_vote`` succeeded; False if it raised (the error is logged).
    """
    # Tiny inline JPEG (data URL) so the row carries a syntactically valid image.
    inline_jpeg = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k="
    sample_fields = dict(
        username="test_user",
        model_a="gemini",
        model_b="mistral",
        model_a_output="# Test Gemini Output\n\nThis is a **test** markdown from Gemini.",
        model_b_output="## Test Mistral Output\n\nThis is a *test* markdown from Mistral.",
        vote="model_a",
        image_url=inline_jpeg,
    )
    try:
        inserted = add_vote(**sample_fields)
        logger.info(f"✅ Sample vote added: {inserted}")
        return True
    except Exception as err:
        logger.error(f"❌ Error adding sample vote: {err}")
        return False
def get_vote_statistics(
    votes: Optional[List[Dict[str, Any]]] = None,
    models: Sequence[str] = ("gemini", "mistral", "openai"),
) -> Dict[str, Any]:
    """Summarise how many votes each model won.

    A vote is credited to ``row["model_a"]`` when ``row["vote"] == "model_a"``
    and to ``row["model_b"]`` when it is ``"model_b"``. Any other value counts
    towards ``total_votes`` but towards no model.

    Args:
        votes: Vote rows to summarise. When ``None`` (the default), they are
            fetched with ``get_all_votes()``. Pass already-fetched rows to
            avoid a second download.
        models: Models to report on, in output order.

    Returns:
        ``{"total_votes": N}``, then ``"<model>_votes"`` for each model, then
        ``"<model>_percentage"`` (share of ``total_votes`` x 100, or ``0``
        when there are no votes). Returns ``{}`` if anything raises (logged).
    """
    try:
        if votes is None:
            votes = get_all_votes()
        total_votes = len(votes)

        # Map each vote to the winning model's name and count wins per name.
        wins = Counter()
        for row in votes:
            choice = row.get('vote')
            if choice == 'model_a':
                wins[row.get('model_a')] += 1
            elif choice == 'model_b':
                wins[row.get('model_b')] += 1

        # All "<model>_votes" keys first, then all "<model>_percentage" keys,
        # matching the original output layout.
        stats: Dict[str, Any] = {"total_votes": total_votes}
        for model in models:
            stats[f"{model}_votes"] = wins[model]
        for model in models:
            stats[f"{model}_percentage"] = (
                (wins[model] / total_votes * 100) if total_votes > 0 else 0
            )
        return stats
    except Exception as e:
        logger.error(f"❌ Error getting vote statistics: {e}")
        return {}
def calculate_elo_rating(rating_a: float, rating_b: float, result_a: float, k_factor: float = 32) -> tuple[float, float]:
    """
    Calculate new ELO ratings for two players after a match.

    Args:
        rating_a: Current ELO rating of player A
        rating_b: Current ELO rating of player B
        result_a: Result for player A (1 for win, 0.5 for draw, 0 for loss);
            player B's result is taken as ``1 - result_a``.
        k_factor: K-factor determines how much a single result affects the rating

    Returns:
        tuple: (new_rating_a, new_rating_b)
    """
    # Expected score on the standard logistic curve (400-point scale).
    expected_a = 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
    expected_b = 1 / (1 + 10 ** ((rating_a - rating_b) / 400))
    # Move each rating by K times (actual - expected).
    new_rating_a = rating_a + k_factor * (result_a - expected_a)
    new_rating_b = rating_b + k_factor * ((1 - result_a) - expected_b)
    return new_rating_a, new_rating_b


def calculate_elo_ratings_from_votes(
    votes: List[Dict[str, Any]],
    initial_rating: float = 1500,
    k_factor: float = 32,
) -> Dict[str, float]:
    """
    Calculate ELO ratings for all models by replaying the vote history in order.

    ``"gemini"``, ``"mistral"`` and ``"openai"`` are always present in the
    result. Any other model name found in the votes is added on first
    appearance, starting at ``initial_rating``; previously such a vote raised
    KeyError.

    Skipped rows:
        - missing ``model_a``, ``model_b`` or ``vote``;
        - a ``vote`` other than ``"model_a"`` / ``"model_b"``;
        - self-matches (``model_a == model_b``), which would otherwise
          overwrite one side's update with the other's.

    Args:
        votes: Vote dictionaries from the database, in the order to apply them
            (ELO is order-dependent).
        initial_rating: Starting rating for every model.
        k_factor: K-factor passed to ``calculate_elo_rating``.

    Returns:
        dict: Current ELO rating for each model.
    """
    elo_ratings: Dict[str, float] = {
        model: initial_rating for model in ("gemini", "mistral", "openai")
    }
    for vote in votes:
        model_a = vote.get('model_a')
        model_b = vote.get('model_b')
        vote_choice = vote.get('vote')
        if not (model_a and model_b) or model_a == model_b:
            continue
        if vote_choice == 'model_a':
            result_a = 1  # Model A wins
        elif vote_choice == 'model_b':
            result_a = 0  # Model A loses
        else:
            continue  # Skip invalid / missing choices

        rating_a = elo_ratings.setdefault(model_a, initial_rating)
        rating_b = elo_ratings.setdefault(model_b, initial_rating)
        elo_ratings[model_a], elo_ratings[model_b] = calculate_elo_rating(
            rating_a, rating_b, result_a, k_factor
        )
    return elo_ratings
if __name__ == "__main__":
    # The root logger defaults to WARNING, so without this the logger.info
    # status messages emitted above would never be shown.
    logging.basicConfig(level=logging.INFO)
    connected = test_connection()
    print(connected)
    # Don't attempt the sample insert (a real write) against an unreachable endpoint.
    print(test_add_sample_vote() if connected else False)