INIC / app.py
Fred808's picture
Update app.py
82d0262 verified
from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
import os
import shutil
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
import tempfile
import time
# Read settings (notably the Hugging Face token) from a local .env file, if present.
load_dotenv()

app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")

# Worker threads for the blocking download/upload calls, so they can be
# scheduled off the event loop via run_in_executor.
executor = ThreadPoolExecutor(max_workers=2)

# The token is mandatory: refuse to start (at import time) when it is missing
# or empty instead of failing later in the middle of a transfer.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
# Source: model repository whose files are copied.
HARDCODED_MODEL_REPO_ID: str = "openai/gpt-oss-120b"  # Your model

# Destination: dataset repository that receives the files.
HARDCODED_DATASET_REPO_ID: str = "Fred808/helium_memory"  # Your dataset

# Folder inside the dataset repository under which the files are uploaded.
# NOTE(review): this says "gpt-oss-20b" while the source model above is
# "gpt-oss-120b" -- confirm the destination folder name is intentional.
HARDCODED_PATH_IN_REPO: str = "model_data/gpt-oss-20b"

# Module-level transfer state: written by run_transfer() (and, for temp_dir,
# by the /cleanup endpoint) and reported by the status endpoints.
transfer_completed: bool = False
transfer_error = None  # error message string once a transfer has failed
transfer_started: bool = False
temp_dir = None  # directory holding the downloaded files, once downloaded
def download_model_files() -> str:
    """Download every file of the hardcoded model repository into a new temp directory.

    Returns:
        Path of the freshly created temporary directory containing the
        downloaded files. The directory is NOT deleted on success; the caller
        is responsible for removing it.

    Raises:
        Exception: Any error raised while listing or downloading files is
            re-raised after the temporary directory has been removed.
    """
    print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")

    # Local name deliberately differs from the module-level ``temp_dir`` so the
    # global is not shadowed inside this function.
    download_dir = tempfile.mkdtemp(prefix="hf_model_")
    print(f"Created temporary directory: {download_dir}")

    try:
        # Get list of all files in the repository
        files = list_repo_files(
            repo_id=HARDCODED_MODEL_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )
        print(f"Found {len(files)} files to download")

        # Download each file straight into the target directory
        for file_path in files:
            if file_path.endswith('/'):  # Skip directory-like entries
                continue
            print(f"Downloading: {file_path}")

            # Make sure the file's parent folder exists inside the target directory
            file_dir = os.path.join(download_dir, os.path.dirname(file_path))
            os.makedirs(file_dir, exist_ok=True)

            local_path = hf_hub_download(
                repo_id=HARDCODED_MODEL_REPO_ID,
                filename=file_path,
                local_dir=download_dir,
                force_download=True,
                token=HF_TOKEN,
            )
            print(f"Downloaded to: {local_path}")

        # Remove any top-level ".cache*" directories so they are not uploaded
        # along with the model files.
        for item in os.listdir(download_dir):
            item_path = os.path.join(download_dir, item)
            if item.startswith('.cache') and os.path.isdir(item_path):
                shutil.rmtree(item_path)
                print(f"Removed cache directory: {item_path}")

        print(f"All files downloaded to: {download_dir}")
        print(f"Final directory contents: {os.listdir(download_dir)}")
        return download_dir
    except Exception as e:
        # Clean up the partial download before propagating the error
        shutil.rmtree(download_dir, ignore_errors=True)
        print(f"Download failed: {str(e)}")
        raise
def upload_folder_to_dataset(folder_path: str):
    """Push every file under *folder_path* to the hardcoded dataset repository.

    Files are uploaded one at a time with ``HfApi.upload_file``. Each file is
    placed under ``HARDCODED_PATH_IN_REPO`` at the same relative path it has
    inside *folder_path*.

    Args:
        folder_path: Local directory whose contents are uploaded.

    Raises:
        Exception: Any upload error is logged and re-raised.
    """
    hub = HfApi(token=HF_TOKEN)
    print(f"Uploading {folder_path} to {HARDCODED_DATASET_REPO_ID} at {HARDCODED_PATH_IN_REPO}...")

    # Log what is about to be uploaded
    print(f"Folder contents: {os.listdir(folder_path)}")

    try:
        for current_dir, _subdirs, filenames in os.walk(folder_path):
            for filename in filenames:
                local_file = os.path.join(current_dir, filename)
                rel_to_root = os.path.relpath(local_file, folder_path)
                # Repository paths use forward slashes regardless of the local OS
                destination = os.path.join(HARDCODED_PATH_IN_REPO, rel_to_root).replace('\\', '/')

                print(f"Uploading {local_file} to {destination}")
                hub.upload_file(
                    path_or_fileobj=local_file,
                    path_in_repo=destination,
                    repo_id=HARDCODED_DATASET_REPO_ID,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                print(f"Successfully uploaded: {filename}")
        print("Upload complete!")
    except Exception as exc:
        print(f"Upload failed: {str(exc)}")
        raise
async def run_transfer():
    """Download the model, upload it to the dataset, and record the outcome.

    Progress is published through the module-level ``transfer_started``,
    ``transfer_completed``, ``transfer_error`` and ``temp_dir`` variables.
    Exceptions are not propagated: a failure is stored as a message in
    ``transfer_error``.
    """
    global transfer_completed, transfer_error, transfer_started, temp_dir
    transfer_started = True
    try:
        # get_running_loop() is the preferred way to obtain the loop from
        # inside a coroutine; it always returns the loop executing this task.
        loop = asyncio.get_running_loop()

        # Blocking download runs on the worker pool; the global is only set
        # once the download has finished successfully.
        temp_dir = await loop.run_in_executor(executor, download_model_files)

        # Verify files were downloaded
        downloaded = os.listdir(temp_dir)
        if not downloaded:
            raise RuntimeError("No files were downloaded")
        print(f"Final downloaded files: {downloaded}")

        # Blocking upload also runs on the worker pool
        await loop.run_in_executor(executor, upload_folder_to_dataset, temp_dir)

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
        transfer_completed = True
    except Exception as e:
        # Record the failure for the status endpoints instead of raising
        error_msg = f"Transfer failed: {str(e)}"
        print(error_msg)
        transfer_error = error_msg
def _current_status() -> str:
    """Collapse the module-level transfer flags into one status string.

    Returns:
        "completed", "failed", "running" or "not_started".
    """
    if transfer_completed:
        return "completed"
    # A failed transfer leaves transfer_started set, so the error must be
    # checked before "running"; otherwise a failure is reported as still running.
    if transfer_error:
        return "failed"
    if transfer_started:
        return "running"
    return "not_started"


@app.get("/")
async def root():
    """Health check endpoint that also reports the transfer state."""
    return {
        "message": "Hugging Face Model Transfer Service is running",
        "status": _current_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }


@app.get("/status")
async def get_status():
    """Get transfer status, including the destination path in the dataset."""
    return {
        "status": _current_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "path_in_repo": HARDCODED_PATH_IN_REPO,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet",
    }
@app.get("/cleanup")
async def cleanup():
    """Manually delete the directory of downloaded files, if one exists.

    Returns:
        A dict with a "message" describing what was done.

    Raises:
        HTTPException: 500 when removing the directory fails.
    """
    global temp_dir

    # Nothing recorded, or the recorded directory is already gone
    if not (temp_dir and os.path.exists(temp_dir)):
        return {"message": "No temporary directory to clean up"}

    try:
        shutil.rmtree(temp_dir)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(exc)}")

    # Build the reply before forgetting the path
    reply = {"message": f"Cleaned up temporary directory: {temp_dir}"}
    temp_dir = None
    return reply
@app.on_event("startup")
async def startup_event():
    """Start the transfer in the background when the application starts."""
    print("Starting model transfer process...")
    # Run the transfer without waiting for it to finish. The event loop keeps
    # only weak references to tasks, so hold a strong reference on app.state;
    # a task with no other reference may be garbage-collected mid-execution.
    app.state.transfer_task = asyncio.create_task(run_transfer())
# Pass-through HTTP middleware: hands each request to the next handler and
# returns that handler's response unchanged.
@app.middleware("http")
async def keep_alive_middleware(request, call_next):
    return await call_next(request)
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 7860, with keep-alive and
    # graceful-shutdown timeouts both raised to 600.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "timeout_keep_alive": 600,
        "timeout_graceful_shutdown": 600,
    }
    uvicorn.Server(uvicorn.Config(app, **server_options)).run()