본문 바로가기
함수 설명/인공지능 (Pytorch)

[PyTorch] 배치(batch) 단위의 행렬 곱셈: torch.bmm() 설명

by First Adventure 2024. 8. 24.
반응형

소개

  torch.bmm은 PyTorch에서 3차원 텐서(batch of matrices) 간의 행렬 곱셈(batch matrix multiplication)을 수행하는 함수입니다. 이 함수는 배치(batch) 단위로 행렬 곱셈을 처리하며, 각 배치에서 두 2차원 텐서(행렬)의 곱셈을 병렬로 수행합니다. torch.bmm은 딥러닝 모델에서 여러 행렬을 동시에 곱해야 하는 경우 유용하며, 특히 RNN, CNN 등의 네트워크에서 자주 사용됩니다.

 

기본 사용법

상세 설명

  • 배치 크기
    • torch.bmm은 3D 텐서를 입력으로 받으며, 첫 번째 차원이 배치 크기를 나타냅니다. 각 배치에서 2차원 텐서(행렬) 간의 곱셈이 병렬로 실행되어 결과가 출력됩니다.
    • 예를 들어, 텐서 크기가 (10, 3, 4)인 경우 10개의 3x4 행렬이 있으며, 이와 곱할 텐서는 (10, 4, 5) 크기를 가져야 하며, 결과는 (10, 3, 5) 크기의 텐서가 됩니다.
  • 행렬 차원
    • 두 텐서 간의 행렬 곱셈을 수행하려면, 첫 번째 텐서의 열 수와 두 번째 텐서의 행 수가 같아야 합니다. 그렇지 않으면 오류가 발생합니다.
    • torch.bmm은 배치 행렬 곱셈에 최적화되어 있으며, RNN이나 CNN과 같은 네트워크에서 다수의 행렬을 병렬로 처리할 때 유용합니다.
    • RNN의 배치 처리나 CNN에서의 필터 연산 등 여러 행렬을 동시에 곱해야 하는 경우 사용됩니다.

예시 설명

  • 첫 번째 예시에서 torch.bmm(tensor1, tensor2)는 두 개의 3D 텐서 간의 행렬 곱셈을 수행하며, 각각 10개의 3x4와 4x5 행렬을 곱하여 10개의 3x5 행렬을 생성합니다.
  • 두 번째 예시에서는 작은 배치 크기(2)의 2x2 행렬을 곱하여 결과를 출력합니다.
import torch

# 두 3D 텐서 생성 (배치 크기 x 행 x 열)
tensor1 = torch.randn(10, 3, 4)  # 10개의 3x4 행렬
tensor2 = torch.randn(10, 4, 5)  # 10개의 4x5 행렬

# 배치 행렬 곱셈
result = torch.bmm(tensor1, tensor2)
print(result.size())
# 출력: torch.Size([10, 3, 5])

# 작은 배치 크기의 행렬 곱셈
tensor1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # 2개의 2x2 행렬
tensor2 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # 2개의 2x2 행렬

# 배치 행렬 곱셈
result = torch.bmm(tensor1, tensor2)
print(result)
# 출력: tensor([[[ 7, 10],
#               [15, 22]],

#              [[67, 78],
#               [91, 106]]])

 

라이센스

  PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.

 

관련 내용

  

 

반응형