|
from fastapi import FastAPI, HTTPException |
|
from huggingface_hub import HfApi, hf_hub_download, list_repo_files |
|
import os |
|
import shutil |
|
import asyncio |
|
from concurrent.futures import ThreadPoolExecutor |
|
from dotenv import load_dotenv |
|
import tempfile |
|
import time |
|
|
|
|
|
# Load variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

app = FastAPI(
    title="Hugging Face Model Transfer Service",
    version="1.0.0",
)

# Worker pool that run_transfer() hands the blocking download/upload calls to.
executor = ThreadPoolExecutor(max_workers=2)

# Every Hub call in this module passes this token, so refuse to start without it.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Source model repository, destination dataset repository, and the folder
# inside the dataset repository that receives the files.
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-120b"
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory"
# NOTE(review): this path says "20b" while the model ID above says "120b";
# confirm the mismatch is intentional.
HARDCODED_PATH_IN_REPO = "model_data/gpt-oss-20b"

# Module-level transfer state: written by run_transfer() and the cleanup
# endpoint, read by the status endpoints.
transfer_started = False
transfer_completed = False
transfer_error = None
temp_dir = None
|
|
|
def download_model_files() -> str:
    """Download every file of the hardcoded model repo into a new temp directory.

    Returns:
        Path of the temporary directory that holds the downloaded files. The
        directory is not deleted here on success; the caller owns it.

    Raises:
        Exception: Any error raised while listing or downloading files. The
            temporary directory is removed before the error is re-raised.
    """
    print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")

    # Named download_dir (not temp_dir) so it does not shadow the module-level
    # temp_dir that tracks the transfer state.
    download_dir = tempfile.mkdtemp(prefix="hf_model_")
    print(f"Created temporary directory: {download_dir}")

    try:
        files = list_repo_files(
            repo_id=HARDCODED_MODEL_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )
        print(f"Found {len(files)} files to download")

        for file_path in files:
            # Skip directory-style entries; only real files can be downloaded.
            if file_path.endswith("/"):
                continue

            print(f"Downloading: {file_path}")

            # Make sure the file's parent folder exists under the temp directory.
            os.makedirs(
                os.path.join(download_dir, os.path.dirname(file_path)),
                exist_ok=True,
            )

            local_path = hf_hub_download(
                repo_id=HARDCODED_MODEL_REPO_ID,
                filename=file_path,
                local_dir=download_dir,
                force_download=True,
                token=HF_TOKEN,
            )
            print(f"Downloaded to: {local_path}")

        # Drop any top-level ".cache*" directories so they are not uploaded
        # along with the model files.
        for item in os.listdir(download_dir):
            item_path = os.path.join(download_dir, item)
            if item.startswith(".cache") and os.path.isdir(item_path):
                shutil.rmtree(item_path)
                print(f"Removed cache directory: {item_path}")

        print(f"All files downloaded to: {download_dir}")
        print(f"Final directory contents: {os.listdir(download_dir)}")
        return download_dir

    except Exception as e:
        # Do not leave a partially filled directory behind on failure.
        shutil.rmtree(download_dir, ignore_errors=True)
        print(f"Download failed: {e}")
        raise
|
|
|
def upload_folder_to_dataset(folder_path: str):
    """Upload every file under ``folder_path`` to the hardcoded dataset repo.

    Files are uploaded one at a time with ``HfApi.upload_file``. Each file
    lands under ``HARDCODED_PATH_IN_REPO`` at its path relative to
    ``folder_path``, using forward slashes.

    Args:
        folder_path: Local directory whose contents are uploaded recursively.

    Raises:
        Exception: Any error raised during the walk or an upload is logged
            and re-raised.
    """
    hub = HfApi(token=HF_TOKEN)
    print(f"Uploading {folder_path} to {HARDCODED_DATASET_REPO_ID} at {HARDCODED_PATH_IN_REPO}...")
    print(f"Folder contents: {os.listdir(folder_path)}")

    try:
        for current_dir, _subdirs, filenames in os.walk(folder_path):
            for filename in filenames:
                local_file = os.path.join(current_dir, filename)
                within_folder = os.path.relpath(local_file, folder_path)
                # Repo paths always use "/", whatever the local separator is.
                destination = os.path.join(HARDCODED_PATH_IN_REPO, within_folder)
                destination = destination.replace("\\", "/")

                print(f"Uploading {local_file} to {destination}")
                hub.upload_file(
                    path_or_fileobj=local_file,
                    path_in_repo=destination,
                    repo_id=HARDCODED_DATASET_REPO_ID,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                print(f"Successfully uploaded: {filename}")

        print("Upload complete!")
    except Exception as exc:
        print(f"Upload failed: {exc}")
        raise
|
|
|
async def run_transfer():
    """Download the model, upload it to the dataset repo, and record the outcome.

    Both blocking steps run on the module-level ``executor``. The result is
    published through the module globals: ``transfer_started`` is set on entry,
    ``temp_dir`` once the download finishes, ``transfer_completed`` on success,
    and ``transfer_error`` (a message string) on failure. Exceptions are caught
    and recorded here, not propagated.
    """
    global transfer_completed, transfer_error, transfer_started, temp_dir
    transfer_started = True
    try:
        # Inside a coroutine, get_running_loop() is the preferred way to get
        # the loop; get_event_loop() has more complex, policy-dependent behavior.
        loop = asyncio.get_running_loop()

        temp_dir = await loop.run_in_executor(executor, download_model_files)

        # List once and reuse the result for both the check and the log line.
        downloaded = os.listdir(temp_dir)
        if not downloaded:
            raise RuntimeError("No files were downloaded")
        print(f"Final downloaded files: {downloaded}")

        await loop.run_in_executor(executor, upload_folder_to_dataset, temp_dir)

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
        transfer_completed = True

    except Exception as e:
        # Record the failure for the status endpoints.
        error_msg = f"Transfer failed: {e}"
        print(error_msg)
        transfer_error = error_msg
|
|
|
@app.get("/")
async def root():
    """Health check endpoint that also reports the current transfer state.

    Returns:
        A dict with a fixed service message, the transfer ``status``
        ("completed", "failed", "running" or "not_started"), the model and
        dataset repo IDs, the recorded error message (or ``None``), and the
        temporary directory path (or a placeholder string if none is set).
    """
    # A recorded error must take precedence over "running": transfer_started
    # stays True after a failure, so without this check a failed transfer
    # would be reported as running forever.
    if transfer_completed:
        status = "completed"
    elif transfer_error:
        status = "failed"
    elif transfer_started:
        status = "running"
    else:
        status = "not_started"

    return {
        "message": "Hugging Face Model Transfer Service is running",
        "status": status,
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }
|
|
|
@app.get("/status")
async def get_status():
    """Report the current transfer state.

    Returns:
        A dict with the transfer ``status`` ("completed", "failed", "running"
        or "not_started"), the model and dataset repo IDs, the destination
        path inside the dataset repo, the recorded error message (or
        ``None``), and the temporary directory path (or a placeholder string
        if none is set).
    """
    # A recorded error must take precedence over "running": transfer_started
    # stays True after a failure, so without this check a failed transfer
    # would be reported as running forever.
    if transfer_completed:
        status = "completed"
    elif transfer_error:
        status = "failed"
    elif transfer_started:
        status = "running"
    else:
        status = "not_started"

    return {
        "status": status,
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "path_in_repo": HARDCODED_PATH_IN_REPO,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }
|
|
|
@app.get("/cleanup")
async def cleanup():
    """Delete the temporary download directory on request.

    Removes the directory recorded in the module-level ``temp_dir`` and resets
    that variable to ``None``. This only checks that a directory is recorded
    and exists on disk; it does not check whether a transfer is still running.

    Returns:
        A dict with a ``message`` saying what was removed, or that there was
        nothing to remove.

    Raises:
        HTTPException: Status 500 if removing the directory fails.
    """
    global temp_dir

    # Nothing recorded, or the directory is already gone.
    if not temp_dir or not os.path.exists(temp_dir):
        return {"message": "No temporary directory to clean up"}

    try:
        shutil.rmtree(temp_dir)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {exc}")

    removed_path = temp_dir
    temp_dir = None
    return {"message": f"Cleaned up temporary directory: {removed_path}"}
|
|
|
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected before
# it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Schedule the transfer as a background task when the application starts.

    The task is only scheduled, not awaited, so startup returns at once while
    the transfer keeps running on the event loop.
    """
    print("Starting model transfer process...")

    transfer_task = asyncio.create_task(run_transfer())
    # Hold the task until it is done, then drop the reference.
    _background_tasks.add(transfer_task)
    transfer_task.add_done_callback(_background_tasks.discard)
|
|
|
|
|
@app.middleware("http")
async def keep_alive_middleware(request, call_next):
    """Pass each HTTP request to the next handler and return its response unchanged."""
    return await call_next(request)
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces at port 7860, with 600-second keep-alive and
    # graceful-shutdown timeouts.
    server_settings = {
        "host": "0.0.0.0",
        "port": 7860,
        "timeout_keep_alive": 600,
        "timeout_graceful_shutdown": 600,
    }
    uvicorn.Server(uvicorn.Config(app, **server_settings)).run()