Non-Maximum Suppression (NMS) Example
A pure-NumPy implementation of single-class and multi-class NMS, demonstrating the greedy suppression algorithm used in object detection pipelines.
import numpy as np
def compute_iou(box, boxes):
    """Compute intersection-over-union between one box and a batch of boxes.

    Args:
        box: (4,) array laid out as [x1, y1, x2, y2].
        boxes: (N, 4) array using the same corner layout.

    Returns:
        (N,) array of IoU values. A 1e-8 epsilon is added to the union so
        the division never hits zero.
    """
    bx1, by1, bx2, by2 = box[0], box[1], box[2], box[3]
    ox1, oy1 = boxes[:, 0], boxes[:, 1]
    ox2, oy2 = boxes[:, 2], boxes[:, 3]

    # Overlap rectangle; its width/height are clamped at zero so disjoint
    # boxes contribute no intersection area.
    overlap_w = np.maximum(0.0, np.minimum(bx2, ox2) - np.maximum(bx1, ox1))
    overlap_h = np.maximum(0.0, np.minimum(by2, oy2) - np.maximum(by1, oy1))
    overlap = overlap_w * overlap_h

    area_box = (bx2 - bx1) * (by2 - by1)
    area_others = (ox2 - ox1) * (oy2 - oy1)
    union = area_box + area_others - overlap
    return overlap / (union + 1e-8)
def nms(boxes, scores, iou_threshold):
    """Greedy single-class non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and drops every other
    remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        boxes: (N, 4) array-like of [x1, y1, x2, y2] boxes.
        scores: (N,) array-like of confidence scores.
        iou_threshold: float; a remaining box survives a round when its IoU
            with the box just kept is <= this value.

    Returns:
        keep_indices: list of selected indices (Python ints) into ``boxes``,
        in descending score order. Equal scores are broken by original index
        (lower index first), and NaN scores rank below every other score.
    """
    boxes = np.asarray(boxes)
    # Sort the negated scores instead of reversing an ascending argsort:
    #  * argsort places NaN last, so reversing it would rank NaN scores
    #    *first* and let them suppress valid detections; -NaN is still NaN,
    #    so here NaN-scored boxes stay at the end of the queue.
    #  * a stable sort makes tie-breaking deterministic (lower index wins).
    # Casting to float64 first keeps negation safe for unsigned/bool inputs
    # and is exact for float32 scores.
    order = np.argsort(-np.asarray(scores, dtype=np.float64), kind="stable")
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        if rest.size == 0:
            break
        ious = compute_iou(boxes[best], boxes[rest])
        # Retain only boxes that do not overlap the kept box too strongly.
        order = rest[ious <= iou_threshold]
    return keep
def multiclass_nms(boxes, scores, labels, iou_threshold, score_threshold=0.0):
    """Run NMS independently within each class label.

    Boxes are only compared against boxes that carry the same label, so
    overlapping detections of different classes never suppress each other.

    Args:
        boxes: (N, 4) array of [x1, y1, x2, y2] boxes.
        scores: (N,) array of confidence scores.
        labels: (N,) array of integer class IDs.
        iou_threshold: forwarded unchanged to ``nms`` for every class.
        score_threshold: boxes whose score is not strictly greater than this
            value are excluded before suppression.

    Returns:
        List of kept indices into the original arrays, grouped by class in
        ascending label order (the order ``np.unique`` yields); within a
        class they follow the order returned by ``nms``.
    """
    selected = []
    for cls in np.unique(labels):
        # Members of this class that also clear the score cut-off.
        mask = (labels == cls) & (scores > score_threshold)
        positions = np.where(mask)[0]
        winners = nms(boxes[mask], scores[mask], iou_threshold)
        # Map class-local indices back to positions in the full arrays.
        selected.extend(positions[winners])
    return selected
# Example usage: boxes 0 and 1 (both class 1) overlap heavily, while box 2
# belongs to class 2 and is far away from both.
boxes = np.array(
    [
        [10, 20, 40, 60],
        [12, 22, 42, 62],
        [100, 200, 150, 250],
    ],
    dtype=np.float32,
)
scores = np.array([0.9, 0.8, 0.85], dtype=np.float32)
labels = np.array([1, 1, 2], dtype=np.int64)

keep = multiclass_nms(
    boxes,
    scores,
    labels,
    iou_threshold=0.5,
    score_threshold=0.0,
)

print(f"Kept indices: {keep}")
print(f"Filtered boxes:\n {boxes[keep]}")
print(f"Filtered scores: {scores[keep]}")
print(f"Filtered labels: {labels[keep]}")
Output (NumPy 2.x; NumPy 1.x prints the kept indices as [0, 2]):
Kept indices: [np.int64(0), np.int64(2)]
Filtered boxes:
 [[ 10.  20.  40.  60.]
 [100. 200. 150. 250.]]
Filtered scores: [0.9  0.85]
Filtered labels: [1 2]

