Fred808 committed on
Commit
9583479
·
verified ·
1 Parent(s): 2cb5059

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -116
app.py CHANGED
@@ -1,15 +1,15 @@
1
- from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
2
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
- from pydantic import BaseModel
4
  from huggingface_hub import snapshot_download, HfApi
5
  import os
6
  import shutil
7
- from typing import Optional
8
  import asyncio
9
  from concurrent.futures import ThreadPoolExecutor
 
10
 
11
- app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")
12
- security = HTTPBearer()
 
 
13
 
14
  # Thread pool for running blocking operations
15
  executor = ThreadPoolExecutor(max_workers=2)
@@ -20,140 +20,100 @@ DOWNLOAD_DIR = "./downloaded_model_data"
20
  # Ensure the download directory exists
21
  os.makedirs(DOWNLOAD_DIR, exist_ok=True)
22
 
23
- class DownloadRequest(BaseModel):
24
- model_repo_id: str
25
- download_dir: Optional[str] = DOWNLOAD_DIR
26
-
27
- class UploadRequest(BaseModel):
28
- dataset_repo_id: str
29
- folder_path: str
30
- path_in_repo: str
31
-
32
- class TransferRequest(BaseModel):
33
- model_repo_id: str
34
- dataset_repo_id: str
35
- path_in_repo: Optional[str] = None
36
-
37
- def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
38
- """Extract and validate Hugging Face token from Authorization header."""
39
- return credentials.credentials
40
-
41
- def download_full_model(repo_id: str, download_dir: str) -> str:
42
- """Downloads an entire model from a Hugging Face model repository."""
43
- print(f"Downloading full model from {repo_id}...")
44
- local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir)
45
- print(f"Downloaded to: {local_dir}")
46
- return local_dir
47
-
48
- def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
49
- """Uploads a folder to a Hugging Face dataset repository."""
50
- api = HfApi(token=token)
51
- print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
52
- api.upload_folder(
53
- folder_path=folder_path,
54
- path_in_repo=path_in_repo,
55
- repo_id=dataset_repo_id,
56
- repo_type="dataset",
57
- )
58
- print("Upload complete!")
59
 
60
- @app.get("/")
61
- async def root():
62
- """Health check endpoint."""
63
- return {"message": "Hugging Face Model Transfer API is running"}
 
 
 
 
 
 
64
 
65
- @app.post("/download")
66
- async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
67
- """Download a model from Hugging Face model repository."""
68
  try:
69
- # Run the blocking download operation in a thread pool
70
- loop = asyncio.get_event_loop()
71
- local_dir = await loop.run_in_executor(
72
- executor,
73
- download_full_model,
74
- request.model_repo_id,
75
- request.download_dir
76
  )
77
-
78
- return {
79
- "message": f"Model {request.model_repo_id} downloaded successfully",
80
- "local_path": local_dir
81
- }
82
  except Exception as e:
83
- raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
 
84
 
85
- @app.post("/upload")
86
- async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
87
- """Upload a folder to a Hugging Face dataset repository."""
 
88
  try:
89
- # Check if folder exists
90
- if not os.path.exists(request.folder_path):
91
- raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")
92
-
93
- # Run the blocking upload operation in a thread pool
94
- loop = asyncio.get_event_loop()
95
- await loop.run_in_executor(
96
- executor,
97
- upload_folder_to_dataset,
98
- request.dataset_repo_id,
99
- request.folder_path,
100
- request.path_in_repo,
101
- token
102
  )
103
-
104
- return {
105
- "message": f"Folder uploaded successfully to {request.dataset_repo_id}",
106
- "path_in_repo": request.path_in_repo
107
- }
108
- except HTTPException:
109
  raise
 
 
 
 
 
 
 
110
  except Exception as e:
111
- raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
112
 
113
- @app.post("/transfer")
114
- async def transfer_model(request: TransferRequest, background_tasks: BackgroundTasks, token: str = Depends(get_token)):
115
- """Download a model and upload it to a dataset repository (combined operation)."""
116
  try:
117
- # Set default path in repo if not provided
118
- path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"
119
-
120
  # Download the model
121
  loop = asyncio.get_event_loop()
122
  local_dir = await loop.run_in_executor(
123
  executor,
124
- download_full_model,
125
- request.model_repo_id,
126
- DOWNLOAD_DIR
127
  )
128
 
129
  # Upload to dataset
130
  await loop.run_in_executor(
131
  executor,
132
  upload_folder_to_dataset,
133
- request.dataset_repo_id,
134
- local_dir,
135
- path_in_repo,
136
- token
137
  )
138
 
139
- # Clean up downloaded files in background
140
- background_tasks.add_task(cleanup_download, local_dir)
 
 
141
 
142
- return {
143
- "message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
144
- "path_in_repo": path_in_repo
145
- }
146
  except Exception as e:
147
- raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}")
 
148
 
149
- def cleanup_download(local_dir: str):
150
- """Clean up downloaded files."""
151
- try:
152
- if os.path.exists(local_dir):
153
- shutil.rmtree(local_dir)
154
- print(f"Cleaned up: {local_dir}")
155
- except Exception as e:
156
- print(f"Cleanup failed: {str(e)}")
 
 
 
 
 
 
157
 
158
  @app.get("/status")
159
  async def get_status():
@@ -167,12 +127,13 @@ async def get_status():
167
  "total": disk_usage.total,
168
  "used": disk_usage.used,
169
  "free": disk_usage.free
170
- }
 
 
171
  }
172
  except Exception as e:
173
  raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
174
 
175
  if __name__ == "__main__":
176
  import uvicorn
177
- uvicorn.run(app, host="0.0.0.0", port=7860)
178
-
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
 
 
2
  from huggingface_hub import snapshot_download, HfApi
3
  import os
4
  import shutil
 
5
  import asyncio
6
  from concurrent.futures import ThreadPoolExecutor
7
+ from dotenv import load_dotenv
8
 
9
+ # Load environment variables from .env file
10
+ load_dotenv()
11
+
12
+ app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")
13
 
14
  # Thread pool for running blocking operations
15
  executor = ThreadPoolExecutor(max_workers=2)
 
20
  # Ensure the download directory exists
21
  os.makedirs(DOWNLOAD_DIR, exist_ok=True)
22
 
23
# Repository the model snapshot is downloaded from.
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-120b"  # Change this to your desired model

# Dataset repository that receives the uploaded files.
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory"  # Change this to your dataset

# Destination folder inside the dataset repository.
HARDCODED_PATH_IN_REPO = "model_data2/"

# Hugging Face access token, read from the environment.
# HUGGINGFACE_TOKEN is the variable this module has always read;
# HUGGINGFACE_HUB_TOKEN is accepted as a fallback because the previous error
# message instructed operators to set that name instead.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    # Fail at import time: nothing in this service can work without a token.
    raise ValueError(
        "Hugging Face token is not set: define the HUGGINGFACE_TOKEN "
        "(or HUGGINGFACE_HUB_TOKEN) environment variable"
    )
36
 
37
def download_full_model() -> str:
    """Download a snapshot of HARDCODED_MODEL_REPO_ID and return its local path.

    The snapshot is requested with DOWNLOAD_DIR as the cache directory and
    HF_TOKEN for authentication. Any failure is reported on stdout and then
    re-raised unchanged.
    """
    print(f"Downloading hardcoded model {HARDCODED_MODEL_REPO_ID}...")
    try:
        local_dir = snapshot_download(
            repo_id=HARDCODED_MODEL_REPO_ID, cache_dir=DOWNLOAD_DIR, token=HF_TOKEN
        )
        print(f"Downloaded to: {local_dir}")
    except Exception as e:
        # Report the failure, then let the caller decide how to handle it.
        print(f"Download failed: {str(e)}")
        raise
    return local_dir
51
 
52
def upload_folder_to_dataset(folder_path: str):
    """Upload *folder_path* into the hardcoded dataset repository.

    Files are placed under HARDCODED_PATH_IN_REPO of HARDCODED_DATASET_REPO_ID
    using an HfApi client authenticated with HF_TOKEN. Any failure during the
    upload is reported on stdout and then re-raised unchanged.
    """
    api = HfApi(token=HF_TOKEN)
    print(f"Uploading {folder_path} to {HARDCODED_DATASET_REPO_ID} at {HARDCODED_PATH_IN_REPO}...")
    upload_args = {
        "folder_path": folder_path,
        "path_in_repo": HARDCODED_PATH_IN_REPO,
        "repo_id": HARDCODED_DATASET_REPO_ID,
        "repo_type": "dataset",
    }
    try:
        api.upload_folder(**upload_args)
        print("Upload complete!")
    except Exception as e:
        # Report the failure, then let the caller decide how to handle it.
        print(f"Upload failed: {str(e)}")
        raise
67
+
68
def cleanup_download(local_dir: str):
    """Remove *local_dir* and everything beneath it, on a best-effort basis.

    A path that does not exist is silently ignored. Any error raised while
    checking or deleting is printed and swallowed, so this function never
    raises and always returns None.
    """
    try:
        if not os.path.exists(local_dir):
            return
        shutil.rmtree(local_dir)
        print(f"Cleaned up: {local_dir}")
    except Exception as e:
        # Cleanup is best-effort: report the problem but do not propagate it.
        print(f"Cleanup failed: {str(e)}")
76
 
77
async def transfer_model():
    """Download the hardcoded model and upload it to the hardcoded dataset repository.

    Download, upload and cleanup are all blocking calls, so each one is run
    on the module's thread pool (``executor``) instead of directly on the
    event loop. Cleanup is attempted only after a successful upload.

    Raises:
        Exception: whatever the download or upload raised; it is reported on
            stdout and then re-raised unchanged.
    """
    # A coroutine always executes inside a running loop, so ask for that loop
    # explicitly rather than going through the legacy get_event_loop() lookup.
    loop = asyncio.get_running_loop()
    try:
        # Download the model
        local_dir = await loop.run_in_executor(executor, download_full_model)

        # Upload to dataset
        await loop.run_in_executor(executor, upload_folder_to_dataset, local_dir)

        # Clean up downloaded files. Deleting a large directory tree is slow
        # disk I/O, so it is kept off the event loop like the other steps.
        await loop.run_in_executor(executor, cleanup_download, local_dir)

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
    except Exception as e:
        print(f"Transfer failed: {str(e)}")
        raise
102
 
103
@app.get("/")
async def root():
    """Health check endpoint that also reports the configured repositories."""
    payload = {"message": "Hugging Face Model Transfer Service is running"}
    payload["hardcoded_model"] = HARDCODED_MODEL_REPO_ID
    payload["hardcoded_dataset"] = HARDCODED_DATASET_REPO_ID
    return payload
111
+
112
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Schedule the model transfer in the background when the application starts.

    The transfer is scheduled, not awaited, so this handler returns without
    waiting for the download and upload to finish.
    """
    print("Starting model transfer process...")
    task = asyncio.create_task(transfer_model())
    _background_tasks.add(task)
    # Release the reference as soon as the task completes (success or failure).
    task.add_done_callback(_background_tasks.discard)
117
 
118
  @app.get("/status")
119
  async def get_status():
 
127
  "total": disk_usage.total,
128
  "used": disk_usage.used,
129
  "free": disk_usage.free
130
+ },
131
+ "model": HARDCODED_MODEL_REPO_ID,
132
+ "dataset": HARDCODED_DATASET_REPO_ID
133
  }
134
  except Exception as e:
135
  raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
136
 
137
if __name__ == "__main__":
    # Launch the ASGI server directly when this file is run as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)