Re-ID 기반 객체 추적(Object Tracking)이란?
객체 추적에서 가장 어려운 문제는 객체가 잠시 가려지거나 화면에서 사라졌다가 다시 등장했을 때도 동일한 ID를 유지하는 것입니다. 이를 위해 대부분의 현대적인 객체 추적 시스템은 Re-ID(Re-Identification) 모델을 사용합니다.
Re-ID는 객체의 외형 정보를 임베딩 벡터로 변환하고, 시간적으로 떨어진 프레임 간에도 같은 객체인지 판단할 수 있도록 도와줍니다. 이번 글에서는 이전 글에서 생성한 MOT crop 데이터셋을 기반으로 Re-ID 임베딩 모델을 구성하고 학습하는 방법을 살펴봅니다.
- 예: 사람이 가려졌다가 다시 등장했을 때 같은 ID로 복원
- 예: 다른 사람과 교차한 이후에도 ID 스위치 방지
목표
- MOT crop + pair 데이터셋을 이용한 Re-ID 학습 파이프라인 구성
- Siamese 구조 기반 임베딩 모델 정의
- 추후 Kalman Filter + Hungarian matching과 자연스럽게 결합 가능하도록 설계
Re-ID 학습을 위한 데이터 구조
이전 글에서 생성한 데이터는 다음과 같은 구조를 가집니다.
mot_crops/
MOT17-02_000001_3.jpg
MOT17-02_000002_3.jpg
...
pairs.txt
pairs.txt는 두 이미지와 라벨로 구성된 pair 데이터입니다.
img_a.jpg,img_b.jpg,1 # 같은 객체
img_c.jpg,img_d.jpg,0 # 다른 객체
이 구조는 Contrastive Loss / Siamese Network 학습에 바로 사용할 수 있습니다.
Re-ID 임베딩 모델 구성
모델 개요
- Backbone: ResNet18
- 입력: 객체 crop 이미지 (256×128)
- 출력: L2 정규화된 임베딩 벡터
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
class ReIDEmbeddingNet(nn.Module):
def __init__(self, embedding_dim=128):
super().__init__()
backbone = resnet18(pretrained=True)
backbone.fc = nn.Identity()
self.backbone = backbone
self.head = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Linear(256, embedding_dim)
)
def forward(self, x):
feat = self.backbone(x)
emb = self.head(feat)
emb = F.normalize(emb, p=2, dim=1)
return emb
출력 임베딩은 L2 정규화를 거쳐 거리 기반 비교에 바로 사용할 수 있도록 합니다.
Contrastive Loss 정의
class ContrastiveLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, e1, e2, label):
dist = torch.norm(e1 - e2, dim=1)
pos = label * dist.pow(2)
neg = (1 - label) * torch.clamp(self.margin - dist, min=0).pow(2)
return (pos + neg).mean()
같은 객체는 거리를 줄이고, 다른 객체는 margin 이상 벌어지도록 학습됩니다.
Re-ID 모델 학습 코드
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# ============================================================
# 0) Utils
# ============================================================
def seed_everything(seed: int = 42):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def ensure_dir(path: str):
os.makedirs(path, exist_ok=True)
# ============================================================
# 1) PIL -> torch tensor (numpy/torchvision 없이)
# ============================================================
def pil_rgb_to_tensor(img: Image.Image) -> torch.Tensor:
"""PIL RGB -> torch.float32 (3,H,W) in [0,1] (numpy 없이)"""
if img.mode != "RGB":
img = img.convert("RGB")
w, h = img.size
raw = img.tobytes() # RGBRGB...
x = torch.ByteTensor(list(raw)).view(h, w, 3).permute(2, 0, 1).contiguous()
return x.float() / 255.0
def resize_pil(img: Image.Image, size_hw: Tuple[int, int]) -> Image.Image:
"""size_hw=(H,W)"""
H, W = size_hw
return img.resize((W, H), resample=Image.BILINEAR)
def random_hflip_pil(img: Image.Image, p=0.5) -> Image.Image:
if random.random() < p:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
@dataclass
class ReIDTransform:
size_hw: Tuple[int, int] = (256, 128) # (H,W) for ReID (commonly 256x128)
hflip_p: float = 0.5
normalize: bool = True
def __call__(self, img: Image.Image) -> torch.Tensor:
img = resize_pil(img, self.size_hw)
img = random_hflip_pil(img, self.hflip_p)
x = pil_rgb_to_tensor(img)
if self.normalize:
# ImageNet mean/std
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
x = (x - mean) / std
return x
# ============================================================
# 2) Dataset: pairs.txt 기반
# ============================================================
class MOTCropPairDataset(Dataset):
"""
pairs.txt format:
img_a.jpg,img_b.jpg,1
img_c.jpg,img_d.jpg,0
"""
def __init__(self, crop_dir: str, pair_file: str, transform=None):
self.crop_dir = crop_dir
self.pair_file = pair_file
self.transform = transform
if not os.path.isdir(crop_dir):
raise FileNotFoundError(f"[ERR] crop_dir not found: {crop_dir}")
if not os.path.isfile(pair_file):
raise FileNotFoundError(f"[ERR] pair_file not found: {pair_file}")
self.samples = self._load_pairs(pair_file)
if len(self.samples) == 0:
raise FileNotFoundError("[ERR] pairs.txt가 비어있거나 유효한 라인이 없습니다.")
def _load_pairs(self, pair_file: str) -> List[Tuple[str, str, int]]:
out = []
with open(pair_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
parts = [p.strip() for p in line.split(",")]
if len(parts) != 3:
continue
a, b, y = parts
try:
y = int(y)
except:
continue
if y not in (0, 1):
continue
pa = os.path.join(self.crop_dir, a)
pb = os.path.join(self.crop_dir, b)
if not (os.path.isfile(pa) and os.path.isfile(pb)):
# 파일 없는 pair는 스킵
continue
out.append((pa, pb, y))
return out
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
pa, pb, y = self.samples[idx]
img1 = Image.open(pa).convert("RGB")
img2 = Image.open(pb).convert("RGB")
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2)
else:
img1 = pil_rgb_to_tensor(img1)
img2 = pil_rgb_to_tensor(img2)
label = torch.tensor(float(y), dtype=torch.float32) # contrastive loss용 (0/1 float)
return img1, img2, label
def split_dataset(ds: Dataset, val_ratio=0.1, seed=42):
n = len(ds)
idxs = list(range(n))
rng = random.Random(seed)
rng.shuffle(idxs)
n_val = int(n * val_ratio)
val_idxs = idxs[:n_val]
trn_idxs = idxs[n_val:]
trn = torch.utils.data.Subset(ds, trn_idxs)
val = torch.utils.data.Subset(ds, val_idxs)
return trn, val
# ============================================================
# 3) Model
# - default: SimpleCNN (torch만)
# - option: resnet18 (torchvision 가능할 때)
# ============================================================
class SimpleCNNBackbone(nn.Module):
"""torch만으로 돌아가는 가벼운 backbone"""
def __init__(self, out_dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 32, 3, 2, 1), nn.BatchNorm2d(32), nn.ReLU(True), # 1/2
nn.Conv2d(32, 64, 3, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True), # 1/4
nn.Conv2d(64, 128, 3, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),# 1/8
nn.Conv2d(128, 256, 3, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),# 1/16
nn.AdaptiveAvgPool2d((1, 1)),
)
self.fc = nn.Linear(256, out_dim)
def forward(self, x):
x = self.net(x).flatten(1)
x = self.fc(x)
return x
class ReIDEmbeddingNet(nn.Module):
def __init__(self, embedding_dim=128, backbone="simplecnn", pretrained=False):
super().__init__()
self.backbone_name = backbone
if backbone == "simplecnn":
feat_dim = 512
self.backbone = SimpleCNNBackbone(out_dim=feat_dim)
elif backbone == "resnet18":
# torchvision이 정상일 때만 사용 권장
from torchvision.models import resnet18
if hasattr(torchvision.models, "ResNet18_Weights"):
# 최신 torchvision
weights = "DEFAULT" if pretrained else None
base = resnet18(weights=weights)
else:
base = resnet18(pretrained=pretrained)
base.fc = nn.Identity()
self.backbone = base
feat_dim = 512
else:
raise ValueError("backbone must be one of: simplecnn, resnet18")
self.head = nn.Sequential(
nn.Linear(feat_dim, 256),
nn.ReLU(inplace=True),
nn.Linear(256, embedding_dim),
)
def forward(self, x):
feat = self.backbone(x)
emb = self.head(feat)
emb = F.normalize(emb, p=2, dim=1)
return emb
# ============================================================
# 4) Contrastive Loss
# ============================================================
class ContrastiveLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, e1, e2, label):
"""
label: float tensor (B,) with 1 for positive(same), 0 for negative(diff)
"""
dist = torch.norm(e1 - e2, dim=1) # (B,)
pos = label * dist.pow(2)
neg = (1 - label) * torch.clamp(self.margin - dist, min=0).pow(2)
return (pos + neg).mean()
@torch.no_grad()
def accuracy_at_threshold(e1, e2, label, thr=0.8):
"""
간단 지표: dist < thr 이면 같은 객체로 예측
"""
dist = torch.norm(e1 - e2, dim=1)
pred_same = (dist < thr).float()
correct = (pred_same == label).float().mean().item()
return correct, dist.mean().item()
# ============================================================
# 5) Train/Eval
# ============================================================
def train_one_epoch(model, loader, optimizer, criterion, device):
model.train()
loss_sum = 0.0
for img1, img2, label in loader:
img1 = img1.to(device, non_blocking=True)
img2 = img2.to(device, non_blocking=True)
label = label.to(device, non_blocking=True)
optimizer.zero_grad(set_to_none=True)
e1 = model(img1)
e2 = model(img2)
loss = criterion(e1, e2, label)
loss.backward()
optimizer.step()
loss_sum += loss.item()
return loss_sum / max(1, len(loader))
@torch.no_grad()
def evaluate(model, loader, criterion, device, thr=0.8):
model.eval()
loss_sum = 0.0
acc_sum = 0.0
dist_sum = 0.0
for img1, img2, label in loader:
img1 = img1.to(device, non_blocking=True)
img2 = img2.to(device, non_blocking=True)
label = label.to(device, non_blocking=True)
e1 = model(img1)
e2 = model(img2)
loss = criterion(e1, e2, label)
acc, mean_dist = accuracy_at_threshold(e1, e2, label, thr=thr)
loss_sum += loss.item()
acc_sum += acc
dist_sum += mean_dist
n = max(1, len(loader))
return loss_sum / n, acc_sum / n, dist_sum / n
def save_ckpt(path, model, optimizer, epoch, best_metric):
ensure_dir(os.path.dirname(path))
torch.save({
"epoch": epoch,
"best_metric": best_metric,
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, path)
# ============================================================
# 6) Main
# ============================================================
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--crop_dir", type=str, default="mot_crops")
p.add_argument("--pair_file", type=str, default="mot_crops/pairs.txt")
p.add_argument("--epochs", type=int, default=10)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--num_workers", type=int, default=4)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--margin", type=float, default=1.0)
p.add_argument("--emb_dim", type=int, default=128)
p.add_argument("--img_h", type=int, default=256)
p.add_argument("--img_w", type=int, default=128)
p.add_argument("--backbone", type=str, default="simplecnn", choices=["simplecnn", "resnet18"])
p.add_argument("--pretrained", action="store_true", help="(resnet18일 때만) pretrained 사용")
p.add_argument("--val_ratio", type=float, default=0.1)
p.add_argument("--dist_thr", type=float, default=0.8, help="eval용 dist threshold")
p.add_argument("--seed", type=int, default=42)
p.add_argument("--out_dir", type=str, default="runs/reid")
p.add_argument("--cpu", action="store_true")
return p.parse_args()
def main():
args = parse_args()
seed_everything(args.seed)
device = torch.device("cpu" if args.cpu or not torch.cuda.is_available() else "cuda")
print("[Device]", device)
transform = ReIDTransform(size_hw=(args.img_h, args.img_w), hflip_p=0.5, normalize=True)
# full_ds = MOTCropPairDataset(
# crop_dir=args.crop_dir,
# pair_file=args.pair_file,
# transform=transform
# )
full_ds = MOTCropPairDataset(
crop_dir="mot_crops/images",
pair_file="mot_crops/pairs.txt"
)
train_ds, val_ds = split_dataset(full_ds, val_ratio=args.val_ratio, seed=args.seed)
print(f"[Data] total={len(full_ds)} train={len(train_ds)} val={len(val_ds)}")
train_loader = DataLoader(
train_ds, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, pin_memory=(device.type == "cuda")
)
val_loader = DataLoader(
val_ds, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, pin_memory=(device.type == "cuda")
)
model = ReIDEmbeddingNet(
embedding_dim=args.emb_dim,
backbone=args.backbone,
pretrained=args.pretrained
).to(device)
criterion = ContrastiveLoss(margin=args.margin)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
ensure_dir(args.out_dir)
best_acc = -1.0
for epoch in range(1, args.epochs + 1):
tr_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
va_loss, va_acc, va_dist = evaluate(model, val_loader, criterion, device, thr=args.dist_thr)
print(f"[Epoch {epoch:02d}] "
f"train_loss={tr_loss:.4f} | val_loss={va_loss:.4f} | "
f"val_acc@thr={va_acc*100:.2f}% | val_mean_dist={va_dist:.4f}")
# best 저장(여기서는 acc 기준)
if va_acc > best_acc:
best_acc = va_acc
save_ckpt(os.path.join(args.out_dir, "best.pth"), model, optimizer, epoch, best_acc)
# 매 epoch 저장
save_ckpt(os.path.join(args.out_dir, "last.pth"), model, optimizer, epoch, best_acc)
print(f"[Done] best_acc={best_acc*100:.2f}% saved to: {args.out_dir}")
if __name__ == "__main__":
main()
마무리
이번 글에서는 MOT 데이터셋에서 생성한 crop 데이터를 활용하여 Re-ID 임베딩 모델을 구성하고 학습하는 전체 과정을 살펴보았습니다. 이 모델은 이후 Kalman Filter, IOU, Hungarian Matching과 결합되어 실제 객체 추적 시스템의 핵심 구성 요소로 사용됩니다.
다음 글에서는 학습된 Re-ID 임베딩을 이용해 프레임 간 객체를 매칭하고, 간단한 다중 객체 추적 파이프라인을 PyTorch로 구현해보겠습니다.
관련 내용
- [실전 예제/객체 추적/PyTorch] 객체 추적 튜토리얼: MOT 데이터셋으로 PyTorch 데이터셋 만들기
- [PyTorch] 맞춤형 데이터셋 만들기: torch.utils.data.Dataset() 사용 가이드
- [PyTorch] 효율적인 데이터 배치: torch.utils.data.DataLoader() 사용 가이드

'실전 예제, 프로젝트' 카테고리의 다른 글
| [실전 예제/이미지 분류/PyTorch] ResNet 기반 이미지 분류 모델 구성과 학습 (0) | 2026.01.14 |
|---|---|
| [실전 예제/변화 탐지/PyTorch] Siamese 기반 변화 탐지 모델 구성과 학습 (0) | 2026.01.14 |
| [실전 예제/객체 추적/PyTorch] MOT 데이터셋으로 객체 추적 데이터셋 구성하기 (0) | 2025.04.27 |
| [실전 예제/리스트/파이썬] 리스트 요소에 같은 연산을 적용하는 6가지 방법 (0) | 2025.04.22 |
| [실전 예제/변화 탐지/PyTorch] 변화 탐지 튜토리얼: LEVIR 데이터셋으로 PyTorch 데이터셋 만들기 (0) | 2025.04.19 |