Spaces:
Running
Running
| # Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from PIL import Image, ImageFilter | |
| from loguru import logger | |
| import requests | |
| import base64 | |
| import json | |
| import io | |
# Endpoint of the MagicArena algorithm-processing API. Both the text-to-image
# call (AlgoType 1) and the image-edit call (AlgoType 2) below post to it.
t2i_url = "https://magicarena.bytedance.com/api/evaluate/v1/algo/process"

# NOTE: per the original author, these headers must NOT be set in the
# production deployment. They appear to route requests to a "ppe"
# (presumably pre-production) environment — confirm before shipping.
headers = dict([
    ("X-TT-ENV", "ppe_general_20"),
    ("X-USE-PPE", "1"),
])
class SeedT2ICaller:
    """Client for the MagicArena text-to-image endpoint (``AlgoType`` 1).

    ``cfg`` must support ``cfg['resolution']``: every generated image is
    resized to a ``resolution x resolution`` square before being returned.
    """

    # Seconds to wait for the HTTP call before giving up. Without a timeout a
    # stalled server would block the caller forever. The value is a generous
    # default (generation is presumably slow); override per instance/subclass.
    request_timeout = 300

    def __init__(self, cfg, *args, **kwargs):
        self.cfg = cfg

    def generate(self, text, *args, **kwargs):
        """Generate one image for the prompt ``text``.

        Args:
            text: Prompt; converted with ``str()`` before being sent.

        Returns:
            ``(PIL.Image.Image, True)`` on success. ``(None, False)`` on any
            failure: non-200 HTTP status, a response ``code`` that is missing
            or non-zero, no image data, or a network/decoding error. Failures
            are logged, never raised.
        """
        try:
            logger.info("Generate images ...")
            req_json = json.dumps({
                "prompt": str(text),
                "use_sr": True,
                "model_version": "general_v2.0_L",
                "req_schedule_conf": "general_v20_9B_pe",
                # The original author had "width"/"height" (64) here, disabled.
            })
            logger.info(f"{req_json}")
            # Send the request as a form-encoded body: AlgoType + serialized ReqJson.
            response = requests.post(
                t2i_url,
                headers=headers,
                data={
                    'AlgoType': 1,
                    'ReqJson': req_json,
                },
                timeout=self.request_timeout,
            )
            logger.info(f"header: {response.headers}")
            if response.status_code != 200:
                logger.warning(f"HTTP error {response.status_code}: {response.text[:500]}")
                return None, False
            resp = response.json()
            # A missing 'code' is treated as an error, as before.
            if resp.get('code') != 0:
                logger.warning(f"response error {resp}")
                return None, False
            # `or {}` also covers an explicit JSON null for 'data'.
            binary_data = (resp.get('data') or {}).get('BinaryData')
            if not binary_data:
                logger.warning(f"response contains no image data {resp}")
                return None, False
            image = Image.open(io.BytesIO(base64.b64decode(binary_data[0])))
            resolution = self.cfg['resolution']
            image = image.resize((resolution, resolution))
            return image, True
        except Exception:
            logger.exception("An error occurred during image generation.")
            return None, False
class SeedEditCaller:
    """Client for the MagicArena image-editing endpoint (``AlgoType`` 2)."""

    # Seconds to wait for the HTTP call before giving up. Without a timeout a
    # stalled server would block the caller forever. Override as needed.
    request_timeout = 300

    def __init__(self, cfg, *args, **kwargs):
        self.cfg = cfg

    def edit(self, image, edit, cfg_scale=0.5, *args, **kwargs):
        """Edit ``image`` according to the instruction ``edit``.

        Args:
            image: A ``PIL.Image.Image``, or the raw encoded bytes of an image
                file (``bytes``/``bytearray``), which are decoded with PIL first.
            edit: Edit instruction; converted with ``str()`` before being sent.
            cfg_scale: Sent to the service as the ``"scale"`` field.

        Returns:
            ``(PIL.Image.Image, True)`` on success. ``(None, False)`` on any
            failure: non-200 HTTP status, a response ``code`` that is missing
            or non-zero, no image data, or a network/decoding error. Failures
            are logged, never raised.
        """
        try:
            if isinstance(image, (bytes, bytearray)):
                image = Image.open(io.BytesIO(image))
            # JPEG cannot store alpha/palette modes (RGBA, LA, P, ...), and
            # saving them would raise. Convert such images to RGB. Note that
            # this drops any alpha channel.
            if image.mode not in ('L', 'RGB'):
                image = image.convert('RGB')
            image_bytes = io.BytesIO()
            image.save(image_bytes, format='JPEG')  # or format='PNG'
            logger.info("Edit images ...")
            req_json = json.dumps({
                "prompt": str(edit),
                "model_version": "byteedit_v2.0",
                "scale": cfg_scale,
            })
            logger.info(f"{req_json}")
            binary = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
            # NOTE(review): unlike the text-to-image call (form-encoded dict),
            # this sends a JSON string as the raw body, with no Content-Type
            # header. It is kept as-is because the server's expectation is not
            # visible here — confirm.
            response = requests.post(
                t2i_url,
                headers=headers,
                data=json.dumps({
                    'AlgoType': 2,
                    'ReqJson': req_json,
                    'BinaryData': [binary],
                }),
                timeout=self.request_timeout,
            )
            logger.info(f"header: {response.headers}")
            if response.status_code != 200:
                logger.warning(f"HTTP error {response.status_code}: {response.text[:500]}")
                return None, False
            resp = response.json()
            # A missing 'code' is treated as an error, as before.
            if resp.get('code') != 0:
                logger.warning(f"response error {resp}")
                return None, False
            # `or {}` also covers an explicit JSON null for 'data'.
            binary_data = (resp.get('data') or {}).get('BinaryData')
            if not binary_data:
                logger.warning(f"response contains no image data {resp}")
                return None, False
            result = Image.open(io.BytesIO(base64.b64decode(binary_data[0])))
            # Image.open is lazy. Decode now so a corrupt payload is reported
            # here instead of raising later in the caller.
            result.load()
            return result, True
        except Exception:
            logger.exception("An error occurred during image editing.")
            return None, False
if __name__ == "__main__":
    # Manual smoke test: generate an image, save it, then edit it.
    # Requires network access to the service.
    cfg_t2i = {
        "resolution": 611
    }
    model_t2i = SeedT2ICaller(cfg_t2i)
    t2i_image, ok = model_t2i.generate("a beautiful girl")
    if not ok:
        raise SystemExit("text-to-image generation failed; see log above")

    # generate() does not write to disk, so persist the result explicitly.
    image_path = "./t2i_image.png"
    t2i_image.save(image_path)
    logger.info(f"saved generated image to {image_path}")

    # edit() expects a PIL image and uploads it as JPEG, so pass an RGB image
    # rather than the raw PNG bytes the original demo passed.
    model_edit = SeedEditCaller(cfg_t2i)
    edited_image, ok = model_edit.edit(
        image=t2i_image.convert("RGB"),
        edit="please edit to a good man",
    )
    if not ok:
        raise SystemExit("image editing failed; see log above")
    logger.info(f"edited image size: {edited_image.size}")