#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import bpu_infer_lib
import os
import time
import sys

# --------------------------- Configuration -------------------------------
MODEL_PATH = "yolov8n.bin"       # model file loaded by bpu_infer_lib
IMAGE_PATH = "test.png"          # image used for the demo run
CONF_THRESH = 0.25               # confidence threshold (decoding and NMS)
IOU_THRESH = 0.45                # IoU threshold for NMS
INPUT_SIZE = (640, 640)          # model input size as (width, height)
# -------------------------------------------------------------------------

# Custom class names (7 classes); a detection's class id indexes this list.
CLASS_NAMES = [
    'write',
    'read',
    'lookup',
    'turn_head',
    'raise_hand',
    'stand',
    'discuss',
]
NUM_CLASSES = len(CLASS_NAMES)   # number of classes (7)

def letterbox(image, target_size=(640, 640), color=(114, 114, 114)):
    """
    Aspect-ratio-preserving letterbox resize plus layout conversion.

    The image is scaled uniformly to fit inside ``target_size``, centred on a
    canvas filled with ``color``, converted BGR -> RGB and transposed to CHW
    with a leading batch axis.

    Args:
        image: HxWx3 BGR uint8 image (as returned by ``cv2.imread``).
        target_size: (width, height) of the output canvas.
        color: BGR fill value for the padding area (applied before the
            BGR -> RGB conversion).

    Returns:
        (input_tensor, (scale, scale), (pad_w, pad_h)) where ``input_tensor``
        is a C-contiguous uint8 array of shape (1, 3, target_h, target_w) in
        RGB order, ``scale`` is the resize factor and ``pad_w``/``pad_h`` are
        the left/top offsets of the resized image on the canvas.
    """
    h, w = image.shape[:2]
    target_w, target_h = target_size
    scale = min(target_w / w, target_h / h)
    # round() rather than int(): truncation can lose a pixel when w * scale
    # lands just below an integer because of floating-point error. Clamp so
    # the resized image always fits on the canvas and is never empty.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))
    resized = cv2.resize(image, (new_w, new_h))

    canvas = np.full((target_h, target_w, 3), color, dtype=np.uint8)
    pad_w = (target_w - new_w) // 2
    pad_h = (target_h - new_h) // 2
    canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

    rgb_img = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    # Values stay uint8 in 0-255; no normalisation is applied here.
    # transpose() only yields a strided view, so copy into a dense NCHW
    # buffer before it is handed to the inference runtime.
    input_tensor = np.ascontiguousarray(rgb_img.transpose(2, 0, 1)[np.newaxis])
    return input_tensor, (scale, scale), (pad_w, pad_h)

def sigmoid(x):
    """
    Element-wise logistic function 1 / (1 + exp(-x)).

    Computed as exp(-logaddexp(0, -x)), which is mathematically identical but
    never overflows: the naive form overflows ``exp`` (and emits a
    RuntimeWarning) for large negative inputs. The input dtype is preserved
    for float arrays.
    """
    return np.exp(-np.logaddexp(0, -x))

def decode_yolov8_output(pred, conf_thresh=0.25, scores_are_logits=None):
    """
    Decode a single-image YOLOv8 output into candidate boxes.

    Accepts shape [1, 4 + NUM_CLASSES, N] (channel-first) or
    [1, N, 4 + NUM_CLASSES] (anchor-first); a 2-D array without the batch
    axis is also accepted. Each anchor row is read as (cx, cy, w, h, class
    scores...), with box values treated as pixels of the letterboxed input
    (clipped to INPUT_SIZE) — assumption about the exported model, confirm.

    Args:
        pred: model output array.
        conf_thresh: minimum best-class score for an anchor to be kept.
        scores_are_logits: True applies a sigmoid to the class scores, False
            uses them as-is. None (default) auto-detects: scores confined to
            [0, 1] are treated as probabilities, anything else as logits.

    Returns:
        List of [x1, y1, x2, y2, conf, class_id] in letterboxed-input
        coordinates; empty list when no anchor passes ``conf_thresh``.
    """
    pred = np.asarray(pred)
    if pred.ndim == 2:
        # Tolerate an output whose batch axis was already squeezed away.
        pred = pred[np.newaxis]
    # Channel-first [1, 4+nc, N] -> anchor-first [1, N, 4+nc].
    if pred.shape[1] == 4 + NUM_CLASSES and pred.shape[2] > 1:
        pred = pred.transpose(0, 2, 1)
    anchors = pred[0]                              # [N, 4+nc]
    xywh = anchors[:, :4]                          # (cx, cy, w, h)
    cls_scores = anchors[:, 4:4 + NUM_CLASSES]     # [N, nc]
    if cls_scores.size == 0:
        return []

    if scores_are_logits is None:
        # NOTE(review): a single-output head with already-decoded xywh boxes
        # typically applies sigmoid inside the model — confirm for this one.
        # Applying sigmoid again would map every score into [0.5, 0.73] and
        # let every anchor pass conf_thresh. Values outside [0, 1] can only
        # be logits, so use that as the discriminator.
        scores_are_logits = bool(cls_scores.min() < 0 or cls_scores.max() > 1)
    if scores_are_logits:
        cls_scores = sigmoid(cls_scores)

    # Confidence = best class probability per anchor.
    max_scores = cls_scores.max(axis=1)
    class_ids = cls_scores.argmax(axis=1)
    mask = max_scores >= conf_thresh
    if not mask.any():
        return []
    xywh = xywh[mask]
    max_scores = max_scores[mask]
    class_ids = class_ids[mask]

    # (cx, cy, w, h) -> (x1, y1, x2, y2), clipped to the input canvas.
    centers = xywh[:, :2]
    half_sizes = xywh[:, 2:4] / 2
    xyxy = np.concatenate((centers - half_sizes, centers + half_sizes), axis=1)
    xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], 0, INPUT_SIZE[0])
    xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], 0, INPUT_SIZE[1])
    return [
        [*box, score, cls_id]
        for box, score, cls_id in zip(xyxy, max_scores, class_ids)
    ]

def nms(detections, iou_thresh=0.45, score_thresh=None):
    """
    Class-agnostic non-maximum suppression using ``cv2.dnn.NMSBoxes``.

    Boxes of different classes can suppress each other because the class id
    is not passed to OpenCV.

    Args:
        detections: list of [x1, y1, x2, y2, conf, class_id] (corner format).
        iou_thresh: IoU above which the lower-scoring overlapping box is dropped.
        score_thresh: boxes scoring below this are discarded; defaults to the
            module-level CONF_THRESH.

    Returns:
        List of the kept detection entries (the original list items), in the
        order NMSBoxes returns their indices.
    """
    if len(detections) == 0:
        return []
    if score_thresh is None:
        score_thresh = CONF_THRESH
    # cv2.dnn.NMSBoxes expects (x, y, width, height) rectangles, not corners;
    # passing (x1, y1, x2, y2) would make it compute the wrong IoU.
    boxes = [
        [float(x1), float(y1), float(x2 - x1), float(y2 - y1)]
        for x1, y1, x2, y2, *_ in detections
    ]
    scores = [float(d[4]) for d in detections]
    indices = cv2.dnn.NMSBoxes(boxes, scores, score_thresh, iou_thresh)
    # Depending on the OpenCV version this is an empty tuple, an (N,) array
    # or an (N, 1) array; normalise to a flat integer array.
    indices = np.asarray(indices, dtype=int).flatten()
    return [detections[i] for i in indices]

def scale_boxes(boxes, scale_ratio, pad, original_shape):
    """
    Map detections from letterboxed-input coordinates back to the original image.

    Args:
        boxes: list of [x1, y1, x2, y2, conf, class_id] with coordinates in
            the letterboxed input (0 ~ INPUT_SIZE).
        scale_ratio: (scale_x, scale_y) resize factors used by the letterbox.
        pad: (pad_w, pad_h) left/top padding added by the letterbox.
        original_shape: (orig_w, orig_h) of the original image.

    Returns:
        A new list of [x1, y1, x2, y2, conf, class_id]: padding removed,
        scale undone, coordinates clipped to [0, orig_w] / [0, orig_h].
        conf and class_id are passed through unchanged.
    """
    sx, sy = scale_ratio
    px, py = pad
    width, height = original_shape

    def _unletterbox(value, offset, factor, limit):
        # Remove the padding offset, undo the resize, then clip to the image.
        return np.clip((value - offset) / factor, 0, limit)

    return [
        [
            _unletterbox(bx1, px, sx, width),
            _unletterbox(by1, py, sy, height),
            _unletterbox(bx2, px, sx, width),
            _unletterbox(by2, py, sy, height),
            score,
            label,
        ]
        for bx1, by1, bx2, by2, score, label in boxes
    ]

def draw_boxes(image_path, detections, output_path="result.jpg"):
    """
    Draw detections on the image at ``image_path`` and save the result.

    Args:
        image_path: path of the original image (re-read from disk here).
        detections: list of [x1, y1, x2, y2, conf, class_id] in original-image
            pixel coordinates; class_id indexes CLASS_NAMES.
        output_path: file the annotated image is written to.

    Prints an error and returns without writing if the image cannot be read,
    and prints an error (instead of a success message) if saving fails.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"无法读取图像 {image_path}")
        return
    for x1, y1, x2, y2, conf, cls_id in detections:
        left, top, right, bottom = int(x1), int(y1), int(x2), int(y2)
        label = f"{CLASS_NAMES[cls_id]}: {conf:.2f}"
        cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 2)
        # Boxes touching the top edge would put the label above the image;
        # keep the text baseline inside it.
        text_y = max(top - 10, 15)
        cv2.putText(img, label, (left, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(output_path, img):
        print(f"无法保存结果图像 {output_path}")
        return
    print(f"结果已保存至: {output_path}")

def main():
    """
    Run the demo pipeline on IMAGE_PATH with the model at MODEL_PATH.

    Steps:
        1. Validate the inputs.
        2. Load the model through bpu_infer_lib.
        3. Letterbox-preprocess the image.
        4. Run inference.
        5. Decode the first output tensor and apply NMS.
        6. Map the boxes back to original-image coordinates.
        7. Draw/save and print the detections.

    Exits with status 1 if the model or image file is missing, or if the
    image cannot be decoded.
    """
    # 1. Make sure the input files exist.
    if not os.path.exists(MODEL_PATH):
        print(f"错误：模型文件 {MODEL_PATH} 不存在！")
        sys.exit(1)
    if not os.path.exists(IMAGE_PATH):
        print(f"错误：图片文件 {IMAGE_PATH} 不存在！")
        sys.exit(1)

    # 2. Initialise the BPU inference library.
    print("初始化推理库...")
    infer = bpu_infer_lib.Infer(False)  # debug=False

    # 3. Load the model.
    print(f"加载模型: {MODEL_PATH}")
    infer.load_model(MODEL_PATH)

    # Optional: print model input/output info (debugging).
    # print(f"模型输入数量: {len(infer.inputs)}")
    # for i, inp in enumerate(infer.inputs):
    #     print(f"Input {i}: {inp.shape}, layout={inp.properties.tensorLayout}")
    # print(f"模型输出数量: {len(infer.outputs)}")
    # for i, out in enumerate(infer.outputs):
    #     print(f"Output {i}: {out.shape}, layout={out.properties.tensorLayout}")

    # 4. Read the original image and preprocess it.
    original_img = cv2.imread(IMAGE_PATH)
    # cv2.imread returns None (it does not raise) for files it cannot decode.
    if original_img is None:
        print(f"错误：无法读取图片文件 {IMAGE_PATH}！")
        sys.exit(1)
    orig_h, orig_w = original_img.shape[:2]
    input_tensor, scale_ratio, pad = letterbox(original_img, INPUT_SIZE)
    # input_tensor is the letterbox output: a (1, 3, H, W) uint8 array.

    # 5. Inference. perf_counter is the high-resolution clock meant for timing.
    print("执行推理...")
    start = time.perf_counter()
    # NOTE(review): assumes forward() accepts the input tensor and returns a
    # list of numpy arrays in this bpu_infer_lib version — confirm.
    outputs = infer.forward(input_tensor)
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"推理耗时: {elapsed_ms:.2f} ms")

    # 6. Post-processing. Only the first output tensor is used (assumes a
    # single-output YOLOv8 head).
    if len(outputs) == 0:
        print("模型没有输出！")
        return
    pred = outputs[0]  # expected (1, 11, 8400) or (1, 8400, 11)
    print(f"输出张量形状: {pred.shape}")

    # Decode candidate boxes.
    raw_dets = decode_yolov8_output(pred, CONF_THRESH)
    if not raw_dets:
        print("没有检测到目标。")
        return

    # Non-maximum suppression.
    final_dets = nms(raw_dets, IOU_THRESH)
    if not final_dets:
        print("NMS 后无剩余目标。")
        return

    # Map coordinates back to the original image.
    final_dets = scale_boxes(final_dets, scale_ratio, pad, (orig_w, orig_h))

    # 7. Draw and save the result image.
    draw_boxes(IMAGE_PATH, final_dets, "result_with_boxes.jpg")

    # Print the detections.
    print(f"检测到 {len(final_dets)} 个目标：")
    for x1, y1, x2, y2, conf, cls_id in final_dets:
        print(f"  {CLASS_NAMES[cls_id]}: 置信度 {conf:.3f}, 框 [{x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f}]")

if __name__ == "__main__":
    # Run the demo pipeline only when executed as a script, not on import.
    main()