본문 바로가기
실전 예제, 프로젝트

[실전 예제/인스턴스 분할/PyTorch] Mask R-CNN 모델 구성과 COCO 학습

by First Adventure 2026. 1. 24.
반응형

인스턴스 분할(Instance Segmentation)이란?

  인스턴스 분할(Instance Segmentation)은 객체 검출(Object Detection)과 의미 분할(Semantic Segmentation)을 결합한 컴퓨터 비전 태스크입니다. 이미지 안의 각 객체를 구분하면서, 객체마다 픽셀 단위의 마스크를 예측하는 것이 핵심입니다.

  Mask R-CNN은 Faster R-CNN 구조를 기반으로, ROI 단위에서 Bounding Box + Class + Mask를 동시에 예측하는 대표적인 인스턴스 분할 모델입니다. 이번 글에서는 이전 글에서 구성한 COCO 인스턴스 분할 데이터셋을 바탕으로, PyTorch에서 Mask R-CNN 모델을 구성하고 실제 학습까지 연결하는 과정을 정리합니다.

  • 예: 사람·차량·사물 영역을 개별 객체 단위로 분할
  • 예: 객체 개수 계산, 정밀 영역 분석

 

목표

  1. COCO 인스턴스 분할 데이터셋을 입력으로 하는 학습 파이프라인 구성
  2. torchvision Mask R-CNN 모델 구조 이해
  3. Box / Mask loss를 포함한 학습 코드 작성

 

학습을 위한 데이터 구조

  이전 글에서 구성한 COCO 인스턴스 분할 데이터셋은 COCO annotation 포맷(JSON)을 그대로 사용하며, 이미지와 마스크 정보를 함께 포함합니다. Mask R-CNN 학습을 위해서는 COCO의  instances_*.json에 포함된 segmentation 정보를 마스크로 디코딩해야 하며, 이를 위해 pycocotools가 필요합니다.

coco/
  train2017/
  val2017/
  annotations/
    instances_train2017.json
    instances_val2017.json

  Mask R-CNN 학습에서는 각 샘플마다 다음 정보가 필요합니다.

  • boxes: (N,4) 형태의 bounding box
  • labels: 각 객체의 class id
  • masks: (N,H,W) 형태의 binary mask

 

Mask R-CNN 모델 구성

모델 개요

  • Backbone: ResNet50 + FPN
  • Detector: Faster R-CNN 기반
  • Segmentation Head: ROIAlign + FCN Mask Head
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# -------------------------
# Model
# -------------------------
def build_model(num_classes: int):
    try:
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        model = maskrcnn_resnet50_fpn(pretrained=True)

    # box predictor
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden, num_classes)
    return model

  COCO 데이터셋을 그대로 사용할 경우, 클래스 수는 background 포함으로 지정해야 합니다.

 

학습 코드

실행 예시

python train_maskrcnn_coco_official.py --coco_root ./coco --epochs 10 --batch 1 --num_workers 4
# train_maskrcnn_coco_official.py
import os
import json
import argparse
from typing import Any, Dict, List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from torchvision.transforms import functional as F
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# COCO segmentation decode (RLE / polygon)
# NOTE: pycocotools가 필요합니다: pip install pycocotools
from pycocotools import mask as coco_mask

# 실행 예시
# python train_maskrcnn_coco_official.py --coco_root ./coco --epochs 10 --batch 1 --num_workers 4

# -------------------------
# Utils
# -------------------------
def collate_fn(batch):
    # detection/segmentation: variable number of instances per image
    return tuple(zip(*batch))


def set_seed(seed: int):
    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# -------------------------
# Category mapping
# -------------------------
def build_category_mapping(instances_ann_file: str) -> Tuple[Dict[int, int], int]:
    """
    COCO category_id는 연속이 아니므로 contiguous label(1..K)로 재매핑.
    background는 0.
    """
    with open(instances_ann_file, "r", encoding="utf-8") as f:
        coco = json.load(f)

    cat_ids = sorted([int(c["id"]) for c in coco.get("categories", [])])
    cat_map = {cid: i + 1 for i, cid in enumerate(cat_ids)}  # 1..K
    num_classes = len(cat_ids) + 1  # + background(0)
    return cat_map, num_classes


# -------------------------
# COCO Instance Segmentation Dataset
# -------------------------
class CocoInstanceSegDataset(Dataset):
    """
    Returns:
      image: FloatTensor [3,H,W] in [0,1]
      target:
        boxes   FloatTensor [N,4]  xyxy
        labels  Int64Tensor [N]
        masks   UInt8/Bool Tensor [N,H,W] (0/1)
        image_id Int64Tensor [1]
        area    FloatTensor [N]
        iscrowd Int64Tensor [N]
    """
    def __init__(self, images_dir: str, ann_file: str, cat_id_to_contiguous: Dict[int, int]):
        self.images_dir = images_dir
        self.ann_file = ann_file
        self.cat_map = cat_id_to_contiguous

        if not os.path.isdir(self.images_dir):
            raise FileNotFoundError(f"[ERR] images_dir not found: {self.images_dir}")
        if not os.path.isfile(self.ann_file):
            raise FileNotFoundError(f"[ERR] ann_file not found: {self.ann_file}")

        with open(self.ann_file, "r", encoding="utf-8") as f:
            coco = json.load(f)

        self.id_to_img = {img["id"]: img for img in coco.get("images", [])}
        self.img_ids = sorted(self.id_to_img.keys())

        self.anns_by_img: Dict[int, List[dict]] = {}
        for ann in coco.get("annotations", []):
            img_id = ann["image_id"]
            self.anns_by_img.setdefault(img_id, []).append(ann)

    def __len__(self) -> int:
        return len(self.img_ids)

    def _ann_to_mask(self, ann: dict, height: int, width: int) -> torch.Tensor:
        """
        ann['segmentation'] can be:
          - polygon list(s)
          - RLE dict (counts: bytes/str)
          - uncompressed RLE dict (counts: list)  <-- 이 케이스가 현재 에러 원인
        Return: uint8 mask (H,W) with values {0,1}
        """
        seg = ann.get("segmentation", None)
        if seg is None or seg == []:
            return torch.zeros((height, width), dtype=torch.uint8)
    
        # 1) RLE(dict) 케이스
        if isinstance(seg, dict):
            rle = seg
    
            # (A) counts가 list인 비압축 RLE -> frPyObjects로 변환 필요
            if isinstance(rle.get("counts"), list):
                rle = coco_mask.frPyObjects(rle, height, width)
                # frPyObjects가 list를 줄 때가 있어 방어
                if isinstance(rle, list):
                    rle = rle[0]
    
            # (B) counts가 str인 경우 -> bytes로 변환(버전/환경에 따라 필요)
            elif isinstance(rle.get("counts"), str):
                rle = dict(rle)
                rle["counts"] = rle["counts"].encode("utf-8")
    
        # 2) polygon(list) 케이스
        else:
            rles = coco_mask.frPyObjects(seg, height, width)
            rle = coco_mask.merge(rles)
    
        m = coco_mask.decode(rle)  # (H,W) or (H,W,1)
        if m.ndim == 3:
            m = m[:, :, 0]
        m = (m > 0).astype("uint8")
        return torch.from_numpy(m)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        img_id = self.img_ids[idx]
        info = self.id_to_img[img_id]
        file_name = info["file_name"]
        width = int(info.get("width", 0))
        height = int(info.get("height", 0))

        img_path = os.path.join(self.images_dir, file_name)
        img = Image.open(img_path).convert("RGB")
        image = F.to_tensor(img)  # float32, [0,1]

        # If width/height missing in json, infer from loaded image
        if width <= 0 or height <= 0:
            height, width = image.shape[1], image.shape[2]

        anns = self.anns_by_img.get(img_id, [])

        boxes: List[List[float]] = []
        labels: List[int] = []
        masks: List[torch.Tensor] = []
        areas: List[float] = []
        iscrowd: List[int] = []

        for ann in anns:
            # bbox: [x,y,w,h] -> xyxy
            x, y, w, h = ann["bbox"]
            if w <= 1 or h <= 1:
                continue

            cat_id = int(ann["category_id"])
            if cat_id not in self.cat_map:
                continue

            x1, y1 = float(x), float(y)
            x2, y2 = float(x + w), float(y + h)

            # clamp to image size
            x1 = max(0.0, min(x1, width - 1.0))
            y1 = max(0.0, min(y1, height - 1.0))
            x2 = max(0.0, min(x2, width * 1.0))
            y2 = max(0.0, min(y2, height * 1.0))
            if x2 <= x1 or y2 <= y1:
                continue

            m = self._ann_to_mask(ann, height, width)
            # 일부 annotation은 segmentation이 비거나 이상할 수 있음 → bbox 영역이랑 불일치 시라도 유지
            # mask가 전부 0이면 skip 하고 싶다면 아래 조건을 켜도 됨.
            # if int(m.sum()) == 0: continue

            boxes.append([x1, y1, x2, y2])
            labels.append(self.cat_map[cat_id])
            masks.append(m)
            areas.append(float(ann.get("area", w * h)))
            iscrowd.append(int(ann.get("iscrowd", 0)))

        if len(boxes) == 0:
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros((0,), dtype=torch.int64)
            masks_t = torch.zeros((0, height, width), dtype=torch.uint8)
            areas_t = torch.zeros((0,), dtype=torch.float32)
            iscrowd_t = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes_t = torch.tensor(boxes, dtype=torch.float32)
            labels_t = torch.tensor(labels, dtype=torch.int64)
            masks_t = torch.stack(masks, dim=0).to(torch.uint8)  # [N,H,W]
            areas_t = torch.tensor(areas, dtype=torch.float32)
            iscrowd_t = torch.tensor(iscrowd, dtype=torch.int64)

        target = {
            "boxes": boxes_t,
            "labels": labels_t,
            "masks": masks_t,  # uint8/0-1
            "image_id": torch.tensor([img_id], dtype=torch.int64),
            "area": areas_t,
            "iscrowd": iscrowd_t,
        }
        return image, target


# -------------------------
# Model
# -------------------------
def build_model(num_classes: int):
    try:
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        model = maskrcnn_resnet50_fpn(pretrained=True)

    # box predictor
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden, num_classes)
    return model


# -------------------------
# Train / quick "val loss"
# -------------------------
def train_one_epoch(model, loader, optimizer, device, epoch: int, print_every: int = 50):
    model.train()
    loss_sum = 0.0

    for it, (images, targets) in enumerate(loader, start=1):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(v for v in loss_dict.values())

        optimizer.zero_grad(set_to_none=True)
        losses.backward()
        optimizer.step()

        loss_sum += float(losses.item())

        if it % print_every == 0:
            avg = loss_sum / it
            detail = {k: float(v.item()) for k, v in loss_dict.items()}
            print(f"[Epoch {epoch}] iter={it}/{len(loader)} loss={avg:.4f} detail={detail}")

    return loss_sum / max(1, len(loader))


@torch.no_grad()
def compute_loss_on_loader(model, loader, device):
    # eval()에서는 loss가 안 나오는 케이스가 많아서 train()+no_grad로 간단 계산
    model.train()
    loss_sum = 0.0
    n = 0
    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        loss_sum += float(sum(v for v in loss_dict.values()).item())
        n += 1
    return loss_sum / max(1, n)


def save_ckpt(model, optimizer, epoch: int, out_path: str):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save({"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict()}, out_path)


# -------------------------
# Main
# -------------------------
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--coco_root", default="coco", help="folder that contains train2017/val2017/annotations")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch", type=int, default=1)  # maskrcnn은 무거워서 기본 1
    p.add_argument("--lr", type=float, default=0.005)
    p.add_argument("--momentum", type=float, default=0.9)
    p.add_argument("--weight_decay", type=float, default=0.0005)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda")
    p.add_argument("--out_dir", default="runs/maskrcnn_coco_official")
    p.add_argument("--print_every", type=int, default=50)
    return p.parse_args()


def main():
    args = parse_args()
    set_seed(args.seed)

    device = torch.device(args.device if torch.cuda.is_available() and args.device.startswith("cuda") else "cpu")

    train_images = os.path.join(args.coco_root, "train2017")
    val_images = os.path.join(args.coco_root, "val2017")
    ann_dir = os.path.join(args.coco_root, "annotations")
    train_ann = os.path.join(ann_dir, "instances_train2017.json")
    val_ann = os.path.join(ann_dir, "instances_val2017.json")

    cat_map, num_classes = build_category_mapping(train_ann)

    print(f"[INFO] COCO root     : {args.coco_root}")
    print(f"[INFO] num_classes  : {num_classes} (incl background)")
    print(f"[INFO] train_images : {train_images}")
    print(f"[INFO] train_ann    : {train_ann}")
    print(f"[INFO] val_images   : {val_images}")
    print(f"[INFO] val_ann      : {val_ann}")

    train_ds = CocoInstanceSegDataset(train_images, train_ann, cat_map)
    val_ds = CocoInstanceSegDataset(val_images, val_ann, cat_map)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    model = build_model(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    print(f"[INFO] device={device} train={len(train_ds)} val={len(val_ds)}")

    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device, epoch, print_every=args.print_every)
        val_loss = compute_loss_on_loader(model, val_loader, device)
        lr_scheduler.step()

        lr_now = optimizer.param_groups[0]["lr"]
        print(f"[EPOCH {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} lr={lr_now:.6f}")

        ckpt_path = os.path.join(args.out_dir, f"epoch_{epoch:03d}.pth")
        save_ckpt(model, optimizer, epoch, ckpt_path)

    print(f"[DONE] checkpoints saved to: {args.out_dir}")


if __name__ == "__main__":
    main()

  Mask R-CNN은 box loss 외에도 mask loss가 추가되기 때문에, 객체 검출보다 메모리 사용량이 큽니다. batch size를 무리하게 키우지 않는 것이 중요합니다.

 

마무리

  이번 글에서는 COCO 데이터셋을 기반으로 Mask R-CNN 모델을 구성하고 PyTorch에서 실제 학습까지 연결하는 과정을 살펴보았습니다. 인스턴스 분할은 객체 검출보다 정보량이 훨씬 풍부한 태스크이지만, 그만큼 데이터 구조와 메모리 관리가 중요합니다.

  다음 글에서는 학습된 Mask R-CNN 모델을 이용해 실제 이미지에 대한 추론 결과를 시각화하고, mask 품질을 확인하는 방법을 정리해보겠습니다.

 

관련 내용

 

 

반응형