Train RepVGG từ đầu và transfer learning (Phần 2/2)

Xin chào anh em Mì AI, như vậy trong bài trước chúng ta đã tìm hiểu qua về RepVGG và thử convert và inference RepVGG. Hôm nay chúng ta sẽ cùng nhau tìm cách train RepVGG với dữ liệu của chúng ta theo cả 2 hình thức là train từ đầu và transfer learning.

Link bài trước cho bạn nào chưa đọc: RepVGG – hậu thế của VGG tái xuất giang hồ (Phần 1/2).

Mình có tìm kiếm trên mạng thì nhiều bài giới thiệu về RepVGG nhưng các bài hướng dẫn train thì còn ít nên mạnh dạn viết bài này để guide các bạn. Nếu có sai sót mong các cao thủ đi qua bỏ quá!

Còn bây giờ thì let’s go!

Phần 1 – Train RepVGG từ đầu với dữ liệu của chúng ta

Như chúng ta đã biết thì RepVGG đã được train với bộ dữ liệu ImageNet và bài trước chúng ta đã sử dụng pretrain đó để inference thử và cho kết quả khá tốt.

RepVGG

Còn bây giờ chúng ta sẽ thử train với data của chúng ta. Bài hôm nay mình sử dụng Fish Dataset trên Kaggle tại địa chỉ https://www.kaggle.com/crowww/a-large-scale-fish-dataset. Đây là bộ dữ liệu gồm 9 loài cá, mỗi loại 1000 ảnh có độ lớn 3GB, khá ổn để có thể train from scratch bài này. Các bạn chú ý mà dữ liệu ít quá thì chớ dại nhé, dễ bị underfit lắm.

Bắt đầu thôi!

Bước 1 – Tiền xử lý dữ liệu

Dữ liệu Fish Dataset này sử dụng cho Image Segmentation nên tổ chức dữ liệu sẽ có thêm folder masking nữa. Trong bài này chúng ta tạm thời sẽ không sử dụng đến dữ liệu masking.

Các bạn tải file dữ liệu cá kia về giải nén. Sau đó tạo thư mục data và 3 thư mục con là raw, train và val. Sau đó copy dữ liệu tải về vào thư mục raw để đảm bảo có cấu trúc như sau:

train RepVGG

Okie rồi, bây giờ ta cần một đoạn script để convert dữ liệu trong thư mục raw vào 2 thư mục train và val. Ở đây mình chọn tỷ lệ chia dữ liệu là 80/20, nghĩa là 80% dữ liệu train và 20% để làm validation.

import os
import random
import shutil

data_root = "data"
data_raw = os.path.join(data_root, "raw")
data_train = os.path.join(data_root, "train")
data_val  = os.path.join(data_root, "val")

for folder in os.listdir(data_raw):
    if folder[0]!=".":
        file_list = []
        full_folder = os.path.join(data_raw, folder, folder)
        print("Folder ", full_folder)
        for file in os.listdir(full_folder):
            if file[0] != ".":
                full_file = os.path.join(full_folder, file)
                file_list.append(full_file)

        total_files = len(file_list)
        print("Total = ",total_files)
        random.shuffle(file_list)
        train_files = file_list[0:800]
        val_files = file_list[800:]
        print("Số file = train ",len(train_files))
        print("Số file = val ", len(val_files))


        train_folder  = os.path.join(data_train, folder)
        if not os.path.exists(train_folder):
            os.makedirs(train_folder)
        i=0
        for train_file in train_files:
            i+=1
            dest_file = os.path.join(train_folder, os.path.basename(train_file))
            print("Copy to train ", train_file, " to ", dest_file)

            shutil.copyfile(train_file, dest_file)

        print("Copy train ", i)
        val_folder = os.path.join(data_val, folder)
        if not os.path.exists(val_folder):
            os.makedirs(val_folder)
        i=0
        for val_file in val_files:
            i+=1
            # print("Copy to val ", val_file)
            shutil.copyfile(val_file, os.path.join(val_folder, os.path.basename(val_file)))
        print("Copy val ", i)

Code language: PHP (php)

Sau khi chạy đoạn script trên (trong github mình để ở file make_dataset.py) thì trong hai thư mục train và val sẽ có các folder con để train.

train RepVGG
Bước 2 – Tiến hành train với kiến trúc mạng RepVGG-A0

Okie, để train thì mình sẽ chạy đúng script train mà github gốc của tác giả cung cấp:

python train.py -a RepVGG-A0 --epochs 50  data Code language: CSS (css)

Trong câu lệnh trên thì:

  • Cái đoạn -a RepVGG-A0 là kiến trúc mạng sẽ sử dụng.
  • –epochs 50: Train 50 epochs
  • data : là tên folder chữ dữ liệu, chính là folder data mà chúng ta đã tạo ra ở bên trên.

Khi train mà bạn thấy màn hình hiển thị như sau là ổn rồi:

Epoch: [0][ 0/29]	Time 27.423 (27.423)	Data 21.897 (21.897)	Loss 6.9711e+00 (6.9711e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)

Epoch: [0][10/29]	Time  0.950 ( 4.026)	Data  0.002 ( 2.848)	Loss 2.0947e+00 (3.3592e+00)	Acc@1  22.27 ( 18.04)	Acc@5  64.84 ( 61.29)
Epoch: [0][20/29]	Time  0.872 ( 3.216)	Data  0.001 ( 2.248)	Loss 1.5741e+00 (2.6636e+00)	Acc@1  38.28 ( 24.78)	Acc@5  92.19 ( 71.06)
Test: [0/8]	Time 19.879 (19.879)	Loss 8.3546e+03 (8.3546e+03)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
 * Acc@1 11.111 Acc@5 55.556Code language: CSS (css)

Sau 50 epochs mình đạt được Top 1 Acc (thể hiện ở mục Acc@1) là khoảng 96% và Top 5 Acc (thể hiện Acc@5) là 99%. Khá ổn, có lẽ do dữ liệu bài này khá dễ phân biệt.

Bước 3 – Convert model và inference

Sau quá trình train, các bạn sẽ thấy xuất hiện 02 file trong folder chứ file train.py đó là:

  • checkpoint.pth.tar: Là file weights cuối cùng tạo ra từ model.
  • model_best.pth.tar: Là file weights tốt nhất mà model lưu lại được.

Tùy nhu cầu của các bạn mà các bạn lấy file weights tương ứng để sử dụng nhé.

Sau đó các bạn convert tương tự như Bài số 1:

python convert.py model_best.pth.tar model_best_deploy.pth -a RepVGG-A0Code language: CSS (css)

Và chúng ta sử dụng lại code inference của bài số 1 nhưng thay đoạn load_checkpoint(model, “RepVGG-A0-deploy.pth”) bằng load_checkpoint(model, “model_best_deploy.pth”) để load model của chúng ta.

Load ngon rồi thì bạn đưa các ảnh cá vào để predict thử xem có chuẩn không nhé! Có vẻ dữ liệu này khá dễ nên model nhận chuẩn vãi lúa 😀

Các điểm chú ý: Bài này sử dụng kiến trúc RepVGG-A0 đẻ train, mà cái thằng RepVGG này nó lại train trên ImageNet với số class vẫn là 1000 nhé trong khi chúng ta chỉ có 9 class. Tuy nhiên vẫn không ảnh hưởng lắm, model vẫn chạy ầm ầm. Bạn nào muốn chuẩn có thể mở file repvgg.py để sửa lại num_classes trong các block khai báo như bên dưới nhé. Bài này mình tạm để nguyên.

def create_RepVGG_A0(deploy=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)Code language: PHP (php)

Phần 2 – Transfer Learning, sử dụng VGG với vai trò Feature Extractor

Trong phần transfer learning này, để đơn giản mình chỉ tập trung vào việc sử dụng RepVGG (với pretrained ImageNet) như một Feature Extractor (FE) sau đó customize lại các layer Fully Connected cuối cùng để phục vụ cho bài toán customize của mình.

Dữ liệu mình vẫn sử dụng bộ dữ liệu trong thư mục data đã tạo ra ở Phần 1 và load bằng data_loader của Pytorch

traindir = os.path.join("data", 'train')
    valdir = os.path.join("data", 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))


    train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=(train_sampler is None),
        num_workers=8, pin_memory=True, sampler=train_sampler)

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)Code language: PHP (php)

Sau đó mình tiến hành Transfer Learning. Cụ thể là đóng băng toàn bộ các layer của mạng RepVGG gốc để giữ lại các weights đã train. Ta chỉ thay đổi lại layer FC cuối cùng thành 9 class và tiến hành train để layer này học mà thôi:

    # Load model RepVGG với weights gốc ImageNet
    from repvgg import create_RepVGG_A0
    from utils import load_checkpoint
    model_ft = create_RepVGG_A0(deploy=True)
    load_checkpoint(model_ft, "RepVGG-A0-deploy.pth")

    # Đóng băng toàn bộ các layer của RepVGG   
    for param in model_ft.parameters():
        param.requires_grad = False
    
    # Chỉnh sửa lại layer cuối, thằng này ko đóng băng
    num_ftrs = model_ft.linear.in_features
    model_ft.linear = nn.Linear(num_ftrs, 9)

    model_ft = model_ft.to(device)
    
    # Định nghĩa loss, optimizer, learning rate schedule
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)Code language: PHP (php)

Bây giờ định nghĩa hàm train để train model, phần này sẽ khá quen thuộc với bạn nào đã học qua Pytorch:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Chuyển giữa các phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return modelCode language: PHP (php)

Và tiến hành train với 25 epochs:

# Tiến hành train
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                           num_epochs=25)Code language: PHP (php)

Sau khoảng vài epochs thì model đã đạt Acc rất tốt trên tập val:

Epoch 0/24
----------
train Loss: 1.3389 Acc: 0.5994
val Loss: 0.4116 Acc: 0.9672

Epoch 1/24
----------
train Loss: 0.4240 Acc: 0.9213
val Loss: 0.1852 Acc: 0.9878

Epoch 2/24
----------
train Loss: 0.3064 Acc: 0.9385
val Loss: 0.1319 Acc: 0.9956

Epoch 3/24
----------
train Loss: 0.2568 Acc: 0.9468
val Loss: 0.1045 Acc: 0.9972

Epoch 4/24
----------
train Loss: 0.2274 Acc: 0.9546
val Loss: 0.0870 Acc: 0.9978

Epoch 5/24
----------
train Loss: 0.2144 Acc: 0.9513
val Loss: 0.0764 Acc: 0.9983

Sau khi train xong, bạn thấy okie thì có thể save model ra file hoặc sử dụng để inference luôn được rồi. Phần này là pytorch cơ bản như:

# Save model
torch.save(model_ft,'mymodel.pth')
# Load model
model = torch.load('mymodel.pth')Code language: PHP (php)

Ở đây mình nhấn mạnh là mình không tập trung vào loss, acc của model mà mình demo training transfer learning là chính thôi nhé!

Rồi, vậy là mình đã guide các bạn cách train từ đầu cũng như transfer learning với RepVGG. Hi vọng giúp được các bạn. Các script này mình sẽ copy lên github để các bạn tham khảo: https://github.com/thangnch/MIAI_RepVGG

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube

Related Post

One Reply to “Train RepVGG từ đầu và transfer learning (Phần 2/2)”

Leave a Reply

Your email address will not be published. Required fields are marked *