Tensorflowでbatch normlization(バッチ正規化)を使用する際の注意点(常に同じものしか出力しなくなったら)

バッチ正規化とは

書く層ごとに分散が1平均が0になるようにするらしい

これを使うことでL2正規化やdropoutの必要性を減らせたり過学習を防げたりといろいろなメリットがあるらしい

こいつを使わないと人生損してるぜHAHAHA的な話もあるらしい

 

 

注意点(落とし穴)

タイトルにも書いた通りちゃんと理屈を知らないと常に同じものしか出力しなくなる

私は理屈知らないのに実装して出力全部同じになった

Tensorflowで実装する際は学習中と学習後で別の処理をしなくてはいけないらしい

 

 

具体的に言うと学習中は与えられるミニバッチの平均や分散を計算する必要があるが学習後は入力に対してそのような処理を行う必要はないらしい

 

じゃあどうすればいいの

Implementing Batch Normalization in Tensorflow - R2RT

この記事の

def batch_norm_wrapper(inputs, is_training, decay = 0.999):

    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs,[0])
        train_mean = tf.assign(pop_mean,
                               pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var,
                              pop_var * decay + batch_var * (1 - decay))
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs,
                batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs,
            pop_mean, pop_var, beta, scale, epsilon)

のようにバッチ正規化を実装するとよい