from flask import Flask, request, send_file, abort
import requests
import io
from PIL import Image
from cachetools import TTLCache, cached
import random
import os
import urllib.parse
import hashlib
from deep_translator import GoogleTranslator
from langdetect import detect

app = Flask(__name__)

# Максимальные значения для ширины и высоты
MAX_WIDTH = 1384
MAX_HEIGHT = 1384

# Кэш на 10 минут
cache = TTLCache(maxsize=100, ttl=600)

# Получаем ключи из переменной окружения
keys = os.getenv("keys", "").split(',')
if not keys:
    raise ValueError("Environment variable 'keys' must be set with a comma-separated list of API keys.")

def get_random_key():
    return random.choice(keys)

def generate_cache_key(prompt, width, height, seed, model_name):
    # Создаем уникальный ключ на основе всех параметров
    return hashlib.md5(f"{prompt}_{width}_{height}_{seed}_{model_name}".encode()).hexdigest()


def scale_dimensions(width, height, max_width, max_height):
    """Масштабирует размеры изображения, сохраняя соотношение сторон, и округляет до чисел, кратных 8."""
    aspect_ratio = width / height
    if width > max_width or height > max_height:
        if width / max_width > height / max_height:
            width = max_width
            height = int(width / aspect_ratio)
        else:
            height = max_height
            width = int(height * aspect_ratio)
    
    # Округляем до ближайших чисел, кратных 8
    width = (width + 3) // 8 * 8
    height = (height + 3) // 8 * 8
    return width, height

@cached(cache, key=lambda prompt, width, height, seed, model_name: generate_cache_key(prompt, width, height, seed, model_name))
def generate_cached_image(prompt, width, height, seed, model_name, api_key):
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "inputs": prompt,
        "parameters": {
            "width": width,
            "height": height,
            "seed": seed
        }
    }

    try:
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{model_name}",
            headers=headers,
            json=data,
            timeout=1550  # Таймаут 3 минуты
        )
        response.raise_for_status()
        image_data = response.content
        image = Image.open(io.BytesIO(image_data))
        return image
    except requests.exceptions.HTTPError as http_err:
        app.logger.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
        return None
    except requests.exceptions.Timeout as timeout_err:
        app.logger.error(f"Timeout error occurred: {timeout_err}")
        return None
    except requests.exceptions.RequestException as req_err:
        app.logger.error(f"Request error occurred: {req_err}")
        return None

@app.route('/prompt/<path:prompt>')
def get_image(prompt):
    width = request.args.get('width', type=int, default=512)
    height = request.args.get('height', type=int, default=512)
    seed = request.args.get('seed', type=int, default=25)
    model_name = request.args.get('model', default="black-forest-labs/FLUX.1-schnell").replace('+', '/')
    api_key = request.args.get('key', default=None)

    # Декодируем URL-кодированный prompt
    prompt = urllib.parse.unquote(prompt)

    # Определяем язык промпта
    try:
        language = detect(prompt)
    except Exception as e:
        app.logger.error(f"Error detecting language: {e}")
        return send_error_image()

    # Переводим промпт, если он не на английском языке
    if language != 'en':
        try:
            translator = GoogleTranslator(source=language, target='en')
            prompt = translator.translate(prompt)
        except Exception as e:
            app.logger.error(f"Error translating prompt: {e}")
            return send_error_image()

    # Масштабируем размеры изображения, если они превышают максимальные значения, и округляем до чисел, кратных 8
    width, height = scale_dimensions(width, height, MAX_WIDTH, MAX_HEIGHT)

    # Используем указанный ключ, если он предоставлен, иначе выбираем случайный ключ
    if api_key is None:
        api_key = get_random_key()

    try:
        image = generate_cached_image(prompt, width, height, seed, model_name, api_key)
        if image is None:
            return send_error_image()
    except Exception as e:
        app.logger.error(f"Error generating image: {e}")
        return send_error_image()

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()

    return send_file(
        io.BytesIO(img_byte_arr),
        mimetype='image/png'
    )

@app.route('/')
def health_check():
    return "OK", 200

def send_error_image():
    error_image_url = "https://raw.githubusercontent.com/Igroshka/-/refs/heads/main/img/nuai/errorimg.png"
    try:
        response = requests.get(error_image_url)
        response.raise_for_status()
        error_image = Image.open(io.BytesIO(response.content))
        img_byte_arr = io.BytesIO()
        error_image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
        return send_file(
            io.BytesIO(img_byte_arr),
            mimetype='image/png'
        )
    except Exception as e:
        app.logger.error(f"Error fetching error image: {e}")
        abort(500, description="Error fetching error image")

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)