반응형
소개
torch.optim.lr_scheduler.StepLR()은 PyTorch에서 학습률(learning rate)을 단계적으로 감소시키는 스케줄러입니다. 주어진 스텝 수마다 학습률을 특정 비율로 줄이는 방식으로, 모델이 학습 중에 수렴하는 속도를 제어하고, 최적화 과정에서 과도한 학습률로 인한 문제를 줄이기 위해 사용됩니다.
기본 사용법
상세 설명
- StepLR의 작동 방식
- 지정된 스텝 수(step_size)마다 학습률을 gamma 비율로 감소시킵니다.
- 학습이 진행됨에 따라 학습률을 단계적으로 줄여 모델이 더 안정적으로 수렴하도록 돕습니다.
- 학습 초기에는 큰 학습률을 유지해 빠르게 최적점을 찾아가고, 후반부에는 작은 학습률로 세밀하게 조정할 수 있습니다.
- StepLR의 적용 분야
- 큰 학습률로 시작해 점진적으로 학습률을 줄이며 학습시키는 전략에 유용합니다.
- 복잡한 학습 문제에서 초기에는 빠른 수렴을, 후반부에는 미세한 조정을 필요로 하는 상황에 적합합니다.
예시 설명
- step_size: 학습률을 조정할 주기를 지정합니다. 예를 들어, step_size=30으로 설정하면 매 30 에포크마다 학습률이 감소합니다.
- gamma: 학습률을 감소시키는 비율을 지정합니다. 예를 들어, gamma=0.1이면 학습률이 이전 학습률의 10%로 줄어듭니다.
- scheduler = StepLR(optimizer, step_size=30, gamma=0.1)은 학습률을 매 30 에포크마다 10%로 감소시킵니다.
- scheduler.step()을 호출하면 현재 학습률이 감소하게 됩니다. 이 함수는 일반적으로 학습 루프 내에서 에포크마다 호출됩니다.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
# 간단한 모델 정의
model = nn.Linear(10, 1)
# 옵티마이저 설정
optimizer = optim.SGD(model.parameters(), lr=0.1)
# StepLR 스케줄러 설정
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# 학습 루프
for epoch in range(100):
optimizer.zero_grad() # 그래디언트 초기화
outputs = model(torch.randn(10)) # 모델의 출력 계산
loss = nn.MSELoss()(outputs, torch.randn(1)) # 손실 계산
loss.backward() # 역전파로 그래디언트 계산
optimizer.step() # 옵티마이저로 가중치 업데이트
scheduler.step() # 학습률 조정
print(f"Epoch {epoch+1}, LR: {scheduler.get_last_lr()[0]}")
라이센스
PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.
관련 내용
[PyTorch] 딥러닝 최적화 알고리즘: torch.optim.Adam() 사용 가이드
[PyTorch] 딥러닝 최적화 알고리즘: torch.optim.SGD() 사용 가이드
[PyTorch] 모델 학습을 최적화하는 스케줄러: torch.optim.lr_scheduler.StepLR() 활용법
반응형
'함수 설명 > 인공지능 (Pytorch)' 카테고리의 다른 글
[PyTorch] 다중 클래스 분류에서 필수: torch.nn.CrossEntropyLoss() 사용 가이드 (0) | 2024.08.17 |
---|---|
[PyTorch] 회귀 문제에서 필수: torch.nn.MSELoss() 사용 가이드 (0) | 2024.08.17 |
[PyTorch] 딥러닝 최적화 알고리즘: torch.optim.SGD() 사용 가이드 (0) | 2024.08.17 |
[PyTorch] 딥러닝 최적화 알고리즘: torch.optim.Adam() 사용 가이드 (0) | 2024.08.17 |
[PyTorch] 맞춤형 데이터셋 만들기: torch.utils.data.Dataset() 사용 가이드 (0) | 2024.08.16 |