matteomarjanovic commited on
Commit
21e5fc0
·
1 Parent(s): a15d459

add prompt generation

Browse files
Files changed (2) hide show
  1. app.py +44 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,9 +6,11 @@ import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
8
  import subprocess
 
9
 
10
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
11
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model_repo_id = "black-forest-labs/FLUX.1-schnell" # Replace to the model you would like to use
14
  lora_path = "matteomarjanovic/flatsketcher"
@@ -26,6 +28,11 @@ pipe.load_lora_weights(lora_path, weight_name=weigths_file)
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
28
 
 
 
 
 
 
29
 
30
  @spaces.GPU #[uncomment to use ZeroGPU]
31
  def infer(
@@ -56,6 +63,28 @@ def infer(
56
 
57
  return image, seed
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  examples = [
61
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
@@ -70,11 +99,15 @@ css = """
70
  }
71
  """
72
 
 
 
73
  with gr.Blocks(css=css) as demo:
74
  with gr.Row():
75
  with gr.Column(elem_id="col-input-image"):
76
  gr.Markdown(" # Drop your image here")
77
- gr.Image()
 
 
78
  with gr.Column(elem_id="col-container"):
79
  gr.Markdown(" # Text-to-Image Gradio Template")
80
 
@@ -148,8 +181,16 @@ with gr.Blocks(css=css) as demo:
148
  triggers=[run_button.click, prompt.submit],
149
  fn=infer,
150
  inputs=[
151
- prompt,
152
- negative_prompt,
 
 
 
 
 
 
 
 
153
  seed,
154
  randomize_seed,
155
  width,
 
6
  from diffusers import DiffusionPipeline
7
  import torch
8
  import subprocess
9
+ from transformers import IdeficsForVisionText2Text, AutoProcessor
10
 
11
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
12
 
13
+ # Load FLUX image generator
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model_repo_id = "black-forest-labs/FLUX.1-schnell" # Replace to the model you would like to use
16
  lora_path = "matteomarjanovic/flatsketcher"
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  MAX_IMAGE_SIZE = 1024
30
 
31
+ # Load IDEFICS model for generate the prompt
32
+ checkpoint = "HuggingFaceM4/idefics-9b"
33
+ processor = AutoProcessor.from_pretrained(checkpoint)
34
+ idefics_model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
35
+
36
 
37
  @spaces.GPU #[uncomment to use ZeroGPU]
38
  def infer(
 
63
 
64
  return image, seed
65
 
66
+ @spaces.GPU #[uncomment to use ZeroGPU]
67
+ def generate_description_fn(
68
+ image,
69
+ progress=gr.Progress(track_tqdm=True),
70
+ ):
71
+ if randomize_seed:
72
+ seed = random.randint(0, MAX_SEED)
73
+
74
+ prompt = [
75
+ "https://images.unsplash.com/photo-1583160247711-2191776b4b91?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=3542&q=80",
76
+ ]
77
+
78
+ generator = torch.Generator().manual_seed(seed)
79
+
80
+ inputs = processor(prompt, return_tensors="pt").to("cuda")
81
+ bad_words_ids = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
82
+
83
+ generated_ids = idefics_model.generate(**inputs, max_new_tokens=10, bad_words_ids=bad_words_ids)
84
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
85
+
86
+ return generated_text[0]
87
+
88
 
89
  examples = [
90
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
 
99
  }
100
  """
101
 
102
+ generated_prompt = ""
103
+
104
  with gr.Blocks(css=css) as demo:
105
  with gr.Row():
106
  with gr.Column(elem_id="col-input-image"):
107
  gr.Markdown(" # Drop your image here")
108
+ input_image = gr.Image()
109
+ generate_button = gr.Button("Generate", scale=0, variant="primary")
110
+ generated_prompt_md = gr.Markdown(generated_prompt)
111
  with gr.Column(elem_id="col-container"):
112
  gr.Markdown(" # Text-to-Image Gradio Template")
113
 
 
181
  triggers=[run_button.click, prompt.submit],
182
  fn=infer,
183
  inputs=[
184
+ input_image
185
+ ],
186
+ outputs=[generated_prompt],
187
+ )
188
+
189
+ gr.on(
190
+ triggers=[generate_button.click],
191
+ fn=generate_description_fn,
192
+ inputs=[
193
+ input_image,
194
  seed,
195
  randomize_seed,
196
  width,
requirements.txt CHANGED
@@ -5,4 +5,5 @@ torch
5
  transformers
6
  xformers
7
  sentencepiece
8
- peft
 
 
5
  transformers
6
  xformers
7
  sentencepiece
8
+ peft
9
+ bitsandbytes