Nhìn đời theo cách của mạng CNN (visualize feature maps- heatmap)

Hello anh em, chúng ta đã làm việc với mạng CNN đã lâu lắm rồi, hôm nay chúng ta sẽ đi đến một món mới hơn là nhìn đời theo cách của mạng CNN nhé (visualize feature maps, heatmap).

Chúng ta sẽ cùng tìm hiểu xem khi ta đưa một vào mạng một hình ảnh (một cô gái chẳng hạn) thì ông CNN ông ấy nhìn thấy cô gái đó như nào? Có như ta nhìn không? Và tập trung vào phần nào của cô gái? Có giống chúng ta khi nhìn vào các cô gái thì hay tập trung vào “v..” ……………ý mình là vẻ đẹp không =))

Phần 1 – Nhắc lại chút về mạng CNN và cách CNN nhìn đời

Nguồn: Tại đây

Như các bài trước mình đã viết trên Mì AI Blog cũng như kênh youtube Mì AI, CNN sử dụng các khái niệm các filter với các kernel size khác nhau để quét qua tấm ảnh từ trên xuống dưới, từ trái sang phải (lý thuyết thôi nhé, thực tế train qua GPU nó chơi đồng thời luôn á). Sau khi quét xong, nó sẽ tạo ra (bằng cách nhân lần lượt ma trận filter và một phần của tấm ảnh) các feature maps – các đặc trưng trích ra được từ tấm ảnh.

Đó, và features maps chính là những gì mà CNN nhìn thấy từ tấm ảnh.

Rồi, nhưng ta sẽ thắc mắc trong những thứ mà CNN nó nhìn thấy thì nó tập trung vào phần nào nhất để có thể predict ra đúng class của tấm ảnh. Ví dụ: đưa vào một tấm ảnh, CNN predict ra là GIRL, vậy chúng ta cần biết được model đã dựa vào phần nào trong những phần nó nhìn thấy để dự đoán đây là girl.

Chúng ta sẽ cùng nhau tìm hiểu các vấn đề trên nhé! Trong bài này ta sẽ cùng mổ xẻ và tìm hiểu (visualize feature maps, heatmap) với một model cụ thể là VGG16.

Phần 2 – Tìm hiểu xem CNN nhìn thấy những gì? (visualize feature maps)

Thực ra là chúng ta sẽ in ra các Feature Maps ấy mà kaka!

Rồi, VGG16 thì mọi người rõ kiến trúc của nó rồi, cụ thể như sau:

Nguồn: Tại đây

Với model này, khi ta đưa ảnh 224x224x3 vào , nó sẽ trả ra 1 vector softmax để từ đó ta in ra class của tấm ảnh.

Ngoài ra, nhìn vào đó các bạn sẽ tháy có tất cả 5 convolution block và trong đó có 13 convolution layers (màu xanh blue). Kích thước output của các conv block như sau:

Nguồn: Tại đây

Và bây giờ muốn biết CNN nhìn thấy gì, ta sẽ in ra output sau các layer Conv là okie (chính là các Feature map). Để làm được điều này ta phải tiến hành load weights VGG16 pretrain trên Imagenet sau đó tạo ra một model mới vẫn dùng chung weights đó nhưng kiến trúc sẽ là

  • Input: Ảnh đầu vào
  • Output: Là các output tại các layer ta cần in feature maps

Ví dụ ta sẽ sửa như sau:

vgg_model = VGG16()
# Ta lấy ra 3 ouput sau layer 2, 9 và 17
output_layer_list = [2, 9, 17]
outputs = [vgg_model.layers[idx].output for idx in output_layer_list]
model = Model(inputs=vgg_model.inputs, outputs=outputs)Code language: PHP (php)

Sau khi chạy đoạn code trên thì model sẽ là biến lưu model mà chúng ta cần sử dụng để in ra các feature maps (chú ý ở đây là model 1 input và nhiều output nhé).

Rồi, bây giờ mình sẽ load 1 tấm ảnh con bò và đưa vào trong mạng để xem model nhìn con bò này như nào nào:

Nguồn: Tại đây

Tất cả đều dùng lệnh của Keras nhé:

from keras.applications.vgg16 import preprocess_input
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array

# Load ảnh
net_input_size = (224,244)
frame = load_img('cow.jpg', target_size=net_input_size)
frame = img_to_array(frame)
frame = expand_dims(frame, axis=0)
frame = preprocess_input(frame)

# Đưa vào model (mới tạo ấy nhé) để lấy output
feature_maps = model.predict(frame)Code language: PHP (php)

Sau bước này, ta sẽ có được 3 output sau 3 conv layer, mỗi output sẽ có size khác nhau nhưng đều có 64 filter trong mỗi ouput. Do đó, với mỗi output ta sẽ vẽ ra 64 filter này theo grid 8×8 (cho đẹp thôi, vẽ cách nào cũng được nhé).

import math
item_per_col = int(math.sqrt(feature_maps[0].shape[3])) # Căn bậc 2 của 64 là 8
for fm in feature_maps:
	idx = 1
	for _ in range(item_per_col):
		for _ in range(item_per_col):
			ax = pyplot.subplot(item_per_col, item_per_col, idx)
			ax.set_xticks([])
			ax.set_yticks([])
			pyplot.imshow(fm[0, :, :, idx-1])
			idx += 1
	pyplot.show()Code language: PHP (php)

Và kết quả sẽ có 3 ảnh là 3 output, đó chính là những gì CNN nhìn thấy đó. Các bạn có thấy các layer về sau sẽ nhìn bức ảnh một cách tổng quát, còn các layer đầu sẽ nhìn chi tiết bức ảnh không?

visualize feature maps
Output của Conv layer số 2
Output của Conv layer số 9
visualize feature maps
Output của Conv layer số 17

Đấy. model nhìn thấy thế đấy, và sau đó nó sẽ dựa vào các input nó nhìn thấy để predict ra đúng class của bức ảnh. Nó không nhìn thấy như chúng ta và xử lý như chúng ta nhỉ 😀

Tiếp theo chúng ta sẽ cùng tìm hiểu xem vậy với một bức ảnh con bò như này thì model sẽ tập trung vào phần nào nhất để có thể dự đoán đúng là con bò chứ ko ra con chó nhé!

Phần 3 – CNN chú ý vào phần nào của bức ảnh con bò? (visualize feature heatmap)

Đầu tiên ta cũng khởi tạo VGG16 và load ảnh như trên:

model = VGG16()
model.summary()
#
 load ảnh
net_input_size = (224,224)
frame = load_img('cow.jpg', target_size=net_input_size)
frame = img_to_array(frame)
frame = expand_dims(frame, axis=0)
frame = preprocess_input(frame)Code language: PHP (php)

Các bạn chú ý ở đây mình có sử dụng model.summary() mục đích là lấy ra tên conv layer cuối cùng, cụ thể là ‘block5_conv3’.

Sau đó chúng ta sử dụng GradientTape của tensorflow. Dịch thô là băng ghi lại gradient. Món này sẽ giúp ta lưu lại được gradient để tính heatmap:

with tf.GradientTape() as tape:
	# Tạo ra một model mới có 1 input và 2 output là output của model và output của conv layer cuối cùng
	last_conv_layer = model.get_layer('block5_conv3')
	new_model = tf.keras.models.Model([model.inputs], [model.output, last_conv_layer.output])

	# Đưa ảnh vào model mới để lấy output
	model_out, last_conv_layer = new_model(frame)

	# Lấy output có prob lớn nhất
	class_out = model_out[:, np.argmax(model_out[0])]

	# Tính gradient của class output đối với output của last_conv_layer
	grads = tape.gradient(class_out, last_conv_layer)


	# Tính giá trị trung bình của gradient, kết quả là 1 vector 512
	pooled_grads = K.mean(grads, axis=(0, 1, 2))Code language: PHP (php)

Tiếp đó là tính heatmap và xử lý qua 1 chút (để vẽ ColorMap cho tiện)


# Nhân pooled_grads với output của last_conv_layer và lấy mean để có heatmap.
# Chú ý last_conv_layer có size (1, 14,14,512)
# Output là heatmap size (1,14,14)
heatmap = tf.reduce_mean(tf.multiply(pooled_grads, last_conv_layer), axis=-1)

# Xử lý heat map, bỏ giá trị âm, scale lại giá trị về 0,1
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
heatmap = heatmap.reshape((14, 14))Code language: PHP (php)

Và cuối cùng là ta load ảnh và show hàng thôi:


# Vẽ heatmap lên ảnh
# Đọc ảnh con bò
img = cv2.imread('cow.jpg')

# Chỉnh lại heatmap
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

# Vẽ heatmap lên ảnh
overlay_img = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
cv2.imshow("B",overlay_img)
cv2.waitKey()Code language: PHP (php)

Và đây, model của chúng ta tập trung vào phần đầu con bò để predict, khôn ra phết:

visualize feature maps heatmap

Bây giờ để thêm phần funny, mình thử 1 cái ảnh khác xem sao 😀

visualize heatmap

Bố ông model CNN tập trung toàn điểm “khôn” thế không biết =)). Khôn như ông quê tôi đầy 😀

Ok, vậy là mình đã guide các bạn visualize feature maps, heatmap và tìm hiểu xem model CNN nhìn đời qua lăng kính Convolution ra sao rồi. Hi vọng các bạn đã hiểu vấn đề. Nếu còn thắc mắc gì các bạn cứ post lên https://www.facebook.com/groups/miaigroup để cùng trao đổi nhé.

Hẹn gặp lại các bạn! Mình xin tặng bạn nào đọc đến đây link github thay lời cảm ơn nha: Tại đây.

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube

Tài liệu tham khảo:

  • https://github.com/Anil-matcha/EIP3-Assignments/blob/master/Gradcam.ipynb
  • https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/

Related Post

2 Replies to “Nhìn đời theo cách của mạng CNN (visualize feature maps- heatmap)”

Leave a Reply

Your email address will not be published. Required fields are marked *