John6666 commited on
Commit
ab9eb81
·
verified ·
1 Parent(s): edbad06

Upload convert_repo_to_safetensors_gr.py

Browse files
Files changed (1) hide show
  1. convert_repo_to_safetensors_gr.py +9 -9
convert_repo_to_safetensors_gr.py CHANGED
@@ -289,8 +289,7 @@ def convert_openai_text_enc_state_dict(text_enc_dict):
289
  return text_enc_dict
290
 
291
 
292
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
293
- progress(0, desc="Start converting...")
294
  # Path for safetensors
295
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
296
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -351,12 +350,10 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16",
351
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
352
 
353
  save_file(state_dict, checkpoint_path)
354
- progress(1, desc="Converted.")
355
 
356
 
357
- def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
358
  try:
359
- progress(0, desc="Start downloading...")
360
  snapshot_download(repo_id=repo_id, local_dir=dir_path, token=get_token())
361
  except Exception as e:
362
  print(f"Error: Failed to download {repo_id}. {e}")
@@ -383,8 +380,11 @@ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progres
383
  def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
384
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
385
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
386
- download_repo(repo_id, download_dir, progress)
387
- convert_diffusers_to_safetensors(download_dir, output_filename, dtype, progress)
 
 
 
388
  return output_filename
389
 
390
 
@@ -394,11 +394,11 @@ def convert_repo_to_safetensors_multi(repo_id, hf_token, files, urls, dtype="fp1
394
  if hf_token: HfFolder.save_token(hf_token)
395
  else: HfFolder.save_token(os.environ.get("HF_TOKEN"))
396
  if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
397
- file = convert_repo_to_safetensors(repo_id, dtype, progress)
398
  if not urls: urls = []
399
  url = ""
400
  if is_upload:
401
- url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private, progress)
402
  if url: urls.append(url)
403
  md = ""
404
  for u in urls:
 
289
  return text_enc_dict
290
 
291
 
292
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
 
293
  # Path for safetensors
294
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
295
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
 
350
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
351
 
352
  save_file(state_dict, checkpoint_path)
 
353
 
354
 
355
+ def download_repo(repo_id, dir_path):
356
  try:
 
357
  snapshot_download(repo_id=repo_id, local_dir=dir_path, token=get_token())
358
  except Exception as e:
359
  print(f"Error: Failed to download {repo_id}. {e}")
 
380
  def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
381
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
382
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
383
+ progress(0, desc="Start downloading...")
384
+ download_repo(repo_id, download_dir)
385
+ progress(0, desc="Start converting...")
386
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
387
+ progress(1, desc="Converted.")
388
  return output_filename
389
 
390
 
 
394
  if hf_token: HfFolder.save_token(hf_token)
395
  else: HfFolder.save_token(os.environ.get("HF_TOKEN"))
396
  if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
397
+ file = convert_repo_to_safetensors(repo_id, dtype)
398
  if not urls: urls = []
399
  url = ""
400
  if is_upload:
401
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
402
  if url: urls.append(url)
403
  md = ""
404
  for u in urls: