Hello anh em Mì AI, hôm nay chúng ta sẽ cùng nhau tìm hiểu và triển khai RepVGG, một hậu thể của model VGG lừng lẫy một thời.
Trước khi bắt đầu mình xin nói đôi lời:
- Đây là một model mới ra mắt, theo paper là 1/2021 nên chưa có nhiều bài tham khảo. Với kiến thức Mì ăn liền, mình tự đọc Paper để hiểu và viết, nếu sai mong các bác lượng thứ và chỉnh sửa giúp.
- Với phương châm Mì nên mình cũng chỉ tìm hiểu cho biết nó là cái chi chi chứ không đi sâu vào mấy cái công thức toán mà mình đọc nghìn năm vẫn không hiểu ah =)).
- Bác nào cần nghiên cứu sâu thì đây là link paper ạ: https://arxiv.org/pdf/2101.03697.pdf
Let’s go!
Phần 1 – Uống nước nhớ nguồn, hồi tưởng về tiền bối VGG
Chắc hẳn anh em làm Deep Learning và đặc biệt là các anh em theo dõi các bài trên Mì AI đã quá quen thuộc với VGG. VGG16 là một convolutional neural network được giới thiệu bởi K. Simonyan and A. Zisserman, Đại học Oxford trong paper “Very Deep Convolutional Networks for Large-Scale Image Recognition” năm 2014. Model đã có nhiều thành công vượt trội so với các model thời bấy giờ (như AlexNet, …). VGG đạt 92.7% top-5 test accuracy trên bộ data ImageNet.
Ngoài độ chính xác, VGG còn có kiến trúc đơn giản nên tốc độ khá nhanh và dễ dàng trong triển khai, customize và transfer learning.
Chính vì thế VGG trở nên thông dụng và được ông Keras tích hợp luôn vào trong framework và chúng ta chẳng cần làm gì ngoài khai báo và sử dụng. Ví dụ về việc khai báo là dùng được của VGG16: Nhận dạng tiền Việt Nam bằng VGG16 CNN Classify – Mì AI.
Phần 2 – Các thế lực cạnh tranh mang tên Multi Branches
Vâng, không ít lâu sau khi VGG trở thành bá chủ thiên hạ thì có một môn phái mới xuất hiện có tên Multi Branches. Điều đặc biệt của môn phái này là kiến trúc mạng đa nhánh, đúng với cái tên của nó.
Để các bạn hình dung sự khác nhau giữa đơn nhánh và đa nhánh ta ghép chúng vào một hình như sau:
Nhìn vào hình các bạn thấy rõ rằng VGG thẳng một trục, còn ông Resnet và GoogLenet thì tỏa như cây.
Đến đây có bạn sẽ thắc mắc thế cái môn phái Multi Branches nó có độc chiêu gì mà có thể khuất phục VGG đại ca được chứ?
Kiến trúc đa nhánh sẽ có nhiều ưu điểm:
- Các model dạng một trục thẳng như VGG thì khi kiến trúc càng sâu sẽ dễ gặp hiện tượng triệt tiêu đạo hàm (Gradient Vanishing) khi thực hiện lan truyền ngược backpropagation. Điều đó dẫn đến các lớp cuối gần như không được update.
- Các model đa nhánh như là sự ghép nối của nhiều model nông (shallow model) với nhau để tạo thành model cuối cùng. Điều đó sẽ tránh việc Gradient Vanishing và cũng góp phần giảm sự phụ thuộc của model chính vào một nhánh nhất định.
- Thêm một điểm nữa là các model đa nhánh sẽ cho phép các đặc trưng của lớp trước có thể ghép nối, truyền thẳng đến lớp sau thông qua các kết nối để tránh mất thông tin hoặc có thể thực hiện nhiều ý đồ khác theo thiết kế của tác giả.
Đấy đấy, do có nhiều ưu điểm thế nên môn phái đa nhánh này đạt nhiều thành công trong giới võ lâm. Ví như thằng Inception Resnet V3 đạt 80% top 1 Accuracy trên ImageNet (theo https://paperswithcode.com/sota/image-classification-on-imagenet)
Vậy là VGG bị khuất phục, bỏ lên núi tu luyện, trước khi đi không quên quăng lại một câu “hẹn ngày tái ngộ”!
Phần 3 – 7 năm sau – ngày tái xuất của RepVGG
Đúng 7 năm sau, từ trên núi cao có một cao thủ xuất hiện, thân thủ phi phàm lao xuống núi rầm rầm. Thấy vậy giang hồ mới xin hỏi quý tính đại danh. Cao thủ bỏ mũ ra thì các anh hùng hào kiệt đều mắt chữ O, mồm chữ A. Rõ ràng có chút gì đó quen quen nhưng lại lạ lạ. Same same but different.
Hóa ra đó chính là hậu thế của VGG, có tên RepVGG. RepVGG đã luyện thành công một model vừa có tốc độ inference bá đạo kèm theo accuracy bá cháy 😀
Thế là các anh hùng ngồi xuống trà đạo rồi RepVGG từ từ kể.
Ngày đó bị khuất phục, tiền bối VGG đã ngày đêm nghiên cứu môn phái Multi Branches và nhận ra rằng ngoài các ưu điểm nói trên thì cũng có nhiều điểm yếu như: sự đa nhánh làm cho việc customize khó, chạy inference chậm, tiêu tốn bộ nhớ RAM gấp đôi kiến trúc đơn giản một nhánh (như hình dưới đây).
Thế là tiền bối VGG mới nghĩ ra độc chiêu biến hóa cho model là “train bằng đa nhánh, nhưng inferce một nhánh” với tuyệt chiêu “Rì pa ra mét tờ” để chuyển đổi tham số mạng từ đa nhánh về 1 nhánh. Cụ thể như sau:
- Đầu tiên, model được train đa nhánh để tận dụng các ưu điểm của môn phái Multi Branches.
- Sau đó, trước khi thực hiện Inference, model sẽ được convert về đơn nhánh. Chỗ này nói thêm 1 chút, trước giờ anh em vẫn biết là model train xong thì weights sẽ được lưu lại và weights này không thể load cho một model có kiến trúc khác. Nhưng với RepVGG tác giả có thêm món Model Re-parameterization để convert được weights từ đa nhánh về 1 nhánh. Nói một cách nôm na là weights mới của mạng đơn nhánh sẽ được tính toán từ weights cũ của mạng đa nhánh một cách toán học và học thuật nào đó để đảm bảo rằng hai model là tương đương (có sai số tầm 1e-6 thì phải). Điều này sẽ làm cho model inference nhanh hơn nhiều với các model Multi Branches khác.
Đó, có vậy thôi, vậy là RepVGG mới được submit hội nghị, có paper ngon lành :D. À quên, nói thêm là RepVGG có nhiều anh em biến thể khác nhau chút xíu về kiến trúc là : RepVGG-A0/1/2 và RepVGG-B0/1/2.
Phần 4 – Thử inference với model RepVGG
Đầu tiên chúng ta tải pretrain weights của RepVGG tại link sau: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq?usp=sharing . Ở đây mình chọn model RepVGG-A0.
Sau khi tải xong ta sẽ có file RepVGG-A0-train.pth. Nhìn cái đuôi là biết Pytorch rồi.
Tiếp theo chúng ta sẽ git clone từ github của mình về (hoặc github gốc của tác giả tại https://github.com/DingXiaoH/RepVGG, nhưng sẽ thiếu file inference mình viết thêm) :
git clone https://github.com/thangnch/MIAI_RepVGG
Code language: PHP (php)
Cài đặt các thư viện cần thiết
pip install torch torchvision opencv-python
Tiếp theo, để sử dụng model để inference ta sẽ thực hiện lệnh convert
python convert.py RepVGG-A0-train.pth RepVGG-A0-deploy.pth -a RepVGG-A0
Code language: CSS (css)
Convert xong ta sẽ có file RepVGG-A0-deploy.pth và bây giờ thử inference xem nào. Đây là đoạn source mình viết thêm, trong source code gốc không có nhé. Bạn nào tải github gốc thì tạo một file python và viết vào file đó nhé. Còn bạn nào dùng github của mình thì nó ở trong file main.py.
Để có thể load được ảnh, thực hiện các bước Transform theo như tác giả đã làm, ta tạo một hàm load ảnh như sau:
def image_loader(image_name):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
trans = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
image = Image.open(image_name)
image = trans(image).float()
image = Variable(image, requires_grad=False)
image = image.unsqueeze(0)
return image
Code language: PHP (php)
Chỗ này mình nói thêm cho các bạn chưa thạo. Đó là ví dụ khi train họ chế cháo ảnh, thêm normalize, resize…. thì khi mình test mình cũng phải process ảnh của chúng ta tương tự như vậy trước khi đưa vào model. Ngoài ra bạn để ý trong các bước transform có một bước .ToTensor() để chuyển ảnh thành tensor nhé.
Ngoài ra, để cho model có thể in ra được nhãn ảnh (ví dụ: cá mập, rùa….) thay vì các con số khô khan thì mình cũng làm thêm một hàm load tên class từ file txt (mình có thể trong github)
def load_imagenet_class_labels():
file = open("imagenet1000_clsidx_to_labels.txt", "r")
contents = file. read()
image_labels = ast.literal_eval(contents)
file.close()
return image_labels
Code language: JavaScript (javascript)
Giờ đến lượt chúng ta load cái model đã convert và tiến hành predict thôi:
# Load labels image net
image_labels = load_imagenet_class_labels()
repvgg_build_func = get_RepVGG_func_by_name('RepVGG-A0')
model = repvgg_build_func(deploy=True)
load_checkpoint(model, "RepVGG-A0-deploy.pth")
model.eval()
image = image_loader("test_images\\n01665541_leatherback_turtle.jpg")
with torch.no_grad():
# compute output
output = model(image)
class_id = numpy.argmax(output.numpy())
print("Class = ", image_labels[class_id])
Code language: PHP (php)
Và đây là kết quả chuẩn rồi:
Như vậu bước đầu mình đã giới thiệu cho các bạn cơ bản về RepVGG và cách thức sử dụng pretrain RepVGG để dự đoán hình ảnh. Minh chạy tốc độ dưới 1s trên máy tính chỉ có CPU Core i5 2400. Khá ổn!
Bài này đã dài mình xin tạm dừng, trong bài sau mình sẽ cùng tìm hiểu cách train từ đầu, transfer learning trên RepVGG nhé! Hẹn gặp lại các bạn
Update: Phần 2 đã có tại đây: https://www.miai.vn/2021/05/12/train-repvgg-tu-dau-va-transfer-learning-phan-2-2/
Chúc các bạn thành công!
#MìAI
Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube
comment first :>