Fred808 committed on
Commit
61fd549
·
verified ·
1 Parent(s): 9583479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -44
app.py CHANGED
@@ -1,10 +1,11 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
- from huggingface_hub import snapshot_download, HfApi
3
  import os
4
  import shutil
5
  import asyncio
6
  from concurrent.futures import ThreadPoolExecutor
7
  from dotenv import load_dotenv
 
8
 
9
  # Load environment variables from .env file
10
  load_dotenv()
@@ -14,38 +15,71 @@ app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")
14
  # Thread pool for running blocking operations
15
  executor = ThreadPoolExecutor(max_workers=2)
16
 
17
- # Local directory to save downloaded model data temporarily
18
- DOWNLOAD_DIR = "./downloaded_model_data"
19
-
20
- # Ensure the download directory exists
21
- os.makedirs(DOWNLOAD_DIR, exist_ok=True)
22
 
23
  # Hardcoded model repository ID
24
- HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-120b" # Change this to your desired model
25
 
26
  # Hardcoded dataset repository ID
27
- HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory" # Change this to your dataset
28
 
29
  # Hardcoded path in repository
30
- HARDCODED_PATH_IN_REPO = "model_data2/"
31
-
32
- # Get Hugging Face token from environment variable
33
- HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
34
- if not HF_TOKEN:
35
- raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is not set")
36
-
37
- def download_full_model() -> str:
38
- """Downloads the hardcoded model from Hugging Face model repository."""
39
- print(f"Downloading hardcoded model {HARDCODED_MODEL_REPO_ID}...")
40
  try:
41
- local_dir = snapshot_download(
42
- repo_id=HARDCODED_MODEL_REPO_ID,
43
- cache_dir=DOWNLOAD_DIR,
 
 
44
  token=HF_TOKEN
45
  )
46
- print(f"Downloaded to: {local_dir}")
47
- return local_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
 
 
49
  print(f"Download failed: {str(e)}")
50
  raise
51
 
@@ -59,46 +93,55 @@ def upload_folder_to_dataset(folder_path: str):
59
  path_in_repo=HARDCODED_PATH_IN_REPO,
60
  repo_id=HARDCODED_DATASET_REPO_ID,
61
  repo_type="dataset",
 
62
  )
63
  print("Upload complete!")
64
  except Exception as e:
65
  print(f"Upload failed: {str(e)}")
66
  raise
67
 
68
- def cleanup_download(local_dir: str):
69
  """Clean up downloaded files."""
70
  try:
71
- if os.path.exists(local_dir):
72
- shutil.rmtree(local_dir)
73
- print(f"Cleaned up: {local_dir}")
74
  except Exception as e:
75
  print(f"Cleanup failed: {str(e)}")
76
 
77
  async def transfer_model():
78
  """Download the hardcoded model and upload it to the hardcoded dataset repository."""
 
79
  try:
80
- # Download the model
81
  loop = asyncio.get_event_loop()
82
- local_dir = await loop.run_in_executor(
83
  executor,
84
- download_full_model
85
  )
86
 
 
 
 
 
 
 
87
  # Upload to dataset
88
  await loop.run_in_executor(
89
  executor,
90
  upload_folder_to_dataset,
91
- local_dir
92
  )
93
 
94
- # Clean up downloaded files
95
- cleanup_download(local_dir)
96
-
97
  print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
98
 
99
  except Exception as e:
100
  print(f"Transfer failed: {str(e)}")
101
  raise
 
 
 
 
102
 
103
  @app.get("/")
104
  async def root():
@@ -117,19 +160,13 @@ async def startup_event():
117
 
118
  @app.get("/status")
119
  async def get_status():
120
- """Get server status and available disk space."""
121
  try:
122
- disk_usage = shutil.disk_usage(DOWNLOAD_DIR)
123
  return {
124
  "status": "healthy",
125
- "download_dir": DOWNLOAD_DIR,
126
- "disk_space": {
127
- "total": disk_usage.total,
128
- "used": disk_usage.used,
129
- "free": disk_usage.free
130
- },
131
  "model": HARDCODED_MODEL_REPO_ID,
132
- "dataset": HARDCODED_DATASET_REPO_ID
 
133
  }
134
  except Exception as e:
135
  raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files
3
  import os
4
  import shutil
5
  import asyncio
6
  from concurrent.futures import ThreadPoolExecutor
7
  from dotenv import load_dotenv
8
+ import tempfile
9
 
10
  # Load environment variables from .env file
11
  load_dotenv()
 
15
  # Thread pool for running blocking operations
16
  executor = ThreadPoolExecutor(max_workers=2)
17
 
18
+ # Get Hugging Face token from environment variable
19
+ HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
20
+ if not HF_TOKEN:
21
+ raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is not set")
 
22
 
23
  # Hardcoded model repository ID
24
+ HARDCODED_MODEL_REPO_ID = "bert-base-uncased" # Change this to your desired model
25
 
26
  # Hardcoded dataset repository ID
27
+ HARDCODED_DATASET_REPO_ID = "your-username/your-dataset-name" # Change this to your dataset
28
 
29
  # Hardcoded path in repository
30
+ HARDCODED_PATH_IN_REPO = "model_data/"
31
+
32
def download_model_files() -> str:
    """Download every file of the hardcoded model repository into a temp directory.

    The files of ``HARDCODED_MODEL_REPO_ID`` are listed and fetched one by one
    into a freshly created temporary directory, keeping the repository's
    relative layout (``<temp_dir>/<path in repo>``).

    Returns:
        Path of the temporary directory holding the downloaded files. The
        caller owns the directory and is responsible for deleting it.

    Raises:
        Exception: Any error raised while listing or downloading is re-raised
            after the temporary directory has been removed.
    """
    print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")

    # Fresh, uniquely named directory so concurrent transfers cannot collide.
    temp_dir = tempfile.mkdtemp(prefix="hf_model_")
    print(f"Created temporary directory: {temp_dir}")

    try:
        files = list_repo_files(
            repo_id=HARDCODED_MODEL_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )
        print(f"Found {len(files)} files to download")

        for file_path in files:
            if file_path.endswith('/'):  # Skip directory-like entries
                continue

            print(f"Downloading: {file_path}")

            # Use local_dir rather than cache_dir: with cache_dir the returned
            # path lives inside the cache's snapshot tree (usually a relative
            # symlink into blobs/), so moving it can leave a dangling link and
            # the cache tree itself would stay in temp_dir and get uploaded.
            # local_dir writes the real file at <temp_dir>/<file_path> and
            # creates any intermediate folders.
            hf_hub_download(
                repo_id=HARDCODED_MODEL_REPO_ID,
                filename=file_path,
                repo_type="model",
                local_dir=temp_dir,
                token=HF_TOKEN,
            )

        # huggingface_hub keeps download metadata under .cache/huggingface
        # inside local_dir; it is not part of the model, so drop it before
        # the folder is handed to the uploader.
        shutil.rmtree(
            os.path.join(temp_dir, ".cache", "huggingface"),
            ignore_errors=True,
        )

        print(f"All files downloaded to: {temp_dir}")
        return temp_dir

    except Exception as e:
        # Do not leave a partially filled directory behind on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        print(f"Download failed: {str(e)}")
        raise
85
 
 
93
  path_in_repo=HARDCODED_PATH_IN_REPO,
94
  repo_id=HARDCODED_DATASET_REPO_ID,
95
  repo_type="dataset",
96
+ token=HF_TOKEN
97
  )
98
  print("Upload complete!")
99
  except Exception as e:
100
  print(f"Upload failed: {str(e)}")
101
  raise
102
 
103
def cleanup_download(temp_dir: str):
    """Remove the temporary download directory, if it is still present.

    Best-effort: a missing directory is silently ignored and any error raised
    during removal is reported on stdout instead of being propagated.

    Args:
        temp_dir: Path of the directory to delete.
    """
    try:
        # Nothing to do when the directory is already gone.
        if not os.path.exists(temp_dir):
            return
        shutil.rmtree(temp_dir)
        print(f"Cleaned up temporary directory: {temp_dir}")
    except Exception as err:
        print(f"Cleanup failed: {str(err)}")
111
 
112
async def transfer_model():
    """Download the hardcoded model and upload it to the hardcoded dataset repository.

    Every blocking step (download, upload, and the final directory removal)
    runs on the shared thread pool so the event loop stays responsive.

    Raises:
        RuntimeError: If the download step produced an empty directory.
        Exception: Any error from the download or upload step is logged and
            re-raised. The temporary directory is removed in every case.
    """
    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    temp_dir = None
    try:
        # Download the model files
        temp_dir = await loop.run_in_executor(
            executor,
            download_model_files
        )

        # Verify files were downloaded (list once, reuse for the log line)
        downloaded = os.listdir(temp_dir)
        if not downloaded:
            raise RuntimeError("No files were downloaded")

        print(f"Downloaded files: {downloaded}")

        # Upload to dataset
        await loop.run_in_executor(
            executor,
            upload_folder_to_dataset,
            temp_dir
        )

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")

    except Exception as e:
        print(f"Transfer failed: {str(e)}")
        raise
    finally:
        # Deleting a whole model directory is blocking disk I/O, so it goes
        # through the executor like the other blocking steps.
        if temp_dir:
            await loop.run_in_executor(executor, cleanup_download, temp_dir)
145
 
146
  @app.get("/")
147
  async def root():
 
160
 
161
@app.get("/status")
async def get_status():
    """Report service health together with the configured transfer settings.

    Returns:
        A dict with the fixed status "healthy", the source model repository,
        the destination dataset repository and the target path inside it.

    Raises:
        HTTPException: With status 500 if building the report fails.
    """
    try:
        report = {
            "status": "healthy",
            "model": HARDCODED_MODEL_REPO_ID,
            "dataset": HARDCODED_DATASET_REPO_ID,
            "path_in_repo": HARDCODED_PATH_IN_REPO,
        }
        return report
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(exc)}")