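"""FastAPI service that downloads a model from the Hugging Face Hub and
re-uploads it into a dataset repository, authenticating each request with a
Hugging Face access token supplied as a bearer token."""
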
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from huggingface_hub import snapshot_download, HfApi
import os
import shutil
from typing import Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor

app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")
security = HTTPBearer()

# Thread pool for running blocking operations
executor = ThreadPoolExecutor(max_workers=2)

# Local directory to save downloaded model data temporarily
DOWNLOAD_DIR = "./downloaded_model_data"

# Ensure the download directory exists
os.makedirs(DOWNLOAD_DIR, exist_ok=True)


class DownloadRequest(BaseModel):
    model_repo_id: str
    download_dir: Optional[str] = DOWNLOAD_DIR


class UploadRequest(BaseModel):
    dataset_repo_id: str
    folder_path: str
    path_in_repo: str


class TransferRequest(BaseModel):
    model_repo_id: str
    dataset_repo_id: str
    path_in_repo: Optional[str] = None


def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Extract and validate Hugging Face token from Authorization header."""
    return credentials.credentials


def download_full_model(repo_id: str, download_dir: str, token: Optional[str] = None) -> str:
    """Download an entire model snapshot from a Hugging Face model repository."""
    print(f"Downloading full model from {repo_id}...")
    # Pass the caller's token so private or gated model repos can be downloaded too
    local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir, token=token)
    print(f"Downloaded to: {local_dir}")
    return local_dir


def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
    """Uploads a folder to a Hugging Face dataset repository."""
    api = HfApi(token=token)
    print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
    api.upload_folder(
        folder_path=folder_path,
        path_in_repo=path_in_repo,
        repo_id=dataset_repo_id,
        repo_type="dataset",
    )
    print("Upload complete!")
@app.get("/")
async def root():
"""Health check endpoint."""
return {"message": "Hugging Face Model Transfer API is running"}
@app.post("/download")
async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
"""Download a model from Hugging Face model repository."""
try:
# Run the blocking download operation in a thread pool
loop = asyncio.get_event_loop()
local_dir = await loop.run_in_executor(
executor,
download_full_model,
request.model_repo_id,
request.download_dir
)
return {
"message": f"Model {request.model_repo_id} downloaded successfully",
"local_path": local_dir
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
@app.post("/upload")
async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
"""Upload a folder to a Hugging Face dataset repository."""
try:
# Check if folder exists
if not os.path.exists(request.folder_path):
raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")
# Run the blocking upload operation in a thread pool
loop = asyncio.get_event_loop()
await loop.run_in_executor(
executor,
upload_folder_to_dataset,
request.dataset_repo_id,
request.folder_path,
request.path_in_repo,
token
)
return {
"message": f"Folder uploaded successfully to {request.dataset_repo_id}",
"path_in_repo": request.path_in_repo
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
@app.post("/transfer")
async def transfer_model(request: TransferRequest, background_tasks: BackgroundTasks, token: str = Depends(get_token)):
"""Download a model and upload it to a dataset repository (combined operation)."""
try:
# Set default path in repo if not provided
path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"
# Download the model
loop = asyncio.get_event_loop()
local_dir = await loop.run_in_executor(
executor,
download_full_model,
request.model_repo_id,
DOWNLOAD_DIR
)
# Upload to dataset
await loop.run_in_executor(
executor,
upload_folder_to_dataset,
request.dataset_repo_id,
local_dir,
path_in_repo,
token
)
# Clean up downloaded files in background
background_tasks.add_task(cleanup_download, local_dir)
return {
"message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
"path_in_repo": path_in_repo
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}")


def cleanup_download(local_dir: str):
    """Clean up downloaded files."""
    try:
        if os.path.exists(local_dir):
            shutil.rmtree(local_dir)
            print(f"Cleaned up: {local_dir}")
    except Exception as e:
        print(f"Cleanup failed: {str(e)}")
@app.get("/status")
async def get_status():
"""Get server status and available disk space."""
try:
disk_usage = shutil.disk_usage(DOWNLOAD_DIR)
return {
"status": "healthy",
"download_dir": DOWNLOAD_DIR,
"disk_space": {
"total": disk_usage.total,
"used": disk_usage.used,
"free": disk_usage.free
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
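
# Example requests (a sketch; the repo IDs and the "hf_xxx" token below are placeholders):
#
#   curl -X POST http://localhost:7860/download \
#        -H "Authorization: Bearer hf_xxx" \
#        -H "Content-Type: application/json" \
#        -d '{"model_repo_id": "gpt2"}'
#
#   curl -X POST http://localhost:7860/transfer \
#        -H "Authorization: Bearer hf_xxx" \
#        -H "Content-Type: application/json" \
#        -d '{"model_repo_id": "gpt2", "dataset_repo_id": "username/my-dataset"}'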