|
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks |
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
from pydantic import BaseModel |
|
from huggingface_hub import snapshot_download, HfApi |
|
import os |
|
import shutil |
|
from typing import Optional |
|
import asyncio |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
# Working directory for downloaded model data. It is created at import time so
# every endpoint can assume it exists.
DOWNLOAD_DIR = "./downloaded_model_data"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")

# HTTP Bearer auth scheme; used as the dependency that extracts the caller's token.
security = HTTPBearer()

# The endpoints hand blocking Hub downloads/uploads to this pool via
# run_in_executor so the event loop stays responsive; at most two run at once.
executor = ThreadPoolExecutor(max_workers=2)
|
|
|
class DownloadRequest(BaseModel):
    """Request body for the POST /download endpoint."""

    # Hub model repository id to download.
    model_repo_id: str

    # Directory handed to snapshot_download as its cache_dir; defaults to DOWNLOAD_DIR.
    # NOTE(review): Optional means a client may send JSON null, which reaches
    # snapshot_download as cache_dir=None — confirm that is intended.
    download_dir: Optional[str] = DOWNLOAD_DIR
|
|
|
class UploadRequest(BaseModel):
    """Request body for the POST /upload endpoint."""

    # Target Hub dataset repository id.
    dataset_repo_id: str

    # Server-side local directory whose contents are uploaded.
    folder_path: str

    # Destination path inside the dataset repository.
    path_in_repo: str
|
|
|
class TransferRequest(BaseModel):
    """Request body for the POST /transfer endpoint (download, then upload)."""

    # Source Hub model repository id.
    model_repo_id: str

    # Destination Hub dataset repository id.
    dataset_repo_id: str

    # Destination path inside the dataset. When omitted or empty, the /transfer
    # handler falls back to "model_data/<model_repo_id>/".
    path_in_repo: Optional[str] = None
|
|
|
def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Return the raw bearer token from the ``Authorization`` header.

    Header parsing is done by the ``HTTPBearer`` dependency; this function does
    no further checking of the token itself.
    """
    raw_token = credentials.credentials
    return raw_token
|
|
|
def download_full_model(repo_id: str, download_dir: str, token: Optional[str] = None) -> str:
    """Download an entire model repository snapshot from the Hugging Face Hub.

    Args:
        repo_id: Hub model repository id.
        download_dir: Passed to ``snapshot_download`` as ``cache_dir``. The
            returned path is the snapshot folder inside that cache.
        token: Optional Hugging Face token, forwarded to ``snapshot_download``;
            required for private or gated repositories. ``None`` keeps the
            previous behaviour.

    Returns:
        The local snapshot directory reported by ``snapshot_download``.
    """
    print(f"Downloading full model from {repo_id}...")
    local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir, token=token)
    print(f"Downloaded to: {local_dir}")
    return local_dir


def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
    """Upload a local folder to a Hugging Face dataset repository.

    Args:
        dataset_repo_id: Target dataset repository id.
        folder_path: Local directory whose contents are uploaded.
        path_in_repo: Destination path inside the dataset repository.
        token: Hugging Face token used to authenticate the upload.
    """
    api = HfApi(token=token)
    print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
    api.upload_folder(
        folder_path=folder_path,
        path_in_repo=path_in_repo,
        repo_id=dataset_repo_id,
        repo_type="dataset",
    )
    print("Upload complete!")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {"message": "Hugging Face Model Transfer API is running"}


@app.post("/download")
async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
    """Download a model from a Hugging Face model repository.

    The caller's bearer token is forwarded to the Hub, so private and gated
    repositories can be downloaded.

    Returns:
        A confirmation message and the local snapshot path.

    Raises:
        HTTPException: 500 if the download fails for any reason.
    """
    # NOTE(review): download_dir is client-controlled, so a caller can make the
    # server write wherever this process has permission. Consider confining it
    # to DOWNLOAD_DIR.
    try:
        # get_running_loop is the supported call inside a coroutine;
        # get_event_loop is deprecated for this use.
        loop = asyncio.get_running_loop()
        # The download blocks, so run it on the shared thread pool.
        local_dir = await loop.run_in_executor(
            executor,
            download_full_model,
            request.model_repo_id,
            request.download_dir,
            token,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    return {
        "message": f"Model {request.model_repo_id} downloaded successfully",
        "local_path": local_dir,
    }
|
|
|
@app.post("/upload")
async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
    """Upload a server-side folder to a Hugging Face dataset repository.

    Returns:
        A confirmation message and the ``path_in_repo`` used.

    Raises:
        HTTPException: 404 if ``request.folder_path`` is not an existing
            directory; 500 if the upload itself fails.
    """
    # NOTE(review): folder_path is client-controlled and unrestricted. Any caller
    # whose token can write to some dataset can push any directory this process
    # can read (not just DOWNLOAD_DIR). Consider confining it to DOWNLOAD_DIR.
    if not os.path.isdir(request.folder_path):
        # Use isdir, not exists: upload_folder expects a directory. A plain file
        # should be reported here as not found instead of failing later as a 500.
        raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")

    try:
        loop = asyncio.get_running_loop()
        # The upload blocks, so run it on the shared thread pool.
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            request.folder_path,
            request.path_in_repo,
            token,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    return {
        "message": f"Folder uploaded successfully to {request.dataset_repo_id}",
        "path_in_repo": request.path_in_repo,
    }
|
|
|
@app.post("/transfer")
async def transfer_model(request: TransferRequest, background_tasks: BackgroundTasks, token: str = Depends(get_token)):
    """Download a model and upload it to a dataset repository (combined operation).

    Steps:
      1. Download the model into a fresh per-request scratch cache under
         DOWNLOAD_DIR, using the caller's token.
      2. Upload the snapshot to ``request.dataset_repo_id``.
      3. Remove the whole scratch directory: in a background task on success,
         inline before the error response on failure.

    Returns:
        A confirmation message and the ``path_in_repo`` used.

    Raises:
        HTTPException: 500 if the download or the upload fails.
    """
    import tempfile  # only this endpoint needs per-request scratch directories

    path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"
    loop = asyncio.get_running_loop()
    scratch_dir = None

    def _download(cache_root: str) -> str:
        # Mirrors download_full_model's logging, but forwards the caller's token
        # and uses a private cache root. The snapshot folder only holds links into
        # the cache's blobs, so deleting just the snapshot would leak the blobs;
        # deleting cache_root removes everything this request wrote.
        print(f"Downloading full model from {request.model_repo_id}...")
        snapshot_path = snapshot_download(
            repo_id=request.model_repo_id,
            cache_dir=cache_root,
            token=token,
        )
        print(f"Downloaded to: {snapshot_path}")
        return snapshot_path

    try:
        # A private directory per request also keeps concurrent transfers from
        # deleting each other's files during cleanup.
        scratch_dir = tempfile.mkdtemp(dir=DOWNLOAD_DIR)
        local_dir = await loop.run_in_executor(executor, _download, scratch_dir)
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            local_dir,
            path_in_repo,
            token,
        )
    except Exception as e:
        # Background tasks do not run when the handler raises, so clean up now.
        # cleanup_download swallows its own errors and won't mask this failure.
        if scratch_dir is not None:
            await loop.run_in_executor(executor, cleanup_download, scratch_dir)
        raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}") from e

    # On success, delete after the response is sent so the client isn't kept waiting.
    background_tasks.add_task(cleanup_download, scratch_dir)

    return {
        "message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
        "path_in_repo": path_in_repo,
    }
|
|
|
def cleanup_download(local_dir: str):
    """Best-effort removal of a downloaded directory tree.

    A path that does not exist is silently ignored. Any error (including a bad
    argument) is printed rather than raised, so this is safe to run as a
    background task.
    """
    try:
        if not os.path.exists(local_dir):
            return
        shutil.rmtree(local_dir)
        print(f"Cleaned up: {local_dir}")
    except Exception as exc:
        print(f"Cleanup failed: {str(exc)}")
|
|
|
@app.get("/status")
async def get_status():
    """Report service health and disk space for the filesystem holding DOWNLOAD_DIR.

    Returns:
        ``status``, ``download_dir`` and a ``disk_space`` mapping of total/used/free
        byte counts.

    Raises:
        HTTPException: 500 if the disk usage query fails.
    """
    try:
        usage = shutil.disk_usage(DOWNLOAD_DIR)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")

    disk_space = {"total": usage.total, "used": usage.used, "free": usage.free}
    return {"status": "healthy", "download_dir": DOWNLOAD_DIR, "disk_space": disk_space}
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve the app on all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|
|
|
|
|
|
live |
|
|
|
Jump to live |
|
|