torch.mul()
- 두 행렬의 단순 곱 연산 ⇒ element wise 곱 (⊙)
- 같은 위치의 원소끼리 곱함
import torch
a = torch.tensor([[2, 2],
[3, 3]])
b = torch.tensor([[4, 4],
[5, 5]])
torch.mul(a, b)
# 출력 : tensor([[ 8, 8],
# [15, 15]])
torch.matmul()
- 행렬 곱셈
- broadcasting 가능
*broadcasting : 사이즈가 맞지 않아도 자동으로 맞춰 계산해주는 기능
import torch
a = torch.tensor([[2, 2],
[3, 3]])
b = torch.tensor([[4, 4],
[5, 5]])
c = torch.tensor([[[2, 2],
[2, 2]]])
torch.matmul(a, b)
# 출력 : tensor([[18, 18],
# [27, 27]])
torch.matmul(a, c) # broadcasting
# 출력 : tensor([[[ 8, 8],
# [12, 12]]])
torch.mm()
- 행렬 곱셈
- broadcasting 불가
- size를 맞춰줘야 함
- 명시적으로 행렬 곱셈 과정을 파악하기 위해서는 mm을 사용하는 것이 좋음
import torch
a = torch.tensor([[2, 2],
[3, 3]])
b = torch.tensor([[4, 4],
[5, 5]])
c = torch.tensor([[[2, 2],
[2, 2]]])
torch.mm(a, b)
# 출력 : tensor([[18, 18],
# [27, 27]])
torch.mm(a, c)
# RuntimeError 발생