본문 바로가기
실전 예제, 프로젝트

[실전 예제/인스턴스 분할/PyTorch] 인스턴스 분할 튜토리얼: COCO 데이터셋으로 PyTorch 데이터셋 만들기

by First Adventure 2025. 4. 19.
반응형

인스턴스 분할(Instance Segmentation)이란?

  인스턴스 분할(Instance Segmentation)은 이미지 속의 객체를 픽셀 단위로 구분하되, 객체의 개별 인스턴스마다 다른 마스크를 예측하는 태스크입니다. 즉, 객체 검출(Object Detection)과 의미론적 분할(Semantic Segmentation)을 동시에 수행하는 고급 비전 과제입니다. 이번 시간에는 PyTorch를 이용하여 인스턴스 분할 데이터셋을 만드는 방법에 대해 알아보도록 하겠습니다.

 

인스턴스 분할 vs 객체 검출 vs 의미론적 분할

  • 인스턴스 분할: 픽셀 단위 클래스 + 인스턴스 구분 모두 수행
  • 객체 검출 (Object Detection): 물체 위치를 박스로 찾음
  • 의미론적 분할 (Semantic): 픽셀 단위 클래스 구분 (개별 인스턴스 구분 X)

 

PyTorch용 COCO Instance Segmentation Dataset 만들기

  COCO는 객체 검출/인스턴스 분할에서 가장 많이 쓰이는 표준 데이터셋 중 하나입니다.

COCO(Common Objects in Context) 데이터셋 특징

  • 80개 객체 클래스
  • 바운딩 박스 + 마스크 + 클래스 라벨 포함
  • 마스크는 RLE(Run-Length Encoding) 또는 binary mask 형식

디렉토리 구조 예시

coco/
  images/
    train2017/
    val2017/
  annotations/
    instances_train2017.json
    instances_val2017.json

 

PyTorch 코드 예제

from pycocotools.coco import COCO
from PIL import Image
import torch
import os


class COCOInstanceDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, ann_file, transform=None, filter_crowd=True):
        self.coco = COCO(ann_file)
        self.img_ids = list(self.coco.imgs.keys())
        self.img_dir = img_dir
        self.transform = transform
        self.filter_crowd = filter_crowd

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        img_info = self.coco.loadImgs(img_id)[0]
        file_name = img_info["file_name"]
        img_path = os.path.join(self.img_dir, file_name)

        # 이미지 로드
        image = Image.open(img_path).convert("RGB")
        W, H = image.size

        # annotation 로드
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        boxes, labels, masks, areas, iscrowd = [], [], [], [], []

        for ann in anns:
            # (선택) crowd는 학습에서 제외하는 편이 일반적으로 안전함
            if self.filter_crowd and int(ann.get("iscrowd", 0)) == 1:
                continue

            # COCO bbox: [x, y, w, h]
            x, y, w, h = ann["bbox"]
            if w <= 1 or h <= 1:
                continue

            # xyxy 변환 + 이미지 경계로 clamp
            x1 = max(0.0, float(x))
            y1 = max(0.0, float(y))
            x2 = min(float(W), x1 + float(w))
            y2 = min(float(H), y1 + float(h))
            if x2 <= x1 or y2 <= y1:
                continue

            boxes.append([x1, y1, x2, y2])

            # 라벨은 우선 COCO category_id 그대로 사용 (재매핑은 학습 글에서 권장)
            labels.append(int(ann["category_id"]))

            # mask: (H, W) uint8 (0/1)
            m = self.coco.annToMask(ann)
            masks.append(torch.as_tensor(m, dtype=torch.uint8))

            # torchvision 권장 필드들
            areas.append(float(ann.get("area", (x2 - x1) * (y2 - y1))))
            iscrowd.append(int(ann.get("iscrowd", 0)))

        # tensor 변환 (빈 케이스도 shape 보장)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        if len(masks) > 0:
            masks = torch.stack(masks, dim=0)  # (N, H, W)
        else:
            masks = torch.zeros((0, H, W), dtype=torch.uint8)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id], dtype=torch.int64),
            "area": torch.as_tensor(areas, dtype=torch.float32) if len(areas) else torch.zeros((0,), dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64) if len(iscrowd) else torch.zeros((0,), dtype=torch.int64),
        }

        if self.transform:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.img_ids)

  COCO의 category_id는 연속값이 아닐 수 있어, 실제 학습에서는 background=0을 포함한 연속 라벨로 재매핑하는 것이 안전합니다.

모델 예시: Mask R-CNN (torchvision 제공)

      • 입력: 이미지
      • 출력: `boxes`, `labels`, `scores`, `masks` 포함된 딕셔너리
      • 학습 시 `target`에 `masks`, `labels`, `boxes`가 모두 있어야 함
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn

model = maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

 

마무리

  PyTorch를 이용하여 인스턴스 분할 데이터셋을 어떻게 만드는지 살펴보았습니다. 다음 시간에는 모델 구성 및 학습 방법을 PyTorch로 작성하는 방법을 알아보도록 하겠습니다.

 

관련 내용

반응형