Update app.py
app.py CHANGED
@@ -191,19 +191,36 @@ if __name__ == "__main__":
 
     @dataclasses.dataclass
     class AppArgs:
-
-
-
-
-
-
-
+        local_model: bool = dataclasses.field(
+            default=False, metadata={"help": "Use local model path instead of Hugging Face model."}
+        )
+        model_name: str = dataclasses.field(
+            default="danhtran2mind/ghibli-fine-tuned-sd-2.1",
+            metadata={"help": "Model name or path for the fine-tuned Stable Diffusion model."}
+        )
+        device: str = dataclasses.field(
+            default="cuda" if torch.cuda.is_available() else "cpu",
+            metadata={"help": "Device to run the model on (e.g., 'cuda', 'cpu')."}
+        )
+        port: int = dataclasses.field(
+            default=7860, metadata={"help": "Port to run the Gradio app on."}
+        )
+        share: bool = dataclasses.field(
+            default=False, metadata={"help": "Set to True for public sharing (Hugging Face Spaces)."}
+        )
 
     parser = HfArgumentParser([AppArgs])
     args_tuple = parser.parse_args_into_dataclasses()
     args = args_tuple[0]
 
+    # Set model_name based on local_model flag
+    if args.local_model:
+        args.model_name = "ghibli-fine-tuned-sd-2.1"
+
+    demo = create_demo(args.model_name, args.device)
+    demo.launch(server_port=args.port, share=args.share)
+
     demo = create_demo(args.model_name, args.device)
     demo.launch(server_port=args.port, share=args.share)
 
-<<add option choose local_model when run app.py>>
+<<add option choose local_model when run app.py>>
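For reference, a minimal standalone sketch of how the new flag behaves once parsed. It assumes HfArgumentParser comes from transformers (as is standard); the trimmed-down AppArgs and the resolve_model_name helper below are illustrative only, not the app's actual code.

import dataclasses

from transformers import HfArgumentParser


@dataclasses.dataclass
class AppArgs:
    # Illustrative subset of the fields added in this commit.
    local_model: bool = dataclasses.field(default=False)
    model_name: str = dataclasses.field(default="danhtran2mind/ghibli-fine-tuned-sd-2.1")


def resolve_model_name(args: AppArgs) -> str:
    # Hypothetical helper mirroring the `if args.local_model:` branch above:
    # swap the Hub repo id for a local checkpoint directory.
    return "ghibli-fine-tuned-sd-2.1" if args.local_model else args.model_name


if __name__ == "__main__":
    # HfArgumentParser turns each dataclass field into a CLI flag,
    # e.g. `python app.py --local_model` selects the local checkpoint.
    (args,) = HfArgumentParser([AppArgs]).parse_args_into_dataclasses()
    print(resolve_model_name(args))

Passing no flags keeps the default Hugging Face repo id, so existing invocations of app.py behave as before.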