Deep Learning

[PyTorch] mul, matmul, mm 차이

y2r1m 2023. 4. 10. 00:08

torch.mul()

  • 두 행렬의 단순 곱 연산 element wise 곱 (⊙)
  • 같은 위치의 원소끼리 곱함
import torch
a = torch.tensor([[2, 2], 
                  [3, 3]])
b = torch.tensor([[4, 4],
                  [5, 5]])
                  
torch.mul(a, b)
# 출력 : tensor([[ 8,  8],
#        	[15, 15]])

 

torch.matmul()

  • 행렬 곱셈
  • broadcasting 가능     

*broadcasting : 사이즈가 맞지 않아도 자동으로 맞춰 계산해주는 기능

import torch
a = torch.tensor([[2, 2], 
                  [3, 3]])
b = torch.tensor([[4, 4],
                  [5, 5]])
c = torch.tensor([[[2, 2],
                  [2, 2]]])
                  
torch.matmul(a, b)
# 출력 : tensor([[18, 18],
#        	[27, 27]])

torch.matmul(a, c) # broadcasting
# 출력 : tensor([[[ 8,  8],
#         	  [12, 12]]])

 

torch.mm()

  • 행렬 곱셈
  • broadcasting 불가
  • size를 맞춰줘야 함
  • 명시적으로 행렬 곱셈 과정을 파악하기 위해서는 mm을 사용하는 것이 좋음
import torch
a = torch.tensor([[2, 2], 
                  [3, 3]])
b = torch.tensor([[4, 4],
                  [5, 5]])
c = torch.tensor([[[2, 2],
                  [2, 2]]])
                  
torch.mm(a, b)
# 출력 : tensor([[18, 18],
#        	[27, 27]])

torch.mm(a, c)
# RuntimeError 발생