#!/usr/bin/env -S uv run
# /// script
# dependencies = [
#   "google-genai>=1.67.0",
#   "pillow>=11.0.0",
# ]
# ///
"""Benchmark 2K versus 4K packshot generation and build a comparison image.

Usage, run from project root:
  uv run ~/.pi/agent/skills/packshot/scripts/benchmark-packshot-size.py brands/[brand]/packshots/[output-name]

Outputs:
  [output-name]_2k_vN.png
  [output-name]_4k_vN.png
  [output-name]_2k-vs-4k_vN.png
  size-benchmark.json
"""
from __future__ import annotations

import importlib.util
import json
import sys
from pathlib import Path
from typing import Any

from PIL import Image, ImageDraw, ImageFont


# Arial faces used for the comparison-sheet labels. These are hard-coded macOS
# system font locations; font() falls back to Pillow's default font when they
# cannot be loaded (e.g. on other operating systems).
FONT_BOLD = "/System/Library/Fonts/Supplemental/Arial Bold.ttf"
FONT_REGULAR = "/System/Library/Fonts/Supplemental/Arial.ttf"


def load_generate_module():
    """Import the sibling ``generate-packshot.py`` script as a module object.

    The script's file name contains a hyphen, so it is not importable with a
    plain ``import`` statement; it is loaded from its file path instead.

    Returns:
        The executed module (exposing ``load_spec``, ``validate_spec`` and
        ``generate_packshot`` as used by ``main``).

    Raises:
        SystemExit: if no import spec/loader can be created for the script.
    """
    generator_file = Path(__file__).with_name("generate-packshot.py")
    module_spec = importlib.util.spec_from_file_location("generate_packshot_module", generator_file)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise SystemExit(f"Could not import {generator_file}")
    generator = importlib.util.module_from_spec(module_spec)
    loader.exec_module(generator)
    return generator


def font(size: int, bold: bool = False) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Load the Arial label font at ``size`` pixels, with a portable fallback.

    Args:
        size: Requested font size in pixels.
        bold: When true, load ``FONT_BOLD`` instead of ``FONT_REGULAR``.

    Returns:
        The Arial TrueType font if it can be loaded; otherwise Pillow's built-in
        default font at ``size``; or, when Pillow has no FreeType support, the
        fixed-size bitmap default font.
    """
    path = FONT_BOLD if bold else FONT_REGULAR
    try:
        return ImageFont.truetype(path, size=size)
    except (OSError, ImportError):
        # OSError: font file missing/unreadable (the FONT_* paths are macOS
        # locations). ImportError: Pillow was built without FreeType.
        pass
    try:
        # Pillow >= 10.1 accepts a size for the default font. Without it the
        # fallback is a ~10 px font, unreadable on the multi-thousand-pixel
        # comparison canvas.
        return ImageFont.load_default(size=size)
    except ImportError:
        # A sized default font also needs FreeType; use the bitmap font instead.
        return ImageFont.load_default()


def draw_centered(draw: ImageDraw.ImageDraw, xy: tuple[int, int], text: str, fnt, fill=(20, 20, 20)) -> None:
    """Draw ``text`` horizontally centred on ``xy[0]``, with its top at ``xy[1]``.

    The horizontal offset is half the width of the text's bounding box as
    measured by ``draw.textbbox`` for ``fnt``.
    """
    center_x, top_y = xy
    left, _top, right, _bottom = draw.textbbox((0, 0), text, font=fnt)
    half_width = (right - left) // 2
    draw.text((center_x - half_width, top_y), text, font=fnt, fill=fill)


def fit_image(image: Image.Image, size: tuple[int, int]) -> Image.Image:
    """Scale ``image`` uniformly so it fits inside ``size`` (width, height).

    The aspect ratio is preserved and each side is at least 1 px. Images may be
    scaled up as well as down: LANCZOS is used when shrinking, BICUBIC when
    enlarging or keeping the size.
    """
    width, height = image.size
    target_w, target_h = size
    factor = min(target_w / width, target_h / height)
    if factor < 1:
        method = Image.Resampling.LANCZOS
    else:
        method = Image.Resampling.BICUBIC
    scaled = tuple(max(1, round(dim * factor)) for dim in (width, height))
    return image.resize(scaled, method)


def next_comparison_path(output_dir: Path, output_name: str) -> Path:
    """Return ``output_dir/<output_name>_2k-vs-4k_v<N>.png`` for the lowest unused N.

    Versions start at 1; N is the first number whose file does not exist yet.
    """
    candidate_version = 1
    while True:
        candidate = output_dir / f"{output_name}_2k-vs-4k_v{candidate_version}.png"
        if not candidate.exists():
            return candidate
        candidate_version += 1


def _load_rgb(path: str | Path) -> Image.Image:
    """Open the image at ``path``, return an RGB copy, and close the file."""
    with Image.open(path) as img:
        # convert() loads the pixel data and returns a new image, so the result
        # stays valid after the file handle is closed.
        return img.convert("RGB")


def _dims_text(result: dict[str, Any]) -> str:
    """Format a generation result as ``"<W>x<H> • <seconds>s"`` for a panel header."""
    dims = result.get("dimensions") or []
    dims_s = f"{dims[0]}x{dims[1]}" if len(dims) == 2 else "unknown dims"
    # Treat a missing *or null* duration as 0 so the :.1f format never fails.
    duration = result.get("duration_seconds") or 0
    return f"{dims_s} • {duration:.1f}s"


def make_comparison(output_dir: Path, output_name: str, results: list[dict[str, Any]], original_path: Path) -> Path:
    """Render a side-by-side 2K vs 4K comparison sheet and save it as a PNG.

    Layout:
        * A header band with "2K" / "4K" titles, each followed by the result's
          dimensions and generation time.
        * Two equally sized white panels. Each is sized to the larger of the two
          outputs in each dimension and holds its image scaled to fit.
        * A captioned thumbnail of the original product image at the top right.

    Args:
        output_dir: Directory the comparison image is written to.
        output_name: Base name used for the versioned comparison file name.
        results: Generation results. Must contain one entry whose
            ``image_size`` is "2K" and one whose ``image_size`` is "4K"
            (case-insensitive). Each entry needs ``output_path`` and may carry
            ``dimensions`` and ``duration_seconds``.
        original_path: Source product image shown as the reference thumbnail.

    Returns:
        Path of the newly written comparison PNG (next free ``_vN`` version).

    Raises:
        KeyError: if ``results`` lacks a 2K or a 4K entry.
    """
    by_size = {r["image_size"].upper(): r for r in results}
    left_result = by_size["2K"]
    right_result = by_size["4K"]
    left = _load_rgb(left_result["output_path"])
    right = _load_rgb(right_result["output_path"])
    original = _load_rgb(original_path)

    # Both panels share one size: the larger output's width and height.
    panel_w = max(left.width, right.width)
    panel_h = max(left.height, right.height)
    # Spacing and header height scale with the panel so proportions hold for large outputs.
    pad = max(56, panel_w // 48)
    label_h = max(260, panel_h // 13)
    canvas_w = panel_w * 2 + pad * 3
    canvas_h = panel_h + label_h + pad * 2
    canvas = Image.new("RGB", (canvas_w, canvas_h), (248, 248, 246))
    draw = ImageDraw.Draw(canvas)

    title_font = font(max(40, panel_w // 42), bold=True)
    meta_font = font(max(28, panel_w // 64), bold=False)
    small_font = font(max(22, panel_w // 82), bold=True)

    panel_y = label_h + pad
    positions = [(pad, panel_y), (pad * 2 + panel_w, panel_y)]
    for x, y in positions:
        # Grey frame, then a white panel background inside it.
        draw.rectangle((x - 2, y - 2, x + panel_w + 2, y + panel_h + 2), outline=(210, 210, 205), width=4)
        draw.rectangle((x, y, x + panel_w, y + panel_h), fill=(255, 255, 255))

    # Centre each scaled image inside its panel.
    for image, (x, y) in zip((left, right), positions):
        fitted = fit_image(image, (panel_w, panel_h))
        canvas.paste(fitted, (x + (panel_w - fitted.width) // 2, y + (panel_h - fitted.height) // 2))

    meta_offset = max(58, panel_w // 38)  # vertical gap from the title row to the meta row
    for label, result, (x, _y) in zip(("2K", "4K"), (left_result, right_result), positions):
        center = x + panel_w // 2
        draw_centered(draw, (center, pad), label, title_font)
        draw_centered(draw, (center, pad + meta_offset), _dims_text(result), meta_font, fill=(70, 70, 70))

    # Keep the original reference large enough to be useful when the comparison canvas is 4K-wide.
    thumb_max = (max(420, panel_w // 5), max(640, panel_h // 5))
    thumb = original.copy()
    thumb.thumbnail(thumb_max, Image.Resampling.LANCZOS)
    thumb_x = canvas_w - pad - thumb.width
    thumb_y = pad // 2

    caption = "original"
    caption_gap = 8  # space between the thumbnail and its caption, and below the caption
    caption_bottom = draw.textbbox((0, 0), caption, font=small_font)[3]
    # The card used a fixed 42 px strip under the thumbnail. The caption font
    # scales with the panel (e.g. 49 px for a 4096 px-wide panel), which spilled
    # out of that strip, so grow it to fit the rendered caption.
    card_strip = max(42, caption_gap + caption_bottom + caption_gap)
    # NOTE(review): the card is drawn after the panels, and its height can exceed
    # the header band (label_h), so it may overlay the top-right corner of the
    # 4K panel. Confirm that this inset placement is intentional.
    draw.rectangle(
        (thumb_x - 8, thumb_y - 8, thumb_x + thumb.width + 8, thumb_y + thumb.height + card_strip),
        fill=(255, 255, 255),
        outline=(180, 180, 176),
        width=3,
    )
    canvas.paste(thumb, (thumb_x, thumb_y))
    draw_centered(draw, (thumb_x + thumb.width // 2, thumb_y + thumb.height + caption_gap), caption, small_font, fill=(35, 35, 35))

    comparison_path = next_comparison_path(output_dir, output_name)
    canvas.save(comparison_path, format="PNG", optimize=True)
    return comparison_path


def _free_backup_path(path: Path) -> Path:
    """Return the first ``<name>.bak<N>`` sibling of ``path`` (N >= 1) that does not exist."""
    version = 1
    while path.with_name(f"{path.name}.bak{version}").exists():
        version += 1
    return path.with_name(f"{path.name}.bak{version}")


def _load_metrics_history(metrics_path: Path) -> list[Any]:
    """Return the records already stored in ``metrics_path`` as a list.

    A missing file yields ``[]``. A file holding a single JSON value that is not
    a list is wrapped into a one-element list. A file that cannot be read or
    parsed is moved aside to a numbered ``.bak`` sibling, with a warning on
    stderr, and ``[]`` is returned. The unreadable history is preserved rather
    than silently overwritten.
    """
    if not metrics_path.exists():
        return []
    try:
        data = json.loads(metrics_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        backup = _free_backup_path(metrics_path)
        metrics_path.replace(backup)
        print(
            f"Warning: could not read {metrics_path} ({exc}); moved it to {backup} and started a new metrics list.",
            file=sys.stderr,
        )
        return []
    return data if isinstance(data, list) else [data]


def write_metrics(output_dir: Path, output_name: str, results: list[dict[str, Any]], comparison_path: Path, original_path: Path) -> None:
    """Append one benchmark record to ``<output_dir>/size-benchmark.json``.

    The file holds a JSON list of records, each with ``output_name``,
    ``original_path``, ``comparison_path`` (both stringified) and the raw
    ``results``. See ``_load_metrics_history`` for how an existing file is read.

    Args:
        output_dir: Directory containing ``size-benchmark.json``.
        output_name: Name of the packshot output being benchmarked.
        results: Per-size generation results. They are serialized as-is, so
            they must be JSON-serializable.
        comparison_path: Path of the comparison image written for this run.
        original_path: Path of the original product image.
    """
    metrics_path = output_dir / "size-benchmark.json"
    record = {
        "output_name": output_name,
        "original_path": str(original_path),
        "comparison_path": str(comparison_path),
        "results": results,
    }
    data = _load_metrics_history(metrics_path)
    data.append(record)
    metrics_path.write_text(json.dumps(data, indent=2), encoding="utf-8")


def main() -> None:
    """Generate 2K and 4K packshots for one output directory and compare them.

    Expects exactly one positional argument: the packshot output directory. A
    literal ``--`` separator is ignored. ``-h``/``--help`` prints the module
    docstring and exits 0; any other argument count prints it and exits 2.
    """
    args = [a for a in sys.argv[1:] if a != "--"]
    wants_help = "-h" in args or "--help" in args
    if wants_help or len(args) != 1:
        print(__doc__.strip())
        raise SystemExit(0 if wants_help else 2)

    output_dir = Path(args[0]).expanduser().resolve()
    generator = load_generate_module()
    spec_data = generator.load_spec(output_dir)
    # validate_spec yields a 5-tuple; only the product images and output name are needed here.
    _, _, product_images, output_name, _ = generator.validate_spec(spec_data, output_dir)
    # The first product image is the "original" reference shown on the comparison sheet.
    original_path = product_images[0]

    # Generate sequentially: 2K first, then 4K, each tagged with its lower-cased size.
    results: list[dict[str, Any]] = [
        generator.generate_packshot(output_dir, image_size=size, tag=size.lower(), quiet=False)
        for size in ("2K", "4K")
    ]

    comparison_path = make_comparison(output_dir, output_name, results, original_path)
    write_metrics(output_dir, output_name, results, comparison_path, original_path)

    print()
    print(f"Comparison: {comparison_path}")
    print(f"Metrics: {output_dir / 'size-benchmark.json'}")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
