Update app.py
Browse files
app.py
CHANGED
|
@@ -4,13 +4,14 @@ import os
|
|
| 4 |
import requests
|
| 5 |
import subprocess
|
| 6 |
from subprocess import getoutput
|
| 7 |
-
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 8 |
|
| 9 |
hf_token = os.environ.get("HF_TOKEN_WITH_WRITE_PERMISSION")
|
| 10 |
|
| 11 |
is_shared_ui = True if "fffiloni/train-dreambooth-lora-sdxl" in os.environ['SPACE_ID'] else False
|
| 12 |
|
| 13 |
-
|
| 14 |
is_gpu_associated = torch.cuda.is_available()
|
| 15 |
if is_gpu_associated:
|
| 16 |
gpu_info = getoutput('nvidia-smi')
|
|
@@ -44,8 +45,7 @@ def get_sleep_time(hf_token):
|
|
| 44 |
return gcTimeout
|
| 45 |
|
| 46 |
def write_to_community(title, description,hf_token):
|
| 47 |
-
|
| 48 |
-
api = HfApi()
|
| 49 |
api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
|
| 50 |
|
| 51 |
|
|
@@ -161,8 +161,9 @@ def main(dataset_id,
|
|
| 161 |
|
| 162 |
instance_data_dir = repo_parts[-1]
|
| 163 |
train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
|
| 164 |
-
|
| 165 |
-
|
|
|
|
| 166 |
|
| 167 |
css="""
|
| 168 |
#col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
|
|
@@ -219,4 +220,4 @@ with gr.Blocks(css=css) as demo:
|
|
| 219 |
outputs = [status]
|
| 220 |
)
|
| 221 |
|
| 222 |
-
demo.queue().launch()
|
|
|
|
| 4 |
import requests
|
| 5 |
import subprocess
|
| 6 |
from subprocess import getoutput
|
| 7 |
+
from huggingface_hub import snapshot_download, HfApi
|
| 8 |
+
|
| 9 |
+
api = HfApi()
|
| 10 |
|
| 11 |
hf_token = os.environ.get("HF_TOKEN_WITH_WRITE_PERMISSION")
|
| 12 |
|
| 13 |
is_shared_ui = True if "fffiloni/train-dreambooth-lora-sdxl" in os.environ['SPACE_ID'] else False
|
| 14 |
|
|
|
|
| 15 |
is_gpu_associated = torch.cuda.is_available()
|
| 16 |
if is_gpu_associated:
|
| 17 |
gpu_info = getoutput('nvidia-smi')
|
|
|
|
| 45 |
return gcTimeout
|
| 46 |
|
| 47 |
def write_to_community(title, description,hf_token):
|
| 48 |
+
|
|
|
|
| 49 |
api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
|
| 50 |
|
| 51 |
|
|
|
|
| 161 |
|
| 162 |
instance_data_dir = repo_parts[-1]
|
| 163 |
train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
|
| 164 |
+
|
| 165 |
+
your_username = api.whoami(token=hf_token)["name"]
|
| 166 |
+
return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
|
| 167 |
|
| 168 |
css="""
|
| 169 |
#col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
|
|
|
|
| 220 |
outputs = [status]
|
| 221 |
)
|
| 222 |
|
| 223 |
+
demo.queue(default_enabled=False).launch(debug=True)
|