본문 바로가기
Developer/Machine Learning

[Tensorflow] MNIST 데이터셋 CNN 기본 예제

by Doony 2020. 1. 12.

지난 포스팅에서 CIFAR-10 데이터셋을 다루는 법에 대해 알아보았습니다. 자세한 포스팅은 여기에서 확인하실 수 있습니다.
이번에는 그보다 더 기본학습예제인 MNIST 데이터셋에 대해 알아보고, CNN 예제 코드를 살펴보겠습니다.


MNIST DATASET?

머신러닝을 공부할 때 아주 유용한 데이터셋입니다. 손글씨로 이루어진 숫자(0~9) 흑백 이미지들이며, 28x28 픽셀 크기로 그 크기는 작습니다. CIFAR 데이터보다 훨씬 가볍기 때문에 이런저런 알고리즘을 테스트하기 편합니다. 특히 집에서 일반 컴퓨터로 작업할 때는 CIFAR보단 MNIST가 가볍고 좋은 것 같네요.


CNN 딥러닝 모델

코드는 아주 단순합니다. 기본적인 CNN 레이어를 두층 만들어놨는데요. 정확도는 거의 1에 수렴할만큼 탁월한 성능을 보여주고 있습니다. 코드는 다음과 같습니다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tensorflow as tf
 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)
tf.set_random_seed(777)
 
learning_rate = 0.001
training_epochs = 10
batch_size = 100
 
= tf.placeholder(tf.float32, [None, 784])
X_img = tf.reshape(X, [-128281])
= tf.placeholder(tf.float32, [None, 10])
 
 
W1 = tf.Variable(tf.random_normal([33132], stddev = 0.01))
L1 = tf.nn.conv2d(X_img, W1, strides = [1111], padding = 'SAME')
L1 = tf.nn.relu(L1)
L1 = tf.nn.max_pool(L1, ksize = [1221], strides = [1221], padding = 'SAME')
 
 
W2 = tf.Variable(tf.random_normal([333264], stddev = 0.01))
L2 = tf.nn.conv2d(L1, W2, strides = [1111], padding = 'SAME')
L2 = tf.nn.relu(L2)
L2 = tf.nn.max_pool(L2, ksize = [1221], strides = [1221], padding = 'SAME')
L2_flat = tf.reshape(L2, [-17*7*64])
 
W3 = tf.Variable(tf.random_normal([7*7*6410], stddev = 0.01))
= tf.Variable(tf.random_normal([10]))
logits = tf.matmul(L2_flat, W3) + b
 
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits, labels = Y))
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost)
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
# 텐서보드 활용하기 위한 스칼라 정의
cost_train_sum = tf.summary.scalar("cost_train", cost)
accuracy_train_sum = tf.summary.scalar("accuracy_train", accuracy)
 
 
sess = tf.Session()
sess.run(tf.global_variables_initializer())
 
 
# 모델 저장을 위한 부분
import os
save_file = './model_mnist.ckpt'
SAVER_DIR = "modelMNIST"
saver = tf.train.Saver()
checkpoint_path = os.path.join(SAVER_DIR, "modelMNIST")
ckpt = tf.train.get_checkpoint_state(SAVER_DIR)
 
# 텐서보드 
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./mygraph/train', sess.graph)
writer_test = tf.summary.FileWriter('./mygraph/test', sess.graph)
 
 
print('Learning Started')
for epoch in range(training_epochs):
    avg_cost = 0
    avg_cost_test = 0
    total_batch = int(mnist.train.num_examples / batch_size)
    
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(bach_size)
        feed_dict = {X: batch_xs, Y: batch_ys}
        summary, c, a, _ = sess.run([merged, cost, accuracy, optimizer], feed_dict = feed_dict)
        avg_cost += c / total_batch
        
        batch_test_xs, batch_test_ys = mnist.test.next_batch(batch_size)
        feed_dict_test = {X: batch_test_xs, Y: batch_test_ys}
        summary_test, c_test, a_test = sess.run([merged, cost, accuracy], feed_dict = feed_dict_test)
        avg_cost_test += c_test / total_batch
        
    print('Epoch: ''%04d' % (epoch+1), 'Train cost = ''{:.9f}'.format(avg_cost))
    print('Epoch: ''%04d' % (epoch+1), 'Train acc = ''{:.9f}'.format(a))
    print('Epoch: ''%04d' % (epoch+1), 'Test cost = ''{:.9f}'.format(avg_cost_test))
    print('Epoch: ''%04d' % (epoch+1), 'Test acc = ''{:.9f}'.format(a_test))
    print('\n')
    
    # Tensorboard에서 epoch별 스칼라값 확인하기 위함
    writer.add_summary(summary, global_step = epoch)
    writer_test.add_summary(summary_test, global_step = epoch)
    
    # epoch별 모델 체크포인트 저장
    saver.save(sess, checkpoint_path, global_step = epoch)
 
print('Learning finished')
writer.close()
saver.save(sess, save_file)
 
cs

딥러닝으로 꼭 해야하나?

MNIST 데이터셋은 특징이 있습니다. CIFAR 데이터셋 대비 훨씬 규칙성이 있다는 점입니다. 아무래도 숫자다보니, 동일한 레이블을 가지는 이미지들이 비교적 비슷한? 경향이 있습니다.
때문에 꼭 딥러닝이 아닌, 다른 머신러닝 알고리즘으로도 좋은 성능을 가지는 모델을 만들 수 있습니다.
자세한 사항은 여기에 정리되어 있습니다. 다음에 다른 알고리즘으로도 학습하여 비교해보는 것도 좋은 공부가 될 것 같네요.

댓글