Skip to content
Mì AI Mì AI Mì AI

Học AI theo cách Mì ăn liền!

Mì AI Mì AI Mì AI

Học AI theo cách Mì ăn liền!

  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
Close

Search

  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
Mì AI Mì AI Mì AI

Học AI theo cách Mì ăn liền!

Mì AI Mì AI Mì AI

Học AI theo cách Mì ăn liền!

  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
Close

Search

  • Trang chủ
  • Kênh Youtube
  • Facebook Group
  • Nói về chủ tiệm Mì
Computer Vision

Train RepVGG từ đầu và transfer learning (Phần 2/2)

By Chủ tiệm Mì
May 12, 2021 8 Min Read
1

Xin chào anh em Mì AI, như vậy trong bài trước chúng ta đã tìm hiểu qua về RepVGG và thử convert và inference RepVGG. Hôm nay chúng ta sẽ cùng nhau tìm cách train RepVGG với dữ liệu của chúng ta theo cả 2 hình thức là train từ đầu và transfer learning.

Link bài trước cho bạn nào chưa đọc: RepVGG – hậu thế của VGG tái xuất giang hồ (Phần 1/2).

Mình có tìm kiếm trên mạng thì nhiều bài giới thiệu về RepVGG nhưng các bài hướng dẫn train thì còn ít nên mạnh dạn viết bài này để guide các bạn. Nếu có sai sót mong các cao thủ đi qua bỏ quá!

Còn bây giờ thì let’s go!

Phần 1 – Train RepVGG từ đầu với dữ liệu của chúng ta

Như chúng ta đã biết thì RepVGG đã được train với bộ dữ liệu ImageNet và bài trước chúng ta đã sử dụng pretrain đó để inference thử và cho kết quả khá tốt.

RepVGG

Còn bây giờ chúng ta sẽ thử train với data của chúng ta. Bài hôm nay mình sử dụng Fish Dataset trên Kaggle tại địa chỉ https://www.kaggle.com/crowww/a-large-scale-fish-dataset. Đây là bộ dữ liệu gồm 9 loài cá, mỗi loại 1000 ảnh có độ lớn 3GB, khá ổn để có thể train from scratch bài này. Các bạn chú ý mà dữ liệu ít quá thì chớ dại nhé, dễ bị underfit lắm.

Bắt đầu thôi!

Bước 1 – Tiền xử lý dữ liệu

Dữ liệu Fish Dataset này sử dụng cho Image Segmentation nên tổ chức dữ liệu sẽ có thêm folder masking nữa. Trong bài này chúng ta tạm thời sẽ không sử dụng đến dữ liệu masking.

Các bạn tải file dữ liệu cá kia về giải nén. Sau đó tạo thư mục data và 3 thư mục con là raw, train và val. Sau đó copy dữ liệu tải về vào thư mục raw để đảm bảo có cấu trúc như sau:

train RepVGG

Okie rồi, bây giờ ta cần một đoạn script để convert dữ liệu trong thư mục raw vào 2 thư mục train và val. Ở đây mình chọn tỷ lệ chia dữ liệu là 80/20, nghĩa là 80% dữ liệu train và 20% để làm validation.

import os
import random
import shutil

data_root = "data"
data_raw = os.path.join(data_root, "raw")
data_train = os.path.join(data_root, "train")
data_val  = os.path.join(data_root, "val")

for folder in os.listdir(data_raw):
    if folder[0]!=".":
        file_list = []
        full_folder = os.path.join(data_raw, folder, folder)
        print("Folder ", full_folder)
        for file in os.listdir(full_folder):
            if file[0] != ".":
                full_file = os.path.join(full_folder, file)
                file_list.append(full_file)

        total_files = len(file_list)
        print("Total = ",total_files)
        random.shuffle(file_list)
        train_files = file_list[0:800]
        val_files = file_list[800:]
        print("Số file = train ",len(train_files))
        print("Số file = val ", len(val_files))


        train_folder  = os.path.join(data_train, folder)
        if not os.path.exists(train_folder):
            os.makedirs(train_folder)
        i=0
        for train_file in train_files:
            i+=1
            dest_file = os.path.join(train_folder, os.path.basename(train_file))
            print("Copy to train ", train_file, " to ", dest_file)

            shutil.copyfile(train_file, dest_file)

        print("Copy train ", i)
        val_folder = os.path.join(data_val, folder)
        if not os.path.exists(val_folder):
            os.makedirs(val_folder)
        i=0
        for val_file in val_files:
            i+=1
            # print("Copy to val ", val_file)
            shutil.copyfile(val_file, os.path.join(val_folder, os.path.basename(val_file)))
        print("Copy val ", i)

Sau khi chạy đoạn script trên (trong github mình để ở file make_dataset.py) thì trong hai thư mục train và val sẽ có các folder con để train.

train RepVGG
Bước 2 – Tiến hành train với kiến trúc mạng RepVGG-A0

Okie, để train thì mình sẽ chạy đúng script train mà github gốc của tác giả cung cấp:

python train.py -a RepVGG-A0 --epochs 50  data 

Trong câu lệnh trên thì:

  • Cái đoạn -a RepVGG-A0 là kiến trúc mạng sẽ sử dụng.
  • –epochs 50: Train 50 epochs
  • data : là tên folder chữ dữ liệu, chính là folder data mà chúng ta đã tạo ra ở bên trên.

Khi train mà bạn thấy màn hình hiển thị như sau là ổn rồi:

Epoch: [0][ 0/29]	Time 27.423 (27.423)	Data 21.897 (21.897)	Loss 6.9711e+00 (6.9711e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)

Epoch: [0][10/29]	Time  0.950 ( 4.026)	Data  0.002 ( 2.848)	Loss 2.0947e+00 (3.3592e+00)	Acc@1  22.27 ( 18.04)	Acc@5  64.84 ( 61.29)
Epoch: [0][20/29]	Time  0.872 ( 3.216)	Data  0.001 ( 2.248)	Loss 1.5741e+00 (2.6636e+00)	Acc@1  38.28 ( 24.78)	Acc@5  92.19 ( 71.06)
Test: [0/8]	Time 19.879 (19.879)	Loss 8.3546e+03 (8.3546e+03)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
 * Acc@1 11.111 Acc@5 55.556

Sau 50 epochs mình đạt được Top 1 Acc (thể hiện ở mục Acc@1) là khoảng 96% và Top 5 Acc (thể hiện Acc@5) là 99%. Khá ổn, có lẽ do dữ liệu bài này khá dễ phân biệt.

Bước 3 – Convert model và inference

Sau quá trình train, các bạn sẽ thấy xuất hiện 02 file trong folder chứ file train.py đó là:

  • checkpoint.pth.tar: Là file weights cuối cùng tạo ra từ model.
  • model_best.pth.tar: Là file weights tốt nhất mà model lưu lại được.

Tùy nhu cầu của các bạn mà các bạn lấy file weights tương ứng để sử dụng nhé.

Sau đó các bạn convert tương tự như Bài số 1:

python convert.py model_best.pth.tar model_best_deploy.pth -a RepVGG-A0

Và chúng ta sử dụng lại code inference của bài số 1 nhưng thay đoạn load_checkpoint(model, “RepVGG-A0-deploy.pth”) bằng load_checkpoint(model, “model_best_deploy.pth”) để load model của chúng ta.

Load ngon rồi thì bạn đưa các ảnh cá vào để predict thử xem có chuẩn không nhé! Có vẻ dữ liệu này khá dễ nên model nhận chuẩn vãi lúa 😀

Các điểm chú ý: Bài này sử dụng kiến trúc RepVGG-A0 đẻ train, mà cái thằng RepVGG này nó lại train trên ImageNet với số class vẫn là 1000 nhé trong khi chúng ta chỉ có 9 class. Tuy nhiên vẫn không ảnh hưởng lắm, model vẫn chạy ầm ầm. Bạn nào muốn chuẩn có thể mở file repvgg.py để sửa lại num_classes trong các block khai báo như bên dưới nhé. Bài này mình tạm để nguyên.

def create_RepVGG_A0(deploy=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)

Phần 2 – Transfer Learning, sử dụng VGG với vai trò Feature Extractor

Trong phần transfer learning này, để đơn giản mình chỉ tập trung vào việc sử dụng RepVGG (với pretrained ImageNet) như một Feature Extractor (FE) sau đó customize lại các layer Fully Connected cuối cùng để phục vụ cho bài toán customize của mình.

Dữ liệu mình vẫn sử dụng bộ dữ liệu trong thư mục data đã tạo ra ở Phần 1 và load bằng data_loader của Pytorch

traindir = os.path.join("data", 'train')
    valdir = os.path.join("data", 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))


    train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=(train_sampler is None),
        num_workers=8, pin_memory=True, sampler=train_sampler)

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

Sau đó mình tiến hành Transfer Learning. Cụ thể là đóng băng toàn bộ các layer của mạng RepVGG gốc để giữ lại các weights đã train. Ta chỉ thay đổi lại layer FC cuối cùng thành 9 class và tiến hành train để layer này học mà thôi:

    # Load model RepVGG với weights gốc ImageNet
    from repvgg import create_RepVGG_A0
    from utils import load_checkpoint
    model_ft = create_RepVGG_A0(deploy=True)
    load_checkpoint(model_ft, "RepVGG-A0-deploy.pth")

    # Đóng băng toàn bộ các layer của RepVGG   
    for param in model_ft.parameters():
        param.requires_grad = False
    
    # Chỉnh sửa lại layer cuối, thằng này ko đóng băng
    num_ftrs = model_ft.linear.in_features
    model_ft.linear = nn.Linear(num_ftrs, 9)

    model_ft = model_ft.to(device)
    
    # Định nghĩa loss, optimizer, learning rate schedule
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Bây giờ định nghĩa hàm train để train model, phần này sẽ khá quen thuộc với bạn nào đã học qua Pytorch:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Chuyển giữa các phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Và tiến hành train với 25 epochs:

# Tiến hành train
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                           num_epochs=25)

Sau khoảng vài epochs thì model đã đạt Acc rất tốt trên tập val:

Epoch 0/24
----------
train Loss: 1.3389 Acc: 0.5994
val Loss: 0.4116 Acc: 0.9672

Epoch 1/24
----------
train Loss: 0.4240 Acc: 0.9213
val Loss: 0.1852 Acc: 0.9878

Epoch 2/24
----------
train Loss: 0.3064 Acc: 0.9385
val Loss: 0.1319 Acc: 0.9956

Epoch 3/24
----------
train Loss: 0.2568 Acc: 0.9468
val Loss: 0.1045 Acc: 0.9972

Epoch 4/24
----------
train Loss: 0.2274 Acc: 0.9546
val Loss: 0.0870 Acc: 0.9978

Epoch 5/24
----------
train Loss: 0.2144 Acc: 0.9513
val Loss: 0.0764 Acc: 0.9983

Sau khi train xong, bạn thấy okie thì có thể save model ra file hoặc sử dụng để inference luôn được rồi. Phần này là pytorch cơ bản như:

# Save model
torch.save(model_ft,'mymodel.pth')
# Load model
model = torch.load('mymodel.pth')

Ở đây mình nhấn mạnh là mình không tập trung vào loss, acc của model mà mình demo training transfer learning là chính thôi nhé!

Rồi, vậy là mình đã guide các bạn cách train từ đầu cũng như transfer learning với RepVGG. Hi vọng giúp được các bạn. Các script này mình sẽ copy lên github để các bạn tham khảo: https://github.com/thangnch/MIAI_RepVGG

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn/
Youtube: http://bit.ly/miaiyoutube

Tags:

repVGGtrain RepVGGtransfer learningtransfer learning repVGG
Author

Chủ tiệm Mì

Follow Me
Other Articles
Previous

RepVGG – hậu thế của VGG tái xuất giang hồ (Phần 1/2)

Next

Thử xây dựng model chống giả mạo giao dịch thẻ Ngân hàng – Mì AI

One Comment
  1. RepVGG - hậu thế của VGG tái xuất giang hồ (Phần 1/2) - Mì AI says:
    May 12, 2021 at 10:30 am

    […] Update: Phần 2 đã có tại đây: https://www.miai.vn/2021/05/12/train-repvgg-tu-dau-va-transfer-learning-phan-2-2/ […]

    Reply

Leave a Reply Cancel reply

Your email address will not be published. Required fields are marked *

Recent Posts

  • Tìm hiểu và cài đặt OpenClaw – trợ lý ảo 24/7 thông minh đa chức năng – Mì Ai
  • Dùng thử Pika – robot học Tiếng Anh cho trẻ cực đỉnh – Mì AI
  • TopView.AI 4.0 – nền tảng tạo AI video cộng tác bá đạo – Mì AI
  • Storm MCP – giải pháp nhanh gọn nhẹ để có MCP Server trong 5 phút – Mì AI
  • VoxCPM thử voice cloning với checkpoint finetune Tiếng Việt – Mì AI

Recent Comments

  1. Chủ tiệm Mì on Thử xây dựng hệ thống Agentic AI với LangGraph – Mì AI
  2. Nguyễn Chiến Thắng on [Nhận diện biển số xe] Chương 3 – Phát hiện biển số bằng OpenCV thuần
  3. Trần Sơn Dương on [Nhận diện biển số xe] Chương 3 – Phát hiện biển số bằng OpenCV thuần
  4. Salomon on [CV] Thử làm model cảnh báo ngủ gật cho tài xế oto bằng Dlib và Resnet
  5. khang on Xây dựng hệ thống nhận diện thủ ngữ – ngôn ngữ ký hiệu tay – để giao tiếp với người khuyết tật

Categories

  • Basic
  • Computer Vision
  • Data Science – Data Analysis
  • Generative AI
  • MÌ ÚP
  • Natural Language Processing
  • RNN-LSTM-GRU
  • Share Data

Là người đi trước, hãy biết đưa tay lại phía sau.

Nguyễn Chiến Thắng
Cảm ơn các bạn đã ủng hộ Mì AI!