#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
YOLO分类模型工具类 - 支持pt和onnx格式的分类模型调用
"""

import cv2
import numpy as np
from typing import Dict, Any, Optional, List, Union
import time
import sys
import os
from pathlib import Path
import torch
import torchvision.transforms as transforms
from PIL import Image

try:
    from ultralytics import YOLO
    ULTRALYTICS_AVAILABLE = True
except ImportError:
    ULTRALYTICS_AVAILABLE = False

try:
    import onnxruntime as ort
    ONNXRUNTIME_AVAILABLE = True
except ImportError:
    ONNXRUNTIME_AVAILABLE = False


# 分类结果元素信息
# {
#     "类型编号": 0,                    # int - 类别ID (0, 1, 2, 3, 4, 5)
#     "类型字符": "class_name",         # str - 类别名称
#     "分类置信度": 0.8542,            # float - 分类置信度 (保留4位小数)
#     "置信度排名": 1,                 # int - 置信度排名 (1表示最高)
#     "索引": 0                        # int - 在top-k结果中的索引
# }


class yolo_cls_model_loader:
    """YOLO classification model loader supporting .pt and .onnx formats.

    Wraps either an ultralytics YOLO model (.pt) or an ONNX Runtime session
    (.onnx) behind a single predict() API that takes a BGR image and returns
    top-k classification results as a list of dicts (Chinese keys; see the
    schema comment at module level).
    """

    def __init__(self, device: str = None, verbose: bool = False):
        """
        Initialize the classification model loader.
        Args:
            device: 'cpu', 'cuda', 'cuda:N', a bare GPU index such as '0',
                or None for automatic selection.
            verbose: whether to print detailed log messages (default False).
        """
        self.model = None                # ultralytics YOLO instance (.pt only)
        self.model_path = None
        self.device = self._select_device(device)
        self.is_model_loaded = False
        self.model_load_error = None     # human-readable reason of the last failed load
        self.model_format = None         # 'pt' or 'onnx'
        self.input_size = (224, 224)     # (height, width) expected by the network
        self.class_names = {}            # {class_id: name} extracted from the model
        self.verbose = verbose
        self.usr_class_names = None      # user override set via set_class_names()

        # ONNX Runtime session state (.onnx only)
        self.ort_session = None
        self.input_name = None
        self.output_name = None

        # ImageNet normalization constants used by preprocess_image()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def _log(self, message: str):
        """
        Internal logging helper: prints *message* only in verbose mode.
        Args:
            message: log message.
        """
        if self.verbose:
            print(message)

    def _select_device(self, device: str = None) -> str:
        """
        Resolve the compute device string.
        Args:
            device: 'cpu', 'cuda'/'cuda:N', a bare GPU index like '0' or '1',
                or None/'auto' for automatic selection.
        Returns:
            Normalized device string such as 'cpu' or 'cuda:0'.
        """
        if device == 'cpu':
            return 'cpu'
        elif device and device.startswith('cuda'):
            return device
        elif device and device.isdigit():
            # A bare GPU index maps to an explicit CUDA device.
            return f'cuda:{device}'
        else:
            # Automatic selection: prefer the first GPU when CUDA is available.
            if torch.cuda.is_available():
                return 'cuda:0'
            else:
                return 'cpu'

    def _validate_model_path(self, model_path: str) -> tuple[bool, str]:
        """
        Validate a model path and its file extension.
        Args:
            model_path: path to the model file.
        Returns:
            (is_valid, error_message) — error_message is '' when valid.
        """
        if not model_path:
            return False, "模型路径不能为空"

        model_path = Path(model_path)
        if not model_path.exists():
            return False, f"模型文件不存在: {model_path}"

        # Only .pt and .onnx files are supported.
        supported_formats = {'.pt', '.onnx'}
        if model_path.suffix.lower() not in supported_formats:
            return False, f"不支持的模型格式: {model_path.suffix}，支持的格式: {supported_formats}"

        return True, ""

    def load_from_path(self, model_path: str, input_size: tuple = None, **kwargs) -> bool:
        """
        Load a YOLO classification model from disk.
        Args:
            model_path: path to the model file (.pt or .onnx).
            input_size: input image size as (height, width); defaults to (224, 224).
            **kwargs: forwarded to the format-specific loader.
        Returns:
            True on success; on failure `model_load_error` holds the reason.
        """
        try:
            # Reject missing files / unsupported extensions up front.
            is_valid, error_msg = self._validate_model_path(model_path)
            if not is_valid:
                self.model_load_error = error_msg
                return False

            # Reloading the very same model is a no-op.
            if (self.is_model_loaded and
                self.model_path == str(model_path) and
                (self.model is not None or self.ort_session is not None)):
                self._log(f"模型已加载，路径相同: {model_path}")
                return True

            # Release any previously loaded model first.
            self.unload_model()

            if input_size:
                self.input_size = input_size

            # Dispatch on the file extension ('pt' / 'onnx').
            model_path_obj = Path(model_path)
            self.model_format = model_path_obj.suffix.lower().replace('.', '')

            if self.model_format == 'pt':
                return self._load_pt_model(model_path, **kwargs)
            elif self.model_format == 'onnx':
                return self._load_onnx_model(model_path, **kwargs)
            else:
                self.model_load_error = f"不支持的模型格式: {self.model_format}"
                return False

        except Exception as e:
            self.model_load_error = f"模型加载失败: {str(e)}"
            self._log(self.model_load_error)
            return False

    def _load_pt_model(self, model_path: str, **kwargs) -> bool:
        """
        Load a PyTorch (.pt) model via ultralytics.
        Args:
            model_path: path to the model file.
            **kwargs: accepted for interface symmetry; currently unused.
        Returns:
            True on success; on failure `model_load_error` holds the reason.
        """
        try:
            # ultralytics is an optional dependency — fail with a clear hint.
            if not ULTRALYTICS_AVAILABLE:
                self.model_load_error = "ultralytics库未安装，请安装: pip install ultralytics"
                return False

            self._log(f"正在加载PyTorch模型: {model_path}")
            self._log(f"使用设备: {self.device}")

            self.model = YOLO(model_path)

            # Move the model onto the selected device when supported.
            if hasattr(self.model, 'to'):
                self.model.to(self.device)

            # Record metadata (path, class names).
            self.model_path = str(model_path)
            self._extract_pt_model_info()

            self.is_model_loaded = True
            self.model_load_error = None

            self._log(f"PyTorch模型加载成功:")
            self._log(f"  格式: {self.model_format}")
            self._log(f"  输入尺寸: {self.input_size}")
            self._log(f"  类别数: {len(self.class_names) if self.class_names else 'Unknown'}")

            return True

        except Exception as e:
            self.model_load_error = f"PyTorch模型加载失败: {str(e)}"
            self._log(self.model_load_error)
            return False

    def _load_onnx_model(self, model_path: str, **kwargs) -> bool:
        """
        Load an ONNX (.onnx) model via ONNX Runtime.
        Args:
            model_path: path to the model file.
            **kwargs: accepted for interface symmetry; currently unused.
        Returns:
            True on success; on failure `model_load_error` holds the reason.
        """
        try:
            # onnxruntime is an optional dependency — fail with a clear hint.
            if not ONNXRUNTIME_AVAILABLE:
                self.model_load_error = "onnxruntime库未安装，请安装: pip install onnxruntime 或 pip install onnxruntime-gpu"
                return False

            self._log(f"正在加载ONNX模型: {model_path}")

            # Prefer the CUDA execution provider when a GPU was selected,
            # always keeping CPU as the fallback provider.
            providers = ['CPUExecutionProvider']
            if self.device.startswith('cuda') and torch.cuda.is_available():
                providers.insert(0, 'CUDAExecutionProvider')
                self._log(f"使用GPU设备: {self.device}")
            else:
                self._log("使用CPU设备")

            self.ort_session = ort.InferenceSession(
                model_path,
                providers=providers
            )

            # Cache input/output tensor names for inference calls.
            self.input_name = self.ort_session.get_inputs()[0].name
            self.output_name = self.ort_session.get_outputs()[0].name

            # Derive (height, width) from a static NCHW input shape, if present.
            input_shape = self.ort_session.get_inputs()[0].shape
            if len(input_shape) == 4:  # NCHW layout
                self.input_size = (input_shape[2], input_shape[3])

            self.model_path = str(model_path)
            self._extract_onnx_model_info()

            self.is_model_loaded = True
            self.model_load_error = None

            self._log(f"ONNX模型加载成功:")
            self._log(f"  格式: {self.model_format}")
            self._log(f"  输入尺寸: {self.input_size}")
            self._log(f"  输入名称: {self.input_name}")
            self._log(f"  输出名称: {self.output_name}")

            return True

        except Exception as e:
            self.model_load_error = f"ONNX模型加载失败: {str(e)}"
            self._log(self.model_load_error)
            return False

    def _extract_pt_model_info(self):
        """Extract class names from a loaded PyTorch model."""
        try:
            if hasattr(self.model, 'names'):
                self.class_names = self.model.names
            else:
                # No `names` attribute: fall back to generic placeholder names.
                self.class_names = {i: f"Class{i}" for i in range(1000)}  # default: 1000 classes
        except Exception as e:
            self._log(f"提取PyTorch模型信息时出错: {e}")
            self.class_names = {0: "Class0"}

    def _extract_onnx_model_info(self):
        """Derive placeholder class names from the ONNX output shape."""
        try:
            # ONNX models carry no class-name metadata here; generate
            # generic names sized to the output dimension.
            output_shape = self.ort_session.get_outputs()[0].shape
            if len(output_shape) >= 2:
                num_classes = output_shape[-1]
                self.class_names = {i: f"Class{i}" for i in range(num_classes)}
            else:
                self.class_names = {i: f"Class{i}" for i in range(1000)}
        except Exception as e:
            self._log(f"提取ONNX模型信息时出错: {e}")
            self.class_names = {0: "Class0"}

    def unload_model(self):
        """Unload the current model and release its memory."""
        try:
            if self.model is not None:
                del self.model
                self.model = None

            if self.ort_session is not None:
                del self.ort_session
                self.ort_session = None

            self.is_model_loaded = False
            self.model_path = None
            self.model_load_error = None

            # Return cached GPU memory to the driver.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self._log("模型已卸载")

        except Exception as e:
            self._log(f"模型卸载时出错: {e}")

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess an image for classification inference using
        ImageNet-style mean/std normalization.
        Args:
            image: input image in BGR order (HxWx3) or grayscale (HxW).
        Returns:
            float32 NCHW array of shape (1, 3, H, W).
        Raises:
            RuntimeError: if any preprocessing step fails.
        """
        try:
            # 1. Convert to RGB (classification networks expect RGB).
            if len(image.shape) == 3 and image.shape[2] == 3:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            elif len(image.shape) == 2:
                # Grayscale input: replicate to three channels.
                image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            else:
                image_rgb = image

            # 2. Resize to the model input size. self.input_size is
            # (height, width) while cv2.resize expects dsize=(width, height),
            # so the tuple must be swapped (fixes wrong output for
            # non-square input sizes).
            target_h, target_w = self.input_size
            image_resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            # 3. Scale to [0, 1] in float32.
            image_normalized = image_resized.astype(np.float32) / 255.0

            # 4. Standardize with the ImageNet mean and std.
            mean = np.array(self.mean, dtype=np.float32)
            std = np.array(self.std, dtype=np.float32)

            image_normalized = (image_normalized - mean) / std

            # 5. HWC -> CHW, then add the batch dimension (NCHW).
            image_tensor = np.transpose(image_normalized, (2, 0, 1))
            image_batch = np.expand_dims(image_tensor, axis=0).astype(np.float32)

            return image_batch

        except Exception as e:
            raise RuntimeError(f"图像预处理失败: {str(e)}")

    def preprocess_image_yolo_style(self, image: np.ndarray) -> np.ndarray:
        """
        YOLO-style preprocessing (matches the official export pipeline:
        resize + RGB + scale to [0, 1], without mean/std normalization).
        Args:
            image: input image in BGR order.
        Returns:
            float32 NCHW array of shape (1, 3, H, W).
        Raises:
            RuntimeError: if any preprocessing step fails.
        """
        try:
            # Work on a copy so the caller's array is never modified.
            original_image = image.copy()

            target_h, target_w = self.input_size

            # 1. Direct resize (no aspect-ratio preserving letterbox),
            # consistent with how the model was exported.
            resized = cv2.resize(original_image, (target_w, target_h))

            # 2. BGR -> RGB.
            rgb_image = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

            # 3. Scale to [0, 1] in float32.
            normalized = rgb_image.astype(np.float32) / 255.0

            # 4. HWC -> CHW.
            transposed = np.transpose(normalized, (2, 0, 1))

            # 5. Add batch dimension: CHW -> NCHW.
            batched = np.expand_dims(transposed, axis=0)

            return batched

        except Exception as e:
            raise RuntimeError(f"YOLO风格预处理失败: {str(e)}")

    def predict(self, image: np.ndarray, top_k: int = 5, use_yolo_preprocess: bool = True, **kwargs) -> List[Dict]:
        """
        Run classification inference with the loaded model.
        Args:
            image: input image in BGR order.
            top_k: number of top results to return (default 5).
            use_yolo_preprocess: for ONNX models, use the YOLO-style
                preprocessing instead of ImageNet normalization (default True).
            **kwargs: extra prediction parameters forwarded to the backend.
        Returns:
            List of classification result dicts, best first.
        Raises:
            RuntimeError: if no model is loaded or prediction fails.
        """
        if not self.is_model_loaded:
            raise RuntimeError("模型未加载")

        # Time the whole inference for the verbose report below.
        start_time = time.perf_counter()

        try:
            model_type = self.model_format.upper() + "模型"
            device_info = f"设备: {self.device.upper()}"

            # Debug banner — only in verbose mode (the guard had been
            # commented out, leaving an unconditional print).
            self._log(f"🚀 推理信息: {model_type} | {device_info} | 全局置信度模式")

            # Dispatch to the backend matching the loaded format.
            if self.model_format == 'pt':
                result = self._predict_pt(image, top_k, **kwargs)
            elif self.model_format == 'onnx':
                result = self._predict_onnx(image, top_k, use_yolo_preprocess, **kwargs)
            else:
                raise RuntimeError(f"不支持的模型格式: {self.model_format}")

            end_time = time.perf_counter()
            inference_time = end_time - start_time

            self._log(f"⚡ 推理耗时: {inference_time:.4f}秒")

            return result

        except Exception as e:
            error_msg = f"分类预测失败: {str(e)}"
            self._log(error_msg)
            raise RuntimeError(error_msg)

    def _predict_pt(self, image: np.ndarray, top_k: int = 5, **kwargs) -> List[Dict]:
        """
        Run inference through the ultralytics PyTorch model.
        Args:
            image: input image (BGR; ultralytics handles conversion).
            top_k: number of top results to return.
            **kwargs: forwarded to the YOLO call.
        Returns:
            List of classification result dicts, best first.
        Raises:
            RuntimeError: if prediction fails.
        """
        try:
            results = self.model(image, **kwargs)

            if not results or len(results) == 0:
                return []

            result = results[0]

            # Classification models expose per-class probabilities via `probs`.
            if hasattr(result, 'probs') and result.probs is not None:
                probs = result.probs.data.cpu().numpy()

                # Indices of the top-k probabilities, descending.
                top_indices = np.argsort(probs)[::-1][:top_k]

                classification_results = []
                for rank, idx in enumerate(top_indices):
                    class_id = int(idx)
                    confidence = float(probs[idx])
                    class_name = self.get_class_name(class_id)

                    result_dict = {
                        "类型编号": class_id,
                        "类型字符": class_name,
                        "分类置信度": round(confidence, 4),
                        "置信度排名": rank + 1,
                        "索引": rank
                    }

                    classification_results.append(result_dict)

                return classification_results
            else:
                return []

        except Exception as e:
            raise RuntimeError(f"PyTorch模型预测失败: {str(e)}")

    def _predict_onnx(self, image: np.ndarray, top_k: int = 5, use_yolo_preprocess: bool = True, **kwargs) -> List[Dict]:
        """
        Run inference through the ONNX Runtime session.
        Args:
            image: input image in BGR order.
            top_k: number of top results to return.
            use_yolo_preprocess: use YOLO-style preprocessing (default True);
                otherwise ImageNet mean/std normalization.
            **kwargs: accepted for interface symmetry; currently unused.
        Returns:
            List of classification result dicts, best first.
        Raises:
            RuntimeError: if prediction fails.
        """
        try:
            # Preprocessing must match how the model was exported.
            if use_yolo_preprocess:
                input_tensor = self.preprocess_image_yolo_style(image)
            else:
                input_tensor = self.preprocess_image(image)

            # Verbose diagnostics about the input tensor.
            self._log(f"输入张量形状: {input_tensor.shape}")
            self._log(f"输入张量数据类型: {input_tensor.dtype}")
            self._log(f"输入张量值范围: [{input_tensor.min():.4f}, {input_tensor.max():.4f}]")

            outputs = self.ort_session.run(
                [self.output_name],
                {self.input_name: input_tensor}
            )

            # Verbose diagnostics about the raw output.
            self._log(f"输出形状: {outputs[0].shape}")
            self._log(f"原始logits: {outputs[0][0]}")

            # Single-image inference: take the first (only) batch entry.
            logits = outputs[0][0]

            # Softmax with max-subtraction for numerical stability.
            exp_logits = np.exp(logits - np.max(logits))
            probs = exp_logits / np.sum(exp_logits)

            self._log(f"Softmax概率: {probs}")

            # Indices of the top-k probabilities, descending.
            top_indices = np.argsort(probs)[::-1][:top_k]

            classification_results = []
            for rank, idx in enumerate(top_indices):
                class_id = int(idx)
                confidence = float(probs[idx])
                class_name = self.get_class_name(class_id)

                result_dict = {
                    "类型编号": class_id,
                    "类型字符": class_name,
                    "分类置信度": round(confidence, 4),
                    "置信度排名": rank + 1,
                    "索引": rank
                }

                classification_results.append(result_dict)

            return classification_results

        except Exception as e:
            raise RuntimeError(f"ONNX模型预测失败: {str(e)}")

    def predict_batch(self, images: List[np.ndarray], top_k: int = 5, **kwargs) -> List[List[Dict]]:
        """
        Classify a batch of images one by one.
        Args:
            images: list of input images (BGR).
            top_k: number of top results per image.
            **kwargs: forwarded to predict().
        Returns:
            One classification-result list per input image, in order.
        Raises:
            RuntimeError: if no model is loaded or any prediction fails.
        """
        if not self.is_model_loaded:
            raise RuntimeError("模型未加载")

        try:
            results = []
            for i, image in enumerate(images):
                self._log(f"处理图像 {i+1}/{len(images)}")
                result = self.predict(image, top_k, **kwargs)
                results.append(result)
            return results
        except Exception as e:
            error_msg = f"批量分类预测失败: {str(e)}"
            self._log(error_msg)
            raise RuntimeError(error_msg)

    def set_class_names(self, class_names: Dict[int, str]):
        """
        Set user-defined class names that override the model's own.
        Args:
            class_names: mapping of {class_id: class_name}.
        """
        self.usr_class_names = class_names
        self._log(f"已设置自定义类别名称，共 {len(class_names)} 个类别")

    def get_class_name(self, class_id: int) -> str:
        """
        Resolve a class ID to a class name.

        Precedence: user overrides, then model-provided names, then a
        generic "Class<N>" placeholder.
        Args:
            class_id: class ID.
        Returns:
            Class name string.
        """
        if self.usr_class_names is not None and class_id in self.usr_class_names:
            return self.usr_class_names[class_id]

        if class_id in self.class_names:
            return self.class_names[class_id]

        return f"Class{class_id}"

    def get_top1_result(self, classification_results: List[Dict]) -> Dict:
        """
        Return the top-1 classification result.
        Args:
            classification_results: result list from predict().
        Returns:
            The highest-confidence result dict, or {} when the list is empty.
        """
        if not classification_results:
            return {}

        # predict() already sorts results by descending confidence.
        return classification_results[0]

    def filter_by_confidence(self, classification_results: List[Dict], min_confidence: float = 0.5) -> List[Dict]:
        """
        Filter classification results by a confidence threshold.
        Args:
            classification_results: result list from predict().
            min_confidence: minimum confidence to keep a result.
        Returns:
            Filtered result list (order preserved).
        """
        return [
            result for result in classification_results
            if result["分类置信度"] >= min_confidence
        ]

    def get_classification_summary(self, classification_results: List[Dict]) -> Dict:
        """
        Build a statistical summary of classification results.
        Args:
            classification_results: result list from predict().
        Returns:
            Summary dict (counts, confidence stats, top-1 info).
        """
        if not classification_results:
            return {
                "总数": 0,
                "最高置信度": 0.0,
                "最低置信度": 0.0,
                "平均置信度": 0.0,
                "top1_类别": "无",
                "top1_置信度": 0.0
            }

        confidences = [result["分类置信度"] for result in classification_results]

        summary = {
            "总数": len(classification_results),
            "最高置信度": round(max(confidences), 4),
            "最低置信度": round(min(confidences), 4),
            # float() keeps the summary free of numpy scalar types
            # (e.g. for JSON serialization).
            "平均置信度": round(float(np.mean(confidences)), 4),
            "top1_类别": classification_results[0]["类型字符"],
            "top1_置信度": classification_results[0]["分类置信度"]
        }

        return summary

    def set_device(self, device: str) -> bool:
        """
        Switch the compute device, migrating the loaded model when possible.
        Args:
            device: device name ('cpu', 'cuda', 'cuda:N', or a GPU index).
        Returns:
            True if the switch succeeded.
        """
        try:
            new_device = self._select_device(device)

            if self.is_model_loaded:
                if self.model_format == 'pt' and hasattr(self.model, 'to'):
                    self.model.to(new_device)
                elif self.model_format == 'onnx':
                    # ONNX Runtime binds providers at session creation, so the
                    # session must be rebuilt to change device.
                    if self.model_path:
                        self.device = new_device
                        return self._load_onnx_model(self.model_path)

            self.device = new_device
            self._log(f"设备已切换到: {self.device}")
            return True

        except Exception as e:
            self._log(f"设备切换失败: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """
        Report current loader/model state.
        Returns:
            Dict with path, format, load state, device, input size,
            class count and last load error.
        """
        return {
            "模型路径": self.model_path,
            "模型格式": self.model_format,
            "是否已加载": self.is_model_loaded,
            "设备": self.device,
            "输入尺寸": self.input_size,
            "类别数量": len(self.class_names) if self.class_names else 0,
            "加载错误": self.model_load_error
        }