danhtran2mind committed
Commit 3a6f563 · verified · 1 Parent(s): bcb6da1

Update app.py

Files changed (1): app.py (+25 -8)
app.py CHANGED
@@ -191,19 +191,36 @@ if __name__ == "__main__":
 
     @dataclasses.dataclass
     class AppArgs:
-        if local_model == True:
-            model_name: str = "ghibli-fine-tuned-sd-2.1"
-        else:
-            model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1"
-        device: str = "cuda" if torch.cuda.is_available() else "cpu"
-        port: int = 7860
-        share: bool = False  # Set to True for public sharing (Hugging Face Spaces)
+        local_model: bool = dataclasses.field(
+            default=False, metadata={"help": "Use local model path instead of Hugging Face model."}
+        )
+        model_name: str = dataclasses.field(
+            default="danhtran2mind/ghibli-fine-tuned-sd-2.1",
+            metadata={"help": "Model name or path for the fine-tuned Stable Diffusion model."}
+        )
+        device: str = dataclasses.field(
+            default="cuda" if torch.cuda.is_available() else "cpu",
+            metadata={"help": "Device to run the model on (e.g., 'cuda', 'cpu')."}
+        )
+        port: int = dataclasses.field(
+            default=7860, metadata={"help": "Port to run the Gradio app on."}
+        )
+        share: bool = dataclasses.field(
+            default=False, metadata={"help": "Set to True for public sharing (Hugging Face Spaces)."}
+        )
 
     parser = HfArgumentParser([AppArgs])
     args_tuple = parser.parse_args_into_dataclasses()
     args = args_tuple[0]
 
+    # Set model_name based on local_model flag
+    if args.local_model:
+        args.model_name = "ghibli-fine-tuned-sd-2.1"
+
+    demo = create_demo(args.model_name, args.device)
+    demo.launch(server_port=args.port, share=args.share)
+
     demo = create_demo(args.model_name, args.device)
     demo.launch(server_port=args.port, share=args.share)
 
-    <<add option choose local_model when run app.py>>
+    <<add option choose local_model when run app.py>>
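For reference, a minimal standalone sketch of how the new AppArgs flags are expected to behave when app.py is run from the command line. The field names and the post-parse local_model override mirror the diff above; the explicit argv list, the reduced set of fields, and the printed output are illustrative assumptions, not part of the commit.

# Minimal sketch (standalone; not part of app.py): assumed CLI behaviour of the
# HfArgumentParser flags added in this commit. Only a subset of the AppArgs
# fields is reproduced here.
import dataclasses
from transformers import HfArgumentParser

@dataclasses.dataclass
class AppArgs:
    local_model: bool = dataclasses.field(default=False)
    model_name: str = dataclasses.field(default="danhtran2mind/ghibli-fine-tuned-sd-2.1")
    port: int = dataclasses.field(default=7860)

parser = HfArgumentParser([AppArgs])
# Roughly equivalent to running: python app.py --local_model true --port 7861
(args,) = parser.parse_args_into_dataclasses(args=["--local_model", "true", "--port", "7861"])

if args.local_model:
    # Same override the commit applies after parsing.
    args.model_name = "ghibli-fine-tuned-sd-2.1"

print(args.model_name, args.port)  # expected: ghibli-fine-tuned-sd-2.1 7861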