DepthAI Tutorials
DepthAI API References

ON THIS PAGE

  • Object tracker on video
  • Similar samples:
  • Demo
  • Setup
  • Source code

Object tracker on video

This example shows how to run a MobileNetV2-SSD person detector on frames read from a video file and track the detected persons.

Similar samples:

Demo

Setup

Please run the install script to download all required dependencies. Note that this script must be run from within a git checkout, so you have to clone the depthai-python repository first and then run the script.
Command Line
git clone https://github.com/luxonis/depthai-python.git
cd depthai-python/examples
python3 install_requirements.py
For additional information, please follow the installation guide.

Source code

Python
C++
Python
GitHub
1#!/usr/bin/env python3
2
3from pathlib import Path
4import cv2
5import depthai as dai
6import numpy as np
7import time
8import argparse
9
# Class names indexed by the detector's label id. The object tracker below is
# configured to track label 1 as "person", so index 1 must be "person";
# index 0 is unused.
labelMap = ["", "person"]

# Default model blob and sample video, resolved relative to this script's directory
nnPathDefault = str((Path(__file__).parent / Path('../models/person-detection-retail-0013_openvino_2021.4_7shave.blob')).resolve().absolute())
videoPathDefault = str((Path(__file__).parent / Path('../models/construction_vest.mp4')).resolve().absolute())

parser = argparse.ArgumentParser()
parser.add_argument('-nnPath', help="Path to mobilenet detection network blob", default=nnPathDefault)
parser.add_argument('-v', '--videoPath', help="Path to video file", default=videoPathDefault)

args = parser.parse_args()

# Create pipeline
pipeline = dai.Pipeline()

# Define sources and outputs
manip = pipeline.create(dai.node.ImageManip)
objectTracker = pipeline.create(dai.node.ObjectTracker)
detectionNetwork = pipeline.create(dai.node.MobileNetDetectionNetwork)

manipOut = pipeline.create(dai.node.XLinkOut)
xinFrame = pipeline.create(dai.node.XLinkIn)
trackerOut = pipeline.create(dai.node.XLinkOut)
xlinkOut = pipeline.create(dai.node.XLinkOut)
nnOut = pipeline.create(dai.node.XLinkOut)

manipOut.setStreamName("manip")
xinFrame.setStreamName("inFrame")
xlinkOut.setStreamName("trackerFrame")
trackerOut.setStreamName("tracklets")
nnOut.setStreamName("nn")

# Properties
# Largest message the host may send: one 1920x1080 frame at 3 bytes per pixel
xinFrame.setMaxDataSize(1920*1080*3)

# Resize to 544x320 while keeping the aspect ratio.
# NOTE(review): 544x320 is presumably the detection blob's input size - confirm
manip.initialConfig.setResizeThumbnail(544, 320)
# The NN model expects BGR input. ImageManip would otherwise pass the input
# frame type through; the host already sends BGR888p frames, so this mainly
# makes the requirement explicit.
manip.initialConfig.setFrameType(dai.ImgFrame.Type.BGR888p)
manip.inputImage.setBlocking(True)

# setting node configs
detectionNetwork.setBlobPath(args.nnPath)
detectionNetwork.setConfidenceThreshold(0.5)
detectionNetwork.input.setBlocking(True)

# All inputs are blocking because frames come from a file; waiting is
# preferable to dropping them when a queue fills up.
objectTracker.inputTrackerFrame.setBlocking(True)
objectTracker.inputDetectionFrame.setBlocking(True)
objectTracker.inputDetections.setBlocking(True)
objectTracker.setDetectionLabelsToTrack([1])  # track only person
# possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS, SHORT_TERM_KCF
objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
# take the smallest ID when new object is tracked, possible options: SMALLEST_ID, UNIQUE_ID
objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.SMALLEST_ID)

# Linking
# Resized frames go to the host preview ("manip") and to the detector
manip.out.link(manipOut.input)
manip.out.link(detectionNetwork.input)
# Frames sent by the host feed both the resizer and the tracker
xinFrame.out.link(manip.inputImage)
xinFrame.out.link(objectTracker.inputTrackerFrame)
# Detections go to the host ("nn") and, together with the NN's passthrough frame, to the tracker
detectionNetwork.out.link(nnOut.input)
detectionNetwork.out.link(objectTracker.inputDetections)
detectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
# Tracklets and the tracker's passthrough frame go back to the host
objectTracker.out.link(trackerOut.input)
objectTracker.passthroughTrackerFrame.link(xlinkOut.input)

# Connect and start the pipeline
with dai.Device(pipeline) as device:

    # Host -> device queue for video frames; device -> host queues for the
    # tracker passthrough frames, tracklets, resized NN input and detections.
    qIn = device.getInputQueue(name="inFrame")
    trackerFrameQ = device.getOutputQueue(name="trackerFrame", maxSize=4)
    tracklets = device.getOutputQueue(name="tracklets", maxSize=4)
    qManip = device.getOutputQueue(name="manip", maxSize=4)
    qDet = device.getOutputQueue(name="nn", maxSize=4)

    # FPS bookkeeping for the on-screen counter
    startTime = time.monotonic()
    counter = 0
    fps = 0
    # Latest detections; read by displayFrame()
    detections = []

    def to_planar(arr: np.ndarray, shape: tuple) -> np.ndarray:
        """Resize an interleaved HxWxC image and flatten it in planar (CHW) order.

        `shape` is (width, height), the order cv2.resize expects.
        """
        return cv2.resize(arr, shape).transpose(2, 0, 1).flatten()

    # NN bounding boxes are in the <0..1> range; scale them to the frame's width/height
    def frameNorm(frame, bbox):
        """Convert normalized bbox coordinates to integer pixel coordinates of `frame`.

        Values are clipped to <0..1>; even indices (x) are scaled by the frame
        width and odd indices (y) by the frame height.
        """
        normVals = np.full(len(bbox), frame.shape[0])
        normVals[::2] = frame.shape[1]
        return (np.clip(np.array(bbox), 0, 1) * normVals).astype(int)

    def labelName(label):
        """Return the labelMap name for a class id, or the id itself if it is out of range."""
        try:
            return labelMap[label]
        except IndexError:
            return label

    def displayFrame(name, frame):
        """Draw the current `detections` (box, label, confidence) on `frame` and show it."""
        for detection in detections:
            bbox = frameNorm(frame, (detection.xmin, detection.ymin, detection.xmax, detection.ymax))
            cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2)
            cv2.putText(frame, str(labelName(detection.label)), (bbox[0] + 10, bbox[1] + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"{int(detection.confidence * 100)}%", (bbox[0] + 10, bbox[1] + 40), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
        cv2.imshow(name, frame)

    cap = cv2.VideoCapture(args.videoPath)
    # Frames get synthetic timestamps spaced as if the video played at simulatedFps
    baseTs = time.monotonic()
    simulatedFps = 30
    # Every frame is resized to this (width, height) before being sent;
    # 1920*1080*3 bytes is the max data size configured on the XLinkIn node.
    inputFrameShape = (1920, 1080)

    try:
        while cap.isOpened():
            read_correctly, frame = cap.read()
            if not read_correctly:
                break  # end of the video (or a read error)

            img = dai.ImgFrame()
            img.setType(dai.ImgFrame.Type.BGR888p)
            img.setData(to_planar(frame, inputFrameShape))
            img.setTimestamp(baseTs)
            baseTs += 1/simulatedFps

            img.setWidth(inputFrameShape[0])
            img.setHeight(inputFrameShape[1])
            qIn.send(img)

            # Keep feeding frames until the tracker has produced an output
            trackFrame = trackerFrameQ.tryGet()
            if trackFrame is None:
                continue

            track = tracklets.get()
            # Named inManip so it does not shadow the ImageManip node `manip`
            inManip = qManip.get()
            inDet = qDet.get()

            counter += 1
            current_time = time.monotonic()
            if (current_time - startTime) > 1:
                fps = counter / (current_time - startTime)
                counter = 0
                startTime = current_time

            detections = inDet.detections
            displayFrame("nn", inManip.getCvFrame())

            color = (255, 0, 0)
            trackerFrame = trackFrame.getCvFrame()
            for t in track.tracklets:
                # Tracklet ROIs are normalized; convert them to pixels of this frame
                roi = t.roi.denormalize(trackerFrame.shape[1], trackerFrame.shape[0])
                x1 = int(roi.topLeft().x)
                y1 = int(roi.topLeft().y)
                x2 = int(roi.bottomRight().x)
                y2 = int(roi.bottomRight().y)

                cv2.putText(trackerFrame, str(labelName(t.label)), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
                cv2.putText(trackerFrame, f"ID: {[t.id]}", (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
                cv2.putText(trackerFrame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
                # Last argument is the line thickness (same as the detection boxes), not a font id
                cv2.rectangle(trackerFrame, (x1, y1), (x2, y2), color, 2)

            cv2.putText(trackerFrame, "Fps: {:.2f}".format(fps), (2, trackerFrame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color)

            cv2.imshow("tracker", trackerFrame)

            if cv2.waitKey(1) == ord('q'):
                break
    finally:
        # Release the video file and close the preview windows even if the loop raises
        cap.release()
        cv2.destroyAllWindows()

Need assistance?

Head over to the Discussion Forum for technical support or any other questions you might have.