
How to initialize weights in PyTorch?

  • How to initialize the weights and biases (for example, with He or Xavier initialization) in a network in PyTorch?
      August 24, 2020 4:34 PM IST
  • To initialize the weights of a single layer, use a function from torch.nn.init. For instance:
    conv1 = torch.nn.Conv2d(...)
    torch.nn.init.xavier_uniform_(conv1.weight)
    Alternatively, you can modify the parameters by writing to conv1.weight.data (which is a torch.Tensor). Example:
    conv1.weight.data.fill_(0.01)

    The same applies to biases:

    conv1.bias.data.fill_(0.01)
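
    Put together as a runnable snippet (the Conv2d arguments below are illustrative, not from the original post):

    import torch

    conv1 = torch.nn.Conv2d(3, 16, kernel_size=3)  # illustrative in/out channels and kernel size
    torch.nn.init.xavier_uniform_(conv1.weight)    # Xavier/Glorot init for the weights
    conv1.bias.data.fill_(0.01)                    # constant init for the biases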

    For an nn.Sequential model or a custom nn.Module:

    Pass an initialization function to torch.nn.Module.apply. It will initialize the weights in the entire nn.Module recursively.

    apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch-nn-init).

    Example:
    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    
    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)
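
    The question also asks about He initialization; a minimal sketch along the same lines (the helper name init_kaiming and the zeroed biases are illustrative choices, not from the original post):

    import torch.nn as nn

    def init_kaiming(m):
        if isinstance(m, nn.Linear):
            # He/Kaiming init, suited to ReLU activations
            nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
    net.apply(init_kaiming)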
      August 24, 2020 4:40 PM IST
  • To initialize layers, you typically don't need to do anything.

    PyTorch will do it for you. If you think about it, this makes a lot of sense: why initialize layers yourself when PyTorch can do it following current best practices?

    Check, for instance, the Linear layer.

    Its __init__ method calls reset_parameters, which uses Kaiming (He) initialization:

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    The same goes for other layer types; Conv2d, for instance, has its own reset_parameters method.

    Note: the payoff of proper initialization is faster training. If your problem calls for a special initialization, you can still apply it afterwards.
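
    For example, a freshly constructed layer already comes with initialized weights (a minimal check; the layer sizes are arbitrary):

    import torch.nn as nn

    layer = nn.Linear(4, 2)
    print(layer.weight)  # already initialized in __init__ via reset_parameters
    print(layer.bias)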

      September 17, 2020 12:14 PM IST
  • import torch.nn as nn

    # example sizes; the original post leaves these undefined
    in_features, h_size = 10, 16

    # a simple network
    rand_net = nn.Sequential(nn.Linear(in_features, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, 1),
                             nn.ReLU())

    # initialization function: first check the module type,
    # then apply the desired change to the weights
    def init_uniform(m):
        if type(m) == nn.Linear:
            nn.init.uniform_(m.weight)

    # use the module's apply function to recursively apply the initialization
    rand_net.apply(init_uniform)
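
    A quick sanity check that the re-initialization took effect (not part of the original post):

    w = rand_net[0].weight
    print(w.min().item(), w.max().item())  # values should lie in [0, 1), the default range of nn.init.uniform_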
      September 17, 2020 12:43 PM IST
  • To initialize weights with a normal distribution, use:

    torch.nn.init.normal_(tensor, mean=0, std=1)

    To fill a tensor with a constant value, write:

    torch.nn.init.constant_(tensor, value)

    Or to use a uniform distribution:

    torch.nn.init.uniform_(tensor, a=0, b=1) # a: lower bound, b: upper bound

    You can find the other initialization methods in the torch.nn.init documentation.
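
    Applied to a concrete layer, for instance (a minimal sketch; the layer shape and the 0.02 standard deviation are arbitrary choices):

    import torch.nn as nn

    layer = nn.Linear(4, 4)
    nn.init.normal_(layer.weight, mean=0.0, std=0.02)  # normally distributed weights
    nn.init.constant_(layer.bias, 0.0)                 # constant (zero) biases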

      September 17, 2020 12:45 PM IST
  • If you want some extra flexibility, you can also set the weights manually.

    Say you have an input of all ones:

    import torch
    import torch.nn as nn
    
    input = torch.ones((8, 8))
    print(input)
    tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])

    And you want to make a dense layer with no bias (so we can easily see the effect of the weights):

    d = nn.Linear(8, 8, bias=False)

    Set all the weights to 0.5 (or anything else):

    d.weight.data = torch.full((8, 8), 0.5)
    print(d.weight.data)

    The weights:

    tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])

    All your weights are now 0.5. Pass the data through:

    d(input)
    tensor([[4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.]], grad_fn=<MmBackward>)

    Remember that each output neuron sees 8 inputs, each with value 1 and weight 0.5 (and no bias), so each output sums to 8 × 0.5 × 1 = 4.
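
    On recent PyTorch versions, writing through .data still works, but a common alternative is to modify the parameter in place under torch.no_grad() (a sketch, not from the original post):

    import torch
    import torch.nn as nn

    d = nn.Linear(8, 8, bias=False)
    with torch.no_grad():   # keep the manual update out of autograd's view
        d.weight.fill_(0.5)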

      September 17, 2020 12:50 PM IST
  • If you see a deprecation warning for torch.nn.init.xavier_uniform, switch to the in-place variant with the trailing underscore, xavier_uniform_:

    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    
    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)
      December 22, 2020 3:02 PM IST