patrickvonplaten committed
Commit 8772541 · 1 Parent(s): dfaea05
Files changed (1):
  app.py +9 -8
app.py CHANGED
@@ -21,7 +21,7 @@ def error_str(error, title="Error"):
 
 def inference(
     repo_id,
-    pr,
+    discuss_nr,
     prompt,
 ):
 
@@ -35,12 +35,13 @@ def inference(
     dtype = torch.float16 if torch_device == "cuda" else torch.float32
 
     try:
-        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=pr, torch_dtype=dtype)
+        revision = f"refs/pr/{discuss_nr}"
+        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
         pipe.to(torch_device)
 
         return pipe(prompt, generator=generator, num_inference_steps=25).images
     except Exception as e:
-        url = f"https://huggingface.co/{repo_id}/discussions/{pr.split('/')[-1]}"
+        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
         message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
         return None, error_str(message + e)
 
@@ -64,11 +65,11 @@ with gr.Blocks(css="style.css") as demo:
         with gr.Group():
             repo_id = gr.Textbox(
                 label="Repo id on Hub",
-                placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4",
+                placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
             )
-            pr = gr.Textbox(
-                label="PR branch",
-                placeholder="PR branch that should be checked, e.g. refs/pr/171",
+            discuss_nr = gr.Textbox(
+                label="Discussion number",
+                placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
             )
             prompt = gr.Textbox(
                 label="Prompt",
@@ -87,7 +88,7 @@ with gr.Blocks(css="style.css") as demo:
 
     inputs = [
         repo_id,
-        pr,
+        discuss_nr,
        prompt,
     ]
     outputs = [gallery, error_output]
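
For reference, a minimal sketch of the call pattern this commit introduces: the app now asks only for the Hub discussion number and builds the refs/pr/<nr> revision string itself before loading the pipeline. The repo id, discussion number, and prompt below are illustrative placeholders, not values taken from the commit.

# Sketch of the new loading path (illustrative values, not part of the commit):
# a Hub PR/discussion number is turned into the "refs/pr/<nr>" revision that
# DiffusionPipeline.from_pretrained checks out before loading the weights.
import torch
from diffusers import DiffusionPipeline

repo_id = "CompVis/stable-diffusion-v1-4"  # example repo id
discuss_nr = "171"                         # example discussion number

revision = f"refs/pr/{discuss_nr}"         # PR branches on the Hub live under refs/pr/
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch_device == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
pipe.to(torch_device)

images = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images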