메인 콘텐츠로 건너뛰기
Colab에서 실행해 보기 이 튜토리얼에서는 MONAI를 사용해 다중 레이블 3D 뇌종양 분할 작업을 위한 학습 워크플로를 구성하고, W&B의 실험 추적 및 데이터 시각화 기능을 사용하는 방법을 설명합니다. 이 튜토리얼에서는 다음과 같은 내용을 다룹니다:
  1. W&B 실행을 초기화하고, 재현성을 위해 해당 실행과 연관된 모든 설정(config)을 동기화합니다.
  2. MONAI transform API:
    1. 딕셔너리 형식 데이터용 MONAI transform.
    2. MONAI transforms API에 따라 새 transform을 정의하는 방법.
    3. 데이터 증강을 위해 강도를 무작위로 조정하는 방법.
  3. 데이터 로딩 및 시각화:
    1. 메타데이터가 포함된 Nifti 이미지를 로드하고, 이미지 목록을 로드한 뒤 스택하는 방법.
    2. 학습 및 검증을 가속화하기 위해 입출력(I/O)과 transform을 캐시하는 방법.
    3. wandb.Table과 W&B의 대화형 세그멘테이션 오버레이를 사용하여 데이터를 시각화합니다.
  4. 3D SegResNet 모델 학습
    1. MONAI의 networks, losses, metrics API를 사용하는 방법.
    2. PyTorch 학습 루프를 사용하여 3D SegResNet 모델을 학습하는 방법.
    3. W&B를 사용하여 학습 실험을 추적합니다.
    4. 모델 체크포인트를 W&B에서 모델 아티팩트로 로깅하고 버전 관리합니다.
  5. wandb.Table과 W&B의 대화형 세그멘테이션 오버레이를 사용하여 검증 데이터셋에 대한 예측을 시각화하고 비교합니다.

설정 및 설치

먼저 MONAI와 W&B의 최신 버전을 설치하세요.
!python -c "import monai" || pip install -q -U "monai[nibabel, tqdm]"
!python -c "import wandb" || pip install -q -U wandb
import os

import numpy as np
from tqdm.auto import tqdm
import wandb

from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism

import torch
그런 다음 Colab 인스턴스를 인증하여 W&B를 사용할 수 있도록 합니다.
wandb.login()

W&B 실행 초기화

실험 추적을 시작하려면 새 W&B 실행을 초기화하세요. 적절한 구성(config) 시스템을 사용하는 것은 재현 가능한 머신 러닝을 위한 권장 모범 사례입니다. W&B를 사용하면 각 실험의 하이퍼파라미터를 추적할 수 있습니다.
with wandb.init(project="monai-brain-tumor-segmentation") as run:

    config = run.config
    config.seed = 0
    config.roi_size = [224, 224, 144]
    config.batch_size = 1
    config.num_workers = 4
    config.max_train_images_visualized = 20
    config.max_val_images_visualized = 20
    config.dice_loss_smoothen_numerator = 0
    config.dice_loss_smoothen_denominator = 1e-5
    config.dice_loss_squared_prediction = True
    config.dice_loss_target_onehot = False
    config.dice_loss_apply_sigmoid = True
    config.initial_learning_rate = 1e-4
    config.weight_decay = 1e-5
    config.max_train_epochs = 50
    config.validation_intervals = 1
    config.dataset_dir = "./dataset/"
    config.checkpoint_dir = "./checkpoints"
    config.inference_roi_size = (128, 128, 64)
    config.max_prediction_images_visualized = 20
결정론적 학습을 활성화하거나 비활성화하려면 모듈의 난수 시드도 설정해야 합니다.
set_determinism(seed=config.seed)

# 디렉토리 생성
os.makedirs(config.dataset_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

데이터 로딩 및 변환

여기서는 monai.transforms API를 사용해 다중 클래스 레이블을 원-핫 형식의 멀티 라벨 세그멘테이션 태스크에 맞는 레이블로 변환하는 커스텀 변환을 생성합니다.
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    brats 클래스를 기반으로 레이블을 다중 채널로 변환합니다:
    레이블 1은 종양 주변 부종(peritumoral edema)
    레이블 2는 GD 조영 종양(GD-enhancing tumor)
    레이블 3은 괴사 및 비조영 종양 핵심(necrotic and non-enhancing tumor core)
    가능한 클래스는 TC (종양 핵심, Tumor core), WT (전체 종양, Whole tumor)
    및 ET (조영 종양, Enhancing tumor)입니다.

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # 레이블 2와 레이블 3을 병합하여 TC 구성
            result.append(torch.logical_or(d[key] == 2, d[key] == 3))
            # 레이블 1, 2, 3을 병합하여 WT 구성
            result.append(
                torch.logical_or(
                    torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # 레이블 2는 ET
            result.append(d[key] == 2)
            d[key] = torch.stack(result, axis=0).float()
        return d
다음으로, 학습용 및 검증용 데이터셋에 대한 변환을 각각 설정합니다.
train_transform = Compose(
    [
        # 4개의 Nifti 이미지를 로드하고 함께 스택
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(
            keys=["image", "label"], roi_size=config.roi_size, random_size=False
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

데이터셋

이 실험에 사용된 데이터셋은 http://medicaldecathlon.com/ 에서 제공합니다. 이 데이터셋은 다중 모달 및 다중 사이트 MRI 데이터(FLAIR, T1w, T1gd, T2w)를 사용해 교종(Gliomas), 괴사성/활성 종양, 부종(oedema)을 세분화(segmentation)합니다. 데이터셋은 총 750개의 4D 볼륨(학습 484개 + 테스트 266개)으로 구성됩니다. DecathlonDataset을 사용해 데이터셋을 자동으로 다운로드하고 압축을 해제하십시오. 이 클래스는 MONAI의 CacheDataset을 상속하며, 이를 통해 학습 시 cache_num=N으로 설정해 N개의 항목을 캐싱할 수 있고, 검증 시에는 기본 인수를 사용해 메모리 크기에 따라 모든 항목을 캐싱할 수 있습니다.
train_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)
val_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)
참고: train_datasettrain_transform을 적용하는 대신, 학습 및 검증 데이터셋 모두에 val_transform을 적용하세요. 이는 학습 전에 데이터셋의 두 분할(학습/검증) 모두에서 샘플을 시각화하기 때문입니다.

데이터셋 시각화

W&B는 이미지, 비디오, 오디오 등을 지원합니다. 결과를 탐색하고 실행, 모델, 데이터셋 간을 시각적으로 비교할 수 있도록 리치 미디어를 로깅할 수 있습니다. segmentation mask overlay system을 사용해 데이터 규모를 시각화하십시오. Tables에 세그멘테이션 마스크를 로깅하려면, 테이블의 각 행마다 wandb.Image 객체를 제공해야 합니다. 아래 의사코드 예제를 참고하십시오:
table = wandb.Table(columns=["ID", "Image"])

for id, img, label in zip(ids, images, labels):
    mask_img = wandb.Image(
        img,
        masks={
            "prediction": {"mask_data": label, "class_labels": class_labels}
            # ...
        },
    )

    table.add_data(id, img)

run.log({"Table": table})
이제 샘플 이미지, 레이블, wandb.Table 객체와 관련 메타데이터를 입력으로 받아 W&B 대시보드에 로깅될 테이블의 행을 채우는 간단한 유틸리티 함수를 작성하세요.
def log_data_samples_into_tables(
    sample_image: np.array,
    sample_label: np.array,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            ground_truth_wandb_images = []
            for channel_idx in range(num_channels):
                ground_truth_wandb_images.append(
                    masks = {
                        "ground-truth/Tumor-Core": {
                            "mask_data": sample_label[0, :, :, slice_idx],
                            "class_labels": {0: "background", 1: "Tumor Core"},
                        },
                        "ground-truth/Whole-Tumor": {
                            "mask_data": sample_label[1, :, :, slice_idx] * 2,
                            "class_labels": {0: "background", 2: "Whole Tumor"},
                        },
                        "ground-truth/Enhancing-Tumor": {
                            "mask_data": sample_label[2, :, :, slice_idx] * 3,
                            "class_labels": {0: "background", 3: "Enhancing Tumor"},
                        },
                    }
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks=masks,
                    )
                )
            table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)
            progress_bar.update(1)
    return table
다음으로, 데이터 시각화를 테이블에 채워 넣을 수 있도록 wandb.Table 객체와 그 객체를 구성하는 열들을 정의합니다.
table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0",
        "Image-Channel-1",
        "Image-Channel-2",
        "Image-Channel-3",
    ]
)
그런 다음 train_datasetval_dataset을 각각 순회하여 데이터 샘플에 대한 시각화를 생성하고, 대시보드에 로깅할 테이블의 각 행을 채웁니다.
# train_dataset에 대한 시각화 생성
max_samples = (
    min(config.max_train_images_visualized, len(train_dataset))
    if config.max_train_images_visualized > 0
    else len(train_dataset)
)
progress_bar = tqdm(
    enumerate(train_dataset[:max_samples]),
    total=max_samples,
    desc="훈련 데이터셋 시각화 생성 중:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="train",
        data_idx=data_idx,
        table=table,
    )

# val_dataset에 대한 시각화 생성
max_samples = (
    min(config.max_val_images_visualized, len(val_dataset))
    if config.max_val_images_visualized > 0
    else len(val_dataset)
)
progress_bar = tqdm(
    enumerate(val_dataset[:max_samples]),
    total=max_samples,
    desc="검증 데이터셋 시각화 생성 중:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="val",
        data_idx=data_idx,
        table=table,
    )

# 대시보드에 테이블 로깅
run.log({"Tumor-Segmentation-Data": table})
데이터는 W&B 대시보드에서 대화형 표 형식으로 표시됩니다. 각 행에서 데이터 볼륨의 특정 슬라이스에 대해, 각 채널이 해당 세그멘테이션 마스크와 함께 오버레이된 모습을 확인할 수 있습니다. Weave 쿼리를 작성하여 테이블의 데이터를 필터링하고 특정 행에 집중할 수 있습니다.
로그된 테이블 데이터
이미지를 열고, 대화형 오버레이를 사용해 각 세그멘테이션 마스크와 어떻게 상호작용할 수 있는지 확인해 보세요.
세그멘테이션 맵
참고: 이 데이터셋의 레이블은 클래스 간에 서로 겹치지 않는 마스크로 구성됩니다. 오버레이는 레이블을 오버레이 내에서 개별 마스크로 로깅합니다.

데이터 불러오기

데이터셋으로부터 데이터를 로드할 PyTorch DataLoader를 생성합니다. DataLoader를 생성하기 전에, 학습에 사용할 데이터에 전처리와 변환을 적용할 수 있도록 train_datasettransformtrain_transform으로 설정합니다.
# 학습 데이터셋에 train_transforms 적용
train_dataset.transform = train_transform

# train_loader 생성
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

# val_loader 생성
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
)

모델, 손실 함수, 그리고 옵티마이저 생성

이 튜토리얼에서는 3D MRI brain tumor segmentation using auto-encoder regularization 논문을 기반으로 SegResNet 모델을 생성합니다. SegResNet 모델은 monai.networks API의 일부로 PyTorch 모듈 형태로 구현되어 있으며, 옵티마이저와 학습률 스케줄러도 함께 제공합니다.
device = torch.device("cuda:0")

# 모델 생성
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

# 옵티마이저 생성
optimizer = torch.optim.Adam(
    model.parameters(),
    config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# 학습률 스케줄러 생성
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.max_train_epochs
)
monai.losses API를 사용해 손실 함수를 멀티레이블 DiceLoss로 정의하고, monai.metrics API를 사용해 이에 대응하는 Dice 지표를 정의합니다.
loss_function = DiceLoss(
    smooth_nr=config.dice_loss_smoothen_numerator,
    smooth_dr=config.dice_loss_smoothen_denominator,
    squared_pred=config.dice_loss_squared_prediction,
    to_onehot_y=config.dice_loss_target_onehot,
    sigmoid=config.dice_loss_apply_sigmoid,
)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# 자동 혼합 정밀도를 사용하여 학습 가속화
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
혼합 정밀도 추론을 위한 간단한 유틸리티를 정의합니다. 이는 학습 과정의 검증 단계에서, 그리고 학습이 끝난 후 모델을 실행할 때 유용합니다.
def inference(model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    with torch.cuda.amp.autocast():
        return _compute(input)

학습 및 검증

학습을 시작하기 전에 이후 학습 및 검증 과정을 run.log()로 추적할 때 사용할 메트릭 속성을 정의합니다.
run.define_metric("epoch/epoch_step")
run.define_metric("epoch/*", step_metric="epoch/epoch_step")
run.define_metric("batch/batch_step")
run.define_metric("batch/*", step_metric="batch/batch_step")
run.define_metric("validation/validation_step")
run.define_metric("validation/*", step_metric="validation/validation_step")

batch_step = 0
validation_step = 0
metric_values = []
metric_values_tumor_core = []
metric_values_whole_tumor = []
metric_values_enhanced_tumor = []

표준 PyTorch 학습 루프 실행

with wandb.init(
    project="monai-brain-tumor-segmentation",
    config=config,
    job_type="train",
    reinit=True,
) as run:

    # W&B 아티팩트 객체 정의
    artifact = wandb.Artifact(
        name=f"{run.id}-checkpoint", type="model"
    )

    epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

    for epoch in epoch_progress_bar:
        model.train()
        epoch_loss = 0

        total_batch_steps = len(train_dataset) // train_loader.batch_size
        batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)
        
        # 학습 단계
        for batch_data in batch_progress_bar:
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:")
            ## 배치별 학습 손실을 W&B에 기록
            run.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()})
            batch_step += 1

        lr_scheduler.step()
        epoch_loss /= total_batch_steps
        ## 배치별 학습 손실 및 학습률을 W&B에 기록
        run.log(
            {
                "epoch/epoch_step": epoch,
                "epoch/mean_train_loss": epoch_loss,
                "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
            }
        )
        epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:")

        # 검증 및 모델 체크포인트 저장 단계
        if (epoch + 1) % config.validation_intervals == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = inference(model, val_inputs)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    dice_metric_batch(y_pred=val_outputs, y=val_labels)

                metric_values.append(dice_metric.aggregate().item())
                metric_batch = dice_metric_batch.aggregate()
                metric_values_tumor_core.append(metric_batch[0].item())
                metric_values_whole_tumor.append(metric_batch[1].item())
                metric_values_enhanced_tumor.append(metric_batch[2].item())
                dice_metric.reset()
                dice_metric_batch.reset()

                checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
                torch.save(model.state_dict(), checkpoint_path)
                
                # W&B 아티팩트를 사용하여 모델 체크포인트 기록 및 버전 관리.
                artifact.add_file(local_path=checkpoint_path)
                run.log_artifact(artifact, aliases=[f"epoch_{epoch}"])

                # 검증 지표를 W&B 대시보드에 기록.
                run.log(
                    {
                        "validation/validation_step": validation_step,
                        "validation/mean_dice": metric_values[-1],
                        "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                        "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                        "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                    }
                )
                validation_step += 1


    # 이 아티팩트의 기록이 완료될 때까지 대기
    artifact.wait()
wandb.log으로 코드를 계측하면 학습 및 검증 과정과 관련된 모든 지표를 추적할 수 있을 뿐만 아니라, W&B 대시보드에서 시스템 지표(여기서는 CPU와 GPU)도 모두 기록할 수 있습니다.
학습 및 검증 추적
학습 중에 로깅된 모델 체크포인트 아티팩트의 다양한 버전에 접근하려면 W&B 실행 대시보드의 Artifacts 탭으로 이동하십시오.
모델 체크포인트 로깅

추론

아티팩트 인터페이스를 사용해 어떤 버전의 아티팩트가 최적의 모델 체크포인트인지 선택할 수 있습니다(이 예에서는 에포크별 평균 학습 손실). 또한 아티팩트의 전체 lineage를 탐색하고, 필요한 버전을 선택해 사용할 수 있습니다.
Model artifact tracking
에포크별 평균 학습 손실이 가장 좋은 모델 아티팩트 버전을 가져와 체크포인트 state dictionary를 모델에 로드합니다.
run = wandb.init(
    project="monai-brain-tumor-segmentation",
    job_type="inference",
    reinit=True,
)
model_artifact = run.use_artifact(
    "geekyrakshit/monai-brain-tumor-segmentation/d5ex6n4a-checkpoint:v49",
    type="model",
)
model_artifact_dir = model_artifact.download()
model.load_state_dict(torch.load(os.path.join(model_artifact_dir, "model.pth")))
model.eval()

예측 시각화 및 정답 레이블과의 비교

사전 학습된 모델의 예측을 시각화하고, 대화형 세그멘테이션 마스크 오버레이를 사용해 해당 정답 세그멘테이션 마스크와 비교하는 또 다른 유틸리티 함수를 만듭니다.
def log_predictions_into_tables(
    sample_image: np.array,
    sample_label: np.array,
    predicted_label: np.array,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            wandb_images = []
            for channel_idx in range(num_channels):
                wandb_images += [
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Tumor-Core": {
                                "mask_data": sample_label[0, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Tumor Core"},
                            },
                            "prediction/Tumor-Core": {
                                "mask_data": predicted_label[0, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Tumor Core"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Whole-Tumor": {
                                "mask_data": sample_label[1, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Whole Tumor"},
                            },
                            "prediction/Whole-Tumor": {
                                "mask_data": predicted_label[1, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Whole Tumor"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Enhancing-Tumor": {
                                "mask_data": sample_label[2, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Enhancing Tumor"},
                            },
                            "prediction/Enhancing-Tumor": {
                                "mask_data": predicted_label[2, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Enhancing Tumor"},
                            },
                        },
                    ),
                ]
            table.add_data(split, data_idx, slice_idx, *wandb_images)
            progress_bar.update(1)
    return table
예측 결과를 예측 테이블에 로깅합니다.
run = wandb.init(
    project="monai-brain-tumor-segmentation",
    job_type="inference",
    reinit=True,
)
# 예측 테이블 생성
prediction_table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0/Tumor-Core",
        "Image-Channel-1/Tumor-Core",
        "Image-Channel-2/Tumor-Core",
        "Image-Channel-3/Tumor-Core",
        "Image-Channel-0/Whole-Tumor",
        "Image-Channel-1/Whole-Tumor",
        "Image-Channel-2/Whole-Tumor",
        "Image-Channel-3/Whole-Tumor",
        "Image-Channel-0/Enhancing-Tumor",
        "Image-Channel-1/Enhancing-Tumor",
        "Image-Channel-2/Enhancing-Tumor",
        "Image-Channel-3/Enhancing-Tumor",
    ]
)

# 추론 및 시각화 수행
with torch.no_grad():
    config.max_prediction_images_visualized
    max_samples = (
        min(config.max_prediction_images_visualized, len(val_dataset))
        if config.max_prediction_images_visualized > 0
        else len(val_dataset)
    )
    progress_bar = tqdm(
        enumerate(val_dataset[:max_samples]),
        total=max_samples,
        desc="예측 생성 중:",
    )
    for data_idx, sample in progress_bar:
        val_input = sample["image"].unsqueeze(0).to(device)
        val_output = inference(model, val_input)
        val_output = post_trans(val_output[0])
        prediction_table = log_predictions_into_tables(
            sample_image=sample["image"].cpu().numpy(),
            sample_label=sample["label"].cpu().numpy(),
            predicted_label=val_output.cpu().numpy(),
            data_idx=data_idx,
            split="validation",
            table=prediction_table,
        )

    run.log({"Predictions/Tumor-Segmentation-Data": prediction_table})


# 실험 종료
run.finish()
대화형 분할 마스크 오버레이를 사용해 각 클래스별 예측 분할 마스크와 정답 레이블을 분석하고 비교하세요.
Predictions and ground-truth

감사의 말 및 추가 리소스