Files
Feeding_control_system/vision/resize_tuili_image_main.py
2025-09-26 13:32:34 +08:00

185 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
# ---------------------------
# Class-id -> label mapping
# ---------------------------
# Labels in class-id order (0..4): no pile-up, small pile, large pile,
# not fully poured, fully poured.
CLASS_NAMES = dict(
    enumerate(
        (
            "未堆料",
            "小堆料",
            "大堆料",
            "未浇筑满",
            "浇筑满",
        )
    )
)
# ---------------------------
# Load the ROI list
# ---------------------------
def load_global_rois(txt_path):
    """Read ROIs from a text file containing one ``x,y,w,h`` integer tuple per line.

    Blank lines are ignored. Malformed lines (non-integer fields or a field
    count other than four) are reported and skipped.

    :param txt_path: path to the ROI text file.
    :return: list of ``(x, y, w, h)`` int tuples; empty if the path is not a
             regular file or no line is valid.
    """
    rois = []
    # isfile (not exists): a directory path would otherwise crash in open().
    if not os.path.isfile(txt_path):
        print(f"ROI 文件不存在: {txt_path}")
        return rois
    # utf-8-sig strips a leading BOM (e.g. files saved by Windows Notepad),
    # which would otherwise make the first line unparseable; errors="replace"
    # turns undecodable bytes into a reported bad line instead of a crash.
    with open(txt_path, 'r', encoding='utf-8-sig', errors='replace') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                # Raised both for non-integer fields and for a wrong field count.
                print(f"无法解析 ROI 行 '{s}': {e}")
                continue
            rois.append((x, y, w, h))
    return rois
# ---------------------------
# Crop and resize ROIs
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Crop each ROI out of ``img`` and resize it to a ``target_size`` square.

    ROIs with a non-positive width/height, or that extend past the image
    bounds, are skipped silently, so the result may be shorter than ``rois``.

    :param img: image array indexed as ``img[row, col, ...]`` (e.g. a cv2 frame).
    :param rois: iterable of ``(x, y, w, h)`` integer tuples in pixels.
    :param target_size: side length of the square output crops.
    :return: list of ``(resized_crop, roi_index)``; ``roi_index`` is the
             ROI's position in ``rois``.
    """
    crops = []
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        # A zero/negative size yields an empty slice, which cv2.resize rejects
        # with an error; skip it like any other invalid ROI.
        if w <= 0 or h <= 0:
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(roi, (target_size, target_size), interpolation=cv2.INTER_AREA)
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# Weighted decision between class 1 (small pile) and class 2 (large pile)
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Choose between "small pile" and "large pile" using a weighted score.

    The score is ``(w1 * p1 + w2 * p2) / (p1 + p2)`` where ``p1``/``p2`` are
    the class-1 and class-2 probabilities; for non-negative probabilities it
    lies between ``w1`` and ``w2``. When ``p1 + p2`` is not positive the
    score is 0.0.

    :param pred_probs: indexable class probabilities; only indices 1 and 2 are read.
    :param threshold: score at or above which the label is "大堆料".
    :param w1: weight applied to the class-1 (small pile) probability.
    :param w2: weight applied to the class-2 (large pile) probability.
    :return: ``(label, score, p1, p2)`` with ``p1``/``p2`` converted to float.
    """
    prob_small, prob_large = float(pred_probs[1]), float(pred_probs[2])
    mass = prob_small + prob_large
    weighted = (w1 * prob_small + w2 * prob_large) / mass if mass > 0 else 0.0
    label = "大堆料" if weighted >= threshold else "小堆料"
    return label, weighted, prob_small, prob_large
# ---------------------------
# Single-image inference
# ---------------------------
def classify_image_weighted(image, model, threshold=0.4):
    """Classify one ROI image, resolving classes 1/2 with the weighted score.

    The top-1 class is the argmax of the model's class probabilities. If it
    is class 1 or 2, the result of ``weighted_small_large`` is returned
    unchanged; otherwise the label comes from ``CLASS_NAMES`` and the score
    is the top-1 probability.

    :param image: image passed directly to ``model``.
    :param model: classification model whose ``model(image)[0].probs.data``
                  holds the class probabilities (ultralytics YOLO).
    :param threshold: threshold forwarded to ``weighted_small_large``.
    :return: ``(label, score, p_class1, p_class2)``.
    """
    output = model(image)
    probs = output[0].probs.data.cpu().numpy().flatten()
    top_id = int(probs.argmax())

    # Classes 1 and 2 (small / large pile) are decided by the weighted score
    # rather than by plain argmax.
    if top_id in (1, 2):
        return weighted_small_large(probs, threshold=threshold)

    label = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
    return label, float(probs[top_id]), float(probs[1]), float(probs[2])
# ---------------------------
# Real-time video stream inference
# ---------------------------
def real_time_inference(rtsp_url, model_path, roi_file, target_size=640, threshold=0.4):
    """Run ROI overflow classification on a live video stream.

    For every frame read from ``rtsp_url``, each ROI from ``roi_file`` is
    cropped, resized to ``target_size`` x ``target_size`` and classified.
    The result is printed, and an overflow alert is printed when the label
    is exactly "大堆料" (large pile) or "浇筑满" (fully poured). Each ROI crop
    and the original frame are shown in OpenCV windows. Press 'q' to quit,
    or 's' to save the current frame as ``frame_<n>.jpg`` in the working
    directory.

    :param rtsp_url: stream URL passed to ``cv2.VideoCapture``.
    :param model_path: path of the YOLO model weights.
    :param roi_file: ROI text file, one ``x,y,w,h`` per line.
    :param target_size: side length each ROI crop is resized to.
    :param threshold: small/large-pile threshold forwarded to the classifier.
    """
    # Validate the cheap input first so a bad ROI file does not pay for a model load.
    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出")
        return
    model = YOLO(model_path)

    cap = cv2.VideoCapture(rtsp_url)
    if not cap.isOpened():
        print(f"❌ 无法打开视频流: {rtsp_url}")
        cap.release()
        return
    print(f"✅ 成功连接到视频流: {rtsp_url}")
    print("按 'q' 键退出,按 's' 键保存当前帧")

    # Labels that count as overflow. Compare by equality, not substring:
    # "浇筑满" is a substring of "未浇筑满" (not fully poured).
    overflow_classes = ("大堆料", "浇筑满")
    frame_count = 0
    # try/finally releases the capture and windows even on an unexpected
    # exception or KeyboardInterrupt.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("❌ 无法读取帧,可能连接已断开")
                break
            frame_count += 1
            print(f"\n处理第 {frame_count} 帧")
            try:
                for roi_resized, roi_idx in crop_and_resize(frame, rois, target_size):
                    final_class, score, p1, p2 = classify_image_weighted(
                        roi_resized, model, threshold=threshold
                    )
                    print(f"ROI {roi_idx} -> 类别: {final_class}, 加权分数: {score:.2f}, "
                          f"class1 置信度: {p1:.2f}, class2 置信度: {p2:.2f}")
                    if final_class in overflow_classes:
                        print(f"🚨 检测到溢料: ROI {roi_idx} - {final_class}")
                    cv2.imshow(f'ROI {roi_idx}', roi_resized)
                cv2.imshow('Original Frame', frame)
            except Exception as e:
                # Per-frame boundary: report and keep streaming. No `continue`
                # here, so the key handling below still runs and 'q' works even
                # when every frame fails.
                print(f"处理帧时出错: {e}")

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            if key == ord('s'):
                filename = f"frame_{frame_count}.jpg"
                # cv2.imwrite reports failure through its return value.
                if cv2.imwrite(filename, frame):
                    print(f"保存帧到 {filename}")
                else:
                    print(f"保存帧失败: {filename}")
    finally:
        cap.release()
        cv2.destroyAllWindows()
    print("✅ 视频流处理结束")
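# ---------------------------
# Entry point - real-time inference example (disabled)
# ---------------------------
# Do not hard-code camera credentials in source; read the RTSP URL from the
# environment instead (`os` is already imported). A password that has been
# committed to version control should be treated as leaked and rotated.
# if __name__ == "__main__":
#     # RTSP stream URL, e.g. "rtsp://<user>:<password>@192.168.1.51:554/streaming/channels/101"
#     rtsp_url = os.environ["OVERFLOW_RTSP_URL"]
#
#     # Configuration
#     model_path = r"models/overflow.pt"
#     roi_file = r"./roi_coordinates/1_rois.txt"
#     target_size = 640
#     threshold = 0.4
#
#     print("开始实时视频流推理...")
#     real_time_inference(rtsp_url, model_path, roi_file, target_size, threshold)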