Spaces:
Sleeping
Sleeping
from flask import Flask, request, jsonify | |
from gradio_client import Client, handle_file | |
import base64 | |
import io | |
import requests | |
from PIL import Image | |
# Flask application object. The gradio Client is not created here; a new
# one is built inside each generate_image request instead.
app = Flask(import_name=__name__)
def get_random_api_key():
    """Return one token chosen at random from the ``KEYS`` environment variable.

    ``KEYS`` is a comma-separated list. Whitespace around entries is stripped
    and empty entries (e.g. from a trailing comma or ``" , "``) are ignored, so
    a blank token is never returned.

    Returns:
        A single non-empty key string.

    Raises:
        ValueError: if ``KEYS`` is unset or contains no non-empty entries.
    """
    # Imported locally: the module's top-level imports do not include them.
    import os
    import random

    keys = [key.strip() for key in os.getenv("KEYS", "").split(",") if key.strip()]
    if not keys:
        raise ValueError("API keys not found. Please set the KEYS environment variable.")
    # Uniform random pick; presumably to spread requests across several tokens.
    return random.choice(keys)
def translate_to_english(prompt):
    """Return *prompt* in English, translating it when another language is detected.

    The language is detected with ``detect``. When the result is not ``'en'``,
    the text is translated with ``GoogleTranslator`` using ``source='auto'``.
    The detector's language codes may not match the codes the translator
    accepts (e.g. case or region variants), so the translator is left to
    identify the source language itself.

    Translation is best-effort:
    - Empty or whitespace-only prompts (and ``None``) are returned unchanged
      without calling the detector.
    - Any error during detection or translation is printed, and the original
      prompt is returned so image generation can still proceed.

    NOTE(review): ``detect`` and ``GoogleTranslator`` are not imported in the
    visible file (presumably ``langdetect.detect`` and
    ``deep_translator.GoogleTranslator``) — confirm those imports exist.
    """
    if not prompt or not prompt.strip():
        return prompt
    try:
        if detect(prompt) == 'en':
            return prompt
        translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:
        print(f"Translation failed, using the original prompt: {e}")
        return prompt
    # Fall back to the original text if the translator gave back nothing.
    return translated or prompt
def _read_image_source(image_data, timeout):
    """Return the raw bytes behind *image_data* (http(s) URL, data URL, or path).

    Returns ``None`` after printing the error when the bytes cannot be obtained.
    """
    if image_data.startswith(("http://", "https://")):
        # NOTE(review): the server downloads an arbitrary client-supplied URL
        # (SSRF risk); consider an allow-list if this endpoint is public.
        try:
            response = requests.get(image_data, timeout=timeout)
            response.raise_for_status()  # Treat HTTP error statuses as failures
            return response.content
        except requests.exceptions.RequestException as e:
            print(f"Ошибка при загрузке изображения по URL: {e}")
            return None
    if image_data.startswith("data:image"):  # base64 data URL
        try:
            _header, encoded = image_data.split(',', 1)
            # In an unescaped query string '+' is decoded to ' ', and
            # b64decode silently discards spaces; restore the '+' characters.
            return base64.b64decode(encoded.replace(" ", "+"))
        except ValueError as e:  # no comma, or bad padding (binascii.Error)
            print(f"Ошибка при декодировании base64: {e}")
            return None
    # Otherwise treat the value as a local file path.
    # NOTE(review): this lets a client read any server-side path; only data
    # that decodes as an image is forwarded, but consider restricting it.
    try:
        with open(image_data, "rb") as f:
            return f.read()
    except (OSError, ValueError) as e:  # ValueError: e.g. NUL byte in the path
        print(f"Ошибка при открытии файла: {e}")
        return None


def _image_to_temp_png(raw):
    """Decode *raw* image bytes with Pillow, write them to a temporary PNG, return its path.

    The file is created with ``delete=False`` because it must still exist when
    gradio_client reads it later; nothing here removes it afterwards.
    """
    import tempfile

    with Image.open(io.BytesIO(raw)) as image:
        image.load()  # Decode fully now so bad data fails before a temp file exists
        if image.mode == "CMYK":
            image = image.convert("RGB")  # PNG cannot store CMYK
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
    return tmp.name


def handle_image_input(image_data, timeout=30):
    """Handle input images supplied in different formats.

    Accepted forms of *image_data*:
    - an http(s) URL,
    - a ``data:image/...;base64,`` data URL,
    - a local file path.

    Every source is decoded with Pillow and re-saved as a temporary PNG, both
    for uniformity and to reject data that is not an image. Because
    ``handle_file`` accepts a file path or URL rather than raw bytes, it is
    given the temporary PNG's path.

    Args:
        image_data: the raw string taken from the request.
        timeout: seconds to wait when downloading a URL.

    Returns:
        The ``handle_file`` payload for the temporary PNG. Returns ``None``
        when the input is empty or cannot be fetched, read, or decoded.
    """
    if not image_data:
        return None
    raw = _read_image_source(image_data, timeout)
    if raw is None:
        return None
    try:
        png_path = _image_to_temp_png(raw)
    except Exception as e:  # Pillow raises several exception types for bad data
        print(f"Failed to decode image data: {e}")
        return None
    return handle_file(png_path)
# NOTE(review): no route was registered for this view in the visible file, so
# it was unreachable; the path below is assumed from the Space's api_name.
@app.route('/generate_image', methods=['GET'])
def generate_image():
    """Generate a style-shaped image via the multimodalart/flux-style-shaping Space.

    Query parameters:
        prompt (required): text prompt, translated to English before use.
        image1 (required): structure image (URL, base64 data URL, or path).
        image2 (required): style image (same formats as image1).
        depth_strength (float, default 15)
        style_strength (float, default 0.5)

    Returns:
        ``{"generated_image": <base64 of the result file>}`` on success.
        A 400 JSON error for missing parameters or unusable input images.
        A 500 JSON error for any failure while translating, calling the
        Space, or reading its result.
    """
    prompt = request.args.get('prompt')
    image1_data = request.args.get('image1')
    image2_data = request.args.get('image2')
    depth_strength = request.args.get('depth_strength', default=15, type=float)
    style_strength = request.args.get('style_strength', default=0.5, type=float)
    if not prompt or not image1_data or not image2_data:
        return jsonify({"error": "Missing required parameters: prompt, image1, and image2."}), 400

    structure_image = handle_image_input(image1_data)
    style_image = handle_image_input(image2_data)
    if not structure_image or not style_image:
        return jsonify({"error": "Failed to process one or both of the input images."}), 400

    try:
        # Inside the try so a translation failure still yields a JSON error.
        prompt = translate_to_english(prompt)
        # A new client per request so each call uses a randomly chosen token.
        client = Client("multimodalart/flux-style-shaping", hf_token=get_random_api_key())
        result = client.predict(
            prompt=prompt,
            structure_image=structure_image,
            style_image=style_image,
            depth_strength=depth_strength,
            style_strength=style_strength,
            api_name="/generate_image"
        )
        # gradio_client usually returns a filepath string for image outputs;
        # the original code expected a dict with "path", so accept both.
        result_path = result["path"] if isinstance(result, dict) else result
        # Encode the generated file as base64 for the JSON response.
        with open(result_path, "rb") as image_file:
            encoded_result = base64.b64encode(image_file.read()).decode('utf-8')
        return jsonify({"generated_image": encoded_result})
    except Exception as e:
        return jsonify({"error": f"Error during image generation: {str(e)}"}), 500
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, so the Werkzeug interactive
    # debugger must not be on by default. Debug mode is opt-in via
    # FLASK_DEBUG=1, for local use only.
    app.run(host='0.0.0.0', port=7860, debug=os.getenv("FLASK_DEBUG") == "1")