"""Background service that copies a Hugging Face model repo into a dataset repo.

On startup it downloads every file of HARDCODED_MODEL_REPO_ID into a temporary
directory, then uploads them under HARDCODED_PATH_IN_REPO in
HARDCODED_DATASET_REPO_ID. Progress is exposed via ``/`` and ``/status``;
``/cleanup`` deletes the downloaded copy.
"""
from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
import os
import shutil
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
import tempfile
import time

# Load environment variables from .env file
load_dotenv()

app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")

# Thread pool for running the blocking download/upload calls off the event loop
executor = ThreadPoolExecutor(max_workers=2)

# Get Hugging Face token from environment variable
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Hardcoded model repository ID
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-120b"  # Your model
# Hardcoded dataset repository ID
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory"  # Your dataset
# Hardcoded path in repository
# NOTE(review): the source model is gpt-oss-120b but files are uploaded under a
# "gpt-oss-20b" folder — confirm which name is intended.
HARDCODED_PATH_IN_REPO = "model_data/gpt-oss-20b"

# Track transfer status (written by run_transfer and cleanup)
transfer_completed = False
transfer_error = None
transfer_started = False
temp_dir = None
# Strong reference to the background transfer task. The event loop only keeps
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_transfer_task = None


def download_model_files() -> str:
    """Download all files of HARDCODED_MODEL_REPO_ID into a new temp directory.

    Returns:
        Path of the temporary directory containing the downloaded files.

    Raises:
        Whatever listing or downloading raises; the temporary directory is
        deleted before the exception is re-raised.
    """
    print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")

    # Create a temporary directory for the downloaded files. Named distinctly
    # from the module-level ``temp_dir`` to avoid confusion with it.
    download_dir = tempfile.mkdtemp(prefix="hf_model_")
    print(f"Created temporary directory: {download_dir}")

    try:
        # Get list of all files in the repository
        files = list_repo_files(
            repo_id=HARDCODED_MODEL_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )
        print(f"Found {len(files)} files to download")

        # Download each file directly into download_dir
        for file_path in files:
            if file_path.endswith('/'):  # Skip directories
                continue

            print(f"Downloading: {file_path}")

            # Create subdirectories if needed
            file_dir = os.path.join(download_dir, os.path.dirname(file_path))
            os.makedirs(file_dir, exist_ok=True)

            # Download the file directly to the target location
            local_path = hf_hub_download(
                repo_id=HARDCODED_MODEL_REPO_ID,
                filename=file_path,
                local_dir=download_dir,
                force_download=True,
                token=HF_TOKEN,
            )
            print(f"Downloaded to: {local_path}")

        # Remove any top-level ".cache*" directories left in download_dir so
        # they are not uploaded along with the model files.
        for item in os.listdir(download_dir):
            item_path = os.path.join(download_dir, item)
            if item.startswith('.cache') and os.path.isdir(item_path):
                shutil.rmtree(item_path)
                print(f"Removed cache directory: {item_path}")

        print(f"All files downloaded to: {download_dir}")
        print(f"Final directory contents: {os.listdir(download_dir)}")
        return download_dir

    except Exception as e:
        # Clean up on error
        shutil.rmtree(download_dir, ignore_errors=True)
        print(f"Download failed: {str(e)}")
        raise


def upload_folder_to_dataset(folder_path: str):
    """Upload every file under *folder_path* to the hardcoded dataset repo.

    Each file is uploaded with its own ``upload_file`` call, placed at
    HARDCODED_PATH_IN_REPO joined with its path relative to *folder_path*
    (always using forward slashes).

    Raises:
        Whatever ``HfApi.upload_file`` raises, after logging it.
    """
    api = HfApi(token=HF_TOKEN)
    print(f"Uploading {folder_path} to {HARDCODED_DATASET_REPO_ID} at {HARDCODED_PATH_IN_REPO}...")

    # Check what we're actually trying to upload
    print(f"Folder contents: {os.listdir(folder_path)}")

    try:
        # Upload each file individually to avoid cache issues
        for root, _dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, folder_path)
                # Repo paths must use "/" even when running on Windows
                repo_path = os.path.join(HARDCODED_PATH_IN_REPO, relative_path).replace('\\', '/')

                print(f"Uploading {file_path} to {repo_path}")
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=repo_path,
                    repo_id=HARDCODED_DATASET_REPO_ID,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                print(f"Successfully uploaded: {file}")

        print("Upload complete!")
    except Exception as e:
        print(f"Upload failed: {str(e)}")
        raise


async def run_transfer():
    """Download the model and upload it to the dataset, recording the outcome.

    Sets ``transfer_started`` immediately, ``temp_dir`` once the download
    finishes, and either ``transfer_completed`` or ``transfer_error`` at the
    end. Exceptions are recorded, not propagated.
    """
    global transfer_completed, transfer_error, transfer_started, temp_dir
    transfer_started = True

    try:
        # get_running_loop() is the preferred call inside a coroutine.
        loop = asyncio.get_running_loop()

        # Download the model files (blocking; run in the thread pool)
        temp_dir = await loop.run_in_executor(executor, download_model_files)

        # Verify files were downloaded
        if not os.listdir(temp_dir):
            raise RuntimeError("No files were downloaded")

        print(f"Final downloaded files: {os.listdir(temp_dir)}")

        # Upload to dataset
        await loop.run_in_executor(executor, upload_folder_to_dataset, temp_dir)

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
        transfer_completed = True
    except Exception as e:
        error_msg = f"Transfer failed: {str(e)}"
        print(error_msg)
        transfer_error = error_msg


def _current_status() -> str:
    """Return one of "completed", "failed", "running" or "not_started"."""
    if transfer_completed:
        return "completed"
    # Without this check a failed transfer would be reported as "running" forever.
    if transfer_error is not None:
        return "failed"
    if transfer_started:
        return "running"
    return "not_started"


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "message": "Hugging Face Model Transfer Service is running",
        "status": _current_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }


@app.get("/status")
async def get_status():
    """Get transfer status."""
    return {
        "status": _current_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "path_in_repo": HARDCODED_PATH_IN_REPO,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }


@app.get("/cleanup")
async def cleanup():
    """Manual cleanup endpoint to remove downloaded files.

    Refuses (HTTP 409) while the transfer is still running, because deleting
    the directory mid-upload would break the upload.
    """
    global temp_dir

    if _current_status() == "running":
        raise HTTPException(status_code=409, detail="Transfer in progress; cleanup refused")

    if not (temp_dir and os.path.exists(temp_dir)):
        return {"message": "No temporary directory to clean up"}

    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")

    message = f"Cleaned up temporary directory: {temp_dir}"
    temp_dir = None
    return {"message": message}


@app.on_event("startup")
async def startup_event():
    """Run the transfer process when the application starts."""
    global _transfer_task
    print("Starting model transfer process...")
    # Run transfer in background without waiting for completion; keep a
    # reference so the task is not garbage-collected while running.
    _transfer_task = asyncio.create_task(run_transfer())


# Pass-through HTTP middleware: forwards every request unchanged. It does not
# affect server lifetime.
@app.middleware("http")
async def keep_alive_middleware(request, call_next):
    response = await call_next(request)
    return response


if __name__ == "__main__":
    import uvicorn

    # Run the server with long keep-alive and graceful-shutdown timeouts
    config = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=600,
        timeout_graceful_shutdown=600,
    )
    server = uvicorn.Server(config)
    server.run()