# OMol25DSF / main.py
# Author: amcgovern. Last commit: "Update main.py" (178e01a, verified).
import logging
import os
from typing import Optional, List

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, hf_hub_url
from pydantic import BaseModel, Field
# Configuration
APP_TITLE = "OMol25 API Service"
REPO_ID = "facebook/OMol25"
REPO_TYPE = "model"

# Hugging Face access token, read from the environment (None when unset).
# Per the warning below, gated-repo access fails without it.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Use logging (stderr by default) instead of print(): stdout is
    # block-buffered when not attached to a TTY (e.g. in a container), so a
    # printed warning can show up late or not at all in the server logs.
    logging.getLogger(__name__).warning(
        "WARNING: No HF_TOKEN found. Gated repo access will fail."
    )
# Shared Hub client used by every endpoint; it is given HF_TOKEN, which may be None.
api = HfApi(token=HF_TOKEN)

# The FastAPI application. The single `servers` entry is the public base URL
# advertised in the generated OpenAPI schema.
app = FastAPI(
    servers=[{"url": "https://amcgovern-omol25dsf.hf.space"}],
    title=APP_TITLE,
)

# CORS for ChatGPT: any origin may call the service with GET requests and any
# headers; credentials are not allowed on cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET"],
    allow_origins=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)
# Response models
class FileListResponse(BaseModel):
    """Result of a repository file listing.

    ``count`` is the number of matching paths; ``files`` may be truncated by
    the endpoint, so ``len(files)`` can be smaller than ``count``.
    """

    count: int = Field(
        ...,
        description="Number of repository files matching the filter, before any truncation.",
    )
    files: List[str] = Field(
        ...,
        description="Matching file paths. May be truncated; compare its length with count.",
    )
class FileURLResponse(BaseModel):
    """Direct download URL for a single file in the repository."""

    path: str = Field(
        ...,
        description="File path within the repository that the URL was built for.",
    )
    url: str = Field(
        ...,
        description="Direct download URL. Gated files may require Hugging Face authentication.",
    )
# Endpoints
@app.get("/")
def root():
    """Confirm the service is running and list its main endpoints."""
    return dict(
        message=f"{APP_TITLE} is running",
        endpoints=["/health", "/files", "/url"],
    )
@app.get("/health")
def health():
    """Return a static OK status with the configured repository ID.

    The Hugging Face Hub is not contacted.
    """
    payload = dict(status="ok", repo=REPO_ID)
    return payload
@app.get("/privacy")
def privacy():
    """Return the service's privacy statement as a JSON object."""
    statement = "No user data stored. Public repo access only."
    return {"privacy": statement}
@app.get("/files", response_model=FileListResponse)
def list_files(q: Optional[str] = Query(None, description="Search filter")):
    """List files in the OMol25 repository.

    When ``q`` is given, keeps only paths that contain it (case-insensitive
    substring match). ``count`` is the number of matching paths; at most the
    first 500 of them are returned in ``files``.

    Raises:
        HTTPException: 503 if the repository file listing cannot be fetched.
    """
    try:
        files = api.list_repo_files(REPO_ID, repo_type=REPO_TYPE)
    except Exception as e:
        # Guard only the Hub call. A failure here means the repo could not be
        # reached (presumably network/auth/gating issues), which maps to 503.
        # Bugs in the filtering below should surface as 500s instead.
        logging.getLogger(__name__).exception("Failed to list files in %s", REPO_ID)
        raise HTTPException(
            status_code=503, detail=f"Cannot access OMol25: {str(e)}"
        ) from e

    if q:
        # Lower-case the needle once rather than once per file.
        needle = q.lower()
        files = [f for f in files if needle in f.lower()]
    # Cap the payload size; `count` still reports the full number of matches.
    return FileListResponse(count=len(files), files=files[:500])
@app.get("/url", response_model=FileURLResponse)
def get_file_url(path: str = Query(..., description="File path in repo")):
    """Build a direct download URL for a file in the OMol25 repository.

    ``path`` is normalized by stripping surrounding whitespace and leading
    slashes, so ``/a/b.txt`` and ``a/b.txt`` give the same URL. The normalized
    path is echoed back. ``hf_hub_url`` only formats the URL; it does not
    check that the file exists.

    Raises:
        HTTPException: 400 if ``path`` is empty after normalization, or if the
            URL cannot be built.
    """
    # A leading "/" would otherwise produce ".../resolve/main//a/b.txt", and
    # an empty path would produce a URL pointing at no file at all.
    clean_path = path.strip().lstrip("/")
    if not clean_path:
        raise HTTPException(
            status_code=400, detail="Invalid path: path must not be empty"
        )
    try:
        url = hf_hub_url(repo_id=REPO_ID, filename=clean_path, repo_type=REPO_TYPE)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid path: {str(e)}") from e
    return FileURLResponse(path=clean_path, url=url)