import json
import logging
import asyncio
import re
from typing import Tuple, Optional, Dict, Any

import aiohttp
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub import hf_api
from transformers import AutoConfig, AutoTokenizer

from app.config.base import HF_TOKEN, API
from app.utils.logging import LogFormatter

logger = logging.getLogger(__name__)


class ModelValidator:
    def __init__(self):
        self.token = HF_TOKEN
        self.api = HfApi(token=self.token)
        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}

    async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
        """Check if model has a valid model card"""
        try:
            logger.info(LogFormatter.info(f"Checking model card for {model_id}"))

            # Get model card content using ModelCard.load
            try:
                model_card = await asyncio.to_thread(
                    ModelCard.load,
                    model_id
                )
                logger.info(LogFormatter.success("Model card found"))
            except Exception as e:
                error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it."
                logger.error(LogFormatter.error(error_msg, e))
                return False, error_msg, None

            # Check license in model card data
            if model_card.data.license is None and not ("license_name" in model_card.data and "license_link" in model_card.data):
                error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair."
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            # Enforce card content length
            if len(model_card.text) < 200:
                error_msg = "Please add a description to your model card; it is too short."
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            logger.info(LogFormatter.success("Model card validation passed"))
            return True, "", model_card

        except Exception as e:
            error_msg = "Failed to validate model card"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e), None

    async def get_safetensors_metadata(self, model_id: str, is_adapter: bool = False, revision: str = "main") -> Optional[Dict]:
        """Get metadata from a safetensors file"""
        try:
            if is_adapter:
                metadata = await asyncio.to_thread(
                    hf_api.parse_safetensors_file_metadata,
                    model_id,
                    "adapter_model.safetensors",
                    token=self.token,
                    revision=revision,
                )
            else:
                metadata = await asyncio.to_thread(
                    hf_api.get_safetensors_metadata,
                    repo_id=model_id,
                    token=self.token,
                    revision=revision,
                )
            return metadata
        except Exception as e:
            logger.error(f"Failed to get safetensors metadata: {str(e)}")
            return None

    async def get_model_size(
        self,
        model_info: Any,
        precision: str,
        base_model: str,
        revision: str
    ) -> Tuple[Optional[float], Optional[str]]:
        """Get model size in billions of parameters"""
        try:
            logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))

            # Check if model is an adapter
            is_adapter = any(s.rfilename == "adapter_config.json" for s in model_info.siblings if hasattr(s, 'rfilename'))

            # Try to get size from safetensors first
            model_size = None

            if is_adapter and base_model:
                # For adapters, we need both adapter and base model sizes
                adapter_meta = await self.get_safetensors_metadata(model_info.id, is_adapter=True, revision=revision)
                base_meta = await self.get_safetensors_metadata(base_model, revision="main")

                if adapter_meta and base_meta:
                    adapter_size = sum(adapter_meta.parameter_count.values())
                    base_size = sum(base_meta.parameter_count.values())
                    model_size = (adapter_size + base_size) / (2 * 1e9)  # Convert to billions, assuming float16
            else:
                # For regular models, just get the model size
                meta = await self.get_safetensors_metadata(model_info.id, revision=revision)
                if meta:
                    total_params = sum(meta.parameter_count.values())
                    model_size = total_params / (2 * 1e9)  # Convert to billions, assuming float16

            if model_size is None:
                # If model size could not be determined, return an error
                return None, "Model size could not be determined"

            # Adjust size for GPTQ models
            size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
            model_size = round(size_factor * model_size, 3)

            logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
            return model_size, None

        except Exception as e:
            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
            return None, str(e)

    async def check_chat_template(
        self,
        model_id: str,
        revision: str
    ) -> Tuple[bool, Optional[str]]:
        """Check if model has a valid chat template"""
        try:
            logger.info(LogFormatter.info(f"Checking chat template for {model_id}"))

            try:
                config_file = await asyncio.to_thread(
                    hf_hub_download,
                    repo_id=model_id,
                    filename="tokenizer_config.json",
                    revision=revision,
                    repo_type="model"
                )

                with open(config_file, 'r') as f:
                    tokenizer_config = json.load(f)

                if 'chat_template' not in tokenizer_config:
                    error_msg = f"The model {model_id} doesn't have a chat_template in its tokenizer_config.json. Please add a chat_template before submitting, or submit without it."
                    logger.error(LogFormatter.error(error_msg))
                    return False, error_msg

                logger.info(LogFormatter.success("Valid chat template found"))
                return True, None

            except Exception as e:
                error_msg = f"Error checking chat_template: {str(e)}"
                logger.error(LogFormatter.error(error_msg))
                return False, error_msg

        except Exception as e:
            error_msg = "Failed to check chat template"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e)

    async def is_model_on_hub(
        self,
        model_name: str,
        revision: str,
        test_tokenizer: bool = False,
        trust_remote_code: bool = False
    ) -> Tuple[bool, Optional[str], Optional[Any]]:
        """Check if model exists and is properly configured on the Hub"""
        try:
            config = await asyncio.to_thread(
                AutoConfig.from_pretrained,
                model_name,
                revision=revision,
                trust_remote_code=trust_remote_code,
                token=self.token,
                force_download=True
            )

            if test_tokenizer:
                try:
                    await asyncio.to_thread(
                        AutoTokenizer.from_pretrained,
                        model_name,
                        revision=revision,
                        trust_remote_code=trust_remote_code,
                        token=self.token
                    )
                except ValueError as e:
                    return False, f"uses a tokenizer which is not in a transformers release: {e}", None
                except Exception:
                    return False, "'s tokenizer cannot be loaded. Is your tokenizer class in a stable transformers release, and correctly configured?", None

            return True, None, config

        except ValueError:
            return False, "needs to be launched with `trust_remote_code=True`. For safety reasons, we do not allow these models to be automatically submitted to the leaderboard.", None

        except Exception as e:
            if "You are trying to access a gated repo." in str(e):
                return True, "uses a gated model.", None
            return False, f"was not found or misconfigured on the hub! Error raised was {e.args[0]}", None
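# Example usage: a minimal sketch, not part of the validator itself. The "org/model-name"
# repo id is a placeholder, and a valid HF_TOKEN plus network access are assumed for the
# Hub calls to succeed; the checks below mirror how a submission pipeline might call this class.
if __name__ == "__main__":
    async def _demo():
        validator = ModelValidator()

        # Verify the repository loads with transformers (config + tokenizer).
        on_hub, hub_error, _config = await validator.is_model_on_hub(
            "org/model-name",  # placeholder repo id
            revision="main",
            test_tokenizer=True,
        )
        print("on hub:", on_hub, hub_error or "")

        # Verify the model card and chat template of the same repository.
        card_ok, card_error, _card = await validator.check_model_card("org/model-name")
        print("model card ok:", card_ok, card_error or "")

        template_ok, template_error = await validator.check_chat_template("org/model-name", revision="main")
        print("chat template ok:", template_ok, template_error or "")

    asyncio.run(_demo())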