본문 바로가기
함수 설명/인공지능 (Pytorch)

[PyTorch] 다중 클래스 분류에서 필수: torch.nn.CrossEntropyLoss() 사용 가이드

by First Adventure 2024. 8. 17.
반응형

소개

  torch.nn.CrossEntropyLoss는 PyTorch에서 다중 클래스 분류 문제에 널리 사용되는 손실 함수입니다. 이 함수는 소프트맥스 함수와 엔트로피 손실 함수를 결합하여, 모델이 예측한 클래스 확률과 실제 클래스 간의 차이를 측정합니다. 이 손실 함수는 모델이 올바르게 예측할 수 있도록 가중치를 조정하는 데 사용됩니다.

 

기본 사용법

상세 설명

  • CrossEntropyLoss의 작동 방식
    • 이 함수는 모델의 출력(로짓, 즉 선형 변환 후의 값)을 받아 소프트맥스 함수로 클래스 확률을 계산하고, 이를 실제 라벨과 비교하여 손실을 계산합니다.
    • 손실 값은 모델이 잘못 예측할수록 커지며, 올바르게 예측할수록 작아집니다. 모델의 가중치를 업데이트하기 위해 이 손실 값이 역전파 과정에서 사용됩니다.
  • CrossEntropyLoss의 적용 분야
    • 다중 클래스 분류 문제에서 널리 사용됩니다. 예를 들어, 이미지 분류, 텍스트 분류, 다중 카테고리 예측 등에서 활용됩니다.
    • 신경망의 출력이 여러 클래스에 대한 확률을 반환해야 할 때 사용됩니다.

예시 설명

  • weight: 각 클래스에 대해 가중치를 부여하여 특정 클래스가 학습에서 더 중요하게 반영되도록 할 수 있습니다.
  • size_average: 손실을 계산한 후 평균을 구할지 여부를 결정합니다. PyTorch 0.4부터는 reduction 파라미터로 대체되었습니다.
  • ignore_index: 특정 클래스 인덱스를 무시하고 손실 계산에서 제외할 수 있습니다.
  • reduction: 손실을 계산한 후 결과를 어떻게 처리할지를 지정합니다. 가능한 값은 none, mean, sum입니다.
  • criterion = nn.CrossEntropyLoss()는 분류 문제에 적합한 손실 함수를 정의하는 코드로, 모델의 예측값과 실제값 간의 교차 엔트로피 손실을 계산합니다.
  • loss = criterion(predictions, targets)는 예측값 predictions과 실제 클래스 라벨 targets 간의 손실을 계산하여 모델이 얼마나 잘 예측했는지를 평가합니다.
import torch
import torch.nn as nn

# CrossEntropyLoss 함수 정의
criterion = nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduction='mean')

# 예측값 (로그잇 확률)과 실제 클래스 라벨 생성
predictions = torch.tensor([[1.5, 0.3, -1.2], [1.2, 0.8, -0.5]])
targets = torch.tensor([0, 2])

# 손실 계산
loss = criterion(predictions, targets)
print(loss.item())
# 출력 예시: 1.4401898384094238

 

라이센스

  PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.

 

관련 내용

  [PyTorch] 신경망의 기본 구성 요소: torch.nn.Linear() 사용 가이드

  [PyTorch] CNN 모델의 기초: torch.nn.Conv2d() 사용 가이드

  [PyTorch] CNN에서 풀링 계층 활용하기: torch.nn.MaxPool2d() 사용 가이드

  [PyTorch] 비선형성을 추가하는 핵심: torch.nn.ReLU() 사용 가이드

  [PyTorch] 다중 클래스 분류에서 필수: torch.nn.CrossEntropyLoss() 사용 가이드

  [PyTorch] 회귀 문제에서 필수: torch.nn.MSELoss() 사용 가이드

반응형