PyTorch: What is the equivalent of numpy.linalg.multi_dot() in PyTorch?

  • I am trying to multiply several matrices together in PyTorch and was wondering: what is the equivalent of numpy.linalg.multi_dot() in PyTorch?

    If there isn't one, what is the next best way (in terms of speed and memory) I can do this in PyTorch?

    Code:

    import numpy as np
    import torch
    
    A = np.random.rand(3, 3)
    B = np.random.rand(3, 3)
    C = np.random.rand(3, 3)
    
    # multi_dot takes a single sequence of arrays, not separate arguments
    results = np.linalg.multi_dot([A, B, C])
    
    A_tsr = torch.tensor(A)
    B_tsr = torch.tensor(B)
    C_tsr = torch.tensor(C)
    
    # What is the PyTorch equivalent of np.linalg.multi_dot()?
      January 7, 2022 12:55 PM IST
  • ~~Looks like one can send tensors into multi_dot~~

    It looks like the NumPy implementation converts every input to a NumPy array. If your tensors are on the CPU and detached (i.e. they do not require grad), this should work, although the result will be a NumPy array rather than a tensor. Otherwise, the conversion to NumPy fails.
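
    A minimal sketch of that CPU/detach route, assuming 2-D inputs and that losing autograd is acceptable. The helper name `numpy_multi_dot` is hypothetical; it explicitly detaches and moves each tensor to the CPU, calls `np.linalg.multi_dot`, and wraps the result back into a tensor on the original device.

```python
import numpy as np
import torch


def numpy_multi_dot(tensors):
    """Hypothetical helper: chain-multiply tensors via np.linalg.multi_dot.

    Each tensor is detached and copied to the CPU before the NumPy call,
    so gradients do NOT flow through the returned tensor.
    Assumes 2-D inputs so that the NumPy result is an ndarray, not a scalar.
    """
    device = tensors[0].device
    arrays = [t.detach().cpu().numpy() for t in tensors]
    result = np.linalg.multi_dot(arrays)
    # Wrap the NumPy result back into a tensor on the original device.
    return torch.from_numpy(result).to(device)


A = torch.rand(3, 4, dtype=torch.float64, requires_grad=True)
B = torch.rand(4, 5, dtype=torch.float64)
C = torch.rand(5, 2, dtype=torch.float64)

out = numpy_multi_dot([A, B, C])
print(out.shape)  # torch.Size([3, 2])
```

    Because of the round trip through NumPy, `out.requires_grad` is `False` and, for GPU tensors, the data is copied to the host and back, so this is only a stopgap.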

    So, in general, there is likely no drop-in alternative. I think your best option is to take the multi_dot implementation, e.g. from here for NumPy v1.19.0, and adapt it to handle tensors (i.e. skip the cast to NumPy). Given the similar interface and the simplicity of the code, this should be fairly straightforward.

      January 8, 2022 2:31 PM IST