import torch
import numpy as np
import mmcv, cv2
from PIL import Image, ImageDraw
from facenet_pytorch import MTCNN
from IPython import display
# Run MTCNN on the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# keep_all=True makes detect() return every face found in an image/frame,
# not just the single highest-confidence one.
mtcnn = MTCNN(keep_all=True, device=device)
# draw the bounding boxes for face detection
# draw the bounding boxes for face detection
def draw_bbox(bounding_boxes, image):
    """Draw a rectangle on `image` for every detected face.

    Args:
        bounding_boxes: iterable of [x1, y1, x2, y2] boxes as returned by
            MTCNN.detect, or None when no faces were detected.
        image: image array in OpenCV BGR layout; modified in place by
            cv2.rectangle.

    Returns:
        Tuple (image, faces): the (possibly annotated) image and the number
        of faces drawn (0 when `bounding_boxes` is None).
    """
    if bounding_boxes is None:
        return image, 0
    # Unpack each box directly instead of indexing by position.
    for x1, y1, x2, y2 in bounding_boxes:
        # (0, 0, 255) is red in OpenCV's BGR channel order.
        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)),
                      (0, 0, 255), 2)
    return image, len(bounding_boxes)
def image_tracker(image_path, data):
    """Detect faces in an image, draw bounding boxes, and overwrite the file.

    Args:
        image_path: path of the image to annotate; the annotated result is
            written back to the same path.
        data: object with `height` and `width` attributes; when both are
            positive the annotated image is resized to (width, height).

    Returns:
        dict with 'resolution' (the original PIL (width, height) size) and
        'faces_detected' (number of faces found, 0 if none).
    """
    info = {}
    image = Image.open(image_path).convert('RGB')
    original_size = image.size  # PIL reports (width, height)
    # Keep uint8 pixels: cv2.imwrite expects 8-bit data, so the float32
    # copy the original code made was unnecessary and could corrupt the
    # saved file depending on the output format.
    image_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # MTCNN.detect consumes the RGB PIL image; drawing happens on the BGR
    # OpenCV array.
    bounding_boxes, _confidences = mtcnn.detect(image)
    image_array, faces = draw_bbox(bounding_boxes, image_array)
    if data.height > 0 and data.width > 0:
        image_array = cv2.resize(image_array, (data.width, data.height),
                                 interpolation=cv2.INTER_AREA)
    cv2.imwrite(image_path, image_array)
    info['resolution'] = original_size
    info['faces_detected'] = faces
    return info
def video_tracker(video_path, data):
    """Detect and box faces on every frame of a video, overwriting the file.

    Args:
        video_path: path of the input video; the annotated result is written
            back to the same path.
        data: object with `height` and `width` attributes; when both are
            positive every output frame is resized to (width, height).

    Returns:
        dict with 'resolution', 'frame_rate', and 'faces_detected' — a dict
        mapping 'frame_N' to the face count for frames where faces were found.
    """
    video = mmcv.VideoReader(video_path)
    info = {
        "resolution": video.resolution,
        "frame_rate": video.fps,
    }
    frame_info = {}
    # Decode all frames up front (as RGB PIL images) so the reader is done
    # with the file before we overwrite it below.
    frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
              for f in video]
    resize_output = data.height > 0 and data.width > 0
    frames_tracked = []
    for i, frame in enumerate(frames):
        print('\rTracking frame: {}'.format(i + 1), end='')
        # Detect faces
        boxes, _ = mtcnn.detect(frame)
        if boxes is not None:
            frame_info['frame_' + str(i + 1)] = len(boxes)
            # Draw faces on a copy so the source frame stays untouched.
            frame_draw = frame.copy()
            draw = ImageDraw.Draw(frame_draw)
            for box in boxes:
                draw.rectangle(box.tolist(), outline=(255, 0, 0), width=6)
        else:
            frame_draw = frame
        # BUG FIX: resize EVERY frame, not only frames with detections —
        # the original skipped resizing face-less frames, handing
        # cv2.VideoWriter mixed frame sizes (mismatched frames are dropped).
        if resize_output:
            frame_draw = frame_draw.resize((data.width, data.height),
                                           Image.BILINEAR)
        frames_tracked.append(frame_draw)
    print('\nDone')
    # Guard empty videos so frames_tracked[0] below can't raise IndexError.
    if not frames_tracked:
        info['faces_detected'] = frame_info
        return info
    dim = frames_tracked[0].size
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    # BUG FIX: preserve the source frame rate instead of hard-coding 25.0,
    # which made the output play at the wrong speed for any other input fps.
    video_tracked = cv2.VideoWriter(video_path, fourcc, video.fps, dim)
    for frame in frames_tracked:
        video_tracked.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    video_tracked.release()
    info['faces_detected'] = frame_info
    return info