thefreeham commited on
Commit
68040ae
·
1 Parent(s): 5e22df1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -50
app.py CHANGED
@@ -15,66 +15,64 @@ from consts import ModelSize
15
  import gradio as gr
16
 
17
  def greet(name):
 
 
 
18
  app = Flask(__name__)
19
  CORS(app)
20
  print("--> Starting DALL-E Server. This might take up to two minutes.")
21
- return "Hello " + name + "!!"
22
 
23
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
24
- iface.launch()
25
-
26
-
27
- app = Flask(__name__)
28
- CORS(app)
29
- print("--> Starting DALL-E Server. This might take up to two minutes.")
30
-
31
- from dalle_model import DalleModel
32
- dalle_model = None
33
-
34
- parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
35
- parser.add_argument("--port", type=int, default=8000, help = "backend port")
36
- parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
37
- parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
38
- args = parser.parse_args()
39
-
40
- @app.route("/dalle", methods=["POST"])
41
- @cross_origin()
42
- def generate_images_api():
43
- json_data = request.get_json(force=True)
44
- text_prompt = json_data["text"]
45
- num_images = json_data["num_images"]
46
- generated_imgs = dalle_model.generate_images(text_prompt, num_images)
47
-
48
- generated_images = []
49
- if args.save_to_disk:
50
- dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
51
- Path(dir_name).mkdir(parents=True, exist_ok=True)
52
-
53
- for idx, img in enumerate(generated_imgs):
54
  if args.save_to_disk:
55
- img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
 
 
 
 
 
56
 
57
- buffered = BytesIO()
58
- img.save(buffered, format="JPEG")
59
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
60
- generated_images.append(img_str)
61
 
62
- print(f"Created {num_images} images from text prompt [{text_prompt}]")
63
- return jsonify(generated_images)
64
 
65
 
66
- @app.route("/", methods=["GET"])
67
- @cross_origin()
68
- def health_check():
69
- return jsonify(success=True)
70
 
71
 
72
- with app.app_context():
73
- dalle_model = DalleModel(args.model_version)
74
- dalle_model.generate_images("warm-up", 1)
75
- print("--> DALL-E Server is up and running!")
76
- print(f"--> Model selected - DALL-E {args.model_version}")
77
 
78
 
79
- if __name__ == "__main__":
80
- app.run(host="0.0.0.0", port=args.port, debug=False)
 
 
 
 
 
15
  import gradio as gr
16
 
17
  def greet(name):
18
+
19
+ return "Hello " + name + "!!"
20
+
21
  app = Flask(__name__)
22
  CORS(app)
23
  print("--> Starting DALL-E Server. This might take up to two minutes.")
 
24
 
25
+ from dalle_model import DalleModel
26
+ dalle_model = None
27
+
28
+ parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
29
+ parser.add_argument("--port", type=int, default=8000, help = "backend port")
30
+ parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
31
+ parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
32
+ args = parser.parse_args()
33
+
34
+ @app.route("/dalle", methods=["POST"])
35
+ @cross_origin()
36
+ def generate_images_api():
37
+ json_data = request.get_json(force=True)
38
+ text_prompt = json_data["text"]
39
+ num_images = json_data["num_images"]
40
+ generated_imgs = dalle_model.generate_images(text_prompt, num_images)
41
+
42
+ generated_images = []
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  if args.save_to_disk:
44
+ dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
45
+ Path(dir_name).mkdir(parents=True, exist_ok=True)
46
+
47
+ for idx, img in enumerate(generated_imgs):
48
+ if args.save_to_disk:
49
+ img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
50
 
51
+ buffered = BytesIO()
52
+ img.save(buffered, format="JPEG")
53
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
54
+ generated_images.append(img_str)
55
 
56
+ print(f"Created {num_images} images from text prompt [{text_prompt}]")
57
+ return jsonify(generated_images)
58
 
59
 
60
+ @app.route("/", methods=["GET"])
61
+ @cross_origin()
62
+ def health_check():
63
+ return jsonify(success=True)
64
 
65
 
66
+ with app.app_context():
67
+ dalle_model = DalleModel(args.model_version)
68
+ dalle_model.generate_images("warm-up", 1)
69
+ print("--> DALL-E Server is up and running!")
70
+ print(f"--> Model selected - DALL-E {args.model_version}")
71
 
72
 
73
+ if __name__ == "__main__":
74
+ app.run(host="0.0.0.0", port=args.port, debug=False)
75
+
76
+
77
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
78
+ iface.launch()