Update app.py
Browse files
app.py
CHANGED
@@ -17,9 +17,9 @@ app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")
|
|
17 |
executor = ThreadPoolExecutor(max_workers=2)
|
18 |
|
19 |
# Get Hugging Face token from environment variable
|
20 |
-
HF_TOKEN = os.getenv("
|
21 |
if not HF_TOKEN:
|
22 |
-
raise ValueError("
|
23 |
|
24 |
# Hardcoded model repository ID
|
25 |
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-20b" # Your model
|
@@ -28,12 +28,13 @@ HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-20b" # Your model
|
|
28 |
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory" # Your dataset
|
29 |
|
30 |
# Hardcoded path in repository
|
31 |
-
HARDCODED_PATH_IN_REPO = "model_data/
|
32 |
|
33 |
# Track transfer status
|
34 |
transfer_completed = False
|
35 |
transfer_error = None
|
36 |
transfer_started = False
|
|
|
37 |
|
38 |
def download_model_files() -> str:
|
39 |
"""Download all files from a model repository to a temporary directory."""
|
@@ -125,20 +126,10 @@ def upload_folder_to_dataset(folder_path: str):
|
|
125 |
print(f"Upload failed: {str(e)}")
|
126 |
raise
|
127 |
|
128 |
-
def cleanup_download(temp_dir: str):
|
129 |
-
"""Clean up downloaded files."""
|
130 |
-
try:
|
131 |
-
if os.path.exists(temp_dir):
|
132 |
-
shutil.rmtree(temp_dir)
|
133 |
-
print(f"Cleaned up temporary directory: {temp_dir}")
|
134 |
-
except Exception as e:
|
135 |
-
print(f"Cleanup failed: {str(e)}")
|
136 |
-
|
137 |
async def run_transfer():
|
138 |
"""Run the transfer process and update status."""
|
139 |
-
global transfer_completed, transfer_error, transfer_started
|
140 |
transfer_started = True
|
141 |
-
temp_dir = None
|
142 |
try:
|
143 |
# Download the model files
|
144 |
loop = asyncio.get_event_loop()
|
@@ -167,10 +158,6 @@ async def run_transfer():
|
|
167 |
error_msg = f"Transfer failed: {str(e)}"
|
168 |
print(error_msg)
|
169 |
transfer_error = error_msg
|
170 |
-
finally:
|
171 |
-
# Clean up downloaded files
|
172 |
-
if temp_dir:
|
173 |
-
cleanup_download(temp_dir)
|
174 |
|
175 |
@app.get("/")
|
176 |
async def root():
|
@@ -181,7 +168,8 @@ async def root():
|
|
181 |
"status": status,
|
182 |
"model": HARDCODED_MODEL_REPO_ID,
|
183 |
"dataset": HARDCODED_DATASET_REPO_ID,
|
184 |
-
"error": transfer_error
|
|
|
185 |
}
|
186 |
|
187 |
@app.get("/status")
|
@@ -193,9 +181,25 @@ async def get_status():
|
|
193 |
"model": HARDCODED_MODEL_REPO_ID,
|
194 |
"dataset": HARDCODED_DATASET_REPO_ID,
|
195 |
"path_in_repo": HARDCODED_PATH_IN_REPO,
|
196 |
-
"error": transfer_error
|
|
|
197 |
}
|
198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
@app.on_event("startup")
|
200 |
async def startup_event():
|
201 |
"""Run the transfer process when the application starts."""
|
|
|
17 |
executor = ThreadPoolExecutor(max_workers=2)
|
18 |
|
19 |
# Get Hugging Face token from environment variable
|
20 |
+
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
|
21 |
if not HF_TOKEN:
|
22 |
+
raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is not set")
|
23 |
|
24 |
# Hardcoded model repository ID
|
25 |
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-20b" # Your model
|
|
|
28 |
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory" # Your dataset
|
29 |
|
30 |
# Hardcoded path in repository
|
31 |
+
HARDCODED_PATH_IN_REPO = "model_data/"
|
32 |
|
33 |
# Track transfer status
|
34 |
transfer_completed = False
|
35 |
transfer_error = None
|
36 |
transfer_started = False
|
37 |
+
temp_dir = None
|
38 |
|
39 |
def download_model_files() -> str:
|
40 |
"""Download all files from a model repository to a temporary directory."""
|
|
|
126 |
print(f"Upload failed: {str(e)}")
|
127 |
raise
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
async def run_transfer():
|
130 |
"""Run the transfer process and update status."""
|
131 |
+
global transfer_completed, transfer_error, transfer_started, temp_dir
|
132 |
transfer_started = True
|
|
|
133 |
try:
|
134 |
# Download the model files
|
135 |
loop = asyncio.get_event_loop()
|
|
|
158 |
error_msg = f"Transfer failed: {str(e)}"
|
159 |
print(error_msg)
|
160 |
transfer_error = error_msg
|
|
|
|
|
|
|
|
|
161 |
|
162 |
@app.get("/")
|
163 |
async def root():
|
|
|
168 |
"status": status,
|
169 |
"model": HARDCODED_MODEL_REPO_ID,
|
170 |
"dataset": HARDCODED_DATASET_REPO_ID,
|
171 |
+
"error": transfer_error,
|
172 |
+
"temp_dir": temp_dir if temp_dir else "Not created yet"
|
173 |
}
|
174 |
|
175 |
@app.get("/status")
|
|
|
181 |
"model": HARDCODED_MODEL_REPO_ID,
|
182 |
"dataset": HARDCODED_DATASET_REPO_ID,
|
183 |
"path_in_repo": HARDCODED_PATH_IN_REPO,
|
184 |
+
"error": transfer_error,
|
185 |
+
"temp_dir": temp_dir if temp_dir else "Not created yet"
|
186 |
}
|
187 |
|
188 |
+
@app.get("/cleanup")
|
189 |
+
async def cleanup():
|
190 |
+
"""Manual cleanup endpoint to remove downloaded files."""
|
191 |
+
global temp_dir
|
192 |
+
if temp_dir and os.path.exists(temp_dir):
|
193 |
+
try:
|
194 |
+
shutil.rmtree(temp_dir)
|
195 |
+
message = f"Cleaned up temporary directory: {temp_dir}"
|
196 |
+
temp_dir = None
|
197 |
+
return {"message": message}
|
198 |
+
except Exception as e:
|
199 |
+
raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
|
200 |
+
else:
|
201 |
+
return {"message": "No temporary directory to clean up"}
|
202 |
+
|
203 |
@app.on_event("startup")
|
204 |
async def startup_event():
|
205 |
"""Run the transfer process when the application starts."""
|