import gradio as gr
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
def _palette_html(hex_colors):
    """Render a list of ``#rrggbb`` strings as a titled row of labeled swatches."""
    swatches = "".join(
        '<div style="display:flex;flex-direction:column;align-items:center;margin:6px;">'
        '<div style="width:80px;height:80px;border-radius:8px;'
        f'border:1px solid #ccc;background:{hex_color};"></div>'
        f"<code>{hex_color}</code>"
        "</div>"
        for hex_color in hex_colors
    )
    return (
        "<h3>Color Palette</h3>"
        f'<div style="display:flex;flex-wrap:wrap;">{swatches}</div>'
    )


def extract_colors(img, num_colors):
    """Extract the dominant colors of an image and render them as an HTML palette.

    The image is converted to RGB, downsampled to 150x150 and its pixels are
    clustered with k-means; each non-empty cluster center becomes one swatch.

    Args:
        img: A PIL image (as delivered by ``gr.Image(type="pil")``), or ``None``
            when nothing is uploaded / the upload was cleared.
        num_colors: Requested number of palette colors. Coerced to ``int``
            because slider values may arrive as floats.

    Returns:
        An HTML string with one swatch per distinct color, ordered from the
        most to the least common cluster, or a short message when ``img`` is
        ``None``.
    """
    if img is None:
        return "<p>No image uploaded.</p>"

    # Normalize to 3-channel RGB *before* flattening. RGBA, grayscale ("L"),
    # palette ("P") or CMYK images do not have exactly 3 values per pixel, and
    # reshape(-1, 3) would otherwise silently mix channels of adjacent pixels.
    img_rgb = img.convert("RGB")
    # Downsample for faster clustering; aspect ratio is irrelevant for colors.
    img_resized = img_rgb.resize((150, 150))
    pixels = np.asarray(img_resized, dtype=np.float64).reshape(-1, 3)

    # KMeans requires an integer cluster count of at least 1.
    k = max(1, int(num_colors))
    # n_init is pinned explicitly so results don't depend on the sklearn
    # version's default (and to avoid its FutureWarning).
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(pixels)

    # Order clusters by the number of pixels they cover, most dominant first.
    counts = np.bincount(labels, minlength=k)
    order = np.argsort(-counts, kind="stable")

    hex_colors = []
    for idx in order:
        if counts[idx] == 0:
            # Empty cluster: the image has fewer distinct colors than requested.
            continue
        r, g, b = np.clip(np.rint(kmeans.cluster_centers_[idx]), 0, 255).astype(int)
        hex_color = f"#{int(r):02x}{int(g):02x}{int(b):02x}"
        # Near-identical centers can round to the same hex value; show it once.
        if hex_color not in hex_colors:
            hex_colors.append(hex_color)

    return _palette_html(hex_colors)
# Build the Gradio UI: image + color-count inputs on the left, palette on the right.
with gr.Blocks(title="Image Color Palette Extractor") as demo:
    gr.Markdown("# Image Color Palette Extractor")
    gr.Markdown("Upload an image to extract the main colors and generate a palette.")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image")
            num_colors = gr.Slider(2, 12, value=6, step=1, label="Number of Colors")
            extract_btn = gr.Button("Extract Colors", variant="primary")
        with gr.Column(scale=2):
            palette_output = gr.HTML()

    # Event listeners must be registered inside the Blocks context.
    # Explicit button press re-runs extraction (e.g. after moving the slider).
    extract_btn.click(
        extract_colors,
        inputs=[input_img, num_colors],
        outputs=palette_output,
    )
    # Run automatically whenever a new image is uploaded or the image is cleared.
    input_img.change(
        extract_colors,
        inputs=[input_img, num_colors],
        outputs=palette_output,
    )

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.queue()
    demo.launch()