# INIC / app.py
# Hugging Face file view: commit 970ed33 "Update app.py" by Fred808 (verified), 6.15 kB.
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from huggingface_hub import snapshot_download, HfApi
import os
import shutil
from typing import Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Scratch directory where downloaded model data is kept temporarily.
# It is created at import time so that it exists before any request arrives.
DOWNLOAD_DIR = "./downloaded_model_data"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# Worker threads for the blocking download/upload helpers, so the async
# endpoints can await them without stalling the event loop.
executor = ThreadPoolExecutor(max_workers=2)

# HTTP Bearer scheme through which endpoints receive the caller's token.
security = HTTPBearer()

app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")
class DownloadRequest(BaseModel):
    # Request body accepted by the /download endpoint.

    # Repository id of the model to fetch.
    model_repo_id: str

    # Directory the download is written under; when the field is omitted
    # the module-level DOWNLOAD_DIR is used.
    download_dir: Optional[str] = DOWNLOAD_DIR
class UploadRequest(BaseModel):
    # Request body accepted by the /upload endpoint.

    # Dataset repository that receives the files.
    dataset_repo_id: str

    # Folder on the server's filesystem whose contents are uploaded.
    folder_path: str

    # Destination path inside the dataset repository.
    path_in_repo: str
class TransferRequest(BaseModel):
    # Request body accepted by the /transfer endpoint.

    # Model repository to download from.
    model_repo_id: str

    # Dataset repository to upload into.
    dataset_repo_id: str

    # Destination path inside the dataset repository. When left unset, the
    # /transfer handler substitutes "model_data/<model_repo_id>/".
    path_in_repo: Optional[str] = None
def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Return the bearer token supplied in the Authorization header.

    The header is parsed by the ``security`` dependency; this function only
    unwraps the token string and performs no validation of its own.

    Args:
        credentials: Parsed Authorization credentials injected by FastAPI.

    Returns:
        The raw token string.
    """
    bearer_token = credentials.credentials
    return bearer_token
def download_full_model(repo_id: str, download_dir: str) -> str:
    """Fetch a complete snapshot of a Hugging Face repository.

    This call blocks until the download finishes, so async callers should
    run it in a worker thread.

    Args:
        repo_id: Repository identifier forwarded to ``snapshot_download``.
        download_dir: Directory forwarded to ``snapshot_download`` as its
            ``cache_dir``.

    Returns:
        The local path returned by ``snapshot_download``.
    """
    # NOTE(review): no token is forwarded here, so this presumably only works
    # for repositories readable without authentication -- confirm.
    print(f"Downloading full model from {repo_id}...")
    snapshot_path = snapshot_download(cache_dir=download_dir, repo_id=repo_id)
    print(f"Downloaded to: {snapshot_path}")
    return snapshot_path
def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
    """Push the contents of a local folder into a Hugging Face dataset repo.

    This call blocks until the upload finishes, so async callers should run
    it in a worker thread.

    Args:
        dataset_repo_id: Target repository; uploaded with ``repo_type="dataset"``.
        folder_path: Local folder to upload.
        path_in_repo: Destination path inside the repository.
        token: Hugging Face token used to construct the ``HfApi`` client.
    """
    hub_client = HfApi(token=token)
    print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
    hub_client.upload_folder(
        repo_id=dataset_repo_id,
        repo_type="dataset",
        folder_path=folder_path,
        path_in_repo=path_in_repo,
    )
    print("Upload complete!")
@app.get("/")
async def root():
    """Health check endpoint."""
    # Fixed payload: a successful response only shows that the app is serving
    # requests; nothing else is checked here.
    payload = {"message": "Hugging Face Model Transfer API is running"}
    return payload
@app.post("/download")
async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
    """Download a model from Hugging Face model repository."""
    # Parameters:
    #   request: body carrying the model repo id and an optional target directory.
    #   token:   bearer token from the Authorization header. The dependency
    #            makes the header mandatory, but the token is not forwarded to
    #            the download helper.
    # Returns a dict with a confirmation message and the local snapshot path.
    # Raises HTTPException(500) if the download helper fails for any reason.

    # download_dir is Optional, so a client can send an explicit null. Fall
    # back to the server default rather than passing None on to the helper.
    target_dir = request.download_dir or DOWNLOAD_DIR
    # NOTE(review): download_dir is client-controlled and used unchecked as a
    # filesystem location; consider restricting it to paths under DOWNLOAD_DIR.

    try:
        # get_running_loop() is the recommended call inside a coroutine.
        loop = asyncio.get_running_loop()
        # The download blocks, so it runs on the shared thread pool.
        local_dir = await loop.run_in_executor(
            executor,
            download_full_model,
            request.model_repo_id,
            target_dir,
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    return {
        "message": f"Model {request.model_repo_id} downloaded successfully",
        "local_path": local_dir,
    }
@app.post("/upload")
async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
    """Upload a folder to a Hugging Face dataset repository."""
    # Parameters:
    #   request: body naming the dataset repo, the local folder and the
    #            destination path inside the repo.
    #   token:   bearer token from the Authorization header, handed to the
    #            upload helper.
    # Returns a dict with a confirmation message and the destination path.
    # Raises HTTPException(404) when folder_path is not an existing directory,
    # and HTTPException(500) when the upload helper fails.

    # isdir rather than exists: a path that points at a regular file is not a
    # folder that can be uploaded, so report it as not found up front.
    if not os.path.isdir(request.folder_path):
        raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")

    try:
        # get_running_loop() is the recommended call inside a coroutine.
        loop = asyncio.get_running_loop()
        # The upload blocks, so it runs on the shared thread pool.
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            request.folder_path,
            request.path_in_repo,
            token,
        )
    except HTTPException:
        # Let HTTP errors through untouched instead of rewrapping them as 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    return {
        "message": f"Folder uploaded successfully to {request.dataset_repo_id}",
        "path_in_repo": request.path_in_repo,
    }
@app.post("/transfer")
async def transfer_model(request: TransferRequest, background_tasks: BackgroundTasks, token: str = Depends(get_token)):
    """Download a model and upload it to a dataset repository (combined operation)."""
    # Parameters:
    #   request:          body naming the source model repo, the target dataset
    #                     repo and an optional destination path.
    #   background_tasks: used to remove the downloaded files after the
    #                     response has been produced.
    #   token:            bearer token from the Authorization header, handed
    #                     to the upload helper.
    # Returns a dict with a confirmation message and the destination path.
    # Raises HTTPException(500) if either the download or the upload fails.

    # Default destination when the caller did not provide one.
    path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"

    # get_running_loop() is the recommended call inside a coroutine.
    loop = asyncio.get_running_loop()

    # Stays None until the download has succeeded, so the error path knows
    # whether there is anything on disk to remove.
    local_dir = None
    try:
        # Both steps block, so each runs on the shared thread pool.
        local_dir = await loop.run_in_executor(
            executor,
            download_full_model,
            request.model_repo_id,
            DOWNLOAD_DIR,
        )
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            request.dataset_repo_id,
            local_dir,
            path_in_repo,
            token,
        )
    except Exception as e:
        # The background cleanup below is only scheduled on success, so a
        # failed upload would otherwise leave the download on disk. Remove it
        # here; cleanup_download reports its own errors instead of raising.
        if local_dir is not None:
            await loop.run_in_executor(executor, cleanup_download, local_dir)
        raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}") from e

    # Success: defer the cleanup so the response is not held up by it.
    background_tasks.add_task(cleanup_download, local_dir)
    return {
        "message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
        "path_in_repo": path_in_repo,
    }
def cleanup_download(local_dir: str):
    """Remove a downloaded directory tree on a best-effort basis.

    A path that does not exist is silently ignored. Any error raised while
    deleting is printed rather than propagated, so callers never see an
    exception from this function.

    Args:
        local_dir: Path of the directory tree to delete.
    """
    try:
        if not os.path.exists(local_dir):
            return
        shutil.rmtree(local_dir)
        print(f"Cleaned up: {local_dir}")
    except Exception as err:
        print(f"Cleanup failed: {str(err)}")
@app.get("/status")
async def get_status():
    """Get server status and available disk space."""
    # Reports disk usage for the filesystem that holds DOWNLOAD_DIR.
    # Any failure while reading it is surfaced as an HTTP 500.
    try:
        total_bytes, used_bytes, free_bytes = shutil.disk_usage(DOWNLOAD_DIR)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")

    disk_space = {
        "total": total_bytes,
        "used": used_bytes,
        "free": free_bytes,
    }
    return {
        "status": "healthy",
        "download_dir": DOWNLOAD_DIR,
        "disk_space": disk_space,
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here so that importing this module
    # elsewhere does not require it.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
# (Hugging Face file-view footer: "live" / "Jump to live")