#!/usr/bin/env -S uv run
# /// script
# dependencies = [
#   "google-genai>=1.67.0",
#   "pillow>=11.0.0",
# ]
# ///
"""Generate original packshot staging variants and a labeled comparison sheet.

Variants are the three original packshot workflow archetypes:
  - ghost: ghost-mannequin packshot
  - flatlay: top-down flat-lay packshot
  - folded: stylized folded/streetwear packshot

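Usage, run from the project root:
  uv run skills/references/generate-packshot-variants.py brands/[brand]/packshots/[output-name]
  uv run skills/references/generate-packshot-variants.py brands/[brand]/packshots/[output-name] --size 4K
  uv run skills/references/generate-packshot-variants.py brands/[brand]/packshots/[output-name] --variants ghost,folded

Options:
  --size 1K|2K|4K      Image size to request (case-insensitive; default 2K).
  --variants LIST      Comma-separated subset of ghost, flatlay, folded (default: all three).
  -h, --help           Show this help and exit.

Outputs (one image per selected variant; the size tag follows --size, e.g. _4k_
for --size 4K; vN is the lowest unused version number):
  [output-name]_ghost_2k_vN.png
  [output-name]_flatlay_2k_vN.png
  [output-name]_folded_2k_vN.png
  [output-name]_staging-variants_2k_vN.png   (labeled side-by-side comparison sheet)
  staging-variants.json                       (run metrics; each run appends a record)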
"""
from __future__ import annotations

import importlib.util
import json
import sys
import time
from pathlib import Path
from typing import Any

from google import genai
from google.genai import types
from PIL import Image, ImageDraw, ImageFont


# Label fonts for the comparison sheet. These are macOS system font paths;
# font() falls back to Pillow's default font when they cannot be loaded.
FONT_BOLD = "/System/Library/Fonts/Supplemental/Arial Bold.ttf"
FONT_REGULAR = "/System/Library/Fonts/Supplemental/Arial.ttf"

# Staging archetypes, keyed by the names accepted by --variants.
#   label:   human-readable name printed to the console and drawn as the panel
#            title on the comparison sheet.
#   staging: instruction paragraph that replaces the leading paragraph (or
#            line) of the spec's base prompt -- see variant_prompt/prompt_body.
VARIANTS = {
    "ghost": {
        "label": "Ghost mannequin",
        "staging": "Produce a high-end commercial ghost-mannequin packshot, showing the item in a three-dimensional, hollow-body floating state.",
    },
    "flatlay": {
        "label": "Top-down flat-lay",
        "staging": "Produce a clean, professional top-down flat-lay packshot, oriented with surgical precision for e-commerce.",
    },
    "folded": {
        "label": "Folded/styled",
        "staging": "Produce a stylized streetwear-inspired folded packshot, emphasizing the item's texture and silhouette in a relaxed studio setting.",
    },
}
# Variants rendered when --variants is not given, in comparison-sheet order
# (panels are laid out left to right in this order).
DEFAULT_VARIANTS = ["ghost", "flatlay", "folded"]
# Values accepted by --size; passed to the model as ImageConfig.image_size.
VALID_IMAGE_SIZES = {"1K", "2K", "4K"}


def load_generate_module():
    """Import the sibling ``generate-packshot.py`` script and return it as a module.

    The hyphenated file name cannot be imported with a plain ``import``
    statement, so the module is built from its file location instead.

    Raises:
        SystemExit: if no import spec/loader can be created for the file.
    """
    sibling = Path(__file__).parent / "generate-packshot.py"
    module_spec = importlib.util.spec_from_file_location("generate_packshot_module", sibling)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise SystemExit(f"Could not import {sibling}")
    loaded = importlib.util.module_from_spec(module_spec)
    loader.exec_module(loaded)
    return loaded


def font(size: int, bold: bool = False) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return a label font of ``size`` pixels for the comparison sheet.

    Tries the configured Arial face (FONT_BOLD / FONT_REGULAR) first. If that
    file cannot be loaded, falls back to Pillow's bundled default font scaled
    to ``size``. Only if that also fails (Pillow built without FreeType) does
    it return the fixed-size bitmap default.

    Args:
        size: Font size in pixels.
        bold: Use FONT_BOLD instead of FONT_REGULAR.
    """
    path = FONT_BOLD if bold else FONT_REGULAR
    try:
        return ImageFont.truetype(path, size=size)
    except (OSError, ImportError):
        # OSError: font file missing/unreadable (e.g. not on macOS).
        # ImportError: Pillow was built without FreeType support.
        pass
    try:
        # A bare load_default() yields a small fixed-size font, unreadable on a
        # multi-thousand-pixel canvas; Pillow >= 10.1 can scale it via ``size``.
        return ImageFont.load_default(size=size)
    except (OSError, ImportError):
        return ImageFont.load_default()


def draw_centered(draw: ImageDraw.ImageDraw, xy: tuple[int, int], text: str, fnt, fill=(20, 20, 20)) -> None:
    """Draw ``text`` horizontally centred on ``xy[0]``.

    The text width is measured with ``draw.textbbox`` at the origin. The
    vertical coordinate ``xy[1]`` is passed to ``draw.text`` unchanged.
    """
    center_x, top_y = xy
    left, _top, right, _bottom = draw.textbbox((0, 0), text, font=fnt)
    half_width = (right - left) // 2
    draw.text((center_x - half_width, top_y), text, font=fnt, fill=fill)


def fit_image(image: Image.Image, size: tuple[int, int]) -> Image.Image:
    """Scale ``image`` uniformly so it fits inside the ``(width, height)`` box.

    The aspect ratio is kept and each side is at least 1 px. LANCZOS is used
    when shrinking and BICUBIC otherwise. Returns a new image.
    """
    box_w, box_h = size
    width, height = image.size
    factor = min(box_w / width, box_h / height)
    target = tuple(max(1, round(side * factor)) for side in (width, height))
    if factor < 1:
        method = Image.Resampling.LANCZOS
    else:
        method = Image.Resampling.BICUBIC
    return image.resize(target, method)


def prompt_body(prompt: str) -> str:
    """Return ``prompt`` without its leading paragraph.

    The prompt is stripped first. Everything up to the first blank line is
    dropped. If there is no blank line, the first line is dropped. A
    single-line prompt is returned whole (stripped).
    """
    text = prompt.strip()
    for separator in ("\n\n", "\n"):
        _head, found, rest = text.partition(separator)
        if found:
            return rest
    return text


def variant_prompt(base_prompt: str, variant_key: str) -> str:
    """Build the prompt for one variant.

    The result is the variant's staging paragraph, a blank line, then
    ``base_prompt`` with its leading paragraph (or line) removed by
    ``prompt_body``.
    """
    remainder = prompt_body(base_prompt)
    staging = VARIANTS[variant_key]["staging"]
    return f"{staging}\n\n{remainder}"


def next_path(output_dir: Path, output_name: str, suffix: str) -> Path:
    """Return the first ``{output_name}_{suffix}_vN.png`` in ``output_dir`` that does not exist.

    N counts up from 1. The returned file is not created.
    """
    version = 1
    while True:
        candidate = output_dir / f"{output_name}_{suffix}_v{version}.png"
        if not candidate.exists():
            return candidate
        version += 1


def generate_variant(gen: Any, output_dir: Path, spec_data: dict[str, Any], variant_key: str, image_size: str) -> dict[str, Any]:
    """Render one staging variant with the image model and save it to disk.

    Args:
        gen: Module returned by ``load_generate_module``. It supplies
            ``validate_spec``, ``get_api_key``, ``image_part``,
            ``save_response_image``, ``image_dimensions`` and ``MODEL``.
        output_dir: Packshot directory; the rendered image is written here.
        spec_data: Spec as returned by ``gen.load_spec``.
        variant_key: Key into VARIANTS.
        image_size: Requested size ("1K"/"2K"/"4K"). It is sent to the model
            and used, lower-cased, in the output file name.

    Returns:
        A metrics record for this render: variant, label, size, output path,
        dimensions, duration in seconds, aspect ratio, reference image paths,
        and model.
    """
    spec_prompt, aspect_ratio, product_images, output_name, _ = gen.validate_spec(spec_data, output_dir)
    prompt = variant_prompt(spec_prompt, variant_key)
    label = VARIANTS[variant_key]["label"]
    size_tag = image_size.lower()
    out_path = next_path(output_dir, output_name, f"{variant_key}_{size_tag}")

    print()
    print(f"{output_name} - {label} ({image_size})")
    print(f"  Output: {out_path}")
    print(f"  Refs: {len(product_images)}")

    client = genai.Client(api_key=gen.get_api_key())
    # Text prompt first, then one part per product reference image.
    contents: list[Any] = [prompt, *(gen.image_part(ref) for ref in product_images)]

    started = time.perf_counter()
    response = client.models.generate_content(
        model=gen.MODEL,
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
            image_config=types.ImageConfig(aspect_ratio=aspect_ratio, image_size=image_size),
        ),
    )
    elapsed = time.perf_counter() - started

    gen.save_response_image(response, out_path)
    dims = gen.image_dimensions(out_path)

    print(f"  Saved: {out_path.name}")
    if dims:
        print(f"  Dims: {dims[0]}x{dims[1]}")
    else:
        print("  Dims: unknown")
    print(f"  Time: {elapsed:.1f}s")

    # Key order is kept stable because the record is later dumped to JSON.
    record: dict[str, Any] = {
        "variant": variant_key,
        "label": label,
        "image_size": image_size,
        "output_path": str(out_path),
        "dimensions": dims,
        "duration_seconds": elapsed,
        "aspect_ratio": aspect_ratio,
        "product_images": [str(ref) for ref in product_images],
        "model": gen.MODEL,
    }
    return record


def make_comparison(output_dir: Path, output_name: str, results: list[dict[str, Any]], original_path: Path, image_size: str) -> Path:
    """Compose the rendered variants side by side into one labeled PNG sheet.

    Each result's image is scaled into a white panel. Above each panel go a
    title (the variant label) and a meta line (size, pixel dimensions,
    generation time). A thumbnail of ``original_path``, captioned "original",
    is drawn last at the top-right, so it overlays the rightmost panel wherever
    they overlap.

    Args:
        output_dir: Directory the sheet is written to.
        output_name: Base name for the versioned sheet file.
        results: Records from ``generate_variant``; panels follow this order.
        original_path: Reference image shown as the thumbnail.
        image_size: Size label ("1K"/"2K"/"4K") shown on each panel.

    Returns:
        Path of the saved ``{output_name}_staging-variants_{size}_vN.png``.

    Raises:
        ValueError: if ``results`` is empty (there is nothing to compare).
    """
    if not results:
        raise ValueError("make_comparison needs at least one variant result")

    def load_rgb(path) -> Image.Image:
        # Decode to an independent RGB copy and close the source file right away.
        with Image.open(path) as opened:
            return opened.convert("RGB")

    images = [load_rgb(r["output_path"]) for r in results]
    original = load_rgb(original_path)

    # All panels use the largest variant's size so differing outputs line up.
    panel_w = max(i.width for i in images)
    panel_h = max(i.height for i in images)
    pad = max(54, panel_w // 52)
    label_h = max(250, panel_h // 13)
    canvas_w = panel_w * len(images) + pad * (len(images) + 1)
    canvas_h = panel_h + label_h + pad * 2
    canvas = Image.new("RGB", (canvas_w, canvas_h), (248, 248, 246))
    draw = ImageDraw.Draw(canvas)

    title_font = font(max(40, panel_w // 42), bold=True)
    meta_font = font(max(26, panel_w // 68), bold=False)
    small_font = font(max(22, panel_w // 82), bold=True)

    panel_y = label_h + pad
    for i, (img, result) in enumerate(zip(images, results)):
        x = pad + i * (panel_w + pad)
        draw.rectangle((x - 2, panel_y - 2, x + panel_w + 2, panel_y + panel_h + 2), outline=(210, 210, 205), width=4)
        draw.rectangle((x, panel_y, x + panel_w, panel_y + panel_h), fill=(255, 255, 255))
        fitted = fit_image(img, (panel_w, panel_h))
        canvas.paste(fitted, (x + (panel_w - fitted.width) // 2, panel_y + (panel_h - fitted.height) // 2))

        center = x + panel_w // 2
        dims = result.get("dimensions") or []
        dims_s = f"{dims[0]}x{dims[1]}" if len(dims) == 2 else "unknown dims"
        draw_centered(draw, (center, pad), result["label"], title_font)
        draw_centered(draw, (center, pad + max(58, panel_w // 38)), f"{image_size} • {dims_s} • {result['duration_seconds']:.1f}s", meta_font, fill=(70, 70, 70))

    # NOTE(review): the thumbnail box can be taller than the label band
    # (label_h), in which case it covers the top-right of the last panel;
    # confirm this inset placement is intended.
    thumb_max = (max(420, panel_w // 5), max(640, panel_h // 5))
    thumb = original.copy()
    thumb.thumbnail(thumb_max, Image.Resampling.LANCZOS)
    thumb_x = canvas_w - pad - thumb.width
    thumb_y = pad // 2
    draw.rectangle((thumb_x - 8, thumb_y - 8, thumb_x + thumb.width + 8, thumb_y + thumb.height + 42), fill=(255, 255, 255), outline=(180, 180, 176), width=3)
    canvas.paste(thumb, (thumb_x, thumb_y))
    draw_centered(draw, (thumb_x + thumb.width // 2, thumb_y + thumb.height + 8), "original", small_font, fill=(35, 35, 35))

    comparison_path = next_path(output_dir, output_name, f"staging-variants_{image_size.lower()}")
    canvas.save(comparison_path, format="PNG", optimize=True)
    return comparison_path


def append_metrics(output_dir: Path, output_name: str, results: list[dict[str, Any]], comparison_path: Path, original_path: Path) -> None:
    """Append one run record to ``output_dir/staging-variants.json``.

    The file holds a JSON list of run records; an existing file containing a
    single non-list value is wrapped into a list. If the existing file is not
    valid UTF-8 JSON, it is renamed to ``staging-variants.corrupt-N.json``
    (first free N) with a warning on stderr rather than being silently
    overwritten, so earlier history is never lost. The new contents go to a
    temporary file that is then moved over the original, so an interrupted
    write cannot leave a truncated history behind.

    Args:
        output_dir: Directory containing the metrics file.
        output_name: Packshot output name recorded with the run.
        results: Per-variant records from ``generate_variant``.
        comparison_path: Path of the comparison sheet for this run.
        original_path: Reference image used as the sheet's "original".
    """
    metrics_path = output_dir / "staging-variants.json"
    record = {
        "output_name": output_name,
        "original_path": str(original_path),
        "comparison_path": str(comparison_path),
        "results": results,
    }

    data: list[Any] = []
    if metrics_path.exists():
        try:
            loaded = json.loads(metrics_path.read_text(encoding="utf-8"))
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            n = 1
            backup = metrics_path.with_name(f"{metrics_path.stem}.corrupt-{n}{metrics_path.suffix}")
            while backup.exists():
                n += 1
                backup = metrics_path.with_name(f"{metrics_path.stem}.corrupt-{n}{metrics_path.suffix}")
            metrics_path.rename(backup)
            print(
                f"Warning: {metrics_path.name} was not valid JSON ({exc}); moved it to {backup.name}",
                file=sys.stderr,
            )
        else:
            data = loaded if isinstance(loaded, list) else [loaded]

    data.append(record)
    tmp_path = metrics_path.with_name(metrics_path.name + ".tmp")
    tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
    tmp_path.replace(metrics_path)


def main() -> None:
    """Parse the command line, render the requested variants, and build the sheet.

    Usage: ``<output-dir> [--size 1K|2K|4K] [--variants key,key,...] [-h|--help]``.

    - A bare ``--`` separator is ignored.
    - ``--size`` is case-insensitive and must be in VALID_IMAGE_SIZES.
    - ``--variants`` must name at least one key of VARIANTS.
    - Invalid option values exit with an error message.
    - ``-h``/``--help`` prints the module docstring and exits 0; any
      positional-argument count other than one prints it and exits 2.

    The spec's first product image is used as the "original" thumbnail on the
    comparison sheet.
    """
    args = [arg for arg in sys.argv[1:] if arg != "--"]
    image_size = "2K"
    variant_keys = list(DEFAULT_VARIANTS)

    def pop_value(flag: str) -> str | None:
        # Remove the first ``flag`` and the value after it from ``args`` (in
        # place) and return that value; None when the flag is absent.
        if flag not in args:
            return None
        i = args.index(flag)
        if i + 1 >= len(args):
            raise SystemExit(f"Error: {flag} requires a value")
        value = args[i + 1]
        del args[i : i + 2]
        return value

    # ``is not None`` (not truthiness) so an explicit empty value is rejected
    # instead of silently falling back to the default.
    size_arg = pop_value("--size")
    if size_arg is not None:
        image_size = size_arg.upper()
        if image_size not in VALID_IMAGE_SIZES:
            raise SystemExit(f"Error: --size must be one of {sorted(VALID_IMAGE_SIZES)}")

    variants_arg = pop_value("--variants")
    if variants_arg is not None:
        variant_keys = [x.strip() for x in variants_arg.split(",") if x.strip()]
        if not variant_keys:
            # Otherwise e.g. "--variants ," renders nothing and then crashes
            # while building an empty comparison sheet.
            raise SystemExit(f"Error: --variants needs at least one of {sorted(VARIANTS)}")
        invalid = [x for x in variant_keys if x not in VARIANTS]
        if invalid:
            raise SystemExit(f"Error: unknown variants {invalid}. Valid: {sorted(VARIANTS)}")

    wants_help = "-h" in args or "--help" in args
    if wants_help or len(args) != 1:
        # __doc__ is None when Python runs with -OO.
        print((__doc__ or "").strip())
        raise SystemExit(0 if wants_help else 2)

    output_dir = Path(args[0]).expanduser().resolve()
    gen = load_generate_module()
    spec_data = gen.load_spec(output_dir)
    _, _, product_images, output_name, _ = gen.validate_spec(spec_data, output_dir)
    if not product_images:
        raise SystemExit("Error: the spec lists no product images; the first one is needed as the comparison original")
    original_path = product_images[0]

    results = [generate_variant(gen, output_dir, spec_data, key, image_size) for key in variant_keys]
    comparison_path = make_comparison(output_dir, output_name, results, original_path, image_size)
    append_metrics(output_dir, output_name, results, comparison_path, original_path)

    print()
    print(f"Comparison: {comparison_path}")
    print(f"Metrics: {output_dir / 'staging-variants.json'}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
