본문 바로가기
실전 예제, 프로젝트

[실전 예제/변화 탐지/PyTorch] Siamese 기반 변화 탐지 모델 구성과 학습

by First Adventure 2026. 1. 14.
반응형

Siamese 기반 변화 탐지(Change Detection)이란?

  변화 탐지(Change Detection)는 두 시점(A/B)의 이미지를 비교해 변화가 발생한 영역을 픽셀 단위로 분할(Segmentation)하는 컴퓨터 비전 태스크입니다. 단순히 “다르다/같다”를 판별하는 것이 아니라, 어디가 얼마나 변했는지를 마스크 형태로 예측하는 것이 핵심입니다.

  Siamese 기반 모델은 두 입력 이미지(A/B)를 같은 가중치(shared weights)를 가진 인코더로 각각 특징을 추출한 뒤, 특징 차이/결합을 통해 변화 영역을 복원합니다. 이번 글에서는 LEVIR 데이터셋을 기준으로 Siamese 구조 모델을 구성하고 학습하는 방법을 PyTorch로 정리합니다.

  • 예: 건물 신축/철거, 도로 확장 등 변화 영역 자동 탐지
  • 예: 재난 전/후 피해 지역 추정(변화 마스크 생성)

 

목표

  1. LEVIR (A/B + Mask) 데이터셋을 그대로 활용해 학습 파이프라인 구성
  2. Siamese Encoder + Decoder 기반 변화 탐지 모델 정의
  3. BCEWithLogits + Dice Loss 조합으로 학습 안정화
  4. IoU/F1 평가 및 간단한 추론/시각화까지 연결

 

학습을 위한 데이터 구조

  이전 글에서 구성한 LEVIR 데이터셋은 기본적으로 아래 형태를 가집니다. (A/B 두 시점 이미지 + 변화 마스크)

LEVIR-CD/
  train/
    A/
      xxx.png
    B/
      xxx.png
    label/
      xxx.png
  val/
    A/
    B/
    label/

  중요한 점은 transform이 이미지(A/B)와 마스크(label)에 동시에 동일하게 적용되어야 한다는 것입니다. 이미지에만 Resize/Flip이 적용되면 학습이 바로 깨집니다.

 

PyTorch Transform (A/B/Mask 동시 적용)

  아래는 가장 기본적인 형태의 동기화 Transform 예시입니다. 랜덤 플립/크롭 등을 넣을 때도 항상 A/B/Mask에 동일한 파라미터로 적용해야 합니다.

import random
import torchvision.transforms.functional as TF

class CDCompose:
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, imgA, imgB, mask):
        for t in self.transforms:
            imgA, imgB, mask = t(imgA, imgB, mask)
        return imgA, imgB, mask

class CDToTensor:
    def __call__(self, imgA, imgB, mask):
        imgA = TF.to_tensor(imgA)  # 0~1
        imgB = TF.to_tensor(imgB)
        # label은 0/255로 들어오는 경우가 많아서 0/1로 정규화
        mask = TF.to_tensor(mask)
        mask = (mask > 0.5).float()
        return imgA, imgB, mask

class CDResize:
    def __init__(self, size):
        self.size = size  # (H,W)
    def __call__(self, imgA, imgB, mask):
        imgA = TF.resize(imgA, self.size)
        imgB = TF.resize(imgB, self.size)
        # mask는 보간으로 값이 흐려지면 안 되므로 NEAREST 사용
        mask = TF.resize(mask, self.size, interpolation=TF.InterpolationMode.NEAREST)
        return imgA, imgB, mask

class CDRandomHFlip:
    def __init__(self, p=0.5):
        self.p = p
    def __call__(self, imgA, imgB, mask):
        if random.random() < self.p:
            imgA = TF.hflip(imgA)
            imgB = TF.hflip(imgB)
            mask = TF.hflip(mask)
        return imgA, imgB, mask

def get_cd_transforms(train=True, size=(512, 512)):
    t = [CDResize(size)]
    if train:
        t.append(CDRandomHFlip(0.5))
    t.append(CDToTensor())
    return CDCompose(t)

 

Siamese 변화 탐지 모델 구성

모델 개요

  • 입력: (imgA, imgB) 두 장의 이미지
  • Encoder: shared weights (Siamese)
  • Feature Fusion: |FA - FB| (절대 차이) 또는 concat([FA, FB, |FA-FB|])
  • Decoder: 변화 마스크(1채널) 복원

  아래는 학습/튜토리얼 용도로 가장 안정적인 가벼운 Siamese U-Net 스타일 구현입니다. (너무 복잡하게 가지 않고, 이후 성능 개선도 가능한 형태로 작성했습니다)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.block(x)

class Encoder(nn.Module):
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        self.c1 = nn.Sequential(ConvBNReLU(in_ch, base), ConvBNReLU(base, base))
        self.p1 = nn.MaxPool2d(2)
        self.c2 = nn.Sequential(ConvBNReLU(base, base*2), ConvBNReLU(base*2, base*2))
        self.p2 = nn.MaxPool2d(2)
        self.c3 = nn.Sequential(ConvBNReLU(base*2, base*4), ConvBNReLU(base*4, base*4))
        self.p3 = nn.MaxPool2d(2)
        self.c4 = nn.Sequential(ConvBNReLU(base*4, base*8), ConvBNReLU(base*8, base*8))

    def forward(self, x):
        f1 = self.c1(x)      # (B, base,   H,   W)
        x  = self.p1(f1)
        f2 = self.c2(x)      # (B, base*2, H/2, W/2)
        x  = self.p2(f2)
        f3 = self.c3(x)      # (B, base*4, H/4, W/4)
        x  = self.p3(f3)
        f4 = self.c4(x)      # (B, base*8, H/8, W/8)
        return f1, f2, f3, f4

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            ConvBNReLU(out_ch + skip_ch, out_ch),
            ConvBNReLU(out_ch, out_ch)
        )
    def forward(self, x, skip):
        x = self.up(x)
        # size mismatch 방어(입력 크기가 2의 배수가 아닐 때)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class SiameseChangeNet(nn.Module):
    def __init__(self, base=32):
        super().__init__()
        self.enc = Encoder(in_ch=3, base=base)  # shared

        # fusion channel: concat([FA, FB, |FA-FB|]) => 3x
        self.fuse4 = ConvBNReLU(base*8*3, base*8)

        self.up3 = UpBlock(base*8, base*4*3, base*4)
        self.up2 = UpBlock(base*4, base*2*3, base*2)
        self.up1 = UpBlock(base*2, base*1*3, base*1)

        self.head = nn.Conv2d(base, 1, kernel_size=1)

    def forward(self, a, b):
        a1, a2, a3, a4 = self.enc(a)
        b1, b2, b3, b4 = self.enc(b)

        d1 = torch.abs(a1 - b1)
        d2 = torch.abs(a2 - b2)
        d3 = torch.abs(a3 - b3)
        d4 = torch.abs(a4 - b4)

        s1 = torch.cat([a1, b1, d1], dim=1)
        s2 = torch.cat([a2, b2, d2], dim=1)
        s3 = torch.cat([a3, b3, d3], dim=1)
        s4 = torch.cat([a4, b4, d4], dim=1)

        x = self.fuse4(s4)
        x = self.up3(x, s3)
        x = self.up2(x, s2)
        x = self.up1(x, s1)

        logits = self.head(x)  # (B,1,H,W)
        return logits

  출력은 sigmoid를 통과하기 전의 logits이며, 학습 시 BCEWithLogitsLoss를 사용하면 수치적으로 안정적입니다.

 

Loss 구성 (BCE + Dice)

  변화 탐지는 보통 변화 픽셀이 매우 적은 클래스 불균형이 발생합니다. BCE만 쓰면 변화 영역이 얇게 사라지거나 배경만 찍는 방향으로 가는 경우가 있어, Dice Loss를 함께 쓰면 안정적인 경우가 많습니다.

class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(probs.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        inter = (probs * targets).sum(dim=1)
        union = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2 * inter + self.eps) / (union + self.eps)
        return 1 - dice.mean()

bce = nn.BCEWithLogitsLoss()
dice = DiceLoss()

def total_loss(logits, mask, w_dice=1.0):
    return bce(logits, mask) + w_dice * dice(logits, mask)

 

학습 코드 (Train / Eval)

  이제 DataLoader에서 (imgA, imgB, mask)를 받아 모델을 학습합니다. 아래 코드는 AMP(자동 혼합정밀)를 포함한 “실전형 최소 학습 루프”입니다.

import os
import random
from glob import glob

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms.functional as TF

# =========================================================
# 1) Transform (A/B/Mask 동기화)
# =========================================================
class CDCompose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB, mask):
        for t in self.transforms:
            imgA, imgB, mask = t(imgA, imgB, mask)
        return imgA, imgB, mask


class CDToTensor:
    def __call__(self, imgA, imgB, mask):
        imgA = TF.to_tensor(imgA)  # (3,H,W), float 0~1
        imgB = TF.to_tensor(imgB)
        mask = TF.to_tensor(mask)  # (1,H,W), float 0~1
        mask = (mask > 0.5).float()
        return imgA, imgB, mask


class CDResize:
    def __init__(self, size):
        self.size = size  # (H,W)

    def __call__(self, imgA, imgB, mask):
        imgA = TF.resize(imgA, self.size)
        imgB = TF.resize(imgB, self.size)
        mask = TF.resize(mask, self.size, interpolation=TF.InterpolationMode.NEAREST)
        return imgA, imgB, mask


class CDRandomHFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgA, imgB, mask):
        if random.random() < self.p:
            imgA = TF.hflip(imgA)
            imgB = TF.hflip(imgB)
            mask = TF.hflip(mask)
        return imgA, imgB, mask


def get_cd_transforms(train=True, size=(512, 512)):
    t = [CDResize(size)]
    if train:
        t.append(CDRandomHFlip(0.5))
    t.append(CDToTensor())
    return CDCompose(t)


# =========================================================
# 2) Dataset (LEVIR-CD)
# =========================================================
class LEVIRDataset(Dataset):
    """
    Return: imgA (3,H,W), imgB (3,H,W), mask (1,H,W) float(0/1)
    """
    def __init__(self, root_dir, split="train", transforms=None):
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms

        self.dirA = os.path.join(root_dir, split, "A")
        self.dirB = os.path.join(root_dir, split, "B")
        self.dirL = os.path.join(root_dir, split, "label")

        if not (os.path.isdir(self.dirA) and os.path.isdir(self.dirB) and os.path.isdir(self.dirL)):
            raise FileNotFoundError(
                f"LEVIR 폴더 구조를 확인하세요.\n"
                f"- {self.dirA}\n- {self.dirB}\n- {self.dirL}"
            )

        # A 기준 파일명 목록
        self.names = sorted([os.path.basename(p) for p in glob(os.path.join(self.dirA, "*"))])
        if len(self.names) == 0:
            raise FileNotFoundError(f"'{self.dirA}' 안에 이미지가 없습니다.")

        # B/label 존재 체크(빠진 파일 방어)
        filtered = []
        for n in self.names:
            if os.path.exists(os.path.join(self.dirB, n)) and os.path.exists(os.path.join(self.dirL, n)):
                filtered.append(n)
        self.names = filtered

        if len(self.names) == 0:
            raise FileNotFoundError("A/B/label의 파일명이 일치하는 샘플이 없습니다. 파일명을 확인하세요.")

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        pA = os.path.join(self.dirA, name)
        pB = os.path.join(self.dirB, name)
        pL = os.path.join(self.dirL, name)

        imgA = Image.open(pA).convert("RGB")
        imgB = Image.open(pB).convert("RGB")
        mask = Image.open(pL).convert("L")  # 0/255

        if self.transforms:
            imgA, imgB, mask = self.transforms(imgA, imgB, mask)
        else:
            # 최소 기본값
            imgA, imgB, mask = CDToTensor()(imgA, imgB, mask)

        return imgA, imgB, mask


# =========================================================
# 3) Model (Siamese U-Net style)
# =========================================================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        self.c1 = nn.Sequential(ConvBNReLU(in_ch, base), ConvBNReLU(base, base))
        self.p1 = nn.MaxPool2d(2)
        self.c2 = nn.Sequential(ConvBNReLU(base, base * 2), ConvBNReLU(base * 2, base * 2))
        self.p2 = nn.MaxPool2d(2)
        self.c3 = nn.Sequential(ConvBNReLU(base * 2, base * 4), ConvBNReLU(base * 4, base * 4))
        self.p3 = nn.MaxPool2d(2)
        self.c4 = nn.Sequential(ConvBNReLU(base * 4, base * 8), ConvBNReLU(base * 8, base * 8))

    def forward(self, x):
        f1 = self.c1(x)  # (B, base, H, W)
        x = self.p1(f1)
        f2 = self.c2(x)  # (B, 2b, H/2, W/2)
        x = self.p2(f2)
        f3 = self.c3(x)  # (B, 4b, H/4, W/4)
        x = self.p3(f3)
        f4 = self.c4(x)  # (B, 8b, H/8, W/8)
        return f1, f2, f3, f4


class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            ConvBNReLU(out_ch + skip_ch, out_ch),
            ConvBNReLU(out_ch, out_ch),
        )

    def forward(self, x, skip):
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class SiameseChangeNet(nn.Module):
    def __init__(self, base=32):
        super().__init__()
        self.enc = Encoder(in_ch=3, base=base)  # shared

        self.fuse4 = ConvBNReLU(base * 8 * 3, base * 8)
        self.up3 = UpBlock(base * 8, base * 4 * 3, base * 4)
        self.up2 = UpBlock(base * 4, base * 2 * 3, base * 2)
        self.up1 = UpBlock(base * 2, base * 1 * 3, base * 1)
        self.head = nn.Conv2d(base, 1, kernel_size=1)

    def forward(self, a, b):
        a1, a2, a3, a4 = self.enc(a)
        b1, b2, b3, b4 = self.enc(b)

        d1 = torch.abs(a1 - b1)
        d2 = torch.abs(a2 - b2)
        d3 = torch.abs(a3 - b3)
        d4 = torch.abs(a4 - b4)

        s1 = torch.cat([a1, b1, d1], dim=1)
        s2 = torch.cat([a2, b2, d2], dim=1)
        s3 = torch.cat([a3, b3, d3], dim=1)
        s4 = torch.cat([a4, b4, d4], dim=1)

        x = self.fuse4(s4)
        x = self.up3(x, s3)
        x = self.up2(x, s2)
        x = self.up1(x, s1)

        logits = self.head(x)  # (B,1,H,W)
        return logits


# =========================================================
# 4) Loss (BCE + Dice)
# =========================================================
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(probs.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        inter = (probs * targets).sum(dim=1)
        union = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2 * inter + self.eps) / (union + self.eps)
        return 1 - dice.mean()


bce = nn.BCEWithLogitsLoss()
dice = DiceLoss()


def total_loss(logits, mask, w_dice=1.0):
    return bce(logits, mask) + w_dice * dice(logits, mask)


@torch.no_grad()
def calc_iou(logits, mask, thr=0.5, eps=1e-6):
    pred = (torch.sigmoid(logits) >= thr).float()
    pred = pred.view(pred.size(0), -1)
    mask = mask.view(mask.size(0), -1)

    inter = (pred * mask).sum(dim=1)
    union = (pred + mask - pred * mask).sum(dim=1)
    iou = (inter + eps) / (union + eps)
    return iou.mean().item()


# =========================================================
# 5) Train / Eval
# =========================================================
def train_one_epoch(model, loader, optimizer, scaler, device, w_dice=1.0, amp=True):
    model.train()
    loss_sum, iou_sum = 0.0, 0.0

    for imgA, imgB, mask in loader:
        imgA = imgA.to(device, non_blocking=True)
        imgB = imgB.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=amp):
            logits = model(imgA, imgB)
            loss = total_loss(logits, mask, w_dice=w_dice)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += loss.item()
        iou_sum += calc_iou(logits.detach(), mask)

    return loss_sum / max(1, len(loader)), iou_sum / max(1, len(loader))


@torch.no_grad()
def evaluate(model, loader, device, w_dice=1.0):
    model.eval()
    loss_sum, iou_sum = 0.0, 0.0

    for imgA, imgB, mask in loader:
        imgA = imgA.to(device, non_blocking=True)
        imgB = imgB.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        logits = model(imgA, imgB)
        loss = total_loss(logits, mask, w_dice=w_dice)

        loss_sum += loss.item()
        iou_sum += calc_iou(logits, mask)

    return loss_sum / max(1, len(loader)), iou_sum / max(1, len(loader))


def main():
    # --------------------------
    # Config
    # --------------------------
    data_root = "LEVIR-CD"   # <-- 본인 경로에 맞게 수정
    img_size = (512, 512)
    epochs = 20
    batch_size = 8
    lr = 3e-4
    weight_decay = 1e-4
    w_dice = 1.0
    amp = torch.cuda.is_available()

    save_dir = "checkpoints_cd"
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --------------------------
    # Dataset / Loader
    # --------------------------
    train_ds = LEVIRDataset(
        root_dir=data_root,
        split="train",
        transforms=get_cd_transforms(train=True, size=img_size),
    )
    val_ds = LEVIRDataset(
        root_dir=data_root,
        split="val",
        transforms=get_cd_transforms(train=False, size=img_size),
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=False
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True, drop_last=False
    )

    # --------------------------
    # Model / Optim
    # --------------------------
    model = SiameseChangeNet(base=32).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    best_iou = -1.0

    # --------------------------
    # Train loop
    # --------------------------
    for epoch in range(1, epochs + 1):
        tr_loss, tr_iou = train_one_epoch(model, train_loader, optimizer, scaler, device, w_dice=w_dice, amp=amp)
        va_loss, va_iou = evaluate(model, val_loader, device, w_dice=w_dice)

        print(
            f"[Epoch {epoch:02d}] "
            f"train_loss={tr_loss:.4f} train_iou={tr_iou:.4f} | "
            f"val_loss={va_loss:.4f} val_iou={va_iou:.4f}"
        )

        # best 저장 (val IoU 기준)
        if va_iou > best_iou:
            best_iou = va_iou
            ckpt_path = os.path.join(save_dir, "best.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "best_iou": best_iou,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                },
                ckpt_path,
            )
            print(f"[Saved best] {ckpt_path} (val_iou={best_iou:.4f})")

    print(f"[Done] best_iou={best_iou:.4f}")


if __name__ == "__main__":
    main()

  학습 중에는 loss만 보지 말고 IoU를 같이 확인하는 것이 중요합니다. 변화 탐지는 픽셀 불균형 때문에 loss가 내려가도 결과가 별로인 경우가 종종 있습니다.

 

추론 및 결과 저장(예측 마스크)

  학습이 끝나면 단일 샘플에 대해 예측 마스크를 저장해 결과를 확인할 수 있습니다. 아래는 (A/B/GT/Pred)에서 Pred 마스크를 png로 저장하는 예시입니다.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
from glob import glob
from typing import Tuple

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# 1) PIL 기반 transforms (torchvision / numpy 없이)
# ============================================================

class CDCompose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB, mask):
        for t in self.transforms:
            imgA, imgB, mask = t(imgA, imgB, mask)
        return imgA, imgB, mask


class CDResize:
    def __init__(self, size_hw: Tuple[int, int]):
        self.size_hw = size_hw  # (H, W)

    def __call__(self, imgA, imgB, mask):
        H, W = self.size_hw
        imgA = imgA.resize((W, H), resample=Image.BILINEAR)
        imgB = imgB.resize((W, H), resample=Image.BILINEAR)
        mask = mask.resize((W, H), resample=Image.NEAREST)
        return imgA, imgB, mask


def _pil_rgb_to_tensor_no_numpy(img: Image.Image) -> torch.Tensor:
    """
    PIL RGB -> torch.float32 (3,H,W) in [0,1]  (numpy 없이)
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    raw = img.tobytes()  # RGBRGB...
    x = torch.ByteTensor(list(raw))  # (H*W*3,)
    x = x.view(h, w, 3).permute(2, 0, 1).contiguous()  # (3,H,W)
    return x.float() / 255.0


def _pil_l_to_mask01_tensor_no_numpy(mask: Image.Image) -> torch.Tensor:
    """
    PIL L(0~255) -> torch.float32 (1,H,W) in {0,1}  (numpy 없이)
    """
    if mask.mode != "L":
        mask = mask.convert("L")
    w, h = mask.size
    raw = mask.tobytes()
    x = torch.ByteTensor(list(raw)).view(h, w)  # (H,W) uint8
    x = (x > 0).float().unsqueeze(0)  # (1,H,W)
    return x


class CDToTensor:
    def __call__(self, imgA, imgB, mask):
        tA = _pil_rgb_to_tensor_no_numpy(imgA)
        tB = _pil_rgb_to_tensor_no_numpy(imgB)
        tM = _pil_l_to_mask01_tensor_no_numpy(mask)
        return tA, tB, tM


def get_cd_transforms(size_hw=(512, 512)):
    return CDCompose([
        CDResize(size_hw),
        CDToTensor(),
    ])


# ============================================================
# 2) Dataset (LEVIR-CD)
# ============================================================

class LEVIRDataset(torch.utils.data.Dataset):
    """
    Return: imgA (3,H,W), imgB (3,H,W), mask (1,H,W) float(0/1), name(str)
    """
    def __init__(self, root_dir: str, split="val", transforms=None):
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms

        self.dirA = os.path.join(root_dir, split, "A")
        self.dirB = os.path.join(root_dir, split, "B")
        self.dirL = os.path.join(root_dir, split, "label")

        if not (os.path.isdir(self.dirA) and os.path.isdir(self.dirB) and os.path.isdir(self.dirL)):
            raise FileNotFoundError(
                f"[ERR] LEVIR 폴더 구조를 확인하세요:\n"
                f"- {self.dirA}\n- {self.dirB}\n- {self.dirL}"
            )

        names = sorted([os.path.basename(p) for p in glob(os.path.join(self.dirA, "*"))])
        if len(names) == 0:
            raise FileNotFoundError(f"[ERR] '{self.dirA}' 안에 이미지가 없습니다.")

        filtered = []
        for n in names:
            if os.path.exists(os.path.join(self.dirB, n)) and os.path.exists(os.path.join(self.dirL, n)):
                filtered.append(n)
        self.names = filtered

        if len(self.names) == 0:
            raise FileNotFoundError("[ERR] A/B/label의 파일명이 일치하는 샘플이 없습니다. 파일명을 확인하세요.")

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        pA = os.path.join(self.dirA, name)
        pB = os.path.join(self.dirB, name)
        pL = os.path.join(self.dirL, name)

        imgA = Image.open(pA).convert("RGB")
        imgB = Image.open(pB).convert("RGB")
        mask = Image.open(pL).convert("L")

        if self.transforms is not None:
            imgA, imgB, mask = self.transforms(imgA, imgB, mask)
        else:
            imgA, imgB, mask = CDToTensor()(imgA, imgB, mask)

        return imgA, imgB, mask, name


# ============================================================
# 3) Model (SiameseChangeNet)
# ============================================================

class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        self.c1 = nn.Sequential(ConvBNReLU(in_ch, base), ConvBNReLU(base, base))
        self.p1 = nn.MaxPool2d(2)
        self.c2 = nn.Sequential(ConvBNReLU(base, base * 2), ConvBNReLU(base * 2, base * 2))
        self.p2 = nn.MaxPool2d(2)
        self.c3 = nn.Sequential(ConvBNReLU(base * 2, base * 4), ConvBNReLU(base * 4, base * 4))
        self.p3 = nn.MaxPool2d(2)
        self.c4 = nn.Sequential(ConvBNReLU(base * 4, base * 8), ConvBNReLU(base * 8, base * 8))

    def forward(self, x):
        f1 = self.c1(x)
        x = self.p1(f1)
        f2 = self.c2(x)
        x = self.p2(f2)
        f3 = self.c3(x)
        x = self.p3(f3)
        f4 = self.c4(x)
        return f1, f2, f3, f4


class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            ConvBNReLU(out_ch + skip_ch, out_ch),
            ConvBNReLU(out_ch, out_ch),
        )

    def forward(self, x, skip):
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class SiameseChangeNet(nn.Module):
    def __init__(self, base=32):
        super().__init__()
        self.enc = Encoder(in_ch=3, base=base)

        self.fuse4 = ConvBNReLU(base * 8 * 3, base * 8)

        self.up3 = UpBlock(base * 8, base * 4 * 3, base * 4)
        self.up2 = UpBlock(base * 4, base * 2 * 3, base * 2)
        self.up1 = UpBlock(base * 2, base * 1 * 3, base * 1)

        self.head = nn.Conv2d(base, 1, kernel_size=1)

    def forward(self, a, b):
        a1, a2, a3, a4 = self.enc(a)
        b1, b2, b3, b4 = self.enc(b)

        d1 = torch.abs(a1 - b1)
        d2 = torch.abs(a2 - b2)
        d3 = torch.abs(a3 - b3)
        d4 = torch.abs(a4 - b4)

        s1 = torch.cat([a1, b1, d1], dim=1)
        s2 = torch.cat([a2, b2, d2], dim=1)
        s3 = torch.cat([a3, b3, d3], dim=1)
        s4 = torch.cat([a4, b4, d4], dim=1)

        x = self.fuse4(s4)
        x = self.up3(x, s3)
        x = self.up2(x, s2)
        x = self.up1(x, s1)

        return self.head(x)  # (B,1,H,W) logits


# ============================================================
# 4) Saving utils (PIL, numpy/cv2 없이)
# ============================================================

def tensor_rgb_to_pil(img_t: torch.Tensor) -> Image.Image:
    """
    img_t: (3,H,W) float(0~1) -> PIL RGB
    (numpy 없이 / 단일 저장용이라 느려도 확실)
    """
    x = img_t.detach().cpu()
    x = (x.clamp(0, 1) * 255.0).to(torch.uint8)  # (3,H,W)
    x = x.permute(1, 2, 0).contiguous()          # (H,W,3)

    h, w, _ = x.shape
    raw = bytes(x.reshape(-1).tolist())
    return Image.frombytes("RGB", (w, h), raw)


def mask01_to_pil(mask_t: torch.Tensor) -> Image.Image:
    """
    mask_t: (1,H,W) float{0,1} -> PIL L (0/255)
    """
    m = mask_t.detach().cpu()
    if m.ndim == 3:
        m = m.squeeze(0)
    m = (m.clamp(0, 1) * 255.0).to(torch.uint8)  # (H,W)
    h, w = m.shape
    raw = bytes(m.reshape(-1).tolist())
    return Image.frombytes("L", (w, h), raw)


def overlay_mask_on_rgb(img_rgb_t: torch.Tensor, mask01_t: torch.Tensor, alpha=0.45) -> Image.Image:
    """
    img_rgb_t: (3,H,W) float(0~1)
    mask01_t : (1,H,W) float(0/1)
    -> PIL RGB overlay (mask 영역을 빨간색으로 블렌딩)
    """
    img = (img_rgb_t.detach().cpu().clamp(0, 1) * 255.0).to(torch.uint8)  # (3,H,W)
    mask = mask01_t.detach().cpu()
    if mask.ndim == 3:
        mask = mask.squeeze(0)  # (H,W)
    mask = (mask > 0.5)

    out = img.clone()  # uint8
    if mask.any():
        # 빨강(255,0,0)과 블렌딩: out = (1-a)*img + a*red
        a = float(alpha)
        red = torch.tensor([255, 0, 0], dtype=torch.float32).view(3, 1, 1)
        region = img.float()
        blended = ((1.0 - a) * region + a * red).clamp(0, 255).to(torch.uint8)
        out[:, mask] = blended[:, mask]

    # (3,H,W)->PIL
    out_hw3 = out.permute(1, 2, 0).contiguous()
    h, w, _ = out_hw3.shape
    raw = bytes(out_hw3.reshape(-1).tolist())
    return Image.frombytes("RGB", (w, h), raw)


# ============================================================
# 5) Checkpoint loader (model_state 지원)
# ============================================================

def load_checkpoint(model: nn.Module, ckpt_path: str, device: torch.device):
    ckpt = torch.load(ckpt_path, map_location=device)

    # 1) {model_state: ...} 형태
    if isinstance(ckpt, dict) and "model_state" in ckpt and isinstance(ckpt["model_state"], dict):
        state = ckpt["model_state"]
    # 2) 그냥 state_dict 형태
    elif isinstance(ckpt, dict):
        # 일부는 epoch/best_iou/optimizer_state 등이 섞여있어서 "진짜 weight key"만 남길 수도 있지만
        # 여기서는 엄격하지 않게 로드하되, 실패하면 user가 알려준 구조로 조정 가능.
        state = ckpt
    else:
        raise ValueError("[ERR] ckpt 파일 형식을 해석할 수 없습니다.")

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print("[WARN] missing keys:", missing[:10], "..." if len(missing) > 10 else "")
    if unexpected:
        print("[WARN] unexpected keys:", unexpected[:10], "..." if len(unexpected) > 10 else "")


# ============================================================
# 6) Inference
# ============================================================

@torch.no_grad()
def run_single_infer(
    root: str,
    ckpt: str,
    split: str,
    idx: int,
    out_dir: str,
    thr: float,
    size_hw: Tuple[int, int],
    base: int,
    use_cpu: bool,
):
    device = torch.device("cpu" if use_cpu or not torch.cuda.is_available() else "cuda")
    os.makedirs(out_dir, exist_ok=True)

    ds = LEVIRDataset(root_dir=root, split=split, transforms=get_cd_transforms(size_hw))
    if idx < 0 or idx >= len(ds):
        raise IndexError(f"[ERR] idx={idx} 범위 오류. dataset size={len(ds)}")

    model = SiameseChangeNet(base=base).to(device)
    model.eval()
    load_checkpoint(model, ckpt, device)

    imgA, imgB, gt, name = ds[idx]   # (3,H,W), (3,H,W), (1,H,W), str
    imgA_b = imgA.unsqueeze(0).to(device)
    imgB_b = imgB.unsqueeze(0).to(device)

    logits = model(imgA_b, imgB_b)           # (1,1,H,W)
    prob = torch.sigmoid(logits).squeeze(0)  # (1,H,W)
    pred = (prob >= thr).float()             # (1,H,W)

    # Save
    stem = os.path.splitext(name)[0]
    pA = os.path.join(out_dir, f"{stem}_A.png")
    pB = os.path.join(out_dir, f"{stem}_B.png")
    pGT = os.path.join(out_dir, f"{stem}_GT.png")
    pPR = os.path.join(out_dir, f"{stem}_Pred.png")
    pOA = os.path.join(out_dir, f"{stem}_OverlayA.png")
    pOB = os.path.join(out_dir, f"{stem}_OverlayB.png")

    tensor_rgb_to_pil(imgA).save(pA)
    tensor_rgb_to_pil(imgB).save(pB)
    mask01_to_pil(gt).save(pGT)
    mask01_to_pil(pred).save(pPR)

    overlay_mask_on_rgb(imgA, pred, alpha=0.45).save(pOA)
    overlay_mask_on_rgb(imgB, pred, alpha=0.45).save(pOB)

    print("[OK] saved:")
    print(" -", pA)
    print(" -", pB)
    print(" -", pGT)
    print(" -", pPR)
    print(" -", pOA)
    print(" -", pOB)


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--root", required=True, help="LEVIR-CD root (ex: /path/LEVIR-CD)")
    p.add_argument("--ckpt", required=True, help="checkpoint path (.pth)")
    p.add_argument("--split", default="val", choices=["train", "val", "test"])
    p.add_argument("--idx", type=int, default=0)
    p.add_argument("--out_dir", default="out_infer")
    p.add_argument("--thr", type=float, default=0.5)
    p.add_argument("--size", type=int, nargs=2, default=[512, 512], metavar=("H", "W"))
    p.add_argument("--base", type=int, default=32)
    p.add_argument("--cpu", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    run_single_infer(
        root=args.root,
        ckpt=args.ckpt,
        split=args.split,
        idx=args.idx,
        out_dir=args.out_dir,
        thr=args.thr,
        size_hw=(args.size[0], args.size[1]),
        base=args.base,
        use_cpu=args.cpu,
    )


if __name__ == "__main__":
    main()

 

마무리

  이번 글에서는 LEVIR 데이터셋을 기준으로 Siamese 기반 변화 탐지 모델을 구성하고, BCE + Dice Loss 조합으로 학습하는 전체 과정을 살펴보았습니다. 핵심은 두 시점 입력(A/B)을 동일 인코더로 처리하고, 특징 차이를 이용해 변화 마스크를 복원하는 구조를 안정적으로 학습 루프에 연결하는 것입니다.

  다음 글에서는 예측 결과를 좀 더 보기 좋게 시각화(Overlay)하고, IoU/F1을 제대로 측정하는 평가 코드까지 정리해 "학습 → 평가 → 개선" 흐름이 자연스럽게 이어지도록 구성해보겠습니다.

 

관련 내용

 

 

반응형