"""HTTP API for copying Hugging Face model repositories into dataset repositories.

Endpoints:
    GET  /          liveness message
    POST /download  download a model snapshot into the local download area
    POST /upload    upload a folder from the download area to a dataset repo
    POST /transfer  download a model, upload it to a dataset repo, then clean up
    GET  /status    disk usage of the download area

Every POST endpoint expects ``Authorization: Bearer <hugging face token>``;
the token is forwarded to the Hugging Face Hub for both downloads and uploads.
"""
import asyncio
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from huggingface_hub import HfApi, snapshot_download
from pydantic import BaseModel

app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")
security = HTTPBearer()

# Thread pool for running blocking operations
executor = ThreadPoolExecutor(max_workers=2)

# Local directory to save downloaded model data temporarily
DOWNLOAD_DIR = "./downloaded_model_data"

# Ensure the download directory exists
os.makedirs(DOWNLOAD_DIR, exist_ok=True)


class DownloadRequest(BaseModel):
    """Request body for ``POST /download``."""

    # Model repository to download, e.g. "org/name".
    model_repo_id: str
    # Cache directory for the download; must resolve inside DOWNLOAD_DIR.
    download_dir: Optional[str] = DOWNLOAD_DIR


class UploadRequest(BaseModel):
    """Request body for ``POST /upload``."""

    # Target dataset repository.
    dataset_repo_id: str
    # Server-side folder to upload; must resolve inside DOWNLOAD_DIR.
    folder_path: str
    # Destination directory inside the dataset repository.
    path_in_repo: str


class TransferRequest(BaseModel):
    """Request body for ``POST /transfer``."""

    # Model repository to copy from.
    model_repo_id: str
    # Dataset repository to copy into.
    dataset_repo_id: str
    # Destination directory; defaults to "model_data/<model_repo_id>/".
    path_in_repo: Optional[str] = None


def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Return the bearer token from the Authorization header.

    The token is only extracted here, not validated; an invalid token
    surfaces later as an error from the Hugging Face Hub call that uses it.
    """
    return credentials.credentials


def _is_inside_download_dir(path: str) -> bool:
    """Return True if *path* resolves to DOWNLOAD_DIR or a location beneath it."""
    base = os.path.realpath(DOWNLOAD_DIR)
    target = os.path.realpath(path)
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised e.g. for paths on different drives: cannot be inside base.
        return False


def _cache_repo_root(snapshot_dir: str) -> str:
    """Map a hub-cache snapshot directory to its repository cache folder.

    With ``cache_dir`` set, ``snapshot_download`` returns
    ``<cache_dir>/models--<org>--<name>/snapshots/<revision>``, where the
    files are normally symlinks into the sibling ``blobs`` directory.
    Deleting only the snapshot therefore frees almost no disk space. If
    *snapshot_dir* does not match that layout it is returned unchanged.
    """
    snapshots_dir = os.path.dirname(os.path.normpath(snapshot_dir))
    if os.path.basename(snapshots_dir) != "snapshots":
        return snapshot_dir
    repo_root = os.path.dirname(snapshots_dir)
    if not os.path.basename(repo_root).startswith("models--"):
        return snapshot_dir
    return repo_root


def download_full_model(repo_id: str, download_dir: str, token: Optional[str] = None) -> str:
    """Download every file of a Hugging Face model repository.

    Args:
        repo_id: Model repository id, e.g. ``"org/name"``.
        download_dir: Directory passed to the hub client as ``cache_dir``.
        token: Hugging Face access token, needed for private or gated
            models. ``None`` lets the hub client use its own default.

    Returns:
        The local snapshot directory reported by ``snapshot_download``.
    """
    print(f"Downloading full model from {repo_id}...")
    local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir, token=token)
    print(f"Downloaded to: {local_dir}")
    return local_dir


def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
    """Upload a local folder to a Hugging Face dataset repository.

    Args:
        dataset_repo_id: Target dataset repository id.
        folder_path: Local folder whose contents are uploaded.
        path_in_repo: Destination directory inside the repository.
        token: Hugging Face access token used for the upload.
    """
    api = HfApi(token=token)
    print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
    api.upload_folder(
        folder_path=folder_path,
        path_in_repo=path_in_repo,
        repo_id=dataset_repo_id,
        repo_type="dataset",
    )
    print("Upload complete!")


def cleanup_download(local_dir: str):
    """Best-effort recursive removal of *local_dir*; failures are only logged."""
    try:
        if os.path.exists(local_dir):
            shutil.rmtree(local_dir)
            print(f"Cleaned up: {local_dir}")
    except Exception as e:
        # Deliberately swallowed: a failed cleanup must not fail the request.
        print(f"Cleanup failed: {str(e)}")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {"message": "Hugging Face Model Transfer API is running"}


@app.post("/download")
async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
    """Download a model from a Hugging Face model repository.

    Responds 400 if ``download_dir`` points outside DOWNLOAD_DIR and 500 if
    the download itself fails.
    """
    # An explicit null falls back to the service's own download area rather
    # than the process-wide default hub cache.
    download_dir = request.download_dir or DOWNLOAD_DIR
    # The path comes from the client: never write outside the download area.
    if not _is_inside_download_dir(download_dir):
        raise HTTPException(
            status_code=400,
            detail=f"download_dir must be inside {DOWNLOAD_DIR}",
        )

    try:
        # Run the blocking download operation in a thread pool
        loop = asyncio.get_running_loop()
        local_dir = await loop.run_in_executor(
            executor,
            download_full_model,
            request.model_repo_id,
            download_dir,
            token,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    return {
        "message": f"Model {request.model_repo_id} downloaded successfully",
        "local_path": local_dir,
    }


@app.post("/upload")
async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
    """Upload a folder from the download area to a dataset repository.

    Responds 400 if ``folder_path`` points outside DOWNLOAD_DIR, 404 if the
    folder does not exist, and 500 if the upload itself fails.
    """
    # The path comes from the client: without this check any directory
    # readable by the server could be pushed to a caller-chosen repository.
    if not _is_inside_download_dir(request.folder_path):
        raise HTTPException(
            status_code=400,
            detail=f"folder_path must be inside {DOWNLOAD_DIR}",
        )
    if not os.path.exists(request.folder_path):
        raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")

    try:
        # Run the blocking upload operation in a thread pool
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            request.folder_path,
            request.path_in_repo,
            token,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    return {
        "message": f"Folder uploaded successfully to {request.dataset_repo_id}",
        "path_in_repo": request.path_in_repo,
    }


@app.post("/transfer")
async def transfer_model(
    request: TransferRequest,
    background_tasks: BackgroundTasks,
    token: str = Depends(get_token),
):
    """Download a model and upload it to a dataset repository (combined operation).

    The downloaded files are removed afterwards: in a background task on
    success, and before the 500 response is raised on failure.
    """
    # Set default path in repo if not provided
    path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"

    loop = asyncio.get_running_loop()
    local_dir = None
    try:
        # Download the model
        local_dir = await loop.run_in_executor(
            executor,
            download_full_model,
            request.model_repo_id,
            DOWNLOAD_DIR,
            token,
        )
        # Upload to dataset
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            local_dir,
            path_in_repo,
            token,
        )
    except Exception as e:
        # The success-path background task is never scheduled on this path,
        # so reclaim the disk space here instead.
        if local_dir is not None:
            await loop.run_in_executor(executor, cleanup_download, _cache_repo_root(local_dir))
        raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}") from e

    # Clean up downloaded files in background (whole repo cache folder, so
    # the blobs behind the snapshot symlinks are removed as well).
    background_tasks.add_task(cleanup_download, _cache_repo_root(local_dir))

    return {
        "message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
        "path_in_repo": path_in_repo,
    }


@app.get("/status")
async def get_status():
    """Get server status and available disk space (in bytes) for DOWNLOAD_DIR."""
    try:
        disk_usage = shutil.disk_usage(DOWNLOAD_DIR)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}") from e

    return {
        "status": "healthy",
        "download_dir": DOWNLOAD_DIR,
        "disk_space": {
            "total": disk_usage.total,
            "used": disk_usage.used,
            "free": disk_usage.free,
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)