#!/usr/bin/env python3
"""
Image Consistency Checker for Broken Spire
Compares generated frames with reference character images using simple colour features (per-channel RGB histograms, average colour and brightness)
"""

import os
import sys
from collections import Counter
from pathlib import Path

import numpy as np
from PIL import Image

# Configuration
# The three working directories default to the original author's WSL paths.
# Each one can be overridden with an environment variable, so the script can
# run on another machine without editing the source:
#   BROKEN_SPIRE_BASE_DIR       - parent directory of the three below
#   BROKEN_SPIRE_GENERATED_DIR  - generated frames, named "<frame_name>.png"
#   BROKEN_SPIRE_REFERENCES_DIR - reference character images
#   BROKEN_SPIRE_OUTPUT_DIR     - where consistency_report.txt is written
_BASE_DIR = os.environ.get(
    "BROKEN_SPIRE_BASE_DIR", "/mnt/c/Users/fbmor/broken-spire-comparison"
)
GENERATED_DIR = os.environ.get(
    "BROKEN_SPIRE_GENERATED_DIR", os.path.join(_BASE_DIR, "generated")
)
REFERENCES_DIR = os.environ.get(
    "BROKEN_SPIRE_REFERENCES_DIR", os.path.join(_BASE_DIR, "references")
)
OUTPUT_DIR = os.environ.get(
    "BROKEN_SPIRE_OUTPUT_DIR", os.path.join(_BASE_DIR, "results")
)

# Character mapping: generated frame stem (file name without ".png") ->
# reference image file name(s) inside REFERENCES_DIR. A frame is scored
# against every listed reference and the best score is kept.
CHARACTER_MAP = {
    "02_ash_birth_00002_": ["Ash.png", "Ash 2.png"],
    "03_far_future_ash_evil_00002_": ["Far-Future Ash.png"],
    "05_everly_soldier_00002_": ["Everly.png"],
    "06_eva_doctor_00002_": ["Éva Moreau.png"],
    "07_nova_warrior_00002_": ["Nova Human.png", "Nova devil.png"],
    "08_violet_devil_00002_": ["Violet humaine.png", "Violet Devil.png"],
    "09_lin_weishan_00002_": ["Lin Weishan.png"],
    "10_tc23_esper_00002_": ["TC-23.png", "Esper.png"],
    "11_jonas_ghost_00002_": ["Jonas.png"],
    "13_ash_everly_tension_00002_": ["Everly.png"],
    "14_ash_eva_medical_00002_": ["Éva Moreau.png"],
    "15_ash_nova_war_00002_": ["Nova Human.png"],
    "17_all_characters_triangle_00002_": ["Ash.png", "Everly.png", "Éva Moreau.png", "Nova Human.png"],
}

def load_image(path, size=(256, 256)):
    """Load an image as an RGB NumPy array resized to a common size.

    Args:
        path: Path of the image file to read.
        size: ``(width, height)`` the image is resized to, so that images of
            different resolutions produce comparable features. Defaults to
            ``(256, 256)``.

    Returns:
        A NumPy array of the RGB pixel values, or ``None`` if the file could
        not be opened or decoded. The error is reported on stderr.
    """
    try:
        # Use the image as a context manager so the file handle is released
        # deterministically. convert() returns a new, fully loaded image, so
        # it remains usable after the source image is closed.
        with Image.open(path) as src:
            img = src.convert('RGB').resize(size)  # Standardize size
        return np.array(img)
    except Exception as e:
        # Best effort: an unreadable file is reported and treated as "no
        # image", so one bad file does not abort the whole analysis.
        print(f"Error loading {path}: {e}", file=sys.stderr)
        return None

def calculate_features(img_array):
    """Build a flat colour-feature vector for an RGB image array.

    For a 3-channel ``[row, col, channel]`` RGB array (as produced by
    ``load_image``) the returned 52-element vector is laid out as:

        [0:16]   red histogram, 16 bins over 0-255, normalised to sum to 1
        [16:32]  green histogram, same binning
        [32:48]  blue histogram, same binning
        [48:51]  mean R, G, B, each scaled to 0-1
        [51]     mean over all pixels and channels, scaled to 0-1
    """
    # One normalised 16-bin histogram per colour channel (R, G, B).
    channel_counts = [
        np.histogram(img_array[:, :, channel], bins=16, range=(0, 256))[0]
        for channel in range(3)
    ]
    channel_hists = [counts / counts.sum() for counts in channel_counts]

    mean_rgb = img_array.mean(axis=(0, 1)) / 255.0
    overall_brightness = img_array.mean() / 255.0

    return np.concatenate(channel_hists + [mean_rgb, [overall_brightness]])

def compare_images(img1_path, img2_path):
    """Return a colour-similarity score in [0, 1] for two image files.

    Both images are loaded with ``load_image`` and described with
    ``calculate_features``. The score is the average of two terms:

    * the cosine similarity of the full feature vectors, which lies in
      [0, 1] because every feature is non-negative; and
    * a histogram-overlap term ``1 - TV``, where ``TV`` is the mean total
      variation distance between the two images' normalised per-channel
      histograms: 0 for identical colour distributions, 1 for disjoint ones.

    Returns 0.0 if either image cannot be loaded.
    """
    img1 = load_image(img1_path)
    img2 = load_image(img2_path)

    if img1 is None or img2 is None:
        return 0.0

    features1 = calculate_features(img1)
    features2 = calculate_features(img2)

    # Cosine similarity
    norm1 = np.linalg.norm(features1)
    norm2 = np.linalg.norm(features2)
    if norm1 == 0 or norm2 == 0:
        return 0.0
    similarity = np.dot(features1, features2) / (norm1 * norm2)

    # Colour-distribution similarity. The first 48 features are three
    # histograms that each sum to 1, so the L1 distance is at most 2 per
    # channel and at most 6 in total. Normalising by that maximum keeps the
    # term in [0, 1]. A plain mean of |diff| over the 48 bins could never
    # drop below 0.875, so it barely discriminated and inflated every score.
    hist_l1 = np.abs(features1[:48] - features2[:48]).sum()
    color_sim = 1.0 - hist_l1 / 6.0

    return float((similarity + color_sim) / 2)

def _classify_score(score):
    """Map a best-match score to ``(status_key, console_label)``."""
    if score >= 0.70:
        return "PASS", "✅ PASS"
    if score >= 0.50:
        return "PARTIAL", "⚠️  PARTIAL"
    return "FAIL", "❌ FAIL"


def analyze_character_consistency():
    """Score every generated frame in CHARACTER_MAP against its references.

    For each frame, ``<GENERATED_DIR>/<frame_name>.png`` is compared with
    every listed reference in ``REFERENCES_DIR`` using ``compare_images``,
    and the highest score is kept. Progress is printed to stdout.

    Returns:
        A list of ``(frame_name, status, score)`` tuples in CHARACTER_MAP
        order. ``status`` is one of:

        * ``"PASS"``: score >= 0.70
        * ``"PARTIAL"``: score >= 0.50
        * ``"FAIL"``: below 0.50, including frames with no usable reference
        * ``"MISSING"``: generated file absent; score is 0
    """
    print("=" * 70)
    print("BROKEN SPIRE - CHARACTER CONSISTENCY ANALYSIS")
    print("=" * 70)

    results = []

    for frame_name, ref_names in CHARACTER_MAP.items():
        gen_path = os.path.join(GENERATED_DIR, f"{frame_name}.png")

        print(f"\n🔍 {frame_name}")
        print(f"   Generated: {os.path.basename(gen_path)}")

        if not os.path.exists(gen_path):
            print("   ❌ GENERATED FILE MISSING")
            results.append((frame_name, "MISSING", 0))
            continue

        best_match = 0
        best_ref = None

        for ref_name in ref_names:
            ref_path = os.path.join(REFERENCES_DIR, ref_name)
            if not os.path.exists(ref_path):
                # Report rather than skip silently. Otherwise a frame whose
                # references are all absent shows up as FAIL with no reason.
                print(f"   ⚠️  Reference missing: {ref_name}")
                continue

            similarity = compare_images(gen_path, ref_path)
            print(f"   vs {ref_name}: {similarity:.2%}")

            if similarity > best_match:
                best_match = similarity
                best_ref = ref_name

        # Determine pass/fail
        status_key, status = _classify_score(best_match)
        results.append((frame_name, status_key, best_match))

        if best_ref is None:
            print(f"   Best match: no reference could be scored {status}")
        else:
            print(f"   Best match: {best_ref} ({best_match:.2%}) {status}")

    return results

def generate_report(results):
    """Print a summary of ``results`` and save a per-frame report file.

    Args:
        results: List of ``(frame_name, status, score)`` tuples as returned by
            ``analyze_character_consistency``. ``status`` is ``"PASS"``,
            ``"PARTIAL"``, ``"FAIL"`` or ``"MISSING"``, and ``score`` is a
            fraction printed as a percentage.

    Side effects:
        Prints the summary, the frames needing regeneration and the
        regeneration instructions to stdout. Writes
        ``consistency_report.txt`` into ``OUTPUT_DIR``, creating the
        directory if needed.
    """
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    # Single pass over the results; missing keys count as 0.
    counts = Counter(status for _frame, status, _score in results)

    print(f"\nTotal frames analyzed: {len(results)}")
    print(f"✅ PASS (>=70% match): {counts['PASS']}")
    print(f"⚠️  PARTIAL (50-70%):   {counts['PARTIAL']}")
    print(f"❌ FAIL (<50%):        {counts['FAIL']}")
    print(f"❌ MISSING:            {counts['MISSING']}")

    # List failed frames for regeneration
    print("\n" + "=" * 70)
    print("FRAMES NEEDING REGENERATION (with IP-Adapter)")
    print("=" * 70)

    for frame, status, score in results:
        if status in ("FAIL", "PARTIAL", "MISSING"):
            print(f"  - {frame} ({score:.1%})")

    print("\n" + "=" * 70)
    print("RECOMMENDATIONS")
    print("=" * 70)
    print("""
For frames that failed the consistency check:

1. Open ComfyUI: https://mu1aafo3zo4da1-8188.proxy.runpod.net
2. Load the reference image using IP-Adapter:
   - Add IPAdapter Unified Loader
   - Add IPAdapter Advanced + Load Image
   - Connect reference image
3. Use the same prompt from broken-spire-frames.md
4. Regenerate and replace the failing frame

This will ensure the generated character matches the reference exactly.
""")

    # Save to file. Create the directory here too, so the function also
    # works when it is called directly rather than via the __main__ entry.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    report_path = os.path.join(OUTPUT_DIR, "consistency_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("BROKEN SPIRE - CHARACTER CONSISTENCY REPORT\n")
        f.write("=" * 50 + "\n\n")
        for frame, status, score in results:
            f.write(f"{frame}: {status} ({score:.1%})\n")

if __name__ == "__main__":
    # Ensure the report directory exists, then run the analysis and hand its
    # per-frame results straight to the reporter.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    generate_report(analyze_character_consistency())