增加ailai的旋转检测的推理和部署
This commit is contained in:
105
ailai_obb/angle.py
Normal file
105
ailai_obb/angle.py
Normal file
@ -0,0 +1,105 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
def get_best_obb_angle(image_path, weight_path, return_degree=False):
    """Return the rotation angle of the highest-confidence OBB detection.

    Args:
        image_path: Path of the image to run inference on.
        weight_path: Path of the YOLO OBB weight file.
        return_degree: When True the angle is returned in degrees,
            otherwise in radians (default).

    Returns:
        The rotation angle of the most confident box, or None when the
        image cannot be read or no object is detected.
    """
    frame = cv2.imread(image_path)
    if frame is None:
        print(f"❌ 无法读取图像:{image_path}")
        return None

    # NOTE(review): 'mode' is forwarded to Ultralytics as-is; the task is
    # normally inferred from the weights — confirm the kwarg is needed.
    detections = YOLO(weight_path)(frame, save=False, imgsz=640, conf=0.15, mode='obb')[0].obb
    if not detections:
        print("⚠️ 未检测到目标。")
        return None

    # Pick the most confident box; xywhr is [cx, cy, w, h, r] with r in radians.
    top = max(detections, key=lambda b: b.conf.cpu().numpy()[0])
    angle_rad = top.xywhr.cpu().numpy()[0][4]
    return np.degrees(angle_rad) if return_degree else angle_rad
|
||||
|
||||
|
||||
def save_obb_visual(image_path, weight_path, save_path):
    """Detect OBBs, annotate the highest-confidence box with its rotation
    angle, and save the annotated image.

    Args:
        image_path: Path of the input image.
        weight_path: Path of the YOLO OBB weight file.
        save_path: Destination path for the annotated image.

    Returns:
        None. Prints a diagnostic and returns early when the image cannot
        be read or nothing is detected.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return

    # Load model and run OBB inference.
    model = YOLO(weight_path)
    results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
    result = results[0]

    boxes = result.obb
    if not boxes:
        print("⚠️ 未检测到目标。")
        return

    # Keep only the most confident detection; xywhr = [cx, cy, w, h, r(rad)].
    best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
    cx, cy, w, h, r = best_box.xywhr.cpu().numpy()[0]
    angle_deg = np.degrees(r)

    # Draw the rotated rectangle (cv2.boxPoints expects degrees).
    annotated_img = img.copy()
    rect = ((cx, cy), (w, h), angle_deg)
    box_pts = cv2.boxPoints(rect).astype(int)
    cv2.polylines(annotated_img, [box_pts], isClosed=True, color=(0, 255, 0), thickness=2)

    # Angle label, font scaled roughly to the box size.
    # NOTE(review): Hershey fonts cannot render '°' — the glyph may show
    # as '?'; confirm whether plain 'deg' is preferred.
    text = f"{angle_deg:.1f}°"
    font_scale = max(0.5, min(w, h)/100)
    thickness = 2
    text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    text_x = int(cx - text_size[0]/2)
    text_y = int(cy + text_size[1]/2)
    cv2.putText(annotated_img, text, (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness)

    # Save. Guard against a bare filename: os.path.dirname('') would make
    # os.makedirs raise FileNotFoundError.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(save_path, annotated_img)
    print(f"✅ 检测结果已保存至: {save_path}")
|
||||
|
||||
|
||||
# ===============================
|
||||
# 示例调用
|
||||
# ===============================
|
||||
# ===============================
# Example usage
# ===============================
if __name__ == "__main__":
    weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
    image = r"/home/hx/yolo/output_masks/2.jpg"
    save_path = "./inference_results/best_detected_2.jpg"

    # get_best_obb_angle returns None on read failure / no detection;
    # formatting None with :.4f would raise TypeError, so guard first.
    angle_rad = get_best_obb_angle(image, weight)
    if angle_rad is not None:
        print(f"旋转角(弧度):{angle_rad:.4f}")

    angle_deg = get_best_obb_angle(image, weight, return_degree=True)
    if angle_deg is not None:
        print(f"旋转角(度):{angle_deg:.2f}°")

    save_obb_visual(image, weight, save_path)
|
||||
171
ailai_obb/bag_bushu.py
Normal file
171
ailai_obb/bag_bushu.py
Normal file
@ -0,0 +1,171 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from shapely.geometry import Polygon
|
||||
from rknnlite.api import RKNNLite
|
||||
import os
|
||||
|
||||
CLASSES = ['clamp']   # detector class names (single-class model)
nmsThresh = 0.4       # rotated-IoU threshold used by NMS suppression
objectThresh = 0.5    # minimum confidence for a candidate box
|
||||
|
||||
# ------------------- 工具函数 -------------------
|
||||
def letterbox_resize(image, size, bg_color=114):
    """Resize *image* into *size* keeping aspect ratio, padding with bg_color.

    Returns:
        (canvas, scale, offset_x, offset_y) so detections on the canvas can
        be mapped back to original-image coordinates.
    """
    out_w, out_h = size
    in_h, in_w, _ = image.shape
    ratio = min(out_w / in_w, out_h / in_h)
    resized_w = int(in_w * ratio)
    resized_h = int(in_h * ratio)
    resized = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
    # Uniform background canvas, image centered on it.
    board = np.full((out_h, out_w, 3), bg_color, dtype=np.uint8)
    pad_x = (out_w - resized_w) // 2
    pad_y = (out_h - resized_h) // 2
    board[pad_y:pad_y + resized_h, pad_x:pad_x + resized_w] = resized
    return board, ratio, pad_x, pad_y
|
||||
|
||||
class DetectBox:
    """One rotated detection: class id, score, axis-aligned corners, angle."""

    def __init__(self, classId, score, xmin, ymin, xmax, ymax, angle):
        self.classId = classId  # integer class index; NMS marks suppressed boxes with -1
        self.score = score      # confidence score
        self.xmin = xmin        # unrotated box corners
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax
        self.angle = angle      # rotation in radians
|
||||
|
||||
def rotate_rectangle(x1, y1, x2, y2, a):
    """Rotate the axis-aligned rectangle (x1, y1)-(x2, y2) by *a* radians
    about its centre and return the four integer corner points.

    The corners come back in polygon drawing order: the (x1,y1), (x1,y2),
    (x2,y2), (x2,y1) corners of the unrotated box.
    """
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    cos_a = math.cos(a)
    sin_a = math.sin(a)

    def spin(px, py):
        # Standard 2-D rotation about (cx, cy), truncated to int like the original.
        return (int((px - cx) * cos_a - (py - cy) * sin_a + cx),
                int((px - cx) * sin_a + (py - cy) * cos_a + cy))

    return [spin(x1, y1), spin(x1, y2), spin(x2, y2), spin(x2, y1)]
|
||||
|
||||
def intersection(g, p):
    """Rotated-box IoU of two corner-point lists, via shapely polygons.

    Returns 0 for invalid (degenerate) polygons or an empty union.
    """
    poly_a = Polygon(np.array(g).reshape(-1, 2))
    poly_b = Polygon(np.array(p).reshape(-1, 2))
    if not (poly_a.is_valid and poly_b.is_valid):
        return 0
    overlap = poly_a.intersection(poly_b).area
    total = poly_a.area + poly_b.area - overlap
    if total == 0:
        return 0
    return overlap / total
|
||||
|
||||
def NMS(detectResult):
    """Greedy non-maximum suppression over rotated boxes using polygon IoU.

    Boxes are visited in descending score order. A suppressed box is
    marked in place by setting its classId to -1 (the input boxes are
    mutated). Returns the list of surviving boxes.
    """
    predBoxs = []
    sort_detectboxs = sorted(detectResult, key=lambda x: x.score, reverse=True)
    for i in range(len(sort_detectboxs)):
        if sort_detectboxs[i].classId == -1:
            continue
        # Corner polygon of the current (kept) box.
        p1 = rotate_rectangle(sort_detectboxs[i].xmin, sort_detectboxs[i].ymin,
                              sort_detectboxs[i].xmax, sort_detectboxs[i].ymax,
                              sort_detectboxs[i].angle)
        predBoxs.append(sort_detectboxs[i])
        for j in range(i + 1, len(sort_detectboxs)):
            # Only boxes of the same class suppress each other.
            if sort_detectboxs[j].classId == sort_detectboxs[i].classId:
                p2 = rotate_rectangle(sort_detectboxs[j].xmin, sort_detectboxs[j].ymin,
                                      sort_detectboxs[j].xmax, sort_detectboxs[j].ymax,
                                      sort_detectboxs[j].angle)
                if intersection(p1, p2) > nmsThresh:
                    sort_detectboxs[j].classId = -1
    return predBoxs
|
||||
|
||||
def sigmoid(x):
    """Numerically stable logistic function (element-wise on arrays)."""
    positive_branch = 1 / (1 + np.exp(-x))
    negative_branch = np.exp(x) / (1 + np.exp(x))
    return np.where(x >= 0, positive_branch, negative_branch)
|
||||
|
||||
def softmax(x, axis=-1):
    """Softmax along *axis*, max-shifted for numerical stability."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=axis, keepdims=True)
|
||||
|
||||
def process(out, model_w, model_h, stride, angle_feature, index, scale_w=1, scale_h=1):
    """Decode one detection-head output into DetectBox candidates.

    Args:
        out: head tensor reshaped to (1, 65, model_h*model_w) — 64 distance
            channels followed by 1 class logit (single-class model).
        model_w, model_h: grid size of this head.
        stride: pixels per grid cell (8/16/32).
        angle_feature: flattened angle map shared by all heads.
        index: offset of this head's cells inside angle_feature.
        scale_w, scale_h: optional output scaling factors (default 1).

    Returns:
        List of DetectBox with conf above objectThresh (before NMS).
    """
    class_num = len(CLASSES)
    angle_feature = angle_feature.reshape(-1)
    xywh = out[:, :64, :]
    conf = sigmoid(out[:, 64:, :]).reshape(-1)
    boxes = []
    for ik in range(model_h * model_w * class_num):
        if conf[ik] > objectThresh:
            # Recover grid cell (w, h) and class c from the flat index.
            w = ik % model_w
            h = (ik % (model_w * model_h)) // model_w
            c = ik // (model_w * model_h)
            # Decode box sides: softmax-weighted expectation over 16 bins
            # per side (DFL-style decode).
            xywh_ = xywh[0, :, (h * model_w) + w].reshape(1, 4, 16, 1)
            data = np.arange(16).reshape(1, 1, 16, 1)
            xywh_ = softmax(xywh_, 2)
            xywh_ = np.sum(xywh_ * data, axis=2).reshape(-1)
            xywh_add = xywh_[:2] + xywh_[2:]        # box width / height
            xywh_sub = (xywh_[2:] - xywh_[:2]) / 2  # centre offset from the cell
            # Angle lookup, clamped to the feature length for safety.
            angle_idx = min(index + (h * model_w) + w, len(angle_feature) - 1)
            angle = (angle_feature[angle_idx] - 0.25) * math.pi
            cos_a, sin_a = math.cos(angle), math.sin(angle)
            # Rotate the centre offset by the box angle.
            xy = xywh_sub[0] * cos_a - xywh_sub[1] * sin_a, xywh_sub[0] * sin_a + xywh_sub[1] * cos_a
            xywh1 = np.array([xy[0] + w + 0.5, xy[1] + h + 0.5, xywh_add[0], xywh_add[1]])
            xywh1 *= stride  # grid units -> input pixels
            xmin = (xywh1[0] - xywh1[2]/2) * scale_w
            ymin = (xywh1[1] - xywh1[3]/2) * scale_h
            xmax = (xywh1[0] + xywh1[2]/2) * scale_w
            ymax = (xywh1[1] + xywh1[3]/2) * scale_h
            boxes.append(DetectBox(c, conf[ik], xmin, ymin, xmax, ymax, angle))
    return boxes
|
||||
|
||||
# ------------------- 主函数 -------------------
|
||||
def detect_boxes_angle_rknn(model_path, image_path, save_path=None):
    """Run the RKNN OBB model on one image and report each box's angle.

    Args:
        model_path: path of the compiled .rknn model file.
        image_path: path of the input image.
        save_path: optional path; when given, rotated boxes with their
            angle labels are drawn on the original image and saved there.

    Returns:
        (detect_boxes, img): surviving DetectBox list and the (possibly
        annotated) image, or (None, None) when the image is unreadable.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None, None

    # Letterbox to the 640x640 model input and keep the mapping parameters.
    img_resized, scale, offset_x, offset_y = letterbox_resize(img, (640, 640))
    infer_img = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
    infer_img = np.expand_dims(infer_img, 0)  # add batch dimension

    rknn_lite = RKNNLite(verbose=False)
    rknn_lite.load_rknn(model_path)
    rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)

    results = rknn_lite.inference([infer_img])
    detect_boxes = []
    # All outputs except the last are detection heads; the last output is
    # the angle feature map shared by the three strides.
    for x in results[:-1]:
        index, stride = 0, 0
        if x.shape[2] == 20:
            stride, index = 32, 20*4*20*4 + 20*2*20*2  # 80*80 + 40*40 cells precede this head
        elif x.shape[2] == 40:
            stride, index = 16, 20*4*20*4  # 80*80 cells precede this head
        elif x.shape[2] == 80:
            stride, index = 8, 0
        feature = x.reshape(1, 65, -1)  # 64 distance channels + 1 class logit
        detect_boxes += process(feature, x.shape[3], x.shape[2], stride, results[-1], index)

    detect_boxes = NMS(detect_boxes)

    # Report (and optionally draw) every surviving box.
    for i, box in enumerate(detect_boxes):
        print(f"框 {i+1}: angle = {box.angle:.4f} rad ({np.degrees(box.angle):.2f}°)")
        if save_path:
            # Map letterboxed coordinates back onto the original image.
            xmin = int((box.xmin - offset_x)/scale)
            ymin = int((box.ymin - offset_y)/scale)
            xmax = int((box.xmax - offset_x)/scale)
            ymax = int((box.ymax - offset_y)/scale)
            points = rotate_rectangle(xmin, ymin, xmax, ymax, box.angle)
            cv2.polylines(img, [np.array(points, np.int32)], True, (0, 255, 0), 1)
            cv2.putText(img, f"{np.degrees(box.angle):.1f}°", (xmin, ymin-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1)

    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        cv2.imwrite(save_path, img)
        print(f"✅ 带角度的检测结果已保存到 {save_path}")

    rknn_lite.release()
    return detect_boxes, img
|
||||
|
||||
# ------------------- 使用示例 -------------------
|
||||
# ------------------- Usage example -------------------
if __name__ == "__main__":
    model_path = "obb.rknn"
    image_path = "2.jpg"
    save_path = "./inference_results/boxes_with_angle.jpg"
    # Pre-create the output directory (also created inside the detector).
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    detect_boxes_angle_rknn(model_path, image_path, save_path)
|
||||
|
||||
197
ailai_obb/bushu_angle.py
Normal file
197
ailai_obb/bushu_angle.py
Normal file
@ -0,0 +1,197 @@
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from shapely.geometry import Polygon
|
||||
from rknnlite.api import RKNNLite
|
||||
import os
|
||||
|
||||
# ------------------- 配置 -------------------
|
||||
CLASSES = ['clamp']   # detector class names (single-class model)
nmsThresh = 0.4       # rotated-IoU threshold used by NMS suppression
objectThresh = 0.5    # minimum confidence for a candidate box

# Original (pre-letterbox) image size.
# NOTE(review): ORIG_W / ORIG_H are not referenced anywhere in the visible
# code of this file — confirm whether they are still needed before removing.
ORIG_W = 2560  # original image width
ORIG_H = 1440  # original image height
|
||||
|
||||
# ------------------- 工具函数 -------------------
|
||||
def letterbox_resize(image, size, bg_color=114):
    """Aspect-preserving resize of *image* onto a bg_color canvas of *size*.

    Returns:
        (canvas, scale, offset_x, offset_y) for mapping detections back to
        original-image coordinates.
    """
    target_width, target_height = size
    image_height, image_width, _ = image.shape
    scale = min(target_width / image_width, target_height / image_height)
    new_width, new_height = int(image_width * scale), int(image_height * scale)
    image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    # Uniform background canvas with the resized image centered on it.
    canvas = np.ones((target_height, target_width, 3), dtype=np.uint8) * bg_color
    offset_x, offset_y = (target_width - new_width) // 2, (target_height - new_height) // 2
    canvas[offset_y:offset_y + new_height, offset_x:offset_x + new_width] = image_resized
    return canvas, scale, offset_x, offset_y
|
||||
|
||||
class DetectBox:
    """One rotated detection: class id, score, axis-aligned corners, angle."""

    def __init__(self, classId, score, xmin, ymin, xmax, ymax, angle):
        self.classId = classId  # integer class index; NMS marks suppressed boxes with -1
        self.score = score      # confidence score
        self.xmin = xmin        # unrotated box corners
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax
        self.angle = angle      # rotation in radians
|
||||
|
||||
def rotate_rectangle(x1, y1, x2, y2, a):
    """Rotate the rectangle (x1, y1)-(x2, y2) by *a* radians about its
    centre; returns four integer corner points in polygon drawing order.
    """
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    # Each corner: standard 2-D rotation about (cx, cy), truncated to int.
    x1_new = int((x1 - cx) * math.cos(a) - (y1 - cy) * math.sin(a) + cx)
    y1_new = int((x1 - cx) * math.sin(a) + (y1 - cy) * math.cos(a) + cy)
    x2_new = int((x2 - cx) * math.cos(a) - (y2 - cy) * math.sin(a) + cx)
    y2_new = int((x2 - cx) * math.sin(a) + (y2 - cy) * math.cos(a) + cy)
    x3_new = int((x1 - cx) * math.cos(a) - (y2 - cy) * math.sin(a) + cx)
    y3_new = int((x1 - cx) * math.sin(a) + (y2 - cy) * math.cos(a) + cy)
    x4_new = int((x2 - cx) * math.cos(a) - (y1 - cy) * math.sin(a) + cx)
    y4_new = int((x2 - cx) * math.sin(a) + (y1 - cy) * math.cos(a) + cy)
    return [(x1_new, y1_new), (x3_new, y3_new), (x2_new, y2_new), (x4_new, y4_new)]
|
||||
|
||||
def intersection(g, p):
    """IoU of two rotated boxes given as corner-point lists (via shapely)."""
    g = Polygon(np.array(g).reshape(-1,2))
    p = Polygon(np.array(p).reshape(-1,2))
    if not g.is_valid or not p.is_valid:
        return 0  # degenerate polygon -> treat as no overlap
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    return 0 if union == 0 else inter / union
|
||||
|
||||
def NMS(detectResult):
    """Greedy non-maximum suppression over rotated boxes using polygon IoU.

    Boxes are visited in descending score order. A suppressed box is
    marked in place by setting its classId to -1 (the input boxes are
    mutated). Returns the list of surviving boxes.
    """
    predBoxs = []
    sort_detectboxs = sorted(detectResult, key=lambda x: x.score, reverse=True)
    for i in range(len(sort_detectboxs)):
        if sort_detectboxs[i].classId == -1:
            continue
        # Corner polygon of the current (kept) box.
        p1 = rotate_rectangle(sort_detectboxs[i].xmin, sort_detectboxs[i].ymin,
                              sort_detectboxs[i].xmax, sort_detectboxs[i].ymax,
                              sort_detectboxs[i].angle)
        predBoxs.append(sort_detectboxs[i])
        for j in range(i + 1, len(sort_detectboxs)):
            # Only boxes of the same class suppress each other.
            if sort_detectboxs[j].classId == sort_detectboxs[i].classId:
                p2 = rotate_rectangle(sort_detectboxs[j].xmin, sort_detectboxs[j].ymin,
                                      sort_detectboxs[j].xmax, sort_detectboxs[j].ymax,
                                      sort_detectboxs[j].angle)
                if intersection(p1, p2) > nmsThresh:
                    sort_detectboxs[j].classId = -1
    return predBoxs
|
||||
|
||||
def sigmoid(x):
    # Numerically stable logistic: picks the branch that avoids exp overflow.
    return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
|
||||
|
||||
def softmax(x, axis=-1):
    # Max-shifted softmax for numerical stability.
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
|
||||
|
||||
# ------------------- 关键修改:process函数加入scale -------------------
|
||||
def process(out, model_w, model_h, stride, angle_feature, index, scale=1.0, offset_x=0, offset_y=0):
    """Decode one detection-head output into DetectBox candidates, mapped
    back to original-image coordinates.

    Args:
        out: head tensor reshaped to (1, 65, model_h*model_w) — 64 distance
            channels followed by 1 class logit (single-class model).
        model_w, model_h: grid size of this head.
        stride: pixels per grid cell (8/16/32).
        angle_feature: flattened angle map shared by all heads.
        index: offset of this head's cells inside angle_feature.
        scale, offset_x, offset_y: letterbox parameters used to invert the
            resize/pad and return original-image coordinates.

    Returns:
        List of DetectBox with conf above objectThresh (before NMS).
    """
    class_num = len(CLASSES)
    angle_feature = angle_feature.reshape(-1)
    xywh = out[:, :64, :]
    conf = sigmoid(out[:, 64:, :]).reshape(-1)
    boxes = []
    for ik in range(model_h * model_w * class_num):
        if conf[ik] > objectThresh:
            # Recover grid cell (w, h) and class c from the flat index.
            w = ik % model_w
            h = (ik % (model_w * model_h)) // model_w
            c = ik // (model_w * model_h)
            # Decode box sides: softmax-weighted expectation over 16 bins
            # per side (DFL-style decode).
            xywh_ = xywh[0, :, (h * model_w) + w].reshape(1, 4, 16, 1)
            xywh_ = softmax(xywh_, 2)
            data = np.arange(16).reshape(1, 1, 16, 1)
            xywh_ = np.sum(xywh_ * data, axis=2).reshape(-1)
            xywh_add = xywh_[:2] + xywh_[2:]        # box width / height
            xywh_sub = (xywh_[2:] - xywh_[:2]) / 2  # centre offset from the cell
            # Angle lookup, clamped to the feature length for safety.
            angle_idx = min(index + (h * model_w) + w, len(angle_feature) - 1)
            angle = (angle_feature[angle_idx] - 0.25) * math.pi
            cos_a, sin_a = math.cos(angle), math.sin(angle)
            # Rotate the centre offset by the box angle.
            xy = xywh_sub[0] * cos_a - xywh_sub[1] * sin_a, xywh_sub[0] * sin_a + xywh_sub[1] * cos_a
            xywh1 = np.array([xy[0] + w + 0.5, xy[1] + h + 0.5, xywh_add[0], xywh_add[1]])
            xywh1 *= stride  # grid units -> letterboxed input pixels
            # Invert the letterbox: remove padding, undo the resize scale.
            xmin = (xywh1[0] - xywh1[2]/2 - offset_x) / scale
            ymin = (xywh1[1] - xywh1[3]/2 - offset_y) / scale
            xmax = (xywh1[0] + xywh1[2]/2 - offset_x) / scale
            ymax = (xywh1[1] + xywh1[3]/2 - offset_y) / scale
            boxes.append(DetectBox(c, conf[ik], xmin, ymin, xmax, ymax, angle))
    return boxes
|
||||
|
||||
# ------------------- 新可调用函数 -------------------
|
||||
def detect_boxes_rknn(model_path, image_path):
    """Run the RKNN OBB model on one image; coordinates are already mapped
    back to the original image inside process().

    Args:
        model_path: path of the compiled .rknn model file.
        image_path: path of the input image.

    Returns:
        (detect_boxes, img): surviving DetectBox list and the original
        image, or (None, None) when the image is unreadable.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None, None

    # Letterbox to the 640x640 model input and keep the mapping parameters.
    img_resized, scale, offset_x, offset_y = letterbox_resize(img, (640, 640))
    infer_img = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
    infer_img = np.expand_dims(infer_img, 0)  # add batch dimension

    rknn_lite = RKNNLite(verbose=False)
    rknn_lite.load_rknn(model_path)
    rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)

    results = rknn_lite.inference([infer_img])
    detect_boxes = []
    # All outputs except the last are detection heads; the last output is
    # the angle feature map shared by the three strides.
    for x in results[:-1]:
        index, stride = 0, 0
        if x.shape[2] == 20:
            stride, index = 32, 20*4*20*4 + 20*2*20*2  # 80*80 + 40*40 cells precede this head
        elif x.shape[2] == 40:
            stride, index = 16, 20*4*20*4  # 80*80 cells precede this head
        elif x.shape[2] == 80:
            stride, index = 8, 0
        feature = x.reshape(1, 65, -1)  # 64 distance channels + 1 class logit
        detect_boxes += process(feature, x.shape[3], x.shape[2], stride, results[-1], index,
                                scale=scale, offset_x=offset_x, offset_y=offset_y)

    detect_boxes = NMS(detect_boxes)
    rknn_lite.release()
    return detect_boxes, img
|
||||
|
||||
# ------------------- 绘制与辅助函数 -------------------
|
||||
def get_angles(detect_boxes):
    """Collect the rotation angle (radians) of every detection, in order."""
    angles = []
    for det in detect_boxes:
        angles.append(det.angle)
    return angles
|
||||
|
||||
def draw_boxes(img, detect_boxes, save_path=None):
    """Draw every rotated box with its angle label (degrees) onto *img*.

    Args:
        img: BGR image the boxes refer to (modified in place).
        detect_boxes: iterable of DetectBox in original-image coordinates.
        save_path: optional path; when given, the annotated image is saved.

    Returns:
        The annotated image.
    """
    for box in detect_boxes:
        points = rotate_rectangle(box.xmin, box.ymin, box.xmax, box.ymax, box.angle)
        cv2.polylines(img, [np.array(points, np.int32)], True, (0, 255, 0), 1)
        cv2.putText(img, f"{np.degrees(box.angle):.1f}°", (int(box.xmin), int(box.ymin)-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1)
    if save_path:
        # Guard: os.path.dirname of a bare filename is '', and
        # os.makedirs('') raises FileNotFoundError.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(save_path, img)
        print(f"✅ 带角度的检测结果已保存到 {save_path}")
    return img
|
||||
|
||||
def visualize_top_box(img, detect_boxes, save_path=None):
    """Draw only the highest-scoring rotated box (with angle label) on *img*.

    Args:
        img: BGR image, modified in place.
        detect_boxes: DetectBox list in original-image coordinates.
        save_path: optional path; when given, the annotated image is saved.

    Returns:
        The (possibly annotated) image; returned unchanged when empty.
    """
    if not detect_boxes:
        return img
    top_box = max(detect_boxes, key=lambda x: x.score)
    points = rotate_rectangle(top_box.xmin, top_box.ymin, top_box.xmax, top_box.ymax, top_box.angle)
    cv2.polylines(img, [np.array(points, np.int32)], True, (0, 255, 0), 2)
    cv2.putText(img, f"{np.degrees(top_box.angle):.1f}°", (int(top_box.xmin), int(top_box.ymin)-5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,255), 2)
    if save_path:
        # Guard: os.path.dirname of a bare filename is '', and
        # os.makedirs('') raises FileNotFoundError.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(save_path, img)
    return img
|
||||
|
||||
# ------------------- 使用示例 -------------------
|
||||
# ------------------- Usage example -------------------
if __name__ == "__main__":
    model_path = "obb.rknn"
    image_path = "2.jpg"

    detect_boxes, img = detect_boxes_rknn(model_path, image_path)
    # detect_boxes_rknn returns (None, None) when the image cannot be read;
    # iterating None below would raise TypeError, so bail out explicitly.
    if detect_boxes is None:
        raise SystemExit(1)

    angles = get_angles(detect_boxes)
    for i, angle in enumerate(angles):
        print(f"框 {i+1}: angle = {angle:.4f} rad ({np.degrees(angle):.2f}°)")

    # One output with every surviving box, one with only the top-scoring box.
    save_path_all = "./inference_results/boxes_all.jpg"
    draw_boxes(img.copy(), detect_boxes, save_path_all)

    save_path_top = "./inference_results/top_box.jpg"
    visualize_top_box(img.copy(), detect_boxes, save_path_top)
|
||||
@ -1,85 +0,0 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# 1. 加载模型
|
||||
model = YOLO(r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb2/weights/best.pt')
|
||||
|
||||
# 2. 读取图像
|
||||
img_path = r"/home/hx/yolo/output_masks/2.jpg"
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
if img is None:
|
||||
print(f"❌ 错误:无法读取图像!请检查路径:{img_path}")
|
||||
exit(1)
|
||||
|
||||
# 3. 预测(OBB 模式)
|
||||
results = model(
|
||||
img,
|
||||
save=False,
|
||||
imgsz=640,
|
||||
conf=0.15,
|
||||
mode='obb'
|
||||
)
|
||||
|
||||
# 4. 获取结果并绘制
|
||||
result = results[0]
|
||||
annotated_img = result.plot()
|
||||
|
||||
# 5. 保存结果
|
||||
output_dir = "./inference_results"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
filename = os.path.basename(img_path)
|
||||
save_path = os.path.join(output_dir, "detected_" + filename)
|
||||
cv2.imwrite(save_path, annotated_img)
|
||||
print(f"✅ 推理结果已保存至: {save_path}")
|
||||
|
||||
# 6. 提取旋转框并计算 **两个框之间的夹角**
|
||||
boxes = result.obb
|
||||
if boxes is None or len(boxes) == 0:
|
||||
print("❌ No objects detected.")
|
||||
else:
|
||||
print(f"✅ Detected {len(boxes)} object(s):")
|
||||
directions = [] # 存储每个框的主方向(弧度),归一化到 [0, π)
|
||||
|
||||
for i, box in enumerate(boxes):
|
||||
cls = int(box.cls.cpu().numpy()[0])
|
||||
conf = box.conf.cpu().numpy()[0]
|
||||
xywhr = box.xywhr.cpu().numpy()[0] # [cx, cy, w, h, r]
|
||||
cx, cy, w, h, r_rad = xywhr
|
||||
|
||||
# 确定主方向(长边方向)
|
||||
if w >= h:
|
||||
direction = r_rad # 长边方向就是 r
|
||||
else:
|
||||
direction = r_rad + np.pi / 2 # 长边方向是 r + 90°
|
||||
|
||||
# 归一化到 [0, π)
|
||||
direction = direction % np.pi
|
||||
|
||||
directions.append(direction)
|
||||
angle_deg = np.degrees(direction)
|
||||
print(f" Box {i+1}: Class: {cls}, Confidence: {conf:.3f}, 主方向: {angle_deg:.2f}°")
|
||||
|
||||
# ✅ 计算任意两个框之间的夹角(最小夹角,0° ~ 90°)
|
||||
if len(directions) >= 2:
|
||||
print("\n🔍 计算两个旋转框之间的夹角(主方向夹角):")
|
||||
for i in range(len(directions)):
|
||||
for j in range(i + 1, len(directions)):
|
||||
dir1 = directions[i]
|
||||
dir2 = directions[j]
|
||||
|
||||
# 计算方向差(取最小夹角,考虑周期性)
|
||||
diff = abs(dir1 - dir2)
|
||||
diff = min(diff, np.pi - diff) # 最小夹角(0 ~ π/2)
|
||||
diff_deg = np.degrees(diff)
|
||||
|
||||
print(f" Box {i+1} 与 Box {j+1} 之间的夹角: {diff_deg:.2f}°")
|
||||
else:
|
||||
print("⚠️ 检测到少于两个目标,无法计算夹角。")
|
||||
|
||||
# 7. 显示图像
|
||||
cv2.imshow("YOLO OBB Prediction", annotated_img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
@ -1,103 +0,0 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# ================== 配置参数 ==================
|
||||
MODEL_PATH = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb4/weights/best.pt"
|
||||
IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb2/test" # 图像文件夹路径
|
||||
OUTPUT_DIR = "./inference_results" # 输出结果保存路径
|
||||
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 1. 加载模型
|
||||
print("🔄 加载 YOLO 模型...")
|
||||
model = YOLO(MODEL_PATH)
|
||||
print("✅ 模型加载完成")
|
||||
|
||||
# 获取所有图像文件
|
||||
image_files = [
|
||||
f for f in os.listdir(IMAGE_SOURCE_DIR)
|
||||
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
|
||||
]
|
||||
|
||||
if not image_files:
|
||||
print(f"❌ 错误:在路径中未找到图像文件:{IMAGE_SOURCE_DIR}")
|
||||
exit(1)
|
||||
|
||||
print(f"📁 发现 {len(image_files)} 张图像待处理")
|
||||
|
||||
# ================== 批量处理每张图像 ==================
|
||||
for img_filename in image_files:
|
||||
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
|
||||
print(f"\n🖼️ 正在处理:{img_filename}")
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
print(f"❌ 跳过:无法读取图像 {img_path}")
|
||||
continue
|
||||
|
||||
# 推理(OBB 模式)
|
||||
results = model(
|
||||
img,
|
||||
save=False,
|
||||
imgsz=640,
|
||||
conf=0.15,
|
||||
mode='obb'
|
||||
)
|
||||
|
||||
result = results[0]
|
||||
annotated_img = result.plot() # 绘制旋转框
|
||||
|
||||
# 保存结果图像
|
||||
save_path = os.path.join(OUTPUT_DIR, "detected_" + img_filename)
|
||||
cv2.imwrite(save_path, annotated_img)
|
||||
print(f"✅ 推理结果已保存至: {save_path}")
|
||||
|
||||
# 提取旋转框信息
|
||||
boxes = result.obb
|
||||
directions = [] # 存储每个框的主方向(弧度),归一化到 [0, π)
|
||||
|
||||
if boxes is None or len(boxes) == 0:
|
||||
print("❌ 该图像中未检测到任何目标")
|
||||
else:
|
||||
print(f"✅ 检测到 {len(boxes)} 个目标:")
|
||||
for i, box in enumerate(boxes):
|
||||
cls = int(box.cls.cpu().numpy()[0])
|
||||
conf = box.conf.cpu().numpy()[0]
|
||||
xywhr = box.xywhr.cpu().numpy()[0] # [cx, cy, w, h, r]
|
||||
cx, cy, w, h, r_rad = xywhr
|
||||
|
||||
# 确定主方向(长边方向)
|
||||
if w >= h:
|
||||
direction = r_rad # 长边方向
|
||||
else:
|
||||
direction = r_rad + np.pi / 2 # 长边是宽的方向
|
||||
|
||||
# 归一化到 [0, π)
|
||||
direction = direction % np.pi
|
||||
directions.append(direction)
|
||||
|
||||
angle_deg = np.degrees(direction)
|
||||
print(f" Box {i+1}: Class: {cls}, Confidence: {conf:.3f}, 主方向: {angle_deg:.2f}°")
|
||||
|
||||
# 计算两两之间的夹角(最小夹角,0°~90°)
|
||||
if len(directions) >= 2:
|
||||
print("\n🔍 计算各框之间的夹角(主方向最小夹角):")
|
||||
for i in range(len(directions)):
|
||||
for j in range(i + 1, len(directions)):
|
||||
dir1 = directions[i]
|
||||
dir2 = directions[j]
|
||||
|
||||
diff = abs(dir1 - dir2)
|
||||
min_diff_rad = min(diff, np.pi - diff) # 最小夹角(考虑周期性)
|
||||
min_diff_deg = np.degrees(min_diff_rad)
|
||||
|
||||
print(f" Box {i+1} 与 Box {j+1} 之间夹角: {min_diff_deg:.2f}°")
|
||||
else:
|
||||
print("⚠️ 检测到少于两个目标,无法计算夹角。")
|
||||
|
||||
print("\n🎉 所有图像处理完成!")
|
||||
76
angle_base_obb/angle_caculate.py
Normal file
76
angle_base_obb/angle_caculate.py
Normal file
@ -0,0 +1,76 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
|
||||
def predict_obb_best_angle(model_path, image_path, save_path=None):
    """Angle between the principal directions of the two most confident OBBs.

    Args:
        model_path: YOLO weight path.
        image_path: image path.
        save_path: optional path for the annotated visualisation.

    Returns:
        (angle_deg, annotated_img): angle in degrees between the two
        highest-confidence boxes, or None when the image is unreadable
        (annotated_img is also None) or fewer than two objects are found.
    """
    # 1. Load model.
    model = YOLO(model_path)

    # 2. Read image.
    img = cv2.imread(image_path)
    if img is None:
        print(f"无法读取图像: {image_path}")
        return None, None

    # 3. OBB inference.
    results = model(img, save=False, imgsz=640, conf=0.5, mode='obb')
    result = results[0]

    # 4. Visualisation.
    annotated_img = result.plot()
    if save_path:
        # Guard: os.path.dirname('') would make os.makedirs raise
        # FileNotFoundError when save_path is a bare filename.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(save_path, annotated_img)
        print(f"推理结果已保存至: {save_path}")

    # 5. Need at least two boxes to measure an angle between them.
    boxes = result.obb
    if boxes is None or len(boxes) < 2:
        print("检测到少于两个目标,无法计算夹角。")
        return None, annotated_img

    # Principal direction = orientation of the long side, folded into [0, π)
    # so opposite headings compare equal.
    box_info = []
    for box in boxes:
        conf = box.conf.cpu().numpy()[0]
        cx, cy, w, h, r_rad = box.xywhr.cpu().numpy()[0]
        direction = r_rad if w >= h else r_rad + np.pi/2
        direction = direction % np.pi
        box_info.append((conf, direction))

    # 6. Two most confident boxes.
    box_info = sorted(box_info, key=lambda x: x[0], reverse=True)
    dir1, dir2 = box_info[0][1], box_info[1][1]

    # 7. Minimal angle between the directions (0–90°, periodic in π).
    diff = abs(dir1 - dir2)
    diff = min(diff, np.pi - diff)
    angle_deg = np.degrees(diff)

    print(f"置信度最高两个框主方向夹角: {angle_deg:.2f}°")
    return angle_deg, annotated_img
|
||||
|
||||
|
||||
# ------------------- 测试 -------------------
|
||||
# ------------------- Test driver -------------------
if __name__ == "__main__":
    weight_path = r'best.pt'
    image_path = r"./test_image/3.jpg"
    save_path = "./inference_results/detected_3.jpg"

    # The annotated image is discarded here on purpose: the visualisation
    # is already written to save_path inside predict_obb_best_angle.
    # (The previous version reset annotated_img to None and then checked
    # it, leaving an unreachable imshow branch — removed.)
    angle_deg, _ = predict_obb_best_angle(weight_path, image_path, save_path)
    print(angle_deg)
|
||||
102
angle_base_obb/angle_caculate_file.py
Normal file
102
angle_base_obb/angle_caculate_file.py
Normal file
@ -0,0 +1,102 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
|
||||
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
|
||||
|
||||
|
||||
def process_obb_images(model_path, image_dir, output_dir="./inference_results", conf_thresh=0.15, imgsz=640):
    """Batch OBB inference over a folder: per-image principal directions
    and pairwise angles between detected objects.

    Args:
        model_path: YOLO weight path.
        image_dir: folder containing the input images.
        output_dir: folder where annotated images are written.
        conf_thresh: confidence threshold.
        imgsz: model input size.

    Returns:
        {image_filename: {'angles_deg': [...], 'pairwise_angles_deg': [...]}}
    """
    os.makedirs(output_dir, exist_ok=True)
    results_dict = {}

    print("加载 YOLO 模型...")
    model = YOLO(model_path)
    print("✅ 模型加载完成")

    # Collect image files by extension (case-insensitive).
    image_files = [f for f in os.listdir(image_dir) if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS]
    if not image_files:
        print(f"❌ 未找到图像文件:{image_dir}")
        return results_dict

    print(f"发现 {len(image_files)} 张图像待处理")

    for img_filename in image_files:
        img_path = os.path.join(image_dir, img_filename)
        print(f"\n正在处理:{img_filename}")

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 跳过:无法读取图像 {img_path}")
            continue

        # OBB inference + annotated visualisation.
        results = model(img, save=False, imgsz=imgsz, conf=conf_thresh, mode='obb')
        result = results[0]
        annotated_img = result.plot()

        # Save the visualisation.
        save_path = os.path.join(output_dir, "detected_" + img_filename)
        cv2.imwrite(save_path, annotated_img)
        print(f"✅ 推理结果已保存至: {save_path}")

        # Principal direction of each box: orientation of the long side,
        # folded into [0, π) so opposite headings compare equal.
        boxes = result.obb
        angles_deg = []
        if boxes is None or len(boxes) == 0:
            print("❌ 该图像中未检测到任何目标")
        else:
            for i, box in enumerate(boxes):
                cls = int(box.cls.cpu().numpy()[0])
                conf = box.conf.cpu().numpy()[0]
                cx, cy, w, h, r_rad = box.xywhr.cpu().numpy()[0]
                direction = r_rad if w >= h else r_rad + np.pi / 2
                direction = direction % np.pi
                angle_deg = np.degrees(direction)
                angles_deg.append(angle_deg)
                print(f" Box {i + 1}: Class={cls}, Conf={conf:.3f}, 主方向={angle_deg:.2f}°")

        # Minimal pairwise angle between directions (0–90°, periodic in π).
        pairwise_angles_deg = []
        if len(angles_deg) >= 2:
            for i in range(len(angles_deg)):
                for j in range(i + 1, len(angles_deg)):
                    diff_rad = abs(np.radians(angles_deg[i]) - np.radians(angles_deg[j]))
                    min_diff_rad = min(diff_rad, np.pi - diff_rad)
                    pairwise_angles_deg.append(np.degrees(min_diff_rad))
                    print(f" Box {i + 1} 与 Box {j + 1} 夹角: {np.degrees(min_diff_rad):.2f}°")

        # Record this image's results.
        results_dict[img_filename] = {
            "angles_deg": angles_deg,
            "pairwise_angles_deg": pairwise_angles_deg
        }

    print("\n所有图像处理完成!")
    return results_dict
|
||||
|
||||
|
||||
# ------------------- 测试调用 -------------------
|
||||
# ------------------- Test driver -------------------
if __name__ == "__main__":
    MODEL_PATH = r'best.pt'
    IMAGE_SOURCE_DIR = r"./test_image"
    OUTPUT_DIR = "./inference_results"

    # Run the batch pipeline, then dump the per-image angle summaries.
    summary = process_obb_images(MODEL_PATH, IMAGE_SOURCE_DIR, OUTPUT_DIR)
    for img_name, info in summary.items():
        print(f"\n {img_name}:")
        print(f"主方向角度列表: {info['angles_deg']}")
        print(f"两两夹角列表: {info['pairwise_angles_deg']}")
||||
BIN
angle_base_obb/best.pt
Normal file
BIN
angle_base_obb/best.pt
Normal file
Binary file not shown.
BIN
angle_base_obb/test_image/1.jpg
Normal file
BIN
angle_base_obb/test_image/1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 587 KiB |
BIN
angle_base_obb/test_image/2.jpg
Normal file
BIN
angle_base_obb/test_image/2.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 513 KiB |
BIN
angle_base_obb/test_image/3.jpg
Normal file
BIN
angle_base_obb/test_image/3.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.9 MiB |
@ -7,7 +7,7 @@ from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
|
||||
# ====================== 配置参数 ======================
|
||||
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg_r/exp/weights/best.pt"
|
||||
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg_r/exp2/weights/best.pt"
|
||||
#SOURCE_IMG_DIR = "/home/hx/yolo/output_masks" # 原始输入图像目录
|
||||
SOURCE_IMG_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6" # 原始输入图像目录
|
||||
OUTPUT_DIR = "/home/hx/yolo/output_masks2" # 推理输出根目录
|
||||
|
||||
143
zhuangtai_class_cls/remain_tuili.py
Normal file
143
zhuangtai_class_cls/remain_tuili.py
Normal file
@ -0,0 +1,143 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
|
||||
# ---------------------------
|
||||
# 类别映射
|
||||
# ---------------------------
|
||||
CLASS_NAMES = {
|
||||
0: "未堆料",
|
||||
1: "小堆料",
|
||||
2: "大堆料",
|
||||
3: "未浇筑满",
|
||||
4: "浇筑满"
|
||||
}
|
||||
|
||||
# ---------------------------
|
||||
# 加载 ROI 列表
|
||||
# ---------------------------
|
||||
def load_global_rois(txt_path):
    """Load global ROI rectangles from a text file.

    Each non-empty line must be "x,y,w,h" (four comma-separated ints).
    Malformed lines are reported and skipped.

    Args:
        txt_path: Path to the ROI list file.

    Returns:
        list[tuple[int, int, int, int]]: Parsed (x, y, w, h) rectangles;
        empty if the file is missing or contains no valid lines.
    """
    rois = []
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois
    with open(txt_path, 'r') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                # Narrowed from `except Exception`: only bad numbers or a
                # wrong field count are expected parse failures here.
                print(f"⚠️ 无法解析 ROI 行 '{s}': {e}")
            else:
                rois.append((x, y, w, h))
    return rois
|
||||
|
||||
# ---------------------------
|
||||
# 裁剪并 resize ROI
|
||||
# ---------------------------
|
||||
def crop_and_resize(img, rois, target_size=640):
    """Crop each in-bounds ROI out of *img* and resize it to a square.

    ROIs that fall even partially outside the image are skipped.

    Args:
        img: Image array of shape (H, W[, C]).
        rois: Iterable of (x, y, w, h) rectangles in pixel coordinates.
        target_size: Side length of the square output crops.

    Returns:
        list of (resized_crop, roi_index) tuples, in ROI order.
    """
    img_h, img_w = img.shape[:2]
    out = []
    for idx, (x, y, w, h) in enumerate(rois):
        inside = x >= 0 and y >= 0 and x + w <= img_w and y + h <= img_h
        if not inside:
            continue
        patch = img[y:y + h, x:x + w]
        resized = cv2.resize(patch, (target_size, target_size),
                             interpolation=cv2.INTER_AREA)
        out.append((resized, idx))
    return out
|
||||
|
||||
# ---------------------------
|
||||
# class1/class2 加权判断
|
||||
# ---------------------------
|
||||
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Arbitrate 小堆料 vs 大堆料 from the class-1/class-2 probabilities.

    The two probabilities are folded into a normalized weighted score;
    scores at or above *threshold* map to 大堆料, otherwise 小堆料.

    Args:
        pred_probs: Per-class probability vector (indexable, len >= 3).
        threshold: Decision boundary on the weighted score.
        w1: Weight applied to class 1 (small pile).
        w2: Weight applied to class 2 (large pile).

    Returns:
        (final_class, score, p1, p2): label, weighted score, and the two
        raw probabilities as floats.
    """
    prob_small = float(pred_probs[1])
    prob_large = float(pred_probs[2])
    denom = prob_small + prob_large
    # Guard against a zero denominator when both probabilities vanish.
    score = (w1 * prob_small + w2 * prob_large) / denom if denom > 0 else 0.0
    label = "大堆料" if score >= threshold else "小堆料"
    return label, score, prob_small, prob_large
|
||||
|
||||
# ---------------------------
|
||||
# 单张图片推理函数
|
||||
# ---------------------------
|
||||
def classify_image_weighted(image, model, threshold=0.5):
    """Classify one crop, with weighted arbitration for the pile classes.

    Runs the classifier; when the top class is 1 (小堆料) or 2 (大堆料) the
    final label is decided by the weighted class-1/class-2 score, otherwise
    the top class is returned directly with its own confidence.

    Args:
        image: Image array accepted by the YOLO classification model.
        model: Loaded ultralytics YOLO classification model.
        threshold: Forwarded to weighted_small_large for classes 1/2.

    Returns:
        (final_class, score, p1, p2): label, decision score or confidence,
        and the raw class-1 / class-2 probabilities.
    """
    outputs = model(image)
    probs = outputs[0].probs.data.cpu().numpy().flatten()
    top_id = int(probs.argmax())

    if top_id in (1, 2):
        # Pile classes: delegate the small/large decision to the weighted score.
        return weighted_small_large(probs, threshold=threshold)

    label = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
    return label, float(probs[top_id]), float(probs[1]), float(probs[2])
|
||||
|
||||
# ---------------------------
|
||||
# 批量推理主函数
|
||||
# ---------------------------
|
||||
def batch_classify_images(model_path, input_folder, output_root, roi_file, target_size=640, threshold=0.5):
    """Batch-classify the ROI crops of every image in a folder.

    For each image in *input_folder*, crops the globally configured ROIs,
    classifies each crop, and writes it into a per-class sub-directory of
    *output_root* with the ROI index and scores encoded in the file name.
    Failures on individual images are reported and skipped (best-effort).

    Args:
        model_path: Path to the YOLO classification weights.
        input_folder: Directory containing the source images.
        output_root: Root directory for the per-class output folders.
        roi_file: Text file with one "x,y,w,h" ROI per line.
        target_size: Square size each ROI crop is resized to.
        threshold: Decision threshold forwarded to the weighted classifier.
    """
    # Load model once for the whole batch.
    model = YOLO(model_path)

    # Ensure the output root exists.
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # One output directory per known class label.
    class_dirs = {}
    for name in CLASS_NAMES.values():
        d = output_root / name
        d.mkdir(exist_ok=True)
        class_dirs[name] = d

    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI,退出")
        return

    # Hoisted set: O(1) membership test per file.
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tif'}
    for img_path in Path(input_folder).glob("*.*"):
        if img_path.suffix.lower() not in allowed_suffixes:
            continue
        try:
            img = cv2.imread(str(img_path))
            if img is None:
                continue

            crops = crop_and_resize(img, rois, target_size)

            for roi_resized, roi_idx in crops:
                final_class, score, p1, p2 = classify_image_weighted(roi_resized, model, threshold=threshold)

                # Encode ROI index, final label and scores into the file name.
                suffix = f"_roi{roi_idx}_{final_class}_score{score:.2f}_p1{p1:.2f}_p2{p2:.2f}"
                dst_path = class_dirs[final_class] / f"{img_path.stem}{suffix}{img_path.suffix}"
                # BUG FIX: cv2.imwrite expects a string filename; passing a
                # pathlib.Path raises on many OpenCV builds.
                cv2.imwrite(str(dst_path), roi_resized)
                print(f"{img_path.name}{suffix} -> {final_class} (score={score:.2f}, p1={p1:.2f}, p2={p2:.2f})")

        except Exception as e:
            # Deliberate best-effort batch: keep going on a bad image, report it.
            print(f"⚠️ 处理失败 {img_path.name}: {e}")
|
||||
|
||||
# ---------------------------
|
||||
# 使用示例
|
||||
# ---------------------------
|
||||
if __name__ == "__main__":
    # Example invocation: classify every ROI crop of every image in the folder.
    WEIGHTS = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize/exp_cls2/weights/best.pt"
    SRC_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6"
    DST_ROOT = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classified"
    ROI_FILE = "./roi_coordinates/1_rois.txt"
    CROP_SIZE = 640
    SCORE_THRESHOLD = 0.4  # tunable weighting threshold

    batch_classify_images(WEIGHTS, SRC_DIR, DST_ROOT, ROI_FILE, CROP_SIZE, SCORE_THRESHOLD)
|
||||
Reference in New Issue
Block a user