Upload convert_repo_to_safetensors_gr.py
Browse files
convert_repo_to_safetensors_gr.py
CHANGED
@@ -289,8 +289,7 @@ def convert_openai_text_enc_state_dict(text_enc_dict):
|
|
289 |
return text_enc_dict
|
290 |
|
291 |
|
292 |
-
def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"
|
293 |
-
progress(0, desc="Start converting...")
|
294 |
# Path for safetensors
|
295 |
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
296 |
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
@@ -351,12 +350,10 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16",
|
|
351 |
elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
|
352 |
|
353 |
save_file(state_dict, checkpoint_path)
|
354 |
-
progress(1, desc="Converted.")
|
355 |
|
356 |
|
357 |
-
def download_repo(repo_id, dir_path
|
358 |
try:
|
359 |
-
progress(0, desc="Start downloading...")
|
360 |
snapshot_download(repo_id=repo_id, local_dir=dir_path, token=get_token())
|
361 |
except Exception as e:
|
362 |
print(f"Error: Failed to download {repo_id}. {e}")
|
@@ -383,8 +380,11 @@ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progres
|
|
383 |
def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
|
384 |
download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
|
385 |
output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
|
386 |
-
|
387 |
-
|
|
|
|
|
|
|
388 |
return output_filename
|
389 |
|
390 |
|
@@ -394,11 +394,11 @@ def convert_repo_to_safetensors_multi(repo_id, hf_token, files, urls, dtype="fp1
|
|
394 |
if hf_token: HfFolder.save_token(hf_token)
|
395 |
else: HfFolder.save_token(os.environ.get("HF_TOKEN"))
|
396 |
if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
|
397 |
-
file = convert_repo_to_safetensors(repo_id, dtype
|
398 |
if not urls: urls = []
|
399 |
url = ""
|
400 |
if is_upload:
|
401 |
-
url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private
|
402 |
if url: urls.append(url)
|
403 |
md = ""
|
404 |
for u in urls:
|
|
|
289 |
return text_enc_dict
|
290 |
|
291 |
|
292 |
+
def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
|
|
|
293 |
# Path for safetensors
|
294 |
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
295 |
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
|
|
350 |
elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
|
351 |
|
352 |
save_file(state_dict, checkpoint_path)
|
|
|
353 |
|
354 |
|
355 |
+
def download_repo(repo_id, dir_path):
|
356 |
try:
|
|
|
357 |
snapshot_download(repo_id=repo_id, local_dir=dir_path, token=get_token())
|
358 |
except Exception as e:
|
359 |
print(f"Error: Failed to download {repo_id}. {e}")
|
|
|
380 |
def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
|
381 |
download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
|
382 |
output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
|
383 |
+
progress(0, desc="Start downloading...")
|
384 |
+
download_repo(repo_id, download_dir)
|
385 |
+
progress(0, desc="Start converting...")
|
386 |
+
convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
|
387 |
+
progress(1, desc="Converted.")
|
388 |
return output_filename
|
389 |
|
390 |
|
|
|
394 |
if hf_token: HfFolder.save_token(hf_token)
|
395 |
else: HfFolder.save_token(os.environ.get("HF_TOKEN"))
|
396 |
if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
|
397 |
+
file = convert_repo_to_safetensors(repo_id, dtype)
|
398 |
if not urls: urls = []
|
399 |
url = ""
|
400 |
if is_upload:
|
401 |
+
url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
|
402 |
if url: urls.append(url)
|
403 |
md = ""
|
404 |
for u in urls:
|