import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional

import chromadb
import dateutil.parser
import httpx
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000
CACHE_TTL = "24h"
TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER before any hub downloads

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory when developing on macOS
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="8gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
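# Chroma persists collections (including the HNSW index) under DATA_DIR, so
# embeddings survive restarts and only new records need re-embedding.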


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)

# Add CORS middleware. Starlette matches allow_origins entries exactly, so
# wildcard subdomains need allow_origin_regex rather than "https://*.hf.space".
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # Hugging Face Spaces and domains
    # allow_origins=["http://localhost:5500"],  # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define the embedding function at module level
def get_embedding_function():
    logger.info(f"Using device: {DEVICE}")
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL, device=DEVICE
    )


def setup_database():
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
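        # "hnsw:space": "cosine" makes Chroma report cosine *distance*
        # (1 - cosine similarity), so 0.0 means an identical embedding.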

        # Load dataset data
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        df = df.filter(
            pl.col("datasetId")
            .str.contains_any(
                ["gemma-2-2B-it-thinking-function_calling-V0"]
            )  # course model that's not useful for retrieval
            .not_()
        )

        # Get the most recent last_modified date from the collection
        latest_update = None
        if dataset_collection.count() > 0:
            metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
            logger.info(f"Found {len(metadata)} existing records in collection")
            last_modifieds = [
                dateutil.parser.parse(m.get("last_modified")) for m in metadata
            ]
            latest_update = max(last_modifieds)
            logger.info(f"Most recent record in DB from: {latest_update}")
            logger.info(f"Oldest record in DB from: {min(last_modifieds)}")
        # Filter and process only newer records
        df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])

        # Log some stats about the incoming data
        sample_dates = df.select("last_modified").limit(5).collect()
        logger.info(f"Sample of incoming dates: {sample_dates}")
        total_incoming = df.select(pl.len()).collect().item()
        logger.info(f"Total incoming records: {total_incoming}")

        if latest_update:
            logger.info(f"Filtering records newer than {latest_update}")
            df = df.filter(pl.col("last_modified") > latest_update)
            filtered_count = df.select(pl.len()).collect().item()
            logger.info(f"Found {filtered_count} records to update after filtering")

        df = df.collect()
        total_rows = len(df)
        if total_rows > 0:
            logger.info(f"Updating dataset collection with {total_rows} new records")
            logger.info(
                f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
            )
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                batch_size = len(batch_df)
                logger.info(
                    f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
                    f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
                )
                dataset_collection.upsert(
                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")

        logger.info(
            f"Database initialized with {dataset_collection.count():,} total rows"
        )

        # Load model data
        model_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new model data: {model_row_count}")
        if model_collection.count() < model_row_count:
            model_df = model_df.select(
                [
                    "modelId",
                    "summary",
                    "likes",
                    "downloads",
                    "last_modified",
                    "param_count",
                ]
            )
            model_df = model_df.collect()
            total_rows = len(model_df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df.select(["modelId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                            "param_count": int(param_count)
                            if param_count is not None
                            else 0,
                        }
                        for likes, downloads, last_modified, param_count in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                            batch_df.select(["param_count"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )
            logger.info(
                f"Model database initialized with {model_collection.count():,} rows"
            )
    except Exception as e:
        logger.error(f"Setup error: {e}")


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int
    param_count: Optional[int] = None


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]


@app.get("/")  # decorator restored; redirect the root path to the API docs
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


@app.get("/search/datasets")  # route path inferred from the function name
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
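        # Over-fetch (k * 4) when results will be re-sorted by likes/downloads/
        # trending, so the post-hoc sort still has enough candidates to fill k.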
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")


@app.get("/similarity/datasets")  # route path inferred from the function name
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
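        # Query by the stored embedding. For plain similarity we fetch k + 1
        # because the dataset itself comes back as its own nearest neighbour
        # and is excluded later in process_search_results.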
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")


@app.get("/search/models")  # route path inferred from the function name
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 are excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Search for models based on a text query with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates a missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        where_conditions = []
        if min_likes > 0:
            where_conditions.append({"likes": {"$gte": min_likes}})
        if min_downloads > 0:
            where_conditions.append({"downloads": {"$gte": min_downloads}})

        # Add parameter count filters
        using_param_filters = min_param_count > 0 or max_param_count is not None
        if using_param_filters:
            # Always exclude zero param count when using any parameter filters
            where_conditions.append({"param_count": {"$gt": 0}})
        if min_param_count > 0:
            where_conditions.append({"param_count": {"$gte": min_param_count}})
        if max_param_count is not None:
            where_conditions.append({"param_count": {"$lte": max_param_count}})

        # Build the where clause: Chroma requires $and only for multiple conditions
        where_clause = None
        if len(where_conditions) > 1:
            where_clause = {"$and": where_conditions}
        elif len(where_conditions) == 1:
            where_clause = where_conditions[0]  # Single condition without $and
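        # e.g. min_param_count=1_000_000_000 alone produces:
        # {"$and": [{"param_count": {"$gt": 0}},
        #           {"param_count": {"$gte": 1000000000}}]}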
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where=where_clause,
        )
        query_results = await process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")


@app.get("/similarity/models")  # route path inferred from the function name
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 are excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Find models similar to a specified model, with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates a missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        where_conditions = []
        if min_likes > 0:
            where_conditions.append({"likes": {"$gte": min_likes}})
        if min_downloads > 0:
            where_conditions.append({"downloads": {"$gte": min_downloads}})

        # Add parameter count filters
        using_param_filters = min_param_count > 0 or max_param_count is not None
        if using_param_filters:
            # Always exclude zero param count when using any parameter filters
            where_conditions.append({"param_count": {"$gt": 0}})
        if min_param_count > 0:
            where_conditions.append({"param_count": {"$gte": min_param_count}})
        if max_param_count is not None:
            where_conditions.append({"param_count": {"$lte": max_param_count}})

        # Build the where clause: Chroma requires $and only for multiple conditions
        where_clause = None
        if len(where_conditions) > 1:
            where_clause = {"$and": where_conditions}
        elif len(where_conditions) == 1:
            where_clause = where_conditions[0]  # Single condition without $and

        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where=where_clause,
        )
        query_results = await process_search_results(
            results, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")


async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch the trending score for a model or dataset from the HuggingFace API."""
    try:
        # Use a dedicated name so this doesn't shadow the module-level ChromaDB client
        async with httpx.AsyncClient() as http_client:
            endpoint = "models" if item_type == "model" else "datasets"
            response = await http_client.get(
                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
            )
            response.raise_for_status()
            return response.json().get("trendingScore", 0)
    except Exception as e:
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0  # fail open: treat errors as score 0 so a trending sort still works


async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []

    # Create base results
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue
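        # NB: Chroma returns cosine *distances* here (lower = more similar);
        # the raw value is passed through unchanged under the "similarity" key.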
        result = {
            f"{id_field}_id": current_id,
            "similarity": float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }
        # Add param_count for models if it exists in metadata
        if id_field == "model" and "param_count" in results["metadatas"][0][i]:
            result["param_count"] = results["metadatas"][0][i]["param_count"]
        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))

    # Handle sorting
    if sort_by == "trending":
        # Fetch trending scores for all results concurrently; get_trending_score
        # opens its own HTTP client per call, so no shared client is needed here
        tasks = [
            get_trending_score(
                getattr(result, f"{id_field}_id"),
                "model" if id_field == "model" else "dataset",
            )
            for result in query_results
        ]
        scores = await asyncio.gather(*tasks)
        trending_scores = {
            getattr(result, f"{id_field}_id"): score
            for result, score in zip(query_results, scores)
        }
        # Sort by trending score
        query_results.sort(
            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
            reverse=True,
        )
        query_results = query_results[:k]
    elif sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:  # We fetched an extra result for the similarity + exclude_id case
        query_results = query_results[:k]
    return query_results


async def fetch_trending_models():
    """Fetch trending models from the HuggingFace API."""
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/models")
        response.raise_for_status()
        return response.json()
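

# Assumption: cache trending results for TRENDING_CACHE_TTL (the constant is
# otherwise unused); drop this decorator if per-request freshness is required.
@cache(ttl=TRENDING_CACHE_TTL)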
async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine them with summaries from the database."""
    try:
        # Fetch trending models
        trending_models = await fetch_trending_models()

        # Filter by minimum likes/downloads
        trending_models = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]
        # Sort by trending score
        trending_models = sorted(
            trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
        )
        # Keep up to 3x the limit (buffer for filtering), or all available if fewer,
        # so there are enough models left after parameter filtering
        fetch_limit = min(len(trending_models), limit * 3)
        trending_models = trending_models[:fetch_limit]

        # Get model IDs
        model_ids = [model["modelId"] for model in trending_models]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])

        # Create mappings of model_id to summary and metadata
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))

        # Log parameters for debugging
        logger.debug(
            f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
        )

        # Combine data - collect all results first
        all_results = []
        for model in trending_models:
            if model["modelId"] in id_to_summary:
                metadata = id_to_metadata.get(model["modelId"], {})
                param_count = metadata.get("param_count", 0)
                logger.debug(f"Model: {model['modelId']}, param_count: {param_count}")
                result = ModelQueryResult(
                    model_id=model["modelId"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model["modelId"]],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                    param_count=param_count,
                )
                all_results.append(result)

        # Apply parameter filtering after collecting all results
        filtered_results = all_results
        # Check whether any parameter filtering is being applied
        using_param_filters = min_param_count > 0 or max_param_count is not None
        # Only filter by params if there are specific parameter constraints
        if using_param_filters:
            filtered_results = []
            for result in all_results:
                should_include = True
                # Always exclude models with param_count=0 when any parameter filtering is active
                if result.param_count == 0:
                    logger.debug(
                        f"Filtering out {result.model_id} - param_count=0 while parameter filtering is active"
                    )
                    should_include = False
                # Apply min param filter if specified
                elif min_param_count > 0 and result.param_count < min_param_count:
                    logger.debug(
                        f"Filtering out {result.model_id} - param_count {result.param_count} < min_param_count {min_param_count}"
                    )
                    should_include = False
                # Apply max param filter if specified
                elif (
                    max_param_count is not None and result.param_count > max_param_count
                ):
                    logger.debug(
                        f"Filtering out {result.model_id} - param_count {result.param_count} > max_param_count {max_param_count}"
                    )
                    should_include = False
                if should_include:
                    filtered_results.append(result)

        logger.info(f"After filtering: {len(filtered_results)} models remain")

        # Finally, limit to the requested number
        return filtered_results[:limit]
    except Exception as e:
        logger.error(f"Error fetching trending models: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending models")


@app.get("/trending/models")  # route path inferred from the function name
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 are excluded if any parameter filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Get trending models with their summaries, with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates a missing/unknown parameter count in the dataset
    """
    logger.info(
        f"Request for trending models: limit={limit}, min_likes={min_likes}, "
        f"min_downloads={min_downloads}, min_param_count={min_param_count}, "
        f"max_param_count={max_param_count}"
    )
    results = await get_trending_models_with_summaries(
        limit=limit,
        min_likes=min_likes,
        min_downloads=min_downloads,
        min_param_count=min_param_count,
        max_param_count=max_param_count,
    )
    logger.info(f"Returning {len(results)} trending model results")
    return ModelQueryResponse(results=results)


async def fetch_trending_datasets():
    """Fetch trending datasets from the HuggingFace API."""
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/datasets")
        response.raise_for_status()
        return response.json()
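

# Assumption: apply the same trending cache policy as for models
@cache(ttl=TRENDING_CACHE_TTL)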
async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine them with summaries from the database."""
    try:
        # Fetch trending datasets
        trending_datasets = await fetch_trending_datasets()

        # Filter by minimum likes/downloads
        trending_datasets = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]
        # Sort by trending score and limit
        trending_datasets = sorted(
            trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
        )[:limit]

        # Get dataset IDs
        dataset_ids = [dataset["id"] for dataset in trending_datasets]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])

        # Create a mapping of dataset_id to summary
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        # Combine data
        results = []
        for dataset in trending_datasets:
            if dataset["id"] in id_to_summary:
                result = QueryResult(
                    dataset_id=dataset["id"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[dataset["id"]],
                    likes=dataset.get("likes", 0),
                    downloads=dataset.get("downloads", 0),
                )
                results.append(result)
        return results
    except Exception as e:
        logger.error(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")


@app.get("/trending/datasets")  # route path inferred from the function name
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries."""
    results = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return QueryResponse(results=results)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
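
# Example requests (hypothetical IDs; paths as registered above):
#
#   curl "http://localhost:8000/search/datasets?query=medical+qa&k=3"
#   curl "http://localhost:8000/similarity/models?model_id=some-org/some-model&k=5"
#   curl "http://localhost:8000/search/models?query=code+generation&min_param_count=1000000000"
#   curl "http://localhost:8000/trending/models?limit=5&min_likes=10"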