본문 바로가기
실전 예제, 프로젝트

[실전 예제/객체 추적/PyTorch] Re-ID 기반 객체 추적 모델 구성과 학습

by First Adventure 2026. 1. 14.
반응형

Re-ID 기반 객체 추적(Object Tracking)이란?

  객체 추적에서 가장 어려운 문제는 객체가 잠시 가려지거나 화면에서 사라졌다가 다시 등장했을 때도 동일한 ID를 유지하는 것입니다. 이를 위해 대부분의 현대적인 객체 추적 시스템은 Re-ID(Re-Identification) 모델을 사용합니다.

  Re-ID는 객체의 외형 정보를 임베딩 벡터로 변환하고, 시간적으로 떨어진 프레임 간에도 같은 객체인지 판단할 수 있도록 도와줍니다. 이번 글에서는 이전 글에서 생성한 MOT crop 데이터셋을 기반으로 Re-ID 임베딩 모델을 구성하고 학습하는 방법을 살펴봅니다.

  • 예: 사람이 가려졌다가 다시 등장했을 때 같은 ID로 복원
  • 예: 다른 사람과 교차한 이후에도 ID 스위치 방지

 

목표

  1. MOT crop + pair 데이터셋을 이용한 Re-ID 학습 파이프라인 구성
  2. Siamese 구조 기반 임베딩 모델 정의
  3. 추후 Kalman Filter + Hungarian matching과 자연스럽게 결합 가능하도록 설계

 

Re-ID 학습을 위한 데이터 구조

이전 글에서 생성한 데이터는 다음과 같은 구조를 가집니다.

mot_crops/
  MOT17-02_000001_3.jpg
  MOT17-02_000002_3.jpg
  ...
  pairs.txt

pairs.txt는 두 이미지와 라벨로 구성된 pair 데이터입니다.

img_a.jpg,img_b.jpg,1   # 같은 객체
img_c.jpg,img_d.jpg,0   # 다른 객체

이 구조는 Contrastive Loss / Siamese Network 학습에 바로 사용할 수 있습니다.

 

Re-ID 임베딩 모델 구성

모델 개요

  • Backbone: ResNet18
  • 입력: 객체 crop 이미지 (256×128)
  • 출력: L2 정규화된 임베딩 벡터
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18


class ReIDEmbeddingNet(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        backbone = resnet18(pretrained=True)
        backbone.fc = nn.Identity()
        self.backbone = backbone

        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, x):
        feat = self.backbone(x)
        emb = self.head(feat)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

  출력 임베딩은 L2 정규화를 거쳐 거리 기반 비교에 바로 사용할 수 있도록 합니다.

 

Contrastive Loss 정의

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, e1, e2, label):
        dist = torch.norm(e1 - e2, dim=1)
        pos = label * dist.pow(2)
        neg = (1 - label) * torch.clamp(self.margin - dist, min=0).pow(2)
        return (pos + neg).mean()

  같은 객체는 거리를 줄이고, 다른 객체는 margin 이상 벌어지도록 학습됩니다.

 

Re-ID 모델 학습 코드

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ============================================================
# 0) Utils
# ============================================================

def seed_everything(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)


# ============================================================
# 1) PIL -> torch tensor (numpy/torchvision 없이)
# ============================================================

def pil_rgb_to_tensor(img: Image.Image) -> torch.Tensor:
    """PIL RGB -> torch.float32 (3,H,W) in [0,1] (numpy 없이)"""
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    raw = img.tobytes()  # RGBRGB...
    x = torch.ByteTensor(list(raw)).view(h, w, 3).permute(2, 0, 1).contiguous()
    return x.float() / 255.0


def resize_pil(img: Image.Image, size_hw: Tuple[int, int]) -> Image.Image:
    """size_hw=(H,W)"""
    H, W = size_hw
    return img.resize((W, H), resample=Image.BILINEAR)


def random_hflip_pil(img: Image.Image, p=0.5) -> Image.Image:
    if random.random() < p:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


@dataclass
class ReIDTransform:
    size_hw: Tuple[int, int] = (256, 128)  # (H,W) for ReID (commonly 256x128)
    hflip_p: float = 0.5
    normalize: bool = True

    def __call__(self, img: Image.Image) -> torch.Tensor:
        img = resize_pil(img, self.size_hw)
        img = random_hflip_pil(img, self.hflip_p)
        x = pil_rgb_to_tensor(img)

        if self.normalize:
            # ImageNet mean/std
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            x = (x - mean) / std

        return x


# ============================================================
# 2) Dataset: pairs.txt 기반
# ============================================================

class MOTCropPairDataset(Dataset):
    """
    pairs.txt format:
      img_a.jpg,img_b.jpg,1
      img_c.jpg,img_d.jpg,0
    """
    def __init__(self, crop_dir: str, pair_file: str, transform=None):
        self.crop_dir = crop_dir
        self.pair_file = pair_file
        self.transform = transform

        if not os.path.isdir(crop_dir):
            raise FileNotFoundError(f"[ERR] crop_dir not found: {crop_dir}")
        if not os.path.isfile(pair_file):
            raise FileNotFoundError(f"[ERR] pair_file not found: {pair_file}")

        self.samples = self._load_pairs(pair_file)
        if len(self.samples) == 0:
            raise FileNotFoundError("[ERR] pairs.txt가 비어있거나 유효한 라인이 없습니다.")

    def _load_pairs(self, pair_file: str) -> List[Tuple[str, str, int]]:
        out = []
        with open(pair_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = [p.strip() for p in line.split(",")]
                if len(parts) != 3:
                    continue
                a, b, y = parts
                try:
                    y = int(y)
                except:
                    continue
                if y not in (0, 1):
                    continue

                pa = os.path.join(self.crop_dir, a)
                pb = os.path.join(self.crop_dir, b)
                if not (os.path.isfile(pa) and os.path.isfile(pb)):
                    # 파일 없는 pair는 스킵
                    continue

                out.append((pa, pb, y))
        return out

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        pa, pb, y = self.samples[idx]
        img1 = Image.open(pa).convert("RGB")
        img2 = Image.open(pb).convert("RGB")

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        else:
            img1 = pil_rgb_to_tensor(img1)
            img2 = pil_rgb_to_tensor(img2)

        label = torch.tensor(float(y), dtype=torch.float32)  # contrastive loss용 (0/1 float)
        return img1, img2, label


def split_dataset(ds: Dataset, val_ratio=0.1, seed=42):
    n = len(ds)
    idxs = list(range(n))
    rng = random.Random(seed)
    rng.shuffle(idxs)
    n_val = int(n * val_ratio)
    val_idxs = idxs[:n_val]
    trn_idxs = idxs[n_val:]

    trn = torch.utils.data.Subset(ds, trn_idxs)
    val = torch.utils.data.Subset(ds, val_idxs)
    return trn, val


# ============================================================
# 3) Model
#   - default: SimpleCNN (torch만)
#   - option: resnet18 (torchvision 가능할 때)
# ============================================================

class SimpleCNNBackbone(nn.Module):
    """torch만으로 돌아가는 가벼운 backbone"""
    def __init__(self, out_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1), nn.BatchNorm2d(32), nn.ReLU(True),  # 1/2
            nn.Conv2d(32, 64, 3, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True), # 1/4
            nn.Conv2d(64, 128, 3, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),# 1/8
            nn.Conv2d(128, 256, 3, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),# 1/16
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(256, out_dim)

    def forward(self, x):
        x = self.net(x).flatten(1)
        x = self.fc(x)
        return x


class ReIDEmbeddingNet(nn.Module):
    def __init__(self, embedding_dim=128, backbone="simplecnn", pretrained=False):
        super().__init__()
        self.backbone_name = backbone

        if backbone == "simplecnn":
            feat_dim = 512
            self.backbone = SimpleCNNBackbone(out_dim=feat_dim)

        elif backbone == "resnet18":
            # torchvision이 정상일 때만 사용 권장
            from torchvision.models import resnet18
            if hasattr(torchvision.models, "ResNet18_Weights"):
                # 최신 torchvision
                weights = "DEFAULT" if pretrained else None
                base = resnet18(weights=weights)
            else:
                base = resnet18(pretrained=pretrained)

            base.fc = nn.Identity()
            self.backbone = base
            feat_dim = 512

        else:
            raise ValueError("backbone must be one of: simplecnn, resnet18")

        self.head = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, embedding_dim),
        )

    def forward(self, x):
        feat = self.backbone(x)
        emb = self.head(feat)
        emb = F.normalize(emb, p=2, dim=1)
        return emb


# ============================================================
# 4) Contrastive Loss
# ============================================================

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, e1, e2, label):
        """
        label: float tensor (B,) with 1 for positive(same), 0 for negative(diff)
        """
        dist = torch.norm(e1 - e2, dim=1)  # (B,)
        pos = label * dist.pow(2)
        neg = (1 - label) * torch.clamp(self.margin - dist, min=0).pow(2)
        return (pos + neg).mean()


@torch.no_grad()
def accuracy_at_threshold(e1, e2, label, thr=0.8):
    """
    간단 지표: dist < thr 이면 같은 객체로 예측
    """
    dist = torch.norm(e1 - e2, dim=1)
    pred_same = (dist < thr).float()
    correct = (pred_same == label).float().mean().item()
    return correct, dist.mean().item()


# ============================================================
# 5) Train/Eval
# ============================================================

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    loss_sum = 0.0

    for img1, img2, label in loader:
        img1 = img1.to(device, non_blocking=True)
        img2 = img2.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        e1 = model(img1)
        e2 = model(img2)
        loss = criterion(e1, e2, label)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    return loss_sum / max(1, len(loader))


@torch.no_grad()
def evaluate(model, loader, criterion, device, thr=0.8):
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    dist_sum = 0.0

    for img1, img2, label in loader:
        img1 = img1.to(device, non_blocking=True)
        img2 = img2.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)

        e1 = model(img1)
        e2 = model(img2)
        loss = criterion(e1, e2, label)
        acc, mean_dist = accuracy_at_threshold(e1, e2, label, thr=thr)

        loss_sum += loss.item()
        acc_sum += acc
        dist_sum += mean_dist

    n = max(1, len(loader))
    return loss_sum / n, acc_sum / n, dist_sum / n


def save_ckpt(path, model, optimizer, epoch, best_metric):
    ensure_dir(os.path.dirname(path))
    torch.save({
        "epoch": epoch,
        "best_metric": best_metric,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)


# ============================================================
# 6) Main
# ============================================================

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--crop_dir", type=str, default="mot_crops")
    p.add_argument("--pair_file", type=str, default="mot_crops/pairs.txt")

    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--margin", type=float, default=1.0)

    p.add_argument("--emb_dim", type=int, default=128)
    p.add_argument("--img_h", type=int, default=256)
    p.add_argument("--img_w", type=int, default=128)

    p.add_argument("--backbone", type=str, default="simplecnn", choices=["simplecnn", "resnet18"])
    p.add_argument("--pretrained", action="store_true", help="(resnet18일 때만) pretrained 사용")
    p.add_argument("--val_ratio", type=float, default=0.1)

    p.add_argument("--dist_thr", type=float, default=0.8, help="eval용 dist threshold")
    p.add_argument("--seed", type=int, default=42)

    p.add_argument("--out_dir", type=str, default="runs/reid")
    p.add_argument("--cpu", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    seed_everything(args.seed)

    device = torch.device("cpu" if args.cpu or not torch.cuda.is_available() else "cuda")
    print("[Device]", device)

    transform = ReIDTransform(size_hw=(args.img_h, args.img_w), hflip_p=0.5, normalize=True)

    # full_ds = MOTCropPairDataset(
    #     crop_dir=args.crop_dir,
    #     pair_file=args.pair_file,
    #     transform=transform
    # )
    
    full_ds = MOTCropPairDataset(
    crop_dir="mot_crops/images",
    pair_file="mot_crops/pairs.txt"
    )
    
    train_ds, val_ds = split_dataset(full_ds, val_ratio=args.val_ratio, seed=args.seed)
    print(f"[Data] total={len(full_ds)} train={len(train_ds)} val={len(val_ds)}")

    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=(device.type == "cuda")
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=(device.type == "cuda")
    )

    model = ReIDEmbeddingNet(
        embedding_dim=args.emb_dim,
        backbone=args.backbone,
        pretrained=args.pretrained
    ).to(device)

    criterion = ContrastiveLoss(margin=args.margin)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    ensure_dir(args.out_dir)
    best_acc = -1.0

    for epoch in range(1, args.epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        va_loss, va_acc, va_dist = evaluate(model, val_loader, criterion, device, thr=args.dist_thr)

        print(f"[Epoch {epoch:02d}] "
              f"train_loss={tr_loss:.4f} | val_loss={va_loss:.4f} | "
              f"val_acc@thr={va_acc*100:.2f}% | val_mean_dist={va_dist:.4f}")

        # best 저장(여기서는 acc 기준)
        if va_acc > best_acc:
            best_acc = va_acc
            save_ckpt(os.path.join(args.out_dir, "best.pth"), model, optimizer, epoch, best_acc)

        # 매 epoch 저장
        save_ckpt(os.path.join(args.out_dir, "last.pth"), model, optimizer, epoch, best_acc)

    print(f"[Done] best_acc={best_acc*100:.2f}% saved to: {args.out_dir}")


if __name__ == "__main__":
    main()

 

 

마무리

  이번 글에서는 MOT 데이터셋에서 생성한 crop 데이터를 활용하여 Re-ID 임베딩 모델을 구성하고 학습하는 전체 과정을 살펴보았습니다. 이 모델은 이후 Kalman Filter, IOU, Hungarian Matching과 결합되어 실제 객체 추적 시스템의 핵심 구성 요소로 사용됩니다.

  다음 글에서는 학습된 Re-ID 임베딩을 이용해 프레임 간 객체를 매칭하고, 간단한 다중 객체 추적 파이프라인을 PyTorch로 구현해보겠습니다.

 

관련 내용

 

반응형