VB.net 2010 视频教程 VB.net 2010 视频教程 python基础视频教程
SQL Server 2008 视频教程 c#入门经典教程 Visual Basic从门到精通视频教程
当前位置:
首页 > 编程开发 > 批处理教程 >
  • 『TensorFlow』批处理类

 
回到顶部

 基础知识

下面有莫凡的对于批处理的解释:

1
2
3
4
5
6
7
8
9
10
11
12
13
fc_mean,fc_var = tf.nn.moments(
    Wx_plus_b,
    axes=[0],
    # 想要 normalize 的维度, [0] 代表 batch 维度
    # 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度
)
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b,fc_mean,fc_var,shift,scale,epsilon)
# 上面那一步, 在做如下事情:
# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
# Wx_plus_b = Wx_plus_b * scale + shift

 

回到顶部

tf.contrib.layers.batch_norm:封装好的批处理类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class batch_norm():
    '''batch normalization层'''
 
    def __init__(self, epsilon=1e-5,
                 momentum=0.9, name='batch_norm'):
        '''
        初始化
        :param epsilon:    防零极小值
        :param momentum:   滑动平均参数
        :param name:       节点名称
        '''
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name
 
    def __call__(self, x, train=True):
        # 一个封装了的会在内部调用batch_normalization进行正则化的高级接口
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,        # 滑动平均参数
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,          # 影响滑动平均
                                            scope=self.name)

1.

Note: when training, the moving_mean and moving_variance need to be updated.
    By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
    need to be added as a dependency to the `train_op`. For example:
    
    ```python
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    ```
    
    One can set updates_collections=None to force the updates in place, but that
    can have a speed penalty, especially in distributed settings.

 

2.

is_training: Whether or not the layer is in training mode. In training mode
        it would accumulate the statistics of the moments into `moving_mean` and
        `moving_variance` using an exponential moving average with the given
        `decay`. When it is not in training mode then it would use the values of
        the `moving_mean` and the `moving_variance`.
 

 

回到顶部

tf.nn.batch_normalization:原始接口封装使用

实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装,这个类又进行了一次封装(主要是制订了一部分默认参数),实际操作时可以仅仅使用tf.contrib.layers.batch_norm函数,它已经足够方便了。

添加了滑动平均处理之后,也就是不使用封装,直接使用tf.nn.moments和tf.nn.batch_normalization实现的batch_norm函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
    with tf.variable_scope(scope):
        # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
        # gamma = tf.get_variable(name='gamma', shape=[n_out],
        #                         initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
        batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)
 
        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean,batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean),tf.identity(batch_var)
                # identity之后会把Variable转换为Tensor并入图中,
                # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制
 
        mean,var = tf.cond(phase_train,
                           mean_var_with_update,
                           lambda: (ema.average(batch_mean),ema.average(batch_var)))
       normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
    return normed

 

另一种将滑动平均展开了的方式,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def batch_norm(x, size, training, decay=0.999):
    beta = tf.Variable(tf.zeros([size]), name='beta')
    scale = tf.Variable(tf.ones([size]), name='scale')
    pop_mean = tf.Variable(tf.zeros([size]))
    pop_var = tf.Variable(tf.ones([size]))
    epsilon = 1e-3
 
    batch_mean, batch_var = tf.nn.moments(x, [012])
    train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
    train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
 
    def batch_statistics():
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
 
    def population_statistics():
        return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
 
return tf.cond(training, batch_statistics, population_statistics)

 注, tf.cond:流程控制,参数一True,则执行参数二的函数,否则执行参数三函数。

 

出处:https://www.cnblogs.com/hellcat/p/7380022.html

 



相关教程