#!/usr/bin/env python3
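"""
Image Downscaling Script with S3 and Redis State Management

Reads S3 image URLs (one per line) from urls.txt and downloads each image.
If the image's longer side exceeds <max-dim> pixels, it is resized so that
side equals <max-dim>, and the result is uploaded back over the original S3
object. URLs that were downscaled, or were already small enough, are recorded
in Redis, so an interrupted run can be resumed by re-running the same command.
AWS credentials are read from the environment or a .env file
(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY).

Usage: ./downscale.py <max-dim> <num-parallel>

Example: ./downscale.py 1920 4
"""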

import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from threading import Lock
from urllib.parse import unquote, urlparse

import boto3
import redis
from dotenv import load_dotenv
from PIL import Image
from tqdm import tqdm

# Load environment variables (AWS credentials and the optional overrides
# below) from a local .env file, if present.
load_dotenv()

# Initialize AWS S3 client.
# The region defaults to ap-south-1 (the region of the bucket this script was
# written for). Override it with the S3_REGION environment variable.
s3_client = boto3.client(
    's3',
    aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
    aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
    region_name=os.getenv('S3_REGION', 'ap-south-1'),
)

# Initialize Redis client, which holds the resumable processing state.
# Defaults to a local server on the default port and database 0. Override
# with REDIS_HOST / REDIS_PORT / REDIS_DB.
# decode_responses=True makes Redis replies come back as str rather than bytes.
redis_client = redis.Redis(
    host=os.getenv('REDIS_HOST', 'localhost'),
    port=int(os.getenv('REDIS_PORT', '6379')),
    db=int(os.getenv('REDIS_DB', '0')),
    decode_responses=True,
)

# Redis keys for state management. A URL in either set is not processed again:
# COMPLETED_SET holds URLs that were downscaled and re-uploaded,
# SKIPPED_SET holds URLs whose image already fit within max_dim.
COMPLETED_SET = "downscale:completed"
SKIPPED_SET = "downscale:skipped"

# Run counters shared by all worker threads. Mutate them only while holding
# stats_lock.
stats_lock = Lock()
stats = {
    'processed': 0,     # URLs handled this run: skipped + completed + errors
    'skipped': 0,       # already within max_dim, left untouched
    'completed': 0,     # downscaled and re-uploaded
    'errors': 0,        # raised an exception while processing
    'already_done': 0,  # found in a Redis set from a previous run
}


def parse_s3_url(url):
    """Parse an S3 object URL into ``(bucket, key)``.

    Supported forms:

    * ``s3://bucket/key``: the key is taken verbatim.
    * Virtual-hosted HTTPS, ``https://bucket.s3[.region].amazonaws.com/key``,
      including the legacy ``bucket.s3-region.amazonaws.com`` form. Bucket
      names that contain dots are supported.
    * Path-style HTTPS, ``https://s3[.region].amazonaws.com/bucket/key``.

    For HTTPS URLs the path is percent-decoded (``%20`` becomes a space),
    since keys with special characters appear encoded in object URLs.

    Raises:
        ValueError: if the URL is not a recognised S3 URL, or if the bucket or
            key is empty.
    """
    parsed = urlparse(url)

    if parsed.scheme == 's3':
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
    else:
        host = parsed.hostname or ''
        labels = host.split('.')

        # Locate the S3 service label: "s3", or "s3-<something>" for legacy
        # regional and accelerate endpoints. Use the right-most match so that
        # a bucket name which itself starts with "s3-" is not mistaken for it.
        s3_index = None
        if host.endswith(('.amazonaws.com', '.amazonaws.com.cn')):
            for index, label in enumerate(labels):
                if label == 's3' or label.startswith('s3-'):
                    s3_index = index
        if s3_index is None:
            raise ValueError(f"Unable to parse S3 URL: {url}")

        path = unquote(parsed.path.lstrip('/'))
        if s3_index == 0:
            # Path-style: the bucket is the first path segment.
            bucket, _, key = path.partition('/')
        else:
            # Virtual-hosted style: every label before the service label.
            bucket = '.'.join(labels[:s3_index])
            key = path

    if not bucket or not key:
        raise ValueError(f"Unable to parse S3 URL: {url}")
    return bucket, key


def download_image_from_s3(bucket, key):
    """Download an S3 object into memory.

    Args:
        bucket: S3 bucket name.
        key: object key inside *bucket*.

    Returns:
        A ``BytesIO`` holding the object's bytes, positioned at the start.

    Raises:
        RuntimeError: if the request or the body read fails. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        response = s3_client.get_object(Bucket=bucket, Key=key)
        image_data = response['Body'].read()
    except Exception as e:
        # RuntimeError subclasses Exception, so existing `except Exception`
        # handlers still catch it. `from e` keeps the original traceback.
        raise RuntimeError(f"Failed to download from S3: {e}") from e
    return BytesIO(image_data)


def upload_image_to_s3(bucket, key, image_bytes, content_type='image/png'):
    """Upload encoded image bytes to ``bucket``/``key`` with ``put_object``.

    Args:
        bucket: S3 bucket name.
        key: object key inside *bucket*.
        image_bytes: the encoded image, sent as the request ``Body``.
        content_type: MIME type stored as the object's ``ContentType``.

    Returns:
        True on success.

    Raises:
        RuntimeError: if the request fails. The underlying exception is
            chained as ``__cause__``.
    """
    try:
        s3_client.put_object(
            Bucket=bucket,
            Key=key,
            Body=image_bytes,
            ContentType=content_type,
        )
    except Exception as e:
        # RuntimeError subclasses Exception, so existing `except Exception`
        # handlers still catch it. `from e` keeps the original traceback.
        raise RuntimeError(f"Failed to upload to S3: {e}") from e
    return True


def get_image_content_type(image_format):
    """Map a Pillow format name (e.g. ``'JPEG'``) to a MIME content type.

    The lookup is case-insensitive. ``None`` or an empty string (Pillow may
    report no format) and unknown formats fall back to ``'image/png'``.
    """
    content_types = {
        'PNG': 'image/png',
        'JPEG': 'image/jpeg',
        'JPG': 'image/jpeg',
        # Pillow reports JPEG files that carry multi-picture (MPF) data as
        # 'MPO'. They are still JPEG files.
        'MPO': 'image/jpeg',
        'GIF': 'image/gif',
        'WEBP': 'image/webp',
        'BMP': 'image/bmp',
        'TIFF': 'image/tiff',
    }
    if not image_format:
        return 'image/png'
    return content_types.get(image_format.upper(), 'image/png')


def downscale_image(image_bytes, max_dim):
    """Downscale an encoded image so its longer side is at most ``max_dim``.

    The image is resized with Lanczos resampling, keeping its aspect ratio
    and making the longer side exactly ``max_dim``. It is re-encoded in its
    original format; JPEGs reported by Pillow as 'MPO' are written as plain
    JPEG. EXIF and ICC profile data present on the source are carried over.

    Args:
        image_bytes: binary file-like object holding the encoded image.
        max_dim: maximum allowed length, in pixels, of the longer side.

    Returns:
        ``(needs_downscaling, processed_image_bytes, original_size, new_size)``:
        needs_downscaling is False when the image already fits, in which case
        processed_image_bytes and new_size are None. Otherwise
        processed_image_bytes is a ``BytesIO`` rewound to position 0.
        Both sizes are ``(width, height)`` tuples.

    Raises:
        ValueError: for animated GIF/WebP/PNG images that need resizing, since
            only one frame would be written back.
        Any exception Pillow raises while decoding or encoding.
    """
    img = Image.open(image_bytes)
    original_format = img.format
    width, height = img.size

    if max(width, height) <= max_dim:
        return False, None, (width, height), None

    # Only the current frame is resized and saved. An animated image would
    # therefore lose its animation when uploaded over the original.
    if original_format in ('GIF', 'WEBP', 'PNG') and getattr(img, 'is_animated', False):
        raise ValueError(
            f"Refusing to downscale animated {original_format} image: "
            "only the first frame would be kept"
        )

    # Scale the longer side to max_dim. Integer floor division avoids float
    # rounding error. max(1, ...) stops extreme aspect ratios from producing a
    # zero-length side, which resize() rejects.
    if width > height:
        new_width = max_dim
        new_height = max(1, height * max_dim // width)
    else:
        new_height = max_dim
        new_width = max(1, width * max_dim // height)

    img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Preserve the original format. Pillow reports JPEGs with multi-picture
    # data as 'MPO'; write those back as plain JPEG.
    save_format = original_format or 'PNG'
    if save_format == 'MPO':
        save_format = 'JPEG'

    save_kwargs = {'format': save_format, 'optimize': True}
    if save_format in ('JPEG', 'JPG'):
        save_kwargs['quality'] = 90
    # Pass EXIF and ICC data explicitly. Otherwise the JPEG/WebP writers save
    # without them, dropping the EXIF orientation tag and the colour profile
    # from the re-uploaded image.
    for meta_key in ('exif', 'icc_profile'):
        meta_value = img.info.get(meta_key)
        if meta_value:
            save_kwargs[meta_key] = meta_value

    output = BytesIO()
    img_resized.save(output, **save_kwargs)
    output.seek(0)

    return True, output, (width, height), (new_width, new_height)


def _record_outcome(progress_bar, *counters):
    """Increment the named ``stats`` counters and refresh the bar's postfix.

    The counters are updated and read back under ``stats_lock``, so the
    postfix always shows a consistent snapshot across worker threads.
    """
    with stats_lock:
        for counter in counters:
            stats[counter] += 1
        postfix = {
            'completed': stats['completed'],
            'skipped': stats['skipped'],
            'errors': stats['errors'],
            'already_done': stats['already_done'],
        }
    progress_bar.set_postfix(postfix)


def process_url(url, max_dim, progress_bar):
    """Downscale the image at one S3 URL in place and record the outcome.

    Args:
        url: S3 URL of the image (any form ``parse_s3_url`` accepts).
        max_dim: maximum allowed length, in pixels, of the longer side.
        progress_bar: shared tqdm bar. Its description and postfix are
            updated, and it is advanced by one in every case.

    Returns:
        A dict with ``'status'`` and ``'url'`` plus status-specific fields:

        * ``'already_completed'`` / ``'already_skipped'``: URL was in a Redis
          set from a previous run, so nothing was done.
        * ``'skipped'``: the image already fits. ``'reason'`` says why.
        * ``'completed'``: the image was downscaled and uploaded over the
          original. Includes ``'original_size'`` and ``'new_size'``.
        * ``'error'``: an exception occurred. ``'error'`` holds its message.

    Per-URL failures are returned as ``'error'`` results rather than raised,
    so one bad image does not stop the batch.
    """
    try:
        # URLs recorded by a previous run are not downloaded again.
        if redis_client.sismember(COMPLETED_SET, url):
            _record_outcome(progress_bar, 'already_done')
            return {'status': 'already_completed', 'url': url}

        if redis_client.sismember(SKIPPED_SET, url):
            _record_outcome(progress_bar, 'already_done')
            return {'status': 'already_skipped', 'url': url}

        bucket, key = parse_s3_url(url)

        progress_bar.set_description(f"📥 Downloading {key[:50]}...")
        image_bytes = download_image_from_s3(bucket, key)

        progress_bar.set_description("🔍 Checking dimensions...")
        needs_downscaling, processed_image, original_size, new_size = downscale_image(image_bytes, max_dim)

        if not needs_downscaling:
            redis_client.sadd(SKIPPED_SET, url)
            _record_outcome(progress_bar, 'skipped', 'processed')
            return {
                'status': 'skipped',
                'url': url,
                'reason': f'Dimensions {original_size[0]}x{original_size[1]} <= {max_dim}'
            }

        progress_bar.set_description("📤 Uploading downscaled image...")

        # Re-read the original's format to label the upload with a matching
        # Content-Type. downscale_image re-encodes in the original format.
        image_bytes.seek(0)
        content_type = get_image_content_type(Image.open(image_bytes).format)

        upload_image_to_s3(bucket, key, processed_image.getvalue(), content_type)

        # Record success only after the upload, so a failed upload is retried
        # on the next run.
        redis_client.sadd(COMPLETED_SET, url)
        _record_outcome(progress_bar, 'completed', 'processed')

        return {
            'status': 'completed',
            'url': url,
            'original_size': original_size,
            'new_size': new_size
        }

    except Exception as e:
        _record_outcome(progress_bar, 'errors', 'processed')
        return {
            'status': 'error',
            'url': url,
            'error': str(e)
        }
    finally:
        progress_bar.update(1)


def load_urls(filename='urls.txt'):
    """Load the image URLs to process from a text file.

    The file is read as UTF-8 (a leading BOM is ignored), one URL per line.
    Surrounding whitespace is stripped and blank lines are skipped. A URL
    that appears more than once is kept only at its first occurrence, so two
    workers never process the same object concurrently.

    Args:
        filename: path of the URL list.

    Returns:
        The URLs in file order, without duplicates.

    Raises:
        FileNotFoundError: if *filename* does not exist.
    """
    try:
        # utf-8-sig strips a BOM. A BOM is not whitespace, so strip() would
        # leave it glued to the first URL.
        with open(filename, 'r', encoding='utf-8-sig') as f:
            stripped = (line.strip() for line in f)
            # dict keeps insertion order, so this de-duplicates while
            # preserving the file's order.
            return list(dict.fromkeys(url for url in stripped if url))
    except FileNotFoundError:
        raise FileNotFoundError(f"URLs file not found: {filename}") from None


def _print_summary(results):
    """Print the final counters, any per-URL errors, and up to five resize examples."""
    print("\n" + "━" * 80)
    print("📊 PROCESSING SUMMARY")
    print("━" * 80)
    print(f"Total URLs processed:     {stats['processed']:6d}")
    print(f"✅ Completed (downscaled): {stats['completed']:6d}")
    print(f"⏭️  Skipped (no resize):    {stats['skipped']:6d}")
    print(f"♻️  Already done:          {stats['already_done']:6d}")
    print(f"❌ Errors:                 {stats['errors']:6d}")
    print("━" * 80)

    if stats['errors'] > 0:
        print("\n❌ ERRORS:")
        print("━" * 80)
        for result in results:
            if result['status'] == 'error':
                print(f"URL: {result['url']}")
                print(f"Error: {result['error']}\n")

    completed_results = [r for r in results if r['status'] == 'completed']
    if completed_results:
        print("\n✅ SAMPLE DOWNSCALED IMAGES:")
        print("━" * 80)
        for result in completed_results[:5]:
            orig = result['original_size']
            new = result['new_size']
            print(f"• {result['url'][-60:]}")
            print(f"  {orig[0]}x{orig[1]} → {new[0]}x{new[1]}")
        if len(completed_results) > 5:
            print(f"  ... and {len(completed_results) - 5} more")
        print("━" * 80)


def main():
    """Command-line entry point: ``./downscale.py <max-dim> <num-parallel>``.

    Validates the arguments and loads URLs from urls.txt. Then checks Redis
    connectivity and that AWS credentials are set, and downscales the images
    with ``num_parallel`` worker threads behind a progress bar. Finishes by
    printing a summary.

    Exits with status 1 on bad arguments or failed startup checks. Exits with
    status 130 if interrupted with Ctrl+C; progress already recorded in
    Redis is kept, so re-running the same command resumes.
    """
    if len(sys.argv) != 3:
        print("Usage: ./downscale.py <max-dim> <num-parallel>")
        print("Example: ./downscale.py 1920 4")
        sys.exit(1)

    try:
        max_dim = int(sys.argv[1])
        num_parallel = int(sys.argv[2])
    except ValueError:
        print("Error: Both arguments must be integers")
        sys.exit(1)

    if max_dim <= 0 or num_parallel <= 0:
        print("Error: Arguments must be positive integers")
        sys.exit(1)

    print(f"🚀 Image Downscaling Script")
    print(f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    print(f"📏 Max dimension: {max_dim}px")
    print(f"⚡ Parallel threads: {num_parallel}")
    print(f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")

    print("📂 Loading URLs from urls.txt...")
    try:
        urls = load_urls()
    except FileNotFoundError as e:
        print(f"❌ {e}")
        sys.exit(1)
    total_urls = len(urls)
    print(f"✅ Loaded {total_urls} URLs\n")

    try:
        redis_client.ping()
        print("✅ Connected to Redis")
    except Exception as e:
        print(f"❌ Failed to connect to Redis: {e}")
        print("Please ensure Redis is running: redis-server")
        sys.exit(1)

    if not os.getenv('AWS_ACCESS_KEY_ID') or not os.getenv('AWS_SECRET_ACCESS_KEY'):
        print("❌ AWS credentials not found in .env file")
        sys.exit(1)
    print("✅ AWS credentials loaded")

    # SCARD returns the set size without transferring every member.
    already_completed = redis_client.scard(COMPLETED_SET)
    already_skipped = redis_client.scard(SKIPPED_SET)
    print(f"📊 Previous state: {already_completed} completed, {already_skipped} skipped\n")

    print("⚙️  Processing images...\n")

    results = []
    interrupted = False

    with tqdm(total=total_urls,
              desc="Overall Progress",
              unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as progress_bar:

        with ThreadPoolExecutor(max_workers=num_parallel) as executor:
            future_to_url = {
                executor.submit(process_url, url, max_dim, progress_bar): url
                for url in urls
            }

            try:
                for future in as_completed(future_to_url):
                    results.append(future.result())
            except KeyboardInterrupt:
                # Leaving the executor's `with` block waits for every queued
                # task. Cancel those that have not started so Ctrl+C takes
                # effect; tasks already running are left to finish.
                interrupted = True
                for future in future_to_url:
                    future.cancel()

    _print_summary(results)

    if interrupted:
        print("\n⚠️  Interrupted: pending URLs were cancelled. Re-run the same command to resume.")
        sys.exit(130)

    print("\n✨ Done!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    # main() returns None, so this exits with status 0 on normal completion.
    raise SystemExit(main())
