File size: 7,839 Bytes
61fd549
 
abec798
970ed33
 
 
9583479
61fd549
a16d4ce
abec798
9583479
 
 
 
abec798
970ed33
 
abec798
61fd549
75e825d
61fd549
75e825d
abec798
9583479
82d0262
abec798
9583479
ffc809a
9583479
 
77e9073
61fd549
86e3cbf
 
 
a16d4ce
663422f
86e3cbf
61fd549
 
 
 
 
 
 
 
970ed33
61fd549
 
 
 
 
9583479
970ed33
61fd549
 
 
ffc809a
61fd549
 
 
 
 
 
 
 
 
 
ffc809a
61fd549
 
 
ffc809a
61fd549
 
 
 
ffc809a
 
 
 
 
a16d4ce
ffc809a
 
61fd549
 
ffc809a
61fd549
 
970ed33
61fd549
 
9583479
 
abec798
9583479
 
 
 
ffc809a
 
 
 
970ed33
ffc809a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a16d4ce
ffc809a
9583479
 
 
970ed33
9583479
86e3cbf
 
663422f
a16d4ce
970ed33
61fd549
970ed33
61fd549
970ed33
61fd549
970ed33
 
61fd549
 
 
 
ffc809a
61fd549
970ed33
 
 
 
61fd549
970ed33
 
9583479
86e3cbf
970ed33
 
86e3cbf
 
 
970ed33
9583479
 
 
a16d4ce
9583479
 
86e3cbf
 
 
663422f
 
86e3cbf
 
 
 
 
a16d4ce
86e3cbf
 
 
 
 
663422f
 
9583479
 
663422f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9583479
 
 
 
86e3cbf
 
970ed33
a16d4ce
 
 
 
 
 
970ed33
 
a16d4ce
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
import os
import shutil
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
import tempfile
import time

# Load environment variables from a local .env file (if present) before
# any of them are read below.
load_dotenv()

# FastAPI application object; the route, middleware and startup handlers
# further down register themselves on it.
app = FastAPI(title="Hugging Face Model Transfer Service", version="1.0.0")

# Thread pool used by run_transfer() to run the blocking download and
# upload functions via loop.run_in_executor.
executor = ThreadPoolExecutor(max_workers=2)

# Hugging Face access token, read from the environment. Importing this
# module fails immediately with ValueError when it is missing or empty.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Source model repository that download_model_files() reads from.
HARDCODED_MODEL_REPO_ID = "openai/gpt-oss-120b"  # source model repo

# Destination dataset repository that upload_folder_to_dataset() writes to.
HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory"  # destination dataset repo

# Folder prefix inside the dataset repository under which files are stored.
# NOTE(review): this says "gpt-oss-20b" while the model above is
# "gpt-oss-120b" - confirm which one is intended.
HARDCODED_PATH_IN_REPO = "model_data/gpt-oss-20b"

# Module-level transfer state, written by run_transfer() and reported by
# the "/" and "/status" endpoints.
transfer_completed = False  # set to True after the upload step returns
transfer_error = None  # "Transfer failed: ..." message once a step raises
transfer_started = False  # set to True as soon as run_transfer() begins
temp_dir = None  # download directory path; reset to None by /cleanup

def download_model_files() -> str:
    """Download every file of the hardcoded model repository.

    The files are written into a freshly created temporary directory
    (prefix ``hf_model_``), keeping the repository's folder layout.

    Returns:
        Path of the temporary directory. It is NOT removed on success;
        the caller is responsible for deleting it.

    Raises:
        Exception: Any error raised while listing or downloading files is
            re-raised after the temporary directory has been removed.
    """
    print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")

    # Deliberately not named ``temp_dir``: that name is the module-level
    # status variable and must not be shadowed here.
    download_dir = tempfile.mkdtemp(prefix="hf_model_")
    print(f"Created temporary directory: {download_dir}")

    try:
        # Get list of all files in the repository
        files = list_repo_files(
            repo_id=HARDCODED_MODEL_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )

        print(f"Found {len(files)} files to download")

        for file_path in files:
            if file_path.endswith('/'):  # Skip directory-like entries
                continue

            print(f"Downloading: {file_path}")

            # Make sure the parent folder exists before the download writes
            # into it.
            parent_dir = os.path.join(download_dir, os.path.dirname(file_path))
            os.makedirs(parent_dir, exist_ok=True)

            # Download straight into download_dir, always re-fetching.
            local_path = hf_hub_download(
                repo_id=HARDCODED_MODEL_REPO_ID,
                filename=file_path,
                local_dir=download_dir,
                force_download=True,
                token=HF_TOKEN,
            )

            print(f"Downloaded to: {local_path}")

        # Drop any top-level ".cache*" folder left behind by the download so
        # it is not uploaded along with the model files.
        for item in os.listdir(download_dir):
            item_path = os.path.join(download_dir, item)
            if item.startswith('.cache') and os.path.isdir(item_path):
                shutil.rmtree(item_path)
                print(f"Removed cache directory: {item_path}")

        print(f"All files downloaded to: {download_dir}")
        print(f"Final directory contents: {os.listdir(download_dir)}")
        return download_dir

    except Exception as e:
        # Do not leave a partial download behind on failure.
        shutil.rmtree(download_dir, ignore_errors=True)
        print(f"Download failed: {str(e)}")
        raise

def upload_folder_to_dataset(folder_path: str):
    """Push every file under ``folder_path`` to the hardcoded dataset repo.

    Files are uploaded one at a time with ``HfApi.upload_file``. Each file
    lands at ``HARDCODED_PATH_IN_REPO`` plus its path relative to
    ``folder_path``, using forward slashes.

    Args:
        folder_path: Local directory whose contents are uploaded.

    Raises:
        Exception: Any error from walking the folder or uploading a file
            is logged and re-raised.
    """
    hub = HfApi(token=HF_TOKEN)
    print(f"Uploading {folder_path} to {HARDCODED_DATASET_REPO_ID} at {HARDCODED_PATH_IN_REPO}...")

    # Log the top-level entries so the upload can be checked by eye.
    print(f"Folder contents: {os.listdir(folder_path)}")

    try:
        for current_dir, _subdirs, filenames in os.walk(folder_path):
            for filename in filenames:
                local_file = os.path.join(current_dir, filename)
                inside_folder = os.path.relpath(local_file, folder_path)
                # Repository paths always use "/", whatever the local OS uses.
                destination = os.path.join(
                    HARDCODED_PATH_IN_REPO, inside_folder
                ).replace('\\', '/')

                print(f"Uploading {local_file} to {destination}")

                hub.upload_file(
                    path_or_fileobj=local_file,
                    path_in_repo=destination,
                    repo_id=HARDCODED_DATASET_REPO_ID,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                print(f"Successfully uploaded: {filename}")

        print("Upload complete!")
    except Exception as exc:
        print(f"Upload failed: {str(exc)}")
        raise

async def run_transfer():
    """Download the model, upload it to the dataset, and record the outcome.

    Both blocking steps run on the module-level ``executor`` so the event
    loop stays free to serve requests. Progress is published through the
    module-level ``transfer_started``, ``transfer_completed``,
    ``transfer_error`` and ``temp_dir`` variables.

    No exception is raised to the caller: any failure is caught, printed
    and stored in ``transfer_error`` as ``"Transfer failed: <reason>"``.
    """
    global transfer_completed, transfer_error, transfer_started, temp_dir
    transfer_started = True
    try:
        # Inside a coroutine the running loop is always available;
        # get_running_loop() is the supported way to obtain it.
        loop = asyncio.get_running_loop()

        # Download the model files
        temp_dir = await loop.run_in_executor(executor, download_model_files)

        # Verify files were downloaded
        downloaded = os.listdir(temp_dir)
        if not downloaded:
            raise RuntimeError("No files were downloaded")

        print(f"Final downloaded files: {downloaded}")

        # Upload to dataset
        await loop.run_in_executor(executor, upload_folder_to_dataset, temp_dir)

        print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
        transfer_completed = True

    except Exception as e:
        # Top-level boundary of the background task: record the failure
        # instead of letting it vanish with the task.
        error_msg = f"Transfer failed: {str(e)}"
        print(error_msg)
        transfer_error = error_msg
def _transfer_status() -> str:
    """Collapse the module-level transfer flags into one status string.

    Returns:
        ``"failed"`` when an error was recorded, otherwise ``"completed"``,
        ``"running"`` or ``"not_started"``.
    """
    # Check the error first: a failed transfer leaves transfer_started set
    # and would otherwise be reported as "running" forever.
    if transfer_error:
        return "failed"
    if transfer_completed:
        return "completed"
    if transfer_started:
        return "running"
    return "not_started"


@app.get("/")
async def root():
    """Health check endpoint that also reports the transfer state."""
    return {
        "message": "Hugging Face Model Transfer Service is running",
        "status": _transfer_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet"
    }


@app.get("/status")
async def get_status():
    """Get transfer status, including the destination path in the dataset."""
    return {
        "status": _transfer_status(),
        "model": HARDCODED_MODEL_REPO_ID,
        "dataset": HARDCODED_DATASET_REPO_ID,
        "path_in_repo": HARDCODED_PATH_IN_REPO,
        "error": transfer_error,
        "temp_dir": temp_dir if temp_dir else "Not created yet"
    }

@app.get("/cleanup")
async def cleanup():
    """Manual cleanup endpoint to remove downloaded files.

    Returns:
        A dict with a ``message`` saying what was removed, or that there
        was nothing to remove.

    Raises:
        HTTPException: 409 while the transfer is still using the
            directory; 500 when deleting the directory fails.
    """
    global temp_dir
    if not (temp_dir and os.path.exists(temp_dir)):
        return {"message": "No temporary directory to clean up"}

    # temp_dir is set once the download has finished, so from then until the
    # transfer completes or fails the upload step is still reading from it.
    # Deleting it now would break the transfer.
    if transfer_started and not transfer_completed and transfer_error is None:
        raise HTTPException(
            status_code=409,
            detail="Transfer still in progress; cleanup is only allowed after it finishes",
        )

    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") from e

    message = f"Cleaned up temporary directory: {temp_dir}"
    temp_dir = None
    return {"message": message}

@app.on_event("startup")
async def startup_event():
    """Start the transfer in the background when the application starts.

    The handler returns immediately; it does not wait for the transfer.
    """
    print("Starting model transfer process...")
    # The event loop holds only a weak reference to tasks, so a task whose
    # result is discarded can be garbage-collected before it finishes.
    # Keeping it on app.state gives it a strong reference for the life of
    # the application.
    app.state.transfer_task = asyncio.create_task(run_transfer())

# Pass-through HTTP middleware: forwards each request unchanged (it adds no keep-alive logic of its own)
@app.middleware("http")
async def keep_alive_middleware(request, call_next):
    """Hand the request to the next handler and return its response as-is."""
    return await call_next(request)

if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 7860, with long (600) keep-alive and
    # graceful-shutdown timeouts.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "timeout_keep_alive": 600,
        "timeout_graceful_shutdown": 600,
    }
    uvicorn.Server(uvicorn.Config(app, **server_options)).run()