thefreeham commited on
Commit
d10fb3e
·
1 Parent(s): bfe9080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import argparse
2
  import base64
3
  import os
@@ -14,66 +15,65 @@ from consts import ModelSize
14
 
15
  import gradio as gr
16
 
17
- def greet(name):
18
-
19
 
20
 
21
- app = Flask(__name__)
22
- CORS(app)
23
- print("--> Starting DALL-E Server. This might take up to two minutes.")
24
 
25
- from dalle_model import DalleModel
26
- dalle_model = None
27
 
28
- parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
29
- parser.add_argument("--port", type=int, default=8000, help = "backend port")
30
- parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
31
- parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
32
- args = parser.parse_args()
33
 
34
- @app.route("/dalle", methods=["POST"])
35
- @cross_origin()
36
- def generate_images_api():
37
- json_data = request.get_json(force=True)
38
- text_prompt = json_data["text"]
39
- num_images = json_data["num_images"]
40
- generated_imgs = dalle_model.generate_images(text_prompt, num_images)
41
 
42
- generated_images = []
43
- if args.save_to_disk:
44
- dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
45
- Path(dir_name).mkdir(parents=True, exist_ok=True)
46
 
47
- for idx, img in enumerate(generated_imgs):
48
- if args.save_to_disk:
49
- img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
50
 
51
- buffered = BytesIO()
52
- img.save(buffered, format="JPEG")
53
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
54
- generated_images.append(img_str)
55
 
56
- print(f"Created {num_images} images from text prompt [{text_prompt}]")
57
- return jsonify(generated_images)
58
 
59
 
60
- @app.route("/", methods=["GET"])
61
- @cross_origin()
62
- def health_check():
63
- return jsonify(success=True)
64
 
65
 
66
- with app.app_context():
67
- dalle_model = DalleModel(args.model_version)
68
- dalle_model.generate_images("warm-up", 1)
69
- print("--> DALL-E Server is up and running!")
70
- print(f"--> Model selected - DALL-E {args.model_version}")
71
 
72
-
73
 
74
- if __name__ == "__main__":
75
- app.run(host="0.0.0.0", port=args.port, debug=False)
76
 
 
77
  return "Hello " + name + "!!"
 
78
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
79
  iface.launch()
 
1
+
2
  import argparse
3
  import base64
4
  import os
 
15
 
16
  import gradio as gr
17
 
 
 
18
 
19
 
20
+ app = Flask(__name__)
21
+ CORS(app)
22
+ print("--> Starting DALL-E Server. This might take up to two minutes.")
23
 
24
+ from dalle_model import DalleModel
25
+ dalle_model = None
26
 
27
+ parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
28
+ parser.add_argument("--port", type=int, default=8000, help = "backend port")
29
+ parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
30
+ parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
31
+ args = parser.parse_args()
32
 
33
+ @app.route("/dalle", methods=["POST"])
34
+ @cross_origin()
35
+ def generate_images_api():
36
+ json_data = request.get_json(force=True)
37
+ text_prompt = json_data["text"]
38
+ num_images = json_data["num_images"]
39
+ generated_imgs = dalle_model.generate_images(text_prompt, num_images)
40
 
41
+ generated_images = []
42
+ if args.save_to_disk:
43
+ dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
44
+ Path(dir_name).mkdir(parents=True, exist_ok=True)
45
 
46
+ for idx, img in enumerate(generated_imgs):
47
+ if args.save_to_disk:
48
+ img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
49
 
50
+ buffered = BytesIO()
51
+ img.save(buffered, format="JPEG")
52
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
53
+ generated_images.append(img_str)
54
 
55
+ print(f"Created {num_images} images from text prompt [{text_prompt}]")
56
+ return jsonify(generated_images)
57
 
58
 
59
+ @app.route("/", methods=["GET"])
60
+ @cross_origin()
61
+ def health_check():
62
+ return jsonify(success=True)
63
 
64
 
65
+ with app.app_context():
66
+ dalle_model = DalleModel(args.model_version)
67
+ dalle_model.generate_images("warm-up", 1)
68
+ print("--> DALL-E Server is up and running!")
69
+ print(f"--> Model selected - DALL-E {args.model_version}")
70
 
 
71
 
72
+ if __name__ == "__main__":
73
+ app.run(host="0.0.0.0", port=args.port, debug=False)
74
 
75
+ def greet(name):
76
  return "Hello " + name + "!!"
77
+
78
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
79
  iface.launch()