반응형
소개
torch.nn.CrossEntropyLoss는 PyTorch에서 다중 클래스 분류 문제에 널리 사용되는 손실 함수입니다. 이 함수는 소프트맥스 함수와 엔트로피 손실 함수를 결합하여, 모델이 예측한 클래스 확률과 실제 클래스 간의 차이를 측정합니다. 이 손실 함수는 모델이 올바르게 예측할 수 있도록 가중치를 조정하는 데 사용됩니다.
기본 사용법
상세 설명
- CrossEntropyLoss의 작동 방식
- 이 함수는 모델의 출력(로짓, 즉 선형 변환 후의 값)을 받아 소프트맥스 함수로 클래스 확률을 계산하고, 이를 실제 라벨과 비교하여 손실을 계산합니다.
- 손실 값은 모델이 잘못 예측할수록 커지며, 올바르게 예측할수록 작아집니다. 모델의 가중치를 업데이트하기 위해 이 손실 값이 역전파 과정에서 사용됩니다.
- CrossEntropyLoss의 적용 분야
- 다중 클래스 분류 문제에서 널리 사용됩니다. 예를 들어, 이미지 분류, 텍스트 분류, 다중 카테고리 예측 등에서 활용됩니다.
- 신경망의 출력이 여러 클래스에 대한 확률을 반환해야 할 때 사용됩니다.
예시 설명
- weight: 각 클래스에 대해 가중치를 부여하여 특정 클래스가 학습에서 더 중요하게 반영되도록 할 수 있습니다.
- size_average: 손실을 계산한 후 평균을 구할지 여부를 결정합니다. PyTorch 0.4부터는 reduction 파라미터로 대체되었습니다.
- ignore_index: 특정 클래스 인덱스를 무시하고 손실 계산에서 제외할 수 있습니다.
- reduction: 손실을 계산한 후 결과를 어떻게 처리할지를 지정합니다. 가능한 값은 none, mean, sum입니다.
- criterion = nn.CrossEntropyLoss()는 분류 문제에 적합한 손실 함수를 정의하는 코드로, 모델의 예측값과 실제값 간의 교차 엔트로피 손실을 계산합니다.
- loss = criterion(predictions, targets)는 예측값 predictions과 실제 클래스 라벨 targets 간의 손실을 계산하여 모델이 얼마나 잘 예측했는지를 평가합니다.
import torch
import torch.nn as nn
# CrossEntropyLoss 함수 정의
criterion = nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduction='mean')
# 예측값 (로그잇 확률)과 실제 클래스 라벨 생성
predictions = torch.tensor([[1.5, 0.3, -1.2], [1.2, 0.8, -0.5]])
targets = torch.tensor([0, 2])
# 손실 계산
loss = criterion(predictions, targets)
print(loss.item())
# 출력 예시: 1.4401898384094238
라이센스
PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.
관련 내용
[PyTorch] 신경망의 기본 구성 요소: torch.nn.Linear() 사용 가이드
[PyTorch] CNN 모델의 기초: torch.nn.Conv2d() 사용 가이드
[PyTorch] CNN에서 풀링 계층 활용하기: torch.nn.MaxPool2d() 사용 가이드
[PyTorch] 비선형성을 추가하는 핵심: torch.nn.ReLU() 사용 가이드
[PyTorch] 다중 클래스 분류에서 필수: torch.nn.CrossEntropyLoss() 사용 가이드
반응형
'함수 설명 > 인공지능 (Pytorch)' 카테고리의 다른 글
[PyTorch] CNN에서 풀링 계층 활용하기: torch.nn.MaxPool2d() 사용 가이드 (0) | 2024.08.17 |
---|---|
[PyTorch] 비선형성을 추가하는 핵심: torch.nn.ReLU() 사용 가이드 (0) | 2024.08.17 |
[PyTorch] 회귀 문제에서 필수: torch.nn.MSELoss() 사용 가이드 (0) | 2024.08.17 |
[PyTorch] 모델 학습을 최적화하는 스케줄러: torch.optim.lr_scheduler.StepLR() 활용법 (0) | 2024.08.17 |
[PyTorch] 딥러닝 최적화 알고리즘: torch.optim.SGD() 사용 가이드 (0) | 2024.08.17 |