pandaphd commited on
Commit
5da9f4f
·
1 Parent(s): 94080d2

fix version

Browse files
Files changed (1) hide show
  1. app.py +10 -54
app.py CHANGED
@@ -7,63 +7,19 @@ import subprocess
7
  import gradio as gr
8
  from omegaconf import OmegaConf
9
 
10
- # ================== 新增的下载逻辑 ==================
11
- MODEL_REPO = "pandaphd/generative_photography"
12
- BRANCH = "main"
13
- LOCAL_DIR = "ckpts"
14
-
15
-
16
- def download_hf_folder():
17
- os.makedirs(LOCAL_DIR, exist_ok=True)
18
-
19
- def get_file_list(path=""):
20
- api_url = f"https://huggingface.co/{MODEL_REPO}/tree/{BRANCH}/{path}"
21
- try:
22
- response = requests.get(api_url, timeout=10)
23
- response.raise_for_status()
24
- return response.json()
25
- except Exception as e:
26
- raise RuntimeError(f"Failed to get file list: {str(e)}")
27
-
28
- def download_file(remote_path):
29
- url = f"https://huggingface.co/{MODEL_REPO}/tree/{BRANCH}/{remote_path}"
30
- local_path = os.path.join(LOCAL_DIR, remote_path)
31
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
32
-
33
- # 使用 wget 下载(支持断点续传)
34
- cmd = [
35
- "wget", "-c", "-q", "--show-progress",
36
- "-O", local_path, url
37
- ]
38
- try:
39
- subprocess.run(cmd, check=True)
40
- print(f"Downloaded: {remote_path}")
41
- except subprocess.CalledProcessError:
42
- print(f"Failed to download: {remote_path}")
43
- raise
44
-
45
- print("Downloading models from Hugging Face...")
46
-
47
- # 递归下载函数
48
- def download_recursive(path=""):
49
- for item in get_file_list(path):
50
- item_path = os.path.join(path, item["path"])
51
- if item["type"] == "file":
52
- download_file(item_path)
53
- elif item["type"] == "directory":
54
- download_recursive(item_path)
55
 
56
- try:
57
- download_recursive()
58
- print("All files downloaded successfully!")
59
- except Exception as e:
60
- print(f"Critical error during download: {str(e)}")
61
- exit(1)
62
 
 
 
 
 
 
 
63
 
64
- # ================== 执行下载 ==================
65
- if not os.path.exists(os.path.join(LOCAL_DIR, "models")):
66
- download_hf_folder()
67
 
68
 
69
  torch.manual_seed(42)
 
7
  import gradio as gr
8
  from omegaconf import OmegaConf
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ model_path = "ckpts"
12
+ os.makedirs(model_path, exist_ok=True)
 
 
 
 
13
 
14
+ # Download files from Hugging Face repository using wget
15
+ def download_huggingface_files():
16
+ # URL to the Hugging Face repository
17
+ repo_url = "https://huggingface.co/pandaphd/generative_photography/tree/main/"
18
+ # Using wget to download all the files from the given Hugging Face repository
19
+ subprocess.run(["wget", "-r", "-np", "-nH", "--cut-dirs=3", "-P", model_path, repo_url])
20
 
21
+ print("Downloading models from Hugging Face...")
22
+ download_huggingface_files()
 
23
 
24
 
25
  torch.manual_seed(42)