MNIST Classification

Convolutional Neural Networks

Posted by Gavin on August 11, 2019


Preface

Introduction

Previously we classified the handwritten-digit dataset MNIST with traditional methods; this time we try again with a convolutional neural network (CNN).

A convolutional layer applies kernels to the image; in a CNN, convolutional layers mostly serve to generate features. Stacking them quickly becomes memory-hungry, so pooling layers are introduced to keep the salient features while shrinking the feature maps.
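As a quick check on the shape arithmetic used below (a minimal sketch, assuming TensorFlow 1.x; the variable names are arbitrary): with SAME padding the output spatial size is ceil(input / stride), and a 2×2 max pool with VALID padding halves it.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
c1 = tf.layers.conv2d(x, filters=32, kernel_size=3, strides=1, padding="SAME")      # 28x28
c2 = tf.layers.conv2d(c1, filters=64, kernel_size=3, strides=2, padding="SAME")     # ceil(28/2) = 14
p3 = tf.nn.max_pool(c2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")  # 14/2 = 7
print(c1.shape, c2.shape, p3.shape)  # (?, 28, 28, 32) (?, 14, 14, 64) (?, 7, 7, 64)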

Experiment

Data Preparation

# prepare data
import numpy as np
import tensorflow as tf
from datetime import datetime
from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import OneHotEncoder

mnist = fetch_mldata('MNIST original')
data, target = mnist['data'], mnist['target']
one_hot_encoder = OneHotEncoder()
# one-hot encode the labels to match the 10-unit softmax output below
target = one_hot_encoder.fit_transform(target.reshape(-1, 1)).todense()
X_train, X_test, y_train, y_test = data[:60000], data[60000:], target[:60000], target[60000:]

We use the MNIST data from sklearn and, because of how the TensorFlow output layer and loss are set up, one-hot encode the label column.
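Note that fetch_mldata relies on the defunct mldata.org and was removed in scikit-learn 0.20; on newer versions the equivalent load looks roughly like this (a sketch; the as_frame flag needs scikit-learn >= 0.24 and should be dropped on older releases):

from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1, as_frame=False)
data, target = mnist['data'], mnist['target'].astype(np.int64)  # targets arrive as strings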

Building the Computation Graph

  1. Prepare the inputs

     X = tf.placeholder(tf.float32,[None,784])
     y = tf.placeholder(tf.float32,[None,10])
     X_image = tf.reshape(X,[-1,28,28,1])
    

    Image size: 28×28

  2. First convolutional layer

     conv1 = tf.layers.conv2d(X_image,filters=32, kernel_size=3, strides=1, padding="SAME",activation=tf.nn.relu, name="conv1")
    

    Image size: 28×28 (SAME padding with stride 1 preserves the spatial size); 32 feature maps generated

  3. Second convolutional layer

     conv2 = tf.layers.conv2d(conv1,filters=64, kernel_size=3, strides=2, padding="SAME",activation=tf.nn.relu, name="conv2")
    

    Image size: 14×14 (stride 2 halves it); 64 feature maps generated

  4. Pooling layer

     pool3 = tf.nn.max_pool(conv2,ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1],padding="VALID")
     pool3_flat = tf.reshape(pool3, shape=[-1, 64 * 7 * 7])
    

    Image size: 7×7 with 64 feature maps (64 × 7 × 7 = 3136 values per image),
    flattened into one row per sample for the fully connected layer

  5. Fully connected and output layers

     fc1 = tf.layers.dense(pool3_flat, 64, activation=tf.nn.relu, name="fc1")
     logits = tf.layers.dense(fc1, 10, name="output")
     Y_proba = tf.nn.softmax(logits, name="Y_proba")
    
  6. Loss function and training

     # pass the raw logits: softmax_cross_entropy_with_logits applies softmax
     # internally, so feeding Y_proba here would apply softmax twice
     xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
     loss = tf.reduce_mean(xentropy)
     optimizer = tf.train.AdamOptimizer()
     training_op = optimizer.minimize(loss)
    
  7. Logging

     now = datetime.utcnow().strftime('%Y%m%d%H%M%S')
     root_logdir = 'tf_logs'
     logdir = '{}/run-{}/'.format(root_logdir, now)
     loss_change = tf.summary.scalar('loss',loss)
     file_writer = tf.summary.FileWriter(logdir,tf.get_default_graph())
    
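    The logged loss curve can then be viewed by running `tensorboard --logdir tf_logs` and opening the reported URL in a browser.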
  8. Evaluation

     y_pred = tf.argmax(Y_proba, 1)  # tf.arg_max is deprecated; use tf.argmax
     correct = tf.equal(tf.argmax(y, 1), y_pred)
     accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    
  9. Train

     init = tf.global_variables_initializer()
     saver = tf.train.Saver()
    	
     #prepare train
     def shuffle_batch(X, y, batch_size):
         rnd_idx = np.random.permutation(len(X))
         n_batches = len(X) // batch_size
         for batch_idx in np.array_split(rnd_idx, n_batches):
             X_batch, y_batch = X[batch_idx], y[batch_idx]
             yield X_batch, y_batch
    	
     #train
     n_epochs = 10
     batch_size = 100
     with tf.Session() as sess:
         init.run()
         for epoch in range(n_epochs):
             for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
                 sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
             loss_batch = loss.eval(feed_dict={X: X_batch, y: y_batch})
             acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
             acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
             print(epoch, "Loss:", loss_batch, "Last batch accuracy:", acc_batch, "Test accuracy:", acc_test)
             save_path = saver.save(sess, "./my_mnist_model")
             # add_summary expects the serialized summary, not the raw loss value
             summary_str = loss_change.eval(feed_dict={X: X_batch, y: y_batch})
             file_writer.add_summary(summary_str, epoch)
    

Inspecting the Training Run

The loss curve shows that training has fallen into oscillation.

The computation graph is as shown above.
My first suspicion was that the oscillation came from insufficient training, so I trained for one more epoch to see the result.

There was some improvement, and the loss dropped as well.
After training yet another round, however, it kept oscillating.
As a last resort, I tried setting batch_size to 1.
No improvement; I gave up, with accuracy stuck around 67%.

In hindsight, a likely culprit is the loss wiring rather than the amount of training: the original code fed Y_proba (already softmaxed) into softmax_cross_entropy_with_logits, which applies softmax a second time and flattens the gradients, exactly the kind of bug that stalls training at a mediocre plateau. The corrected code above passes the raw logits instead.

Error Analysis

Plotting the confusion matrix:

Plotting the error-rate matrix shows that 8 is easily confused with the other digits.
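For reference, a minimal sketch of how the two matrices can be produced (assuming matplotlib is available and that this runs inside the session so y_pred can be evaluated; the row-normalization scheme is my assumption, not necessarily what the original figures used):

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

y_true_cls = np.asarray(y_test.argmax(axis=1)).ravel()  # one-hot -> class ids
y_pred_cls = y_pred.eval(feed_dict={X: X_test})         # inside the tf.Session

conf = confusion_matrix(y_true_cls, y_pred_cls)
plt.matshow(conf, cmap=plt.cm.gray)                     # raw counts
plt.show()

row_sums = conf.sum(axis=1, keepdims=True)              # error-rate matrix:
norm_conf = conf / row_sums                             # normalize each row,
np.fill_diagonal(norm_conf, 0)                          # zero out the diagonal
plt.matshow(norm_conf, cmap=plt.cm.gray)
plt.show()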


Postscript

An sklearn-style implementation

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError

class CNN(BaseEstimator, ClassifierMixin):
    """docstring for CNN"""
    def __init__(self):
        self._session = None    
    def _build_graph(self):
        X = tf.placeholder(tf.float32,[None,784])
        y = tf.placeholder(tf.float32,[None,10])
        X_image = tf.reshape(X,[-1,28,28,1])
        conv1 = tf.layers.conv2d(X_image,filters=32, kernel_size=3, strides=1, padding="SAME",activation=tf.nn.relu, name="conv1")
        conv2 = tf.layers.conv2d(conv1,filters=64, kernel_size=3, strides=2, padding="SAME",activation=tf.nn.relu, name="conv2")
        pool3 = tf.nn.max_pool(conv2,ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1],padding="VALID")
        pool3_flat = tf.reshape(pool3, shape=[-1, 64 * 7 * 7])       
        fc1 = tf.layers.dense(pool3_flat, 64, activation=tf.nn.relu, name="fc1")
        logits = tf.layers.dense(fc1, 10, name="output")
        Y_proba = tf.nn.softmax(logits, name="Y_proba")
        # same fix as above: pass the raw logits, not the softmaxed Y_proba
        xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
        loss = tf.reduce_mean(xentropy)
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)
        y_pred = tf.argmax(Y_proba, 1)
        correct = tf.equal(tf.argmax(y, 1), y_pred)
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        self._X, self._y = X, y
        self._Y_proba, self._loss = Y_proba, loss
        self._training_op, self._accuracy = training_op, accuracy
        self._init, self._saver = init, saver
        self._y_pred = y_pred
    def close_session(self):
        if self._session:
            self._session.close()
    def fit(self, X, y, n_epochs=100, batch_size=50):
        self.close_session()
        self._graph = tf.Graph()
        with self._graph.as_default():
            self._build_graph()
        self._session = tf.Session(graph=self._graph)
        with self._session.as_default() as sess:
            self._init.run()
            for epoch in range(n_epochs):
                for X_batch, y_batch in shuffle_batch(X, y, batch_size):
                    sess.run(self._training_op, feed_dict={self._X: X_batch, self._y: y_batch})
                acc_batch = self._accuracy.eval(feed_dict={self._X: X_batch, self._y: y_batch})
                # X_test and y_test are the module-level globals from the experiment above
                acc_test = self._accuracy.eval(feed_dict={self._X: X_test, self._y: y_test})
                print(epoch, "Last batch accuracy:", acc_batch, "Test accuracy:", acc_test)
                if acc_test > 0.98:
                    print("Stop!!")
                    break
        return self  # sklearn convention: fit returns the estimator
    def restore_model(self):
        self.close_session()
        self._graph = tf.Graph()
        with self._graph.as_default():
            self._build_graph()
        self._session = tf.Session(graph=self._graph)
        with self._session.as_default() as sess:
            self._saver.restore(sess, './my_mnist_model')
    def predict(self, X):
        if not self._session:
            raise NotFittedError("This %s instance is not fitted yet" % self.__class__.__name__)
        with self._session.as_default() as sess:
            return self._y_pred.eval(feed_dict={self._X: X})

This was written in a hurry; the hyperparameters weren't turned into instance variables for later reuse.
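A minimal sketch of what that parameterization might look like (the argument names are my own choice, not from the original):

class CNN(BaseEstimator, ClassifierMixin):
    def __init__(self, n_epochs=100, batch_size=50, n_fc=64, learning_rate=0.001):
        # store each constructor argument under the same name,
        # as sklearn's get_params()/set_params() expect
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.n_fc = n_fc
        self.learning_rate = learning_rate
        self._session = None

    def fit(self, X, y):
        # ...build the graph with self.n_fc units in the dense layer and
        # tf.train.AdamOptimizer(self.learning_rate), then run
        # self.n_epochs epochs of self.batch_size-sized batches...
        return self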