DepthAI
  • DepthAI Components
    • AprilTags
    • Benchmark
    • Camera
    • Calibration
    • DetectionNetwork
    • Events
    • FeatureTracker
    • Gate
    • HostNodes
    • ImageAlign
    • ImageManip
    • IMU
    • Misc
    • Model Zoo
    • NeuralDepth
    • NeuralNetwork
    • ObjectTracker
    • RecordReplay
    • RGBD
    • Script
    • SpatialDetectionNetwork
    • SpatialLocationCalculator
    • StereoDepth
    • Sync
    • VideoEncoder
    • Visualizer
    • Warp
    • RVC2-specific
  • Advanced Tutorials
  • API Reference
  • Tools
Software Stack

ON THIS PAGE

  • Spatial Detection Network
  • Pipeline
  • Source code

Spatial Detection Network

Supported on: RVC2, RVC4
The example creates a pipeline to perform YOLOv6-Nano spatial object detection using RGB and stereo depth streams, visualizes results with bounding boxes and spatial coordinates on both colorized depth and RGB frames, and uses a custom visualization node. This example requires the DepthAI v3 API; see the installation instructions.

Pipeline

Source code

Python
C++

Python

Python
GitHub
#!/usr/bin/env python3

import argparse
from pathlib import Path
import cv2
import depthai as dai
import numpy as np

# Frame-rate caps for the two depth back-ends: neural depth is heavier,
# so it runs the cameras slower than classic stereo matching.
NEURAL_FPS = 8
STEREO_DEFAULT_FPS = 20

parser = argparse.ArgumentParser()
parser.add_argument(
    "--depthSource", type=str, default="stereo", choices=["stereo", "neural"]
)
args = parser.parse_args()

# For better results on OAK4, use a segmentation model like
# "luxonis/yolov8-instance-segmentation-large:coco-640x480" for depth
# estimation over the object mask instead of the full bounding box.
modelDescription = dai.NNModelDescription("yolov6-nano")
size = (640, 400)

# Pick the camera rate to match the selected depth source.
fps = STEREO_DEFAULT_FPS if args.depthSource == "stereo" else NEURAL_FPS
class SpatialVisualizer(dai.node.HostNode):
    """Host-side pipeline node that draws spatial detections.

    Receives the depth passthrough, the detection results, and the RGB
    passthrough from a SpatialDetectionNetwork, then shows two OpenCV
    windows: a colorized depth frame with the measurement ROIs, and the
    RGB frame with labeled bounding boxes and XYZ coordinates (mm).
    """

    def __init__(self):
        dai.node.HostNode.__init__(self)
        # Run process() as part of the pipeline loop on the host.
        self.sendProcessingToPipeline(True)

    def build(self, depth: dai.Node.Output, detections: dai.Node.Output, rgb: dai.Node.Output):
        # Argument order must match the parameters of process().
        self.link_args(depth, detections, rgb)

    def process(self, depthPreview, detections, rgbPreview):
        depthPreview = depthPreview.getCvFrame()
        rgbPreview = rgbPreview.getCvFrame()
        depthFrameColor = self.processDepthFrame(depthPreview)
        self.displayResults(rgbPreview, depthFrameColor, detections.detections)

    def processDepthFrame(self, depthFrame):
        """Colorize a raw depth frame.

        Scales the 1st..99th percentile of valid (non-zero) depth values
        into 0-255 and applies COLORMAP_HOT. Zero pixels mean "no depth".
        """
        # Subsample rows only — cheap and sufficient for percentile estimates.
        depthDownscaled = depthFrame[::4]
        validDepth = depthDownscaled[depthDownscaled != 0]
        if validDepth.size == 0:
            # No valid depth at all; use a degenerate range so the frame renders black.
            minDepth, maxDepth = 0, 1
        else:
            # Fix: the original took maxDepth over ALL pixels (zeros included),
            # so a mostly-empty frame could give maxDepth < minDepth and feed
            # np.interp a non-increasing range (undefined results). Both bounds
            # now come from the valid pixels, with a guard against a zero-width range.
            minDepth = np.percentile(validDepth, 1)
            maxDepth = max(np.percentile(validDepth, 99), minDepth + 1)
        depthFrameColor = np.interp(depthFrame, (minDepth, maxDepth), (0, 255)).astype(np.uint8)
        return cv2.applyColorMap(depthFrameColor, cv2.COLORMAP_HOT)

    def displayResults(self, rgbFrame, depthFrameColor, detections):
        """Draw all detections on both frames and show them; 'q' stops the pipeline."""
        height, width, _ = rgbFrame.shape
        for detection in detections:
            self.drawBoundingBoxes(depthFrameColor, detection)
            self.drawDetections(rgbFrame, detection, width, height)

        cv2.imshow("Depth frame", depthFrameColor)
        cv2.imshow("Color frame", rgbFrame)
        if cv2.waitKey(1) == ord('q'):
            self.stopPipeline()

    def drawBoundingBoxes(self, depthFrameColor, detection):
        """Draw the depth ROI used for the spatial measurement of one detection."""
        roiData = detection.boundingBoxMapping
        # ROI is normalized [0..1]; denormalize to pixel coordinates of the depth frame.
        roi = roiData.roi.denormalize(depthFrameColor.shape[1], depthFrameColor.shape[0])
        topLeft = roi.topLeft()
        bottomRight = roi.bottomRight()
        cv2.rectangle(depthFrameColor, (int(topLeft.x), int(topLeft.y)), (int(bottomRight.x), int(bottomRight.y)), (255, 255, 255), 1)

    def drawDetections(self, frame, detection, frameWidth, frameHeight):
        """Draw one detection on the RGB frame: box, label, confidence, and XYZ in mm."""
        # Detection coordinates are normalized [0..1]; scale to pixels.
        x1 = int(detection.xmin * frameWidth)
        x2 = int(detection.xmax * frameWidth)
        y1 = int(detection.ymin * frameHeight)
        y2 = int(detection.ymax * frameHeight)
        label = detection.labelName
        color = (255, 255, 255)
        cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, "{:.2f}".format(detection.confidence * 100), (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"X: {int(detection.spatialCoordinates.x)} mm", (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Y: {int(detection.spatialCoordinates.y)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Z: {int(detection.spatialCoordinates.z)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
# Creates the pipeline and a default device implicitly.
with dai.Pipeline() as p:
    # Camera sources: RGB (CAM_A) plus the stereo pair (CAM_B / CAM_C),
    # all running at the fps selected for the chosen depth back-end.
    camRgb = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_A, sensorFps=fps)
    monoLeft = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_B, sensorFps=fps)
    monoRight = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_C, sensorFps=fps)

    # Depth back-end: classic stereo matching or the neural depth model.
    # (Fix: removed an unused `platform` local that was queried but never read.)
    if args.depthSource == "stereo":
        depthSource = p.create(dai.node.StereoDepth)
        depthSource.setExtendedDisparity(True)  # better close-range accuracy
        monoLeft.requestOutput(size).link(depthSource.left)
        monoRight.requestOutput(size).link(depthSource.right)
    elif args.depthSource == "neural":
        depthSource = p.create(dai.node.NeuralDepth).build(
            monoLeft.requestFullResolutionOutput(),
            monoRight.requestFullResolutionOutput(),
            dai.DeviceModelZoo.NEURAL_DEPTH_LARGE,
        )
    else:
        # argparse `choices` already rejects other values; kept as a guard.
        raise ValueError(f"Invalid depth source: {args.depthSource}")

    spatialDetectionNetwork = p.create(dai.node.SpatialDetectionNetwork).build(
        camRgb, depthSource, modelDescription
    )
    visualizer = p.create(SpatialVisualizer)

    # Average depth over the full bounding box (no segmentation mask),
    # drop frames rather than block, and clamp depth to 0.1-5 m.
    spatialDetectionNetwork.spatialLocationCalculator.initialConfig.setSegmentationPassthrough(False)
    spatialDetectionNetwork.input.setBlocking(False)
    spatialDetectionNetwork.setDepthLowerThreshold(100)
    spatialDetectionNetwork.setDepthUpperThreshold(5000)

    # Feed the visualizer; order matches SpatialVisualizer.process().
    visualizer.build(
        spatialDetectionNetwork.passthroughDepth,
        spatialDetectionNetwork.out,
        spatialDetectionNetwork.passthrough,
    )

    print("Starting pipeline with depth source: ", args.depthSource)

    p.run()

Need assistance?

Head over to the Discussion Forum for technical support or any other questions you might have.