thefreeham commited on
Commit
5d2696c
·
1 Parent(s): 3f614da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -67
app.py CHANGED
@@ -1,76 +1,20 @@
1
-
2
- import argparse
3
- import base64
4
- import os
5
- from pathlib import Path
6
- from io import BytesIO
7
- import time
8
-
9
- from flask import Flask, request, jsonify
10
- from flask_cors import CORS, cross_origin
11
- from consts import IMAGES_OUTPUT_DIR
12
- from utils import parse_arg_boolean, parse_arg_dalle_version
13
- from consts import ModelSize
14
-
15
-
16
  import gradio as gr
 
17
 
 
18
 
 
19
 
20
- app = Flask(__name__)
21
- CORS(app)
22
- print("--> Starting DALL-E Server. This might take up to two minutes.")
23
-
24
- from dalle_model import DalleModel
25
- dalle_model = None
26
-
27
- parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
28
- parser.add_argument("--port", type=int, default=8000, help = "backend port")
29
- parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
30
- parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
31
- args = parser.parse_args()
32
-
33
- @app.route("/dalle", methods=["POST"])
34
- @cross_origin()
35
- def generate_images_api():
36
- json_data = request.get_json(force=True)
37
- text_prompt = json_data["text"]
38
- num_images = json_data["num_images"]
39
- generated_imgs = dalle_model.generate_images(text_prompt, num_images)
40
-
41
- generated_images = []
42
- if args.save_to_disk:
43
- dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
44
- Path(dir_name).mkdir(parents=True, exist_ok=True)
45
-
46
- for idx, img in enumerate(generated_imgs):
47
- if args.save_to_disk:
48
- img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
49
-
50
- buffered = BytesIO()
51
- img.save(buffered, format="JPEG")
52
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
53
- generated_images.append(img_str)
54
-
55
- print(f"Created {num_images} images from text prompt [{text_prompt}]")
56
- return jsonify(generated_images)
57
-
58
-
59
- @app.route("/", methods=["GET"])
60
- @cross_origin()
61
- def health_check():
62
- return jsonify(success=True)
63
-
64
-
65
- with app.app_context():
66
- dalle_model = DalleModel(args.model_version)
67
- dalle_model.generate_images("warm-up", 1)
68
- print("--> DALL-E Server is up and running!")
69
- print(f"--> Model selected - DALL-E {args.model_version}")
70
 
 
 
 
 
71
 
72
- if __name__ == "__main__":
73
- app.run(host="0.0.0.0", port=args.port, debug=False)
74
 
75
  def greet(name):
76
  return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from diffusers import DiffusionPipeline
3
 
4
+ ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")
5
 
6
+ generator = torch.manual_seed(42)
7
 
8
+ prompt = "A painting of a squirrel eating a burger"
9
+ image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ image_processed = image.cpu().permute(0, 2, 3, 1)
12
+ image_processed = image_processed * 255.
13
+ image_processed = image_processed.numpy().astype(np.uint8)
14
+ image_pil = PIL.Image.fromarray(image_processed[0])
15
 
16
+ # save image
17
+ image_pil.save("test.png")
18
 
19
  def greet(name):
20
  return "Hello " + name + "!!"