# davanstrien's picture
# davanstrien HF staff
# refactor trending models fetching to improve parameter filtering and logging
# 0b973e8
import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional
import chromadb
import dateutil.parser
import httpx
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000  # rows sent to ChromaDB per upsert call
CACHE_TTL = "24h"  # TTL for cached search / similarity responses
TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data

# Pick the torch backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Tokenizer for the summarisation model named by MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): this is set after `transformers` has been imported and after the
# tokenizer download above; presumably huggingface_hub has already read the
# variable by then, so it may have no effect in this process - confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# macOS is treated as a local development machine; everything else uses /data.
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="8gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the vector database on startup and close the cache on shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Setup
    setup_database()
    try:
        yield
    finally:
        # Cleanup: runs even if an exception is thrown into the generator at
        # the yield, so the cache backend is always closed.
        await cache.close()
app = FastAPI(lifespan=lifespan)

# Add CORS middleware
# Starlette compares `allow_origins` entries literally (only a bare "*" is
# special), so glob-style entries such as "https://*.hf.space" never match a
# real origin. Subdomain wildcards have to be expressed with
# `allow_origin_regex`, which is matched against the whole Origin header.
app.add_middleware(
    CORSMiddleware,
    # Allow all Hugging Face Spaces (*.hf.space) and Hugging Face domains
    # (*.huggingface.co) over HTTPS.
    # TODO remove before prod: "http://localhost:5500" was used for local testing.
    allow_origin_regex=r"https://([A-Za-z0-9-]+\.)+(hf\.space|huggingface\.co)",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the embedding function at module level
_embedding_function = None


def get_embedding_function():
    """Return the shared SentenceTransformer embedding function.

    The instance is created on first use and reused afterwards, because the
    search endpoints call this on every request and there is no need to
    rebuild the wrapper (or re-log the device) each time.

    Returns:
        A ChromaDB ``SentenceTransformerEmbeddingFunction`` for
        ``EMBEDDING_MODEL`` running on ``DEVICE``.
    """
    global _embedding_function
    if _embedding_function is None:
        logger.info(f"Using device: {DEVICE}")
        _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=EMBEDDING_MODEL, device=DEVICE
        )
    return _embedding_function
# Substrings of dataset IDs that are kept out of the index.
_EXCLUDED_DATASET_ID_PATTERNS = [
    "open-llm-leaderboard-old/",
    # course model that's not useful for retrieving
    "gemma-2-2B-it-thinking-function_calling-V0",
]


def _sync_dataset_collection(dataset_collection):
    """Upsert dataset summaries newer than the newest record already stored."""
    # Load dataset data
    df = pl.scan_parquet(
        "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
    )
    df = df.filter(
        pl.col("datasetId").str.contains_any(_EXCLUDED_DATASET_ID_PATTERNS).not_()
    )

    # Get the most recent last_modified date from the collection
    latest_update = None
    if dataset_collection.count() > 0:
        metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
        logger.info(f"Found {len(metadata)} existing records in collection")
        last_modifieds = [
            dateutil.parser.parse(m.get("last_modified")) for m in metadata
        ]
        latest_update = max(last_modifieds)
        logger.info(f"Most recent record in DB from: {latest_update}")
        logger.info(f"Oldest record in DB from: {min(last_modifieds)}")

    # Filter and process only newer records
    df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])

    # Log some stats about the incoming data
    sample_dates = df.select("last_modified").limit(5).collect()
    logger.info(f"Sample of incoming dates: {sample_dates}")
    total_incoming = df.select(pl.len()).collect().item()
    logger.info(f"Total incoming records: {total_incoming}")

    if latest_update is not None:
        logger.info(f"Filtering records newer than {latest_update}")
        # NOTE(review): assumes the parquet `last_modified` column is comparable
        # with a Python datetime (including timezone awareness) - confirm.
        df = df.filter(pl.col("last_modified") > latest_update)

    df = df.collect()
    total_rows = len(df)
    if latest_update is not None:
        # The collected frame already holds the filtered rows, so its length is
        # reported instead of running a second remote scan just to count them.
        logger.info(f"Found {total_rows} records to update after filtering")

    if total_rows == 0:
        return

    logger.info(f"Updating dataset collection with {total_rows} new records")
    logger.info(
        f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
    )
    for offset in range(0, total_rows, BATCH_SIZE):
        batch_df = df.slice(offset, min(BATCH_SIZE, total_rows - offset))
        batch_size = len(batch_df)
        logger.info(
            f"Processing batch {offset // BATCH_SIZE + 1}: {batch_size} records "
            f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
        )
        dataset_collection.upsert(
            ids=batch_df["datasetId"].to_list(),
            documents=batch_df["summary"].to_list(),
            metadatas=[
                {
                    "likes": int(likes),
                    "downloads": int(downloads),
                    # Stored as text; parsed back with dateutil on the next run.
                    "last_modified": str(last_modified),
                }
                for likes, downloads, last_modified in zip(
                    batch_df["likes"].to_list(),
                    batch_df["downloads"].to_list(),
                    batch_df["last_modified"].to_list(),
                )
            ],
        )
        logger.info(f"Processed {offset + batch_size:,} / {total_rows:,} records")
    logger.info(
        f"Database initialized with {dataset_collection.count():,} total rows"
    )


def _sync_model_collection(model_collection):
    """Re-upsert all model summaries when the source has more rows than the collection."""
    # Load model data
    model_df = pl.scan_parquet(
        "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
    )
    model_row_count = model_df.select(pl.len()).collect().item()
    logger.info(f"Row count of new model data: {model_row_count}")

    # Only a growing row count triggers a refresh; equal or smaller counts are skipped.
    if model_collection.count() >= model_row_count:
        return

    model_df = model_df.select(
        [
            "modelId",
            "summary",
            "likes",
            "downloads",
            "last_modified",
            "param_count",
        ]
    ).collect()
    total_rows = len(model_df)
    for offset in range(0, total_rows, BATCH_SIZE):
        batch_df = model_df.slice(offset, min(BATCH_SIZE, total_rows - offset))
        model_collection.upsert(
            ids=batch_df["modelId"].to_list(),
            documents=batch_df["summary"].to_list(),
            metadatas=[
                {
                    "likes": int(likes),
                    "downloads": int(downloads),
                    "last_modified": str(last_modified),
                    # A missing parameter count is stored as 0.
                    "param_count": int(param_count)
                    if param_count is not None
                    else 0,
                }
                for likes, downloads, last_modified, param_count in zip(
                    batch_df["likes"].to_list(),
                    batch_df["downloads"].to_list(),
                    batch_df["last_modified"].to_list(),
                    batch_df["param_count"].to_list(),
                )
            ],
        )
        logger.info(
            f"Processed {offset + len(batch_df):,} / {total_rows:,} model rows"
        )
    logger.info(
        f"Model database initialized with {model_collection.count():,} rows"
    )


def setup_database():
    """Create the ChromaDB collections and sync them with the source parquet files.

    Both collections use cosine distance. Dataset cards are synced
    incrementally by ``last_modified``; model cards are fully re-upserted
    whenever the source has more rows than the collection.

    Any failure is logged with its traceback and swallowed, so the app still
    starts with whatever data is already persisted.
    """
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
        _sync_dataset_collection(dataset_collection)
        _sync_model_collection(model_collection)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Setup error: {e}")
# Run setup on startup
# NOTE(review): `lifespan` also calls setup_database() when the application
# starts, so the sync runs once here at import time and again at startup -
# confirm whether both calls are intended.
setup_database()
# One dataset hit returned by the dataset search / similarity / trending endpoints.
class QueryResult(BaseModel):
    # Hub dataset identifier (the ChromaDB record ID).
    dataset_id: str
    # NOTE(review): process_search_results fills this from Chroma's `distances`
    # (collections use cosine space), so it presumably holds a distance where
    # lower means closer - confirm. Trending results use a constant 1.0.
    similarity: float
    # Stored summary text for the dataset card.
    summary: str
    likes: int
    downloads: int
# Response envelope for the dataset endpoints.
class QueryResponse(BaseModel):
    # Dataset hits in the order chosen by the endpoint.
    results: List[QueryResult]
# One model hit returned by the model search / similarity / trending endpoints.
class ModelQueryResult(BaseModel):
    # Hub model identifier (the ChromaDB record ID).
    model_id: str
    # NOTE(review): process_search_results fills this from Chroma's `distances`
    # (collections use cosine space), so it presumably holds a distance where
    # lower means closer - confirm. Trending results use a constant 1.0.
    similarity: float
    # Stored summary text for the model card.
    summary: str
    likes: int
    downloads: int
    # Parameter count from the collection metadata; setup_database stores 0
    # when the source value is missing.
    param_count: Optional[int] = None
# Response envelope for the model endpoints.
class ModelQueryResponse(BaseModel):
    # Model hits in the order chosen by the endpoint.
    results: List[ModelQueryResult]
@app.get("/")
async def redirect_to_docs():
    # The root path has no content of its own; send visitors to the API docs.
    from fastapi.responses import RedirectResponse

    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    # Free-text semantic search over the dataset-card summaries.
    # Any failure is logged and reported to the client as HTTP 500.

    # Sorts other than plain similarity fetch 4x candidates before re-ranking.
    candidate_count = k if sort_by == "similarity" else k * 4

    # A metadata filter is only sent when at least one threshold is set.
    metadata_filter = None
    if min_likes > 0 or min_downloads > 0:
        metadata_filter = {
            "$and": [
                {"likes": {"$gte": min_likes}},
                {"downloads": {"$gte": min_downloads}},
            ]
        }

    try:
        dataset_cards = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        matches = dataset_cards.query(
            query_texts=[f"search_query: {query}"],
            n_results=candidate_count,
            where=metadata_filter,
        )
        ranked = await process_search_results(matches, "dataset", k, sort_by)
        return QueryResponse(results=ranked)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    # Nearest-neighbour lookup using the stored embedding of `dataset_id`.
    # Unknown IDs give HTTP 404; any other failure is logged and becomes HTTP 500.

    # Plain similarity asks for one extra hit because the query item itself is
    # dropped afterwards; other sorts fetch 4x candidates before re-ranking.
    candidate_count = k + 1 if sort_by == "similarity" else k * 4

    # A metadata filter is only sent when at least one threshold is set.
    metadata_filter = None
    if min_likes > 0 or min_downloads > 0:
        metadata_filter = {
            "$and": [
                {"likes": {"$gte": min_likes}},
                {"downloads": {"$gte": min_downloads}},
            ]
        }

    try:
        dataset_cards = client.get_collection("dataset_cards")
        stored = dataset_cards.get(ids=[dataset_id], include=["embeddings"])
        if not stored["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        neighbours = dataset_cards.query(
            query_embeddings=[stored["embeddings"][0]],
            n_results=candidate_count,
            where=metadata_filter,
        )
        ranked = await process_search_results(
            neighbours, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=ranked)
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")
@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Search for models based on a text query with optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    # Each entry pairs "is this filter active?" with the clause it contributes;
    # zero param counts are always excluded once any parameter filter is used.
    param_filter_active = min_param_count > 0 or max_param_count is not None
    candidate_clauses = [
        (min_likes > 0, {"likes": {"$gte": min_likes}}),
        (min_downloads > 0, {"downloads": {"$gte": min_downloads}}),
        (param_filter_active, {"param_count": {"$gt": 0}}),
        (min_param_count > 0, {"param_count": {"$gte": min_param_count}}),
        (max_param_count is not None, {"param_count": {"$lte": max_param_count}}),
    ]
    where_conditions = [clause for active, clause in candidate_clauses if active]

    # A single condition is passed bare; "$and" is only used for two or more.
    if not where_conditions:
        where_clause = None
    elif len(where_conditions) == 1:
        where_clause = where_conditions[0]
    else:
        where_clause = {"$and": where_conditions}

    # Sorts other than plain similarity fetch 4x candidates before re-ranking.
    candidate_count = k if sort_by == "similarity" else k * 4

    try:
        model_cards = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        matches = model_cards.query(
            query_texts=[f"search_query: {query}"],
            n_results=candidate_count,
            where=where_clause,
        )
        ranked = await process_search_results(matches, "model", k, sort_by)
        return ModelQueryResponse(results=ranked)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")
@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Find similar models to a specified model with optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    # Each entry pairs "is this filter active?" with the clause it contributes;
    # zero param counts are always excluded once any parameter filter is used.
    param_filter_active = min_param_count > 0 or max_param_count is not None
    candidate_clauses = [
        (min_likes > 0, {"likes": {"$gte": min_likes}}),
        (min_downloads > 0, {"downloads": {"$gte": min_downloads}}),
        (param_filter_active, {"param_count": {"$gt": 0}}),
        (min_param_count > 0, {"param_count": {"$gte": min_param_count}}),
        (max_param_count is not None, {"param_count": {"$lte": max_param_count}}),
    ]
    where_conditions = [clause for active, clause in candidate_clauses if active]

    # A single condition is passed bare; "$and" is only used for two or more.
    if not where_conditions:
        where_clause = None
    elif len(where_conditions) == 1:
        where_clause = where_conditions[0]
    else:
        where_clause = {"$and": where_conditions}

    # Plain similarity asks for one extra hit because the query item itself is
    # dropped afterwards; other sorts fetch 4x candidates before re-ranking.
    candidate_count = k + 1 if sort_by == "similarity" else k * 4

    try:
        model_cards = client.get_collection("model_cards")
        stored = model_cards.get(ids=[model_id], include=["embeddings"])
        if not stored["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        neighbours = model_cards.query(
            query_embeddings=[stored["embeddings"][0]],
            n_results=candidate_count,
            where=where_clause,
        )
        ranked = await process_search_results(
            neighbours, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=ranked)
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")
@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch trending score for a model or dataset from HuggingFace API.

    Args:
        item_id: Hub repository ID of the model or dataset.
        item_type: ``"model"`` selects the models endpoint; any other value
            selects the datasets endpoint.

    Returns:
        The ``trendingScore`` reported by the API as a float, or ``0.0`` when
        the field is missing/null or the request fails (failures are logged,
        never raised).
    """
    endpoint = "models" if item_type == "model" else "datasets"
    url = f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
    try:
        # Named `http` so the module-level ChromaDB `client` is not shadowed.
        async with httpx.AsyncClient() as http:
            response = await http.get(url)
            response.raise_for_status()
            # `or 0` also covers an explicit null, which would otherwise break
            # the numeric sort performed by callers.
            return float(response.json().get("trendingScore") or 0)
    except Exception as e:
        # NOTE(review): this fallback is returned through @cache, so a transient
        # failure is presumably remembered for the whole TTL - confirm.
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0.0
async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format.

    Args:
        results: Result of a Chroma ``collection.query`` for a single query;
            only the first entry of ``ids``, ``distances``, ``documents`` and
            ``metadatas`` is read.
        id_field: ``"dataset"`` builds ``QueryResult`` objects; any other value
            builds ``ModelQueryResult`` objects (``"model"`` also copies
            ``param_count`` when present in the metadata).
        k: Maximum number of results to return after sorting.
        sort_by: ``"similarity"`` keeps Chroma's order, ``"trending"`` orders by
            the Hub trending score, anything else is used as the result
            attribute to sort by in descending order (e.g. ``"likes"``).
        exclude_id: Optional ID to drop from the results (the query item itself
            in the similarity endpoints).

    Returns:
        A list of ``QueryResult`` or ``ModelQueryResult`` objects.
    """
    id_attr = f"{id_field}_id"
    result_cls = QueryResult if id_field == "dataset" else ModelQueryResult

    # Create base results
    query_results = []
    for item_id, distance, document, metadata in zip(
        results["ids"][0],
        results["distances"][0],
        results["documents"][0],
        results["metadatas"][0],
    ):
        if exclude_id and item_id == exclude_id:
            continue
        fields = {
            id_attr: item_id,
            # NOTE(review): this is Chroma's distance value reported under the
            # name "similarity" - confirm that consumers expect that.
            "similarity": float(distance),
            "summary": document,
            "likes": metadata["likes"],
            "downloads": metadata["downloads"],
        }
        # Add param_count for models if it exists in metadata
        if id_field == "model" and "param_count" in metadata:
            fields["param_count"] = metadata["param_count"]
        query_results.append(result_cls(**fields))

    # Handle sorting
    if sort_by == "trending":
        # Fetch trending scores for all results concurrently. Each
        # get_trending_score call opens its own HTTP client, so none is
        # created here.
        item_type = "model" if id_field == "model" else "dataset"
        scores = await asyncio.gather(
            *(
                get_trending_score(getattr(result, id_attr), item_type)
                for result in query_results
            )
        )
        # Stable sort by score, highest first.
        ranked = sorted(
            zip(scores, query_results), key=lambda pair: pair[0], reverse=True
        )
        query_results = [result for _, result in ranked]
    elif sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)

    # Re-ranked lists and the similarity + exclude_id case were over-fetched
    # by the caller, so trim them back to k.
    if sort_by != "similarity" or exclude_id:
        query_results = query_results[:k]
    return query_results
async def fetch_trending_models():
    """Return the decoded JSON of the Hub ``/api/models`` listing.

    Callers treat this listing as the trending feed. Raises
    ``httpx.HTTPStatusError`` on a non-success status code.
    """
    models_url = "https://huggingface.co/api/models"
    async with httpx.AsyncClient() as http:
        reply = await http.get(models_url)
        reply.raise_for_status()
        payload = reply.json()
    return payload
def _param_count_rejection(model_id, param_count, min_param_count, max_param_count):
    """Return why a model fails the parameter-count filter, or None if it passes."""
    # Always exclude models with param_count=0 when any parameter filtering is active
    if param_count == 0:
        return f"Filtering out {model_id} - has param_count=0 but parameter filtering is active"
    if min_param_count > 0 and param_count < min_param_count:
        return f"Filtering out {model_id} - param_count {param_count} < min_param_count {min_param_count}"
    if max_param_count is not None and param_count > max_param_count:
        return f"Filtering out {model_id} - param_count {param_count} > max_param_count {max_param_count}"
    return None


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine with summaries from database.

    Args:
        limit: Maximum number of results to return.
        min_likes: Minimum likes a model must have.
        min_downloads: Minimum downloads a model must have.
        min_param_count: Minimum parameter count; values above 0 activate
            parameter filtering.
        max_param_count: Maximum parameter count; any non-None value activates
            parameter filtering.

    Returns:
        Up to ``limit`` models, ordered by trending score, that have a summary
        in the ``model_cards`` collection. When parameter filtering is active,
        models with ``param_count == 0`` (unknown) are excluded.

    Raises:
        HTTPException: 500 if fetching or combining the data fails.
    """
    try:
        # Fetch trending models
        trending_models = await fetch_trending_models()

        # Filter by minimum likes/downloads
        eligible = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score, then keep up to 3x the limit as a buffer for
        # models that are missing from the database or rejected by the
        # parameter filter (the slice copes with shorter lists).
        eligible.sort(key=lambda m: m.get("trendingScore", 0), reverse=True)
        eligible = eligible[: limit * 3]
        if not eligible:
            # Nothing to look up; avoids calling ChromaDB with an empty ID list.
            return []

        # Fetch summaries from ChromaDB
        model_ids = [model["modelId"] for model in eligible]
        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])

        # Create mapping of model_id to summary and metadata
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))

        logger.info(
            f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
        )
        using_param_filters = min_param_count > 0 or max_param_count is not None

        results = []
        for model in eligible:
            model_id = model["modelId"]
            if model_id not in id_to_summary:
                continue
            # `or {}` tolerates records stored without metadata.
            metadata = id_to_metadata.get(model_id) or {}
            param_count = metadata.get("param_count", 0)
            logger.debug(f"Model: {model_id}, param_count: {param_count}")

            if using_param_filters:
                rejection = _param_count_rejection(
                    model_id, param_count, min_param_count, max_param_count
                )
                if rejection is not None:
                    logger.debug(rejection)
                    continue

            results.append(
                ModelQueryResult(
                    model_id=model_id,
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model_id],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                    param_count=param_count,
                )
            )

        if using_param_filters:
            logger.info(f"After filtering: {len(results)} models remain")

        # Finally limit to the requested number
        return results[:limit]
    except Exception as e:
        logger.error(f"Error fetching trending models: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending models")
@app.get("/trending/models", response_model=ModelQueryResponse)
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any parameter filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Get trending models with their summaries and optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    # Request/response tracing goes through the module logger, like the rest
    # of the file, instead of print().
    logger.info(
        f"Request for trending models with params: limit={limit}, min_likes={min_likes}, min_downloads={min_downloads}, min_param_count={min_param_count}, max_param_count={max_param_count}"
    )
    # The heavy lifting (and the caching) lives in the helper; it raises
    # HTTPException(500) itself on failure.
    results = await get_trending_models_with_summaries(
        limit=limit,
        min_likes=min_likes,
        min_downloads=min_downloads,
        min_param_count=min_param_count,
        max_param_count=max_param_count,
    )
    logger.info(f"Returning {len(results)} trending model results")
    return ModelQueryResponse(results=results)
async def fetch_trending_datasets():
    """Return the decoded JSON of the Hub ``/api/datasets`` listing.

    Callers treat this listing as the trending feed. Raises
    ``httpx.HTTPStatusError`` on a non-success status code.
    """
    datasets_url = "https://huggingface.co/api/datasets"
    async with httpx.AsyncClient() as http:
        reply = await http.get(datasets_url)
        reply.raise_for_status()
        payload = reply.json()
    return payload
@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine with summaries from database.

    Args:
        limit: Number of top trending datasets to look up.
        min_likes: Minimum likes a dataset must have.
        min_downloads: Minimum downloads a dataset must have.

    Returns:
        The top ``limit`` trending datasets that also have a summary in the
        ``dataset_cards`` collection, ordered by trending score. Datasets
        without a stored summary are dropped, so fewer than ``limit`` results
        may be returned.

    Raises:
        HTTPException: 500 if fetching or combining the data fails.
    """
    try:
        # Fetch trending datasets
        trending_datasets = await fetch_trending_datasets()

        # Filter by minimum likes/downloads
        eligible = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score and limit
        eligible.sort(key=lambda d: d.get("trendingScore", 0), reverse=True)
        eligible = eligible[:limit]
        if not eligible:
            # Nothing to look up; avoids calling ChromaDB with an empty ID list.
            return []

        # Fetch summaries from ChromaDB
        dataset_ids = [dataset["id"] for dataset in eligible]
        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])

        # Create mapping of dataset_id to summary
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        # Combine data, keeping only datasets that have a stored summary
        return [
            QueryResult(
                dataset_id=dataset["id"],
                similarity=1.0,  # Not applicable for trending
                summary=id_to_summary[dataset["id"]],
                likes=dataset.get("likes", 0),
                downloads=dataset.get("downloads", 0),
            )
            for dataset in eligible
            if dataset["id"] in id_to_summary
        ]
    except Exception as e:
        logger.error(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")
@app.get("/trending/datasets", response_model=QueryResponse)
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries"""
    # Thin wrapper: the cached helper does the fetching and raises
    # HTTPException(500) itself on failure.
    trending = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    response = QueryResponse(results=trending)
    return response
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    import uvicorn

    # Listens on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)