Implementing Batch Normalization via a Computational Graph
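For reference, during training the forward pass normalizes each feature with the statistics of the current mini-batch and then applies a learned scale and shift:

$$\mu_B = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad \sigma_B^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu_B)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$

The backward pass in the code below follows the chain rule through each node of this computational graph.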

import numpy as np

class BatchNorm():
  def __init__(self, D, eps, momentum):
      # Learnable per-feature scale and shift parameters.
      self.gamma = np.ones((1, D))
      self.beta = np.zeros((1, D))
      self.eps = eps
      # Running statistics, used at test time.
      self.running_mean = np.zeros(D)
      self.running_var = np.zeros(D)
      self.momentum = momentum
      # Values cached during the forward pass for use in the backward pass.
      self.x_hat = None
      self.sample_mean = np.zeros(D)
      self.sample_var = np.zeros(D)
  
  def forward(self, x, mode):
      n = x.shape[0]
      # reshape the input to shape (N, D)
      x_flat = x.reshape(n, -1)
      out = None

      if mode == 'train':
        self.sample_mean = np.mean(x_flat, axis = 0)
        self.sample_var = np.var(x_flat, axis = 0)
        self.x_hat = (x_flat - self.sample_mean) / np.sqrt(self.sample_var + self.eps)
        out = self.gamma * self.x_hat + self.beta
        # Keep a running average of the mean and variance for use at test time.
        # Momentum is the weight given to the accumulated running statistics; (1 - momentum)
        # is the weight given to the current mini-batch. If momentum is 0, the running mean and
        # variance are just those of the last mini-batch, which is noisy and usually undesirable
        # for testing. If momentum is 1, the running statistics never update and stay at their
        # initial values. A value close to 1 (e.g. 0.9) lets each new mini-batch contribute a
        # small amount to the running averages.
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * self.sample_mean
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * self.sample_var
      elif mode == 'test':
        # At test time, normalize with the running statistics instead of batch statistics.
        x_hat = (x_flat - self.running_mean) / np.sqrt(self.running_var + self.eps)
        out = self.gamma * x_hat + self.beta
      else:
        raise ValueError('Invalid batchnorm forward mode "%s"' % mode)

      return out
  
  def backward(self, x, dout):
      n = x.shape[0]
      x_flat = x.reshape(n, -1)
      # Gradients of the learnable scale and shift parameters.
      dgamma = np.sum(dout * self.x_hat, axis=0)
      dbeta = np.sum(dout, axis=0)
      # Backpropagate through the affine transform, then through the normalization,
      # the batch variance, and the batch mean.
      dx_hat = dout * self.gamma
      dvar = -0.5 * np.sum(dx_hat * (x_flat - self.sample_mean), axis=0) * np.power(self.sample_var + self.eps, -1.5)
      dmu = -np.sum(dx_hat / np.sqrt(self.sample_var + self.eps), axis=0) - 2 * dvar * np.sum(x_flat - self.sample_mean, axis=0) / n
      dx = dx_hat / np.sqrt(self.sample_var + self.eps) + 2.0 * dvar * (x_flat - self.sample_mean) / n + dmu / n
      return dx, dgamma, dbeta
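
The gradients computed in backward correspond to the standard chain-rule expressions for batch normalization: first the parameter gradients, then the gradients with respect to the batch variance, the batch mean, and finally the input:

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}$$

$$\frac{\partial L}{\partial \sigma_B^2} = -\frac{1}{2}\,(\sigma_B^2 + \epsilon)^{-3/2}\sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu_B)$$

$$\frac{\partial L}{\partial \mu_B} = -\frac{1}{\sqrt{\sigma_B^2 + \epsilon}}\sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} \;-\; \frac{2}{N}\,\frac{\partial L}{\partial \sigma_B^2}\sum_{i=1}^{N}(x_i - \mu_B)$$

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma_B^2 + \epsilon}}\,\frac{\partial L}{\partial \hat{x}_i} \;+\; \frac{2}{N}\,\frac{\partial L}{\partial \sigma_B^2}\,(x_i - \mu_B) \;+\; \frac{1}{N}\,\frac{\partial L}{\partial \mu_B}$$

As a quick sanity check (not part of the original listing), the sketch below runs a forward/backward pass on random data and compares the analytic dx against a numerical gradient. The shapes, the num_grad helper, and the __main__ guard are illustrative assumptions.

if __name__ == '__main__':
    # Illustrative sizes; any (N, D) batch works.
    np.random.seed(0)
    N, D = 4, 5
    bn = BatchNorm(D, eps=1e-5, momentum=0.9)

    x = np.random.randn(N, D)
    dout = np.random.randn(N, D)

    out = bn.forward(x, mode='train')
    dx, dgamma, dbeta = bn.backward(x, dout)

    # Numerical gradient of sum(forward(x) * dout) with respect to x,
    # for comparison against the analytic dx.
    def num_grad(f, x, dout, h=1e-5):
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=['multi_index'])
        while not it.finished:
            idx = it.multi_index
            old = x[idx]
            x[idx] = old + h
            pos = np.sum(f(x) * dout)
            x[idx] = old - h
            neg = np.sum(f(x) * dout)
            x[idx] = old
            grad[idx] = (pos - neg) / (2 * h)
            it.iternext()
        return grad

    dx_num = num_grad(lambda x: bn.forward(x, mode='train'), x, dout)
    print('max |dx - dx_num| =', np.max(np.abs(dx - dx_num)))

    # Test mode uses the running statistics accumulated during training.
    out_test = bn.forward(x, mode='test')
    print('test output shape:', out_test.shape)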
