Spatial Detection Network
The example creates a pipeline to perform YOLOv6-Nano spatial object detection using RGB and stereo depth streams, visualizes results with bounding boxes and spatial coordinates on both colorized depth and RGB frames, and uses a custom visualization node.

Setup
This example requires the DepthAI v3 API; see the installation instructions.

Pipeline
Source code
Python
C++
Python
Python (GitHub)
1#!/usr/bin/env python3
2
3from pathlib import Path
4import sys
5import cv2
6import depthai as dai
7import numpy as np
8import time
9
10modelDescription = dai.NNModelDescription("yolov6-nano")
11FPS = 30
12
class SpatialVisualizer(dai.node.HostNode):
    """Host-side node that draws spatial detection results.

    Shows two OpenCV windows: "depth" (colorized depth with the ROI used
    for depth averaging drawn per detection) and "rgb" (the NN passthrough
    frame annotated with label, confidence and X/Y/Z coordinates in mm).
    Pressing 'q' in either window stops the pipeline.
    """

    def __init__(self):
        dai.node.HostNode.__init__(self)
        # Run this node's processing on the pipeline's host-side loop.
        self.sendProcessingToPipeline(True)
        # Class-index -> name mapping; assigned by the pipeline setup code
        # (spatialDetectionNetwork.getClasses()). May be left unset.
        self.labelMap = None

    def build(self, depth: dai.Node.Output, detections: dai.Node.Output, rgb: dai.Node.Output):
        # Argument order must match the parameters of process().
        self.link_args(depth, detections, rgb)

    def process(self, depthPreview, detections, rgbPreview):
        # Convert depthai ImgFrame messages to OpenCV/numpy frames.
        depthPreview = depthPreview.getCvFrame()
        rgbPreview = rgbPreview.getCvFrame()
        depthFrameColor = self.processDepthFrame(depthPreview)
        self.displayResults(rgbPreview, depthFrameColor, detections.detections)

    def processDepthFrame(self, depthFrame):
        """Return a color-mapped (HOT) uint8 visualization of a depth frame.

        The contrast range is the 1st..99th percentile of a row-subsampled
        copy of the frame, ignoring zero (invalid) depth pixels.
        """
        # NOTE: [::4] subsamples rows only — enough for percentile estimation.
        depth_downscaled = depthFrame[::4]
        if np.all(depth_downscaled == 0):
            min_depth = 0  # every pixel invalid; range collapses to [0, max]
        else:
            min_depth = np.percentile(depth_downscaled[depth_downscaled != 0], 1)
        max_depth = np.percentile(depth_downscaled, 99)
        depthFrameColor = np.interp(depthFrame, (min_depth, max_depth), (0, 255)).astype(np.uint8)
        return cv2.applyColorMap(depthFrameColor, cv2.COLORMAP_HOT)

    def displayResults(self, rgbFrame, depthFrameColor, detections):
        """Draw every detection and refresh both windows; 'q' stops the pipeline."""
        height, width, _ = rgbFrame.shape
        for detection in detections:
            self.drawBoundingBoxes(depthFrameColor, detection)
            self.drawDetections(rgbFrame, detection, width, height)

        cv2.imshow("depth", depthFrameColor)
        cv2.imshow("rgb", rgbFrame)
        if cv2.waitKey(1) == ord('q'):
            self.stopPipeline()

    def drawBoundingBoxes(self, depthFrameColor, detection):
        """Draw one detection's depth-averaging ROI on the depth frame."""
        roiData = detection.boundingBoxMapping
        roi = roiData.roi
        # ROI is normalized [0, 1]; scale to pixel coordinates.
        roi = roi.denormalize(depthFrameColor.shape[1], depthFrameColor.shape[0])
        topLeft = roi.topLeft()
        bottomRight = roi.bottomRight()
        cv2.rectangle(depthFrameColor, (int(topLeft.x), int(topLeft.y)),
                      (int(bottomRight.x), int(bottomRight.y)), (255, 255, 255), 1)

    def drawDetections(self, frame, detection, frameWidth, frameHeight):
        """Annotate one detection (box, label, confidence, XYZ in mm) on frame."""
        x1 = int(detection.xmin * frameWidth)
        x2 = int(detection.xmax * frameWidth)
        y1 = int(detection.ymin * frameHeight)
        y2 = int(detection.ymax * frameHeight)
        try:
            label = self.labelMap[detection.label]
        except (IndexError, TypeError):
            # FIX: also catch TypeError so an unset labelMap (None) falls back
            # to the raw class index instead of crashing the process callback.
            label = detection.label
        color = (255, 255, 255)
        cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, "{:.2f}".format(detection.confidence * 100), (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"X: {int(detection.spatialCoordinates.x)} mm", (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Y: {int(detection.spatialCoordinates.y)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Z: {int(detection.spatialCoordinates.z)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
72# Creates the pipeline and a default device implicitly
73with dai.Pipeline() as p:
74 # Define sources and outputs
75 camRgb = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_A)
76 monoLeft = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_B)
77 monoRight = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_C)
78 stereo = p.create(dai.node.StereoDepth)
79 spatialDetectionNetwork = p.create(dai.node.SpatialDetectionNetwork).build(camRgb, stereo, modelDescription, fps=FPS)
80 visualizer = p.create(SpatialVisualizer)
81
82 # setting node configs
83 stereo.setExtendedDisparity(True)
84 platform = p.getDefaultDevice().getPlatform()
85 if platform == dai.Platform.RVC2:
86 # For RVC2, width must be divisible by 16
87 stereo.setOutputSize(640, 400)
88
89 spatialDetectionNetwork.input.setBlocking(False)
90 spatialDetectionNetwork.setBoundingBoxScaleFactor(0.5)
91 spatialDetectionNetwork.setDepthLowerThreshold(100)
92 spatialDetectionNetwork.setDepthUpperThreshold(5000)
93
94 # Linking
95 monoLeft.requestOutput((640, 400)).link(stereo.left)
96 monoRight.requestOutput((640, 400)).link(stereo.right)
97 visualizer.labelMap = spatialDetectionNetwork.getClasses()
98
99 visualizer.build(stereo.depth, spatialDetectionNetwork.out, spatialDetectionNetwork.passthrough)
100
101 p.run()
Need assistance?
Head over to Discussion Forum for technical support or any other questions you might have.