#!/usr/bin/env -S uv run
# /// script
# dependencies = [
#   "google-genai>=1.67.0",
#   "pillow>=11.0.0",
# ]
# ///
"""generate-packshot.py - Packshot generation pipeline using Gemini Nano Banana.

Usage, run from project root:
  uv run ~/.pi/agent/skills/packshot/scripts/generate-packshot.py brands/[brand]/packshots/[output-name]
  uv run ~/.pi/agent/skills/packshot/scripts/generate-packshot.py brands/[brand]/packshots/[output-name] --size 4K --tag 4k
  uv run ~/.pi/agent/skills/packshot/scripts/generate-packshot.py brands/[brand]/packshots/[output-name] --dry-run

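Reads packshot-spec.json from the output folder.
Output: [output-name]_vN.png, or [output-name]_[tag]_vN.png when --tag is provided,
where N is the next unused version number (starting at 1).
Add --json to print a JSON summary instead of progress output, or --quiet to
suppress progress output.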

This follows the local `nb` Nano Banana script pattern:
- Gemini API key from macOS Keychain service `gemini-api-key`, falling back to $GEMINI_API_KEY
- model `gemini-3-pro-image-preview`
- image references attached with google-genai Parts
"""
from __future__ import annotations

import io
import json
import mimetypes
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Any

from google import genai
from google.genai import types
from PIL import Image


# Gemini image model used for every generation request.
MODEL = "gemini-3-pro-image-preview"
# macOS Keychain service name that holds the Gemini API key.
KEYCHAIN_SERVICE = "gemini-api-key"
# Product-image extensions (lower-case, with dot) uploaded as-is.
# frozensets: these are constants and must not be mutated by importers.
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".webp"})
# Product-image extensions converted to JPEG before upload.
HEIC_EXTENSIONS = frozenset({".heic", ".heif"})
# Accepted values for the image size (spec `image_size` or `--size`).
VALID_IMAGE_SIZES = frozenset({"1K", "2K", "4K"})


def c(color: str, text: str, enabled: bool = True) -> str:
    """Wrap ``text`` in the ANSI escape sequence for ``color``.

    With ``enabled`` false the text comes back untouched. An unknown colour
    name adds no prefix, but the reset sequence is still appended.
    """
    if enabled:
        palette = dict(
            green="\033[32m",
            yellow="\033[33m",
            red="\033[31m",
            cyan="\033[36m",
            bold="\033[1m",
            reset="\033[0m",
        )
        return "{}{}{}".format(palette.get(color, ""), text, palette["reset"])
    return text


def get_api_key() -> str:
    """Return the Gemini API key.

    The macOS Keychain (service ``KEYCHAIN_SERVICE``) is checked first, the same
    as /opt/homebrew/bin/nb. Otherwise the key is read from ``$GEMINI_API_KEY``.

    Raises:
        SystemExit: if neither source yields a non-empty key.
    """
    try:
        result = subprocess.run(
            ["security", "find-generic-password", "-s", KEYCHAIN_SERVICE, "-w"],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError:
        # The `security` CLI only exists on macOS. Elsewhere subprocess.run raises
        # FileNotFoundError, which previously skipped the env-var fallback below.
        result = None

    if result is not None and result.returncode == 0:
        keychain_key = result.stdout.strip()
        if keychain_key:
            return keychain_key

    key = os.environ.get("GEMINI_API_KEY", "").strip()
    if key:
        return key

    raise SystemExit(
        f"Error: no Gemini API key found in Keychain service '{KEYCHAIN_SERVICE}' or $GEMINI_API_KEY."
    )


def load_spec(output_dir: Path) -> dict[str, Any]:
    """Read and parse ``packshot-spec.json`` from ``output_dir``.

    Returns:
        The parsed top-level JSON object.

    Raises:
        SystemExit: if the file is missing or not a regular file, is not valid
        UTF-8 JSON, or its top level is not a JSON object.
    """
    spec_path = output_dir / "packshot-spec.json"
    if not spec_path.is_file():
        raise SystemExit(f"Error: packshot-spec.json not found at {spec_path}")
    try:
        # Decode explicitly instead of relying on the locale's default encoding.
        # "utf-8-sig" also accepts files saved with a UTF-8 BOM.
        spec = json.loads(spec_path.read_text(encoding="utf-8-sig"))
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise SystemExit(f"Error: invalid JSON in {spec_path}: {exc}") from exc
    if not isinstance(spec, dict):
        # Callers read the spec with .get(); a list or scalar would crash there.
        raise SystemExit(
            f"Error: {spec_path} must contain a JSON object, got {type(spec).__name__}"
        )
    return spec


def resolve_image_path(raw: str, output_dir: Path, project_root: Path) -> Path:
    """Locate a product image referenced by the spec.

    Absolute (or ``~``-prefixed) paths are used as given. Relative paths are tried
    against ``project_root``, then ``output_dir``, then ``output_dir``'s parent.
    The first candidate that is an existing *file* wins.

    Returns:
        The resolved absolute path of the image.

    Raises:
        SystemExit: if no candidate is an existing file.
    """
    p = Path(raw).expanduser()
    candidates = [p] if p.is_absolute() else [project_root / p, output_dir / p, output_dir.parent / p]
    for candidate in candidates:
        # is_file(), not exists(): a same-named directory must not shadow the image.
        if candidate.is_file():
            return candidate.resolve()
    searched = ", ".join(str(cand) for cand in candidates)
    raise SystemExit(f"Error: product image not found: {raw} (searched: {searched})")


def convert_heic_to_jpeg(path: Path) -> Path:
    """Convert a HEIC/HEIF image to JPEG with the macOS ``sips`` tool, caching the result.

    The JPEG is written into a ``.packshot-converted`` folder next to ``path``.
    It is reused as long as its mtime is not older than the source's.

    Returns:
        Path to the converted JPEG.

    Raises:
        SystemExit: if ``sips`` cannot be run, the conversion fails, or no
        output file is produced.
    """
    converted_dir = path.parent / ".packshot-converted"
    converted_dir.mkdir(exist_ok=True)
    # Keep the original suffix in the cache name so `x.heic` and `x.heif` in the
    # same folder do not share, and silently reuse, a single cached `x.jpg`.
    out = converted_dir / f"{path.name}.jpg"
    if out.exists() and out.stat().st_mtime >= path.stat().st_mtime:
        return out
    try:
        subprocess.run(
            ["sips", "-s", "format", "jpeg", str(path), "--out", str(out)],
            check=True,
            capture_output=True,
            text=True,
        )
    except OSError as exc:
        # Typically FileNotFoundError: `sips` ships only with macOS.
        raise SystemExit(
            f"Error: cannot convert {path}: `sips` (macOS) is required for HEIC/HEIF images ({exc})"
        ) from exc
    except subprocess.CalledProcessError as exc:
        # Drop any partial output so it is not mistaken for a fresh cache next run.
        out.unlink(missing_ok=True)
        detail = (exc.stderr or exc.stdout or "").strip()
        raise SystemExit(f"Error: sips failed to convert {path}: {detail}") from exc
    if not out.exists():
        raise SystemExit(f"Error: sips produced no output for {path}")
    return out


def image_part(path: Path) -> types.Part:
    """Load a product image as an inline google-genai ``Part``.

    HEIC/HEIF files are first converted to JPEG via ``convert_heic_to_jpeg``.
    """
    src = convert_heic_to_jpeg(path) if path.suffix.lower() in HEIC_EXTENSIONS else path
    # Map the supported extensions explicitly. mimetypes' answer depends on the
    # Python version and the platform's mime.types files (.webp may be unknown
    # there), and the "image/jpeg" fallback would then mislabel the upload.
    known_mimes = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
    }
    mime = known_mimes.get(src.suffix.lower()) or mimetypes.guess_type(str(src))[0] or "image/jpeg"
    return types.Part.from_bytes(data=src.read_bytes(), mime_type=mime)


def validate_spec(spec: dict[str, Any], output_dir: Path) -> tuple[str, str, list[Path], str, str]:
    prompt = str(spec.get("prompt", "")).strip()
    if not prompt:
        raise SystemExit("Error: packshot-spec.json has no prompt. Run the packshot skill first.")

    raw_images = spec.get("product_images", [])
    if not isinstance(raw_images, list) or not raw_images:
        raise SystemExit("Error: packshot-spec.json has no product_images.")

    project_root = Path.cwd().resolve()
    product_images = [resolve_image_path(str(p), output_dir, project_root) for p in raw_images]
    for p in product_images:
        suffix = p.suffix.lower()
        if suffix not in IMAGE_EXTENSIONS and suffix not in HEIC_EXTENSIONS:
            raise SystemExit(f"Error: unsupported product image type: {p}")

    aspect_ratio = str(spec.get("aspect_ratio", "3:4")).strip() or "3:4"
    output_name = str(spec.get("output_name", output_dir.name)).strip() or output_dir.name
    spec_size = str(spec.get("image_size", "2K")).strip().upper() or "2K"
    if spec_size not in VALID_IMAGE_SIZES:
        spec_size = "2K"
    return prompt, aspect_ratio, product_images, output_name, spec_size


def next_output_path(output_dir: Path, base: str, tag: str | None = None) -> Path:
    clean_tag = tag.strip().replace(" ", "-") if tag else ""
    prefix = f"{base}_{clean_tag}" if clean_tag else base
    version = 1
    while (output_dir / f"{prefix}_v{version}.png").exists():
        version += 1
    return output_dir / f"{prefix}_v{version}.png"


def save_response_image(response: Any, out_path: Path) -> None:
    """Write the first inline image of the first candidate in ``response`` to ``out_path`` as PNG.

    PNG payloads are written byte-for-byte. Other image types are re-encoded to
    PNG with Pillow. Parts without inline data are skipped, but their text is
    kept for the error message.

    Raises:
        SystemExit: if there are no candidates or no image part. When no image is
        returned, the message includes the candidate's finish reason and any text
        the model sent (e.g. a refusal), if present.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        raise SystemExit("Error: Gemini returned no candidates.")
    candidate = candidates[0]
    # Both `content` and `content.parts` may be None (e.g. a blocked response),
    # so guard before iterating instead of crashing with AttributeError/TypeError.
    parts = getattr(getattr(candidate, "content", None), "parts", None) or []

    texts: list[str] = []
    for part in parts:
        inline = getattr(part, "inline_data", None)
        data = getattr(inline, "data", None) if inline is not None else None
        if data:
            mime = getattr(inline, "mime_type", None) or ""
            if mime == "image/png":
                out_path.write_bytes(data)
                return
            # `with` closes the decoded image once it has been re-encoded.
            with Image.open(io.BytesIO(data)) as image:
                image.save(out_path, format="PNG")
            return
        text = getattr(part, "text", None)
        if text and text.strip():
            texts.append(text.strip())

    details: list[str] = []
    reason = getattr(candidate, "finish_reason", None)
    if reason is not None:
        # Enum members expose `.name`; plain strings are used as-is.
        details.append(f"finish reason: {getattr(reason, 'name', reason)}")
    if texts:
        details.append("model said: " + " ".join(texts))
    suffix = f" ({'; '.join(details)})" if details else ""
    raise SystemExit(f"Error: Gemini returned no image data.{suffix}")


def image_dimensions(path: Path) -> tuple[int, int] | None:
    """Return ``(width, height)`` of the image at ``path``, or None if it cannot be read.

    Best-effort: any failure yields None rather than raising.
    """
    try:
        # Use Image.open as a context manager so the underlying file handle is
        # released instead of lingering until garbage collection.
        with Image.open(path) as image:
            return image.size
    except Exception:  # deliberate best-effort: dimensions are informational only
        return None


def generate_packshot(
    output_dir: Path,
    image_size: str | None = None,
    tag: str | None = None,
    dry_run: bool = False,
    quiet: bool = False,
    color: bool = True,
) -> dict[str, Any]:
    """Run one packshot generation for the folder ``output_dir``.

    Loads and validates ``packshot-spec.json`` and picks the next free versioned
    output path. Unless ``dry_run`` is set, it sends the prompt plus product
    images to ``MODEL`` and saves the returned image as PNG.

    Args:
        output_dir: Packshot folder containing ``packshot-spec.json``.
        image_size: Optional override for the spec's image size. After
            upper-casing it must be one of ``VALID_IMAGE_SIZES``.
        tag: Optional tag inserted into the output filename.
        dry_run: Stop after resolving inputs; nothing is generated.
        quiet: Suppress the human-readable progress output.
        color: Use ANSI colour codes in that output.

    Returns:
        A JSON-serialisable summary whose ``status`` is ``"dry_run"`` or ``"ok"``.

    Raises:
        SystemExit: on spec/configuration errors, including an invalid size.
    """
    folder = output_dir.expanduser().resolve()
    prompt, aspect_ratio, product_images, output_name, spec_size = validate_spec(
        load_spec(folder), folder
    )
    size = (image_size or spec_size).upper()
    if size not in VALID_IMAGE_SIZES:
        raise SystemExit(f"Error: image size must be one of {sorted(VALID_IMAGE_SIZES)}")

    out_path = next_output_path(folder, output_name, tag=tag)
    image_strings = [str(img) for img in product_images]

    if not quiet:
        print()
        print(c("bold", f"{output_name} packshot", color))
        summary_rows = (
            ("  Model  : ", MODEL),
            ("  Ratio  : ", aspect_ratio),
            ("  Size   : ", size),
            ("  Output : ", str(out_path)),
            ("  Refs   : ", str(len(product_images))),
        )
        for label, value in summary_rows:
            print(c("cyan", label, color) + value)
        for img in product_images:
            print(f"    - {img}")

    if dry_run:
        if not quiet:
            print()
            print(c("green", "Dry run ok. No generation performed.", color))
        return {
            "status": "dry_run",
            "output_path": str(out_path),
            "image_size": size,
            "aspect_ratio": aspect_ratio,
            "product_images": image_strings,
            "duration_seconds": 0,
        }

    client = genai.Client(api_key=get_api_key())
    # Prompt first, then one inline image Part per product reference.
    contents: list[Any] = [prompt, *(image_part(img) for img in product_images)]

    if not quiet:
        print()
        print(c("yellow", f"Generating {size} packshot with Gemini Nano Banana...", color))
    started = time.perf_counter()
    response = client.models.generate_content(
        model=MODEL,
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
            image_config=types.ImageConfig(aspect_ratio=aspect_ratio, image_size=size),
        ),
    )
    elapsed = time.perf_counter() - started

    save_response_image(response, out_path)
    dims = image_dimensions(out_path)

    if not quiet:
        print(c("green", f"Saved: {out_path}", color))
        if dims:
            print(c("cyan", "  Dims   : ", color) + f"{dims[0]}x{dims[1]}")
        print(c("cyan", "  Time   : ", color) + f"{elapsed:.1f}s")

    return {
        "status": "ok",
        "output_path": str(out_path),
        "image_size": size,
        "aspect_ratio": aspect_ratio,
        "dimensions": dims,
        "product_images": image_strings,
        "duration_seconds": elapsed,
        "model": MODEL,
    }


def main() -> None:
    """Parse the command line and run ``generate_packshot``.

    Parsing is hand-rolled:
    - every ``--`` token is dropped;
    - the first occurrence of each boolean flag ``--dry-run``, ``--json`` and
      ``--quiet`` is removed;
    - ``--tag VALUE`` and then ``--size VALUE`` are extracted.

    Exactly one positional argument (the packshot folder) must remain. Otherwise
    the module docstring is printed and the process exits with 0 for -h/--help
    and 2 for anything else. With ``--json``, progress output is suppressed and
    the result dict is printed as JSON.
    """
    argv = [token for token in sys.argv[1:] if token != "--"]

    def take_switch(name: str) -> bool:
        # Remove one occurrence of a boolean flag and report whether it was present.
        if name in argv:
            argv.remove(name)
            return True
        return False

    def take_option(name: str) -> str | None:
        # Remove the first "name VALUE" pair and return VALUE (None if name is absent).
        try:
            idx = argv.index(name)
        except ValueError:
            return None
        if idx + 1 >= len(argv):
            raise SystemExit(f"Error: {name} requires a value")
        pair = argv[idx : idx + 2]
        del argv[idx : idx + 2]
        return pair[1]

    # Order matters: switches are removed before options take their values.
    dry_run = take_switch("--dry-run")
    as_json = take_switch("--json")
    quiet = take_switch("--quiet")
    tag = take_option("--tag")
    image_size = take_option("--size")

    wants_help = "-h" in argv or "--help" in argv
    if wants_help or len(argv) != 1:
        print(__doc__.strip())
        raise SystemExit(0 if wants_help else 2)

    (folder,) = argv
    result = generate_packshot(
        Path(folder),
        image_size=image_size,
        tag=tag,
        dry_run=dry_run,
        quiet=quiet or as_json,
        color=not as_json,
    )
    if as_json:
        print(json.dumps(result, indent=2))


if __name__ == "__main__":
    # Run the CLI only when executed as a script (e.g. via `uv run`), not on import.
    main()
