import streamlit as st
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from io import BytesIO

def _decode_image(data):
    """Decode encoded image bytes (e.g. JPEG/PNG) into an RGB ``uint8`` array.

    Returns ``None`` when the data is empty or OpenCV cannot decode it
    (corrupt or mislabelled upload), instead of letting a later
    ``cv2.cvtColor`` call fail on a ``None`` image.
    """
    if not data:
        return None
    buffer = np.frombuffer(data, dtype=np.uint8)
    bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    # imdecode yields BGR channel order; the rest of the app works in RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _extract_palette(image, num_colors):
    """Cluster the image's pixels with K-Means and return the cluster centres.

    Returns a ``(k, 3)`` int array of RGB values, where
    ``k = min(num_colors, pixel_count)``: K-Means cannot fit more clusters
    than it has samples, so very small images are capped instead of crashing.
    """
    pixels = image.reshape((-1, 3))
    n_clusters = min(num_colors, len(pixels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(pixels)
    # Round (rather than truncate) the float centres to the nearest 8-bit value.
    return np.rint(kmeans.cluster_centers_).astype(int)


def _palette_entries(colors):
    """Return a list of ``(hex_string, (r, g, b))`` pairs using plain Python ints.

    Going through ``tolist()`` keeps NumPy scalar reprs (``np.int64(12)`` on
    NumPy 2) out of the text shown to the user and written to the download.
    """
    return [("#{:02x}{:02x}{:02x}".format(*rgb), tuple(rgb)) for rgb in colors.tolist()]


# Title and description
st.title("🎨 Automatic Color Palette Generator")
st.write("Upload an image to extract its dominant colors.")

# File uploader for user to upload an image
uploaded_file = st.file_uploader("Choose an image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the stream position.
    file_bytes = uploaded_file.getvalue()
    image = _decode_image(file_bytes)
    if image is None:
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # User selects the number of colors to extract
        num_colors = st.slider("Select number of colors", min_value=3, max_value=10, value=5)

        # st.button is True only on the rerun triggered by its own click, so the
        # palette is kept in session state; otherwise interacting with any later
        # widget (such as the download button) would discard it. The key ties the
        # stored palette to this particular upload and slider value.
        palette_key = (uploaded_file.name, len(file_bytes), num_colors)
        if st.button("Extract Colors"):
            st.session_state["palette"] = (palette_key, _extract_palette(image, num_colors))

        stored = st.session_state.get("palette")
        if stored is not None and stored[0] == palette_key:
            colors = stored[1]

            # Display extracted colors as a color bar
            st.subheader("Extracted Colors")
            fig, ax = plt.subplots(figsize=(len(colors), 1))
            ax.imshow(colors[np.newaxis] / 255)
            ax.set_xticks([])
            ax.set_yticks([])
            st.pyplot(fig)
            # Figures created through pyplot stay open until closed explicitly.
            plt.close(fig)

            # Show each color as HEX plus its RGB values
            entries = _palette_entries(colors)
            for index, (hex_color, rgb) in enumerate(entries, start=1):
                st.markdown(f"**Color {index}:** `{hex_color}` (RGB: {rgb})")

            # Rendered directly, not behind another st.button: a nested button's
            # body never runs, because clicking it reruns the script with
            # "Extract Colors" back to False.
            palette_text = "\n".join(f"{hex_color} - RGB{rgb}" for hex_color, rgb in entries)
            st.download_button(
                "Download Palette as TXT",
                data=palette_text.encode(),
                file_name="color_palette.txt",
                mime="text/plain",
            )