Xây dựng hệ thống nhận diện biển báo giao thông bằng Retinanet

Hello toàn thể anh em Mì AI, hôm nay chúng ta sẽ tìm hiểu về Focal Loss, Retinanet và thực hành Xây dựng hệ thống nhận diện biển báo giao thông bằng Retinanet xem nhé!

Let’s go!

Phần 1 – Giới thiệu qua về Focal Loss, Retinanet

Trước giờ trên Blog có nhiều bài về YOLO, SSD…. và chúng ta thấy rằng các net đó có một điểm yếu là rất dễ miss các vật thể nhỏ trong ảnh. Và theo yêu của các mem (cả về việc tìm hiểu Retinanet và cả về việc đề nghị làm một ví dụ về nhận diện biển báo giao thông, bài phân loại biển báo thì có rồi : Tại đây) nên hôm nay mình triển món này. Retinanet được đánh giá là khá ổn trong việc phát hiện các object nhỏ và lý do thì mời anh em đọc hết bài này nhé.

Tuy nhiên, với phong cách Mì ăn liền nên mình không đi sâu tìm hiểu kiến trúc với toán học nhé. Các bạn có nhu cầu cứ đọc trực tiếp paper gốc tại đây nha.

Okie, và bây giờ chúng ta bắt đầu với Focal Loss trước.

Focal Loss là gì?

Trước giờ khi làm các bài toán về classification chúng ta hay dùng làm loss là Cross Entropy. Hàm loss này có một điểm yếu là đối xử với các class như nhau. Do đó nếu như dữ liệu của chúng ta imbalance, mất cân bằng giữa các class thì hàm loss này tỏ ra kém hiệu quả. Model sẽ có xu hướng nghiêng về các class có nhiều dữ liệu hơn và bỏ sót class thiểu số. Điều này càng trở nên nghiêm trọng khi các class thiểu số là các class quan trọng như: tế bào ung thư, giao dịch giả mạo….

cross entropy
Hàm loss Cross Entropy

Trước Focal Loss cũng đã có Balance Cross Entropy thực hiện bằng cách gán trọng số cao hơn cho các class thiểu số để phạt model mạnh hơn khi dự đoán sai các class này. Do đó các class thiểu số sẽ tác động tới hàm loss nhiều hơn. Tuy vậy theo các tài liệu trên mạng thì nó vẫn chưa thực sự can thiệp vào quá trình Gradient Descent nên GD vẫn sẽ bị điều chỉnh theo các class đa số. Và lý do đó nên Focal Loss (FL) ra đời để xử lý triệt để vấn đề này.

Ở đây mình cũng không có ý định tìm hiểu sâu nên anh em chỉ biết sơ vậy để sử dụng thôi nhé. Các cao thủ đi qua bỏ quá.

focal loss
Hàm Focal Loss
Ơ thế sao đang Retinanet lại nói về Focal Loss?

Đơn giản thôi, vì cái thằng Retinanet nó dùng FL và cũng giới thiệu luôn về FL lần đầu tiên luôn? Thế vì đâu mà ông Retinanet lại phải sử dụng FL? Vì tác giá phát hiện ra rằng trong các bài toán Object Detection (OD) thì ta nhận thấy như sau:

  • Có 2 class là Foreground (chứa vật thể) và Background (nền và không chứa vật thể)
  • Họ nhận thấy “We discover that the extreme foreground-background class imbalance encountered during training of dense detectors is the central cause”. Tạm dịch là do sự imbalance giữa 2 class đó là nguyên nhân chính dẫn đến sự kém hiểu quả trong OD.

Đó, thế là họ quyết định dùng Focal Loss. Link về paper đã để ở trên cho các bạn học chuyên sâu. Vậy đọng lại ở đây là gì? Đó là: Retinanet dùng FL và ta có thể dùng FL với các bài imbalance data.

Còn bây giờ đi tiếp nào!

Retinanet có cấu trúc như sau:

retinanet architect
Nguồn: Tại đây

Okie, nhìn vào sơ đồ trên anh em thấy rõ các việc cần làm như sau:

Từ ảnh đầu vào sẽ đưa qua một mạng CNN đóng vai trò trích xuất đặc trưng. Tuy nhiên ở đây kết hợp thêm một mạng FPN (Feature Pyramid Network) để tạo ra một loạt các output dạng kim tự tháp như trên ảnh Mục đích của việc này là “…for detecting objects at different scales” nghĩa là nhận diện đối tượng ở các tỷ lệ scale khác nhau. Paper đây các bạn nhé!

FPN sẽ gồm một nhánh đi lên (a) và nhánh đi xuống (b). Nhánh đi lên đóng vai trò trích xuất đặc trưng, sẽ giảm về kích thước nhưng tăng về ý nghĩa (giảm 2 lần sau mỗi lần downsampling) . Một số mạng sau vài layer Conv thì dùng feature map (FM) cuối cùng để detection nên dễ bỏ qua các object nhỏ.

Do đó, ở đây FPN bổ sung thêm nhánh bên phải đi xuống để up sampling từ (FM) nhỏ nhất. Mỗi lần upsampling kích thước sẽ tăng lên gấp đôi. Và lúc này tuy là layer được tạo ra vẫn giữ ý nghĩa nhưng vị trí của object đã bị mất qua các bước down sampling và up sampling. Tác giả khắc phục bằng cách nối với một layer ở nhánh đi lên (layer cùng level) bằng phép cộng element-wise additional với hi vọng giữ lại được thông tin đó.

Sau khi cộng vào thì sẽ sinh ra một loạt các FM mới và Retinanet dùng các FM đó để đưa vào một 2 mạng phụ (subnet ) là class subnet và box subnet để predict các thông tin chúng ta cần là:

  • Class subnet thì predict probality của các class trên từng điểm ảnh trên ảnh. Output của mạng này như trên hình các bạn xem là WxHxKA. Trong đó W, H là width và height của ảnh. KA là độ sâu = K * A, K là cố class (ví dụ train 80 class thì K = 80) và A là số anchorbox (trong paper là 9, bằng thực nghiệm). Ý nghĩa là với từng điểm ảnh, với anchorbox này, thì mỗi class có probality là bao nhiêu….
  • Box subnet thì predict ra cái box của vật thể tại từng điểm ảnh và output có độ sâu là 4xA. Vì sao? Vì mỗi anchorbox ta sẽ dự đoán 4 giá trị là : (x, y, w, h) (tâm và rộng dài) của box đó.

Chú ý ở đây là các FM mới sinh ra ( > 1) đều được đưa vào các subnet nhằm detect ra cả các vật thể với kích thước to nhỏ khác nhau nhé!

Rồi, đại khái về ông Retinanet là vậy, với suy nghĩ và cách làm như vậy, tác giả mong muốn có một model OD tốt, detect ngon được các vật với nhiều kích thước khác nhau. Chúng ta sẽ thử xem nhé!

Phần 2 – Train và thử nghiệm với bài toán nhận diện biển báo giao thông

Bài toán này mình sẽ sử dụng dữ liệu của Zalo AI nhé (xin cảm ơn Zalo). Các bạn có thể tải dữ liệu tại Thư viện Mì AI: https://www.miai.vn/thu-vien-mi-ai/ nhé!

Let’s do it! Đầu tiên các bạn clone cái github của mình về nhé:

git clone https://github.com/thangnch/MiAI_Trafficsign_RetinanetCode language: PHP (php)
Bước 1. Convert data

Dữ liệu tải về từ Zalo AI có dạng json theo cấu trúc trên website Zalo AI Challenge này. Để có thể sử dụng để train Retinanet, chúng ta phải thực hiện viết code convert về csv file.

Chú ý một chút là cấu trúc file train ở bên này khác với Yolo nhé, các bạn đừng nhầm lấy dữ liệu YOLO sang train nha!

Anh em lưu ảnh vào một folder nào đó sau đó tạo ra 1 file csv gồm N dòng, mỗi dòng là một bbox (hình chữ nhật) của object trong ảnh. Nếu một ảnh có nhiều bbox thì file ảnh đó sẽ có nhiều dòng:

path/to/image.jpg,x1,y1,x2,y2,class_name

Ví dụ theo như nội dung trang chủ:

/data/imgs/img_001.jpg,837,346,981,456,cow
/data/imgs/img_002.jpg,215,312,279,391,cat
/data/imgs/img_002.jpg,22,5,89,84,bird

Do đó chúng ta sẽ thực hiện convert bằng đoạn code sau:

import json

# Class descriptions:
sign_dict = {
    1: "No entry",
    2: "No parking / waiting",
    3: "No turning",
    4: "Max Speed",
    5: "Other prohibition signs",
    6: "Warning",
    7: "Mandatory"
}

json_path = "data/za_traffic_2020/traffic_train/train_traffic_sign_dataset.json"
csv_path = "data/za_traffic_2020/traffic_train/train_traffic_sign_dataset.csv"
with open(csv_path, "w") as csv_file:
    with open(json_path) as json_file:
        data = json.load(json_file)
        annotations = data['annotations']
        for p in annotations:
            print('Bbox: ' + str(p['bbox']))
            print('Image: ' + str(p['image_id']))
            print('category_id: ' + str(sign_dict[p['category_id']]))
            csv_file.write(
                "images/{}.png,{},{},{},{},{}\n".format(p['image_id'], p['bbox'][0], p['bbox'][1], p['bbox'][0] + p['bbox'][2],
                                                        p['bbox'][1] + p['bbox'][3], sign_dict[p['category_id']]))
Code language: PHP (php)

Chú ý ở đây do trong file json các class được đánh số từ 1 nên mình đang để cái dict như kia nhé. Sau bước này ta sẽ có một file csv (tên là train_traffic_sign_dataset.csv) đúng theo yêu cầu của Retinanet.

Bước 2 – Tạo file classes.csv

Bước này ta tạo ra một file danh mục class và id theo yêu cầu của Retinanet. File này nếu nhiều class có thể code để tạo ra, còn mình ít nên tạo tay luôn.

No entry,0
No parking / waiting,1
No turning,2
Max Speed,3
Other prohibition signs,4
Warning,5
Mandatory,6
Bước 3 – Chạy lệnh compile Cython code

Bạn chuyển về thư mục gốc và chạy lệnh:

python setup.py build_ext --inplaceCode language: CSS (css)
Bước 4 – Train model

Mọi thứ sẵn sàng rồi thì train model thôi các mem. Chạy lệnh như sau:

python keras_retinanet/bin/train.py --epochs 50 --steps 562 --batch-size 8 csv data/za_traffic_2020/traffic_train/train_traffic_sign_dataset.csv data/za_traffic_2020/traffic_train/classes.csv

Các bạn chú ý ở đây nhé:

  • –epochs: mình dàng để là 50
  • –batch-size: mình để là 8 cho phù hợp với VRAM của GPU của mình, các bạn có thể tăng giảm phù hợp để vừa có hiệu năng train cao vừa đỡ bị hết bộ nhớ GPU nha
  • –steps : số bước trong mỗi epoch. Các bạn cứ lấy tổng số ảnh chia lấy phần nguyên cho –batch-size là ra nhé. Ví dụ có 4500 ảnh, batch size = 8 -> steps = 562 😉

Rồi sau khi train tầm 36 epochs thì mình có loss tầm 0.2 nên dừng lại. Mình không đặt nặng vấn đề tối ưu cho bài này nên:

  • Dừng để tiết kiệm thời gian và chi phí thuê GPU.
  • Mình không sử dụng validation khi train, để test thôi mà 😀

Sau mỗi epoch thì retinanet sẽ lưu ra 1 model h5 trong thư mục snapshots, anh em cứ lấy model cuối cùng cũng được. Ví dụ như mình là lấy resnet50_csv_36.h5.

Bước 5 – Convert model để inference

Model train ra của Retinanet cần được convert trước khi inference nhé. Lý do là để tăng tốc độ train nên khi train thì Retinanet sẽ sử dụng một phiên bản stripped down (giản lược) đi nhiều layers và chỉ giữ lại các layer phục vụ training thôi. Do vậy train xong thì cần convert lại.

Ta chạy lệnh:

python keras_retinanet/bin/convert_model.py train_model/resnet50_csv_36.h5 infer_model/resnet50_csv_36.h5

Và đợi chút là okie!

Bước 6 – Test trên ảnh xem nào!

Bây giờ mình code 1 file test_model.py với nội dung sau để test, code khá đơn giản , các bạn đọc là hiểu!


from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color

# import miscellaneous modules
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np

model_path = os.path.join('infer_model', 'resnet50_csv_36.h5')

# load retinanet model
model = models.load_model(model_path, backbone_name='resnet50')

# load label to names mapping for visualization purposes
labels_to_names = {
    0: "No entry",
    1: "No parking / waiting",
    2: "No turning",
    3: "Max Speed",
    4: "Other prohibition signs",
    5: "Warning",
    6: "Mandatory"
}

# load image
image = read_image_bgr('data/za_traffic_2020/traffic_train/images/108.png')

# copy to draw on
draw = image.copy()
draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)

# preprocess image for network
image = preprocess_image(image)
image, scale = resize_image(image)

boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))

# correct for image scale
boxes /= scale

# visualize detections
for box, score, label in zip(boxes[0], scores[0], labels[0]):
    # scores are sorted so we can break
    if score < 0.5:
        break

    color = label_color(label)

    b = box.astype(int)
    draw_box(draw, b, color=color)

    caption = "{} {:.3f}".format(labels_to_names[label], score)
    draw_caption(draw, b, caption)

plt.figure(figsize=(10, 5))
plt.axis('off')
plt.imshow(draw)
plt.tight_layout()
plt.show()Code language: PHP (php)

Và kết quả với 1 số ảnh:

retinanet
retinanet
retinanet

Nhận xét qua thì ta thấy các object trong hình trên khá nhỏ nhưng vẫn không làm khó được Retinanet.

Mình cũng tặng bạn nào đọc đến đây file weights mình train sẵn với biển báo giao thông nhé. Tải tại đây!

Well, như vậy mình cũng đã hướng dẫn các bạn cách train và test model Retina Net. Hi vọng sẽ giúp được các bạn.

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube

Tài liệu tham khảo: https://github.com/fizyr/keras-retinanet

Related Post

22 Replies to “Xây dựng hệ thống nhận diện biển báo giao thông bằng Retinanet”

  1. tạo ra 1 file csv gồm N dòng, mỗi dòng là một bbox (hình chữ nhật) của object trong ảnh ???
    Bước này làm như nào vậy anh ???

  2. Em train bằng laptop không có CUDA thì trong quá trình train toàn sử dụng 100% CPU và gần 100% (8GB ram thật) và gần 100% (8GB ram ảo). Có phải lỗi không anh?

  3. Anh ơi. Em chạy “python setup.py build_ext –inplace” nó tạo ra keras_retinanet/utils nhưng bên trong không có gì hết thì phải làm như nào ạ?

    1. error: D:\python\mi_ai\MiAI_Trafficsign_Retinanet\keras_retinanet\utils\compute_overlap.pyx

  4. (MiAI_Trafficsign_Retinanet) D:\python\mi\MiAI_Trafficsign_Retinanet>python setup.py build_ext –inplace

    running build_ext
    cythoning keras_retinanet/utils/compute_overlap.pyx to keras_retinanet/utils\compute_overlap.c
    creating keras_retinanet
    creating keras_retinanet\utils
    error: D:\python\mi\MiAI_Trafficsign_Retinanet\keras_retinanet\utils\compute_overlap.pyx

    em chạy “python setup.py build_ext –inplace” thì nó báo lên như trên ạ. Cái này phải khắc phục như nào anh?

    1. Bạn ơi, mình bị lỗi giống như bạn vậy. Không biết bạn đã fix được chưa thì giúp mình với. Mình cảm ơn bạn.

  5. Bạn ơi, mình bị lỗi giống như bạn vậy. Không biết bạn đã fix được chưa thì giúp mình với. Mình cảm ơn bạn.

  6. Em chào anh. Em khi chạy đến bước build python setup.py build_ext –inplace cũng bị báo lỗi
    error: D:\python\mi\MiAI_Trafficsign_Retinanet\keras_retinanet\utils\compute_overlap.pyx
    như hai bạn đang gặp trên.
    Sau khi em clone qua project (https://github.com/fizyr/keras-retinanet) thì câu build trên chạy được. Nhưng đến phần train model. Phần outputShape của em có dạng [None,None,None, x] (x là 1 số 1->6). Và lost rất lớn. (>200k). Data trong file csv sau khi convert (csv_file.write(“images/{}.png,{},{},{},{},{}\n”.format(p[‘image_id’],p[‘bbox'[0],p[‘bbox'[1],p[‘bbox’][0] + p[‘bbox’][2],p[‘bbox’][1] + p[‘bbox’][3], sign_dict[p[‘category_id’]])) ) của em có dạng 6 cột với thứ tự image_path x1 y1 x2 y2 class. Không biết có phải bước convert data về csv file bị trái với định dạng của retinaNet ko ạ? Nếu không phiền, mong được anh chia sẻ kinh nghiệm nếu như a đã từng gặp như em ạ. Em cảm ơn! Chúc anh và MiAI ngày càng phát triển!

  7. cho em là project này nó chị áp dụng cho ảnh thôi hay là cả trên ảnh thời gian thực vậy ạ?

  8. anh ơi nhận diện biển báo giao thông này thì kích thước ảnh mẫu lớn nhỏ tùy ý, miễn là các ảnh mẫu kích thước giống nhau là được phải không anh?

  9. Anh có thể hướng dẫn sử dụng RetinaNet cho tập X-RAY của Vinbigdata không ạ ? Em làm theo như trên thay dataset thôi nhưng vẫn gặp nhiều lỗi ạ. Em cảm ơn anh !

  10. github của anh thắng bị lỗi hay sao ấy ạ!
    cythoning keras_retinanet/utils/compute_overlap.pyx to keras_retinanet/utils/compute_overlap.c
    error: /content/gdrive/Shareddrives/Train_NVHS/pytorch-retinanet-master/keras_retinanet/utils/compute_overlap.pyx

Leave a Reply

Your email address will not be published. Required fields are marked *