메인 콘텐츠로 건너뛰기
Colab에서 실행해 보기 이 문서에서는 PyTorch Lightning을 사용해 이미지 분류 파이프라인을 구축합니다. 코드의 가독성과 재현성을 높이기 위해 이 스타일 가이드를 따르겠습니다. 이와 관련된 좋은 설명은 여기에서 확인할 수 있습니다.

PyTorch Lightning과 W&B 설정

이 튜토리얼에서는 PyTorch Lightning과 W&B를 사용합니다.
pip install lightning -q
pip install wandb -qU
import lightning.pytorch as pl

# 머신러닝 추적을 위한 최고의 도구
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb
이제 wandb 계정에 로그인하세요.
wandb.login()

DataModule - 우리가 원하던 데이터 파이프라인

DataModule은 데이터 관련 훅(hook)을 LightningModule에서 분리해, 특정 데이터셋에 구애받지 않는 모델을 개발할 수 있게 해 주는 방법입니다. 데이터 파이프라인을 하나의 공유 가능하고 재사용 가능한 클래스로 구성합니다. DataModule은 PyTorch에서 데이터 처리에 관련된 다섯 가지 단계를 캡슐화합니다:
  • 다운로드 / 토큰화 / 처리.
  • 정제하고 (필요하다면) 디스크에 저장.
  • Dataset 안에 로드.
  • 변환 적용 (회전, 토큰화 등…).
  • DataLoader 안에 래핑.
DataModule에 대해 더 알아보려면 여기를 참고하세요. 이제 Cifar-10 데이터셋을 위한 DataModule을 만들어 봅시다.
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.num_classes = 10
    
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # 데이터로더에서 사용할 train/val 데이터셋 할당
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # 데이터로더에서 사용할 test 데이터셋 할당
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

Callbacks

콜백은 여러 프로젝트에서 재사용할 수 있는 독립적인 프로그램입니다. PyTorch Lightning에는 자주 사용되는 내장 콜백이 몇 가지 포함되어 있습니다. PyTorch Lightning의 콜백에 대해 더 알아보려면 여기를 참고하세요.

내장 콜백

이 튜토리얼에서는 Early StoppingModel Checkpoint 내장 콜백을 사용합니다. 이 콜백들은 Trainer에 전달할 수 있습니다.

사용자 정의 콜백

사용자 정의 Keras 콜백에 익숙하다면, PyTorch 파이프라인에서도 동일한 작업을 수행할 수 있다는 점은 정말 금상첨화입니다. 이미지 분류를 수행하고 있으므로, 일부 이미지 샘플에 대한 모델의 예측을 시각화할 수 있으면 매우 유용합니다. 이를 콜백으로 구현해 두면 모델을 초기 단계에서 디버깅하는 데 도움이 됩니다.
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # 텐서를 CPU로 이동
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # 모델 예측 가져오기
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # 이미지를 wandb Image로 로깅
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })
        

LightningModule - 시스템 정의하기

LightningModule은 모델이 아니라 시스템을 정의합니다. 여기서 시스템이란 모든 연구 코드를 하나의 클래스로 묶어, 독립적으로 동작할 수 있게 하는 것을 의미합니다. LightningModule은 PyTorch 코드를 다음 5개 섹션으로 구성합니다:
  • 계산(__init__)
  • 학습 루프(training_step)
  • 검증 루프(validation_step)
  • 테스트 루프(test_step)
  • 옵티마이저(configure_optimizers)
이렇게 하면 특정 데이터셋에 의존하지 않는 모델을 만들어 쉽게 공유할 수 있습니다. 이제 CIFAR-10 분류를 위한 시스템을 만들어 보겠습니다.
class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # 하이퍼파라미터 로깅
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    # conv 블록에서 Linear 레이어로 들어가는 출력 텐서의 크기를 반환합니다.
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # conv 블록에서 특징 텐서를 반환합니다
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # 추론 시 사용됩니다
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)
       
       return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # 학습 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 검증 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # 검증 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

학습 및 평가

DataModule로 데이터 파이프라인을, LightningModule로 모델 아키텍처와 학습 루프를 구성했다면, 나머지는 PyTorch Lightning Trainer가 모두 자동으로 처리합니다. Trainer는 다음을 자동으로 처리합니다:
  • 에포크 및 배치 반복
  • optimizer.step(), backward, zero_grad() 호출
  • .eval() 호출, 그래디언트 활성화/비활성화
  • 가중치 저장 및 로드
  • W&B 로깅
  • 멀티 GPU 학습 지원
  • TPU 지원
  • 16비트 학습 지원
dm = CIFAR10DataModule(batch_size=32)
# x_dataloader에 접근하려면 prepare_data와 setup을 호출해야 합니다.
dm.prepare_data()
dm.setup()

# 이미지 예측을 로깅하기 위해 커스텀 ImagePredictionLogger 콜백에 필요한 샘플입니다.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)

# wandb 로거 초기화
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# 콜백 초기화
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# 트레이너 초기화
trainer = pl.Trainer(max_epochs=2,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# 모델 학습 
trainer.fit(model, dm)

# 홀드아웃 테스트 세트에서 모델 평가 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# wandb 실행 종료
run.finish()

마무리 생각

저는 TensorFlow/Keras 생태계에서 시작해서, PyTorch가 우아한 프레임워크임에도 불구하고 다소 부담스럽게 느껴졌습니다. 어디까지나 제 개인적인 경험입니다. PyTorch Lightning을 살펴보면서, 제가 PyTorch를 멀리하게 만들었던 거의 모든 이유가 해결된다는 걸 깨달았습니다. 제 기대를 간단히 정리하면 다음과 같습니다:
  • 예전: 전통적인 PyTorch 모델 정의는 여기저기 흩어져 있었습니다. 모델은 어떤 model.py 스크립트에 있고, 학습 루프는 train.py 파일에 있는 식이었죠. 파이프라인을 이해하려면 이리저리 왔다 갔다 하며 코드를 봐야 했습니다.
  • 지금: LightningModuletraining_step, validation_step 등을 모델 정의와 함께 묶어 주는 시스템 역할을 합니다. 이제는 모듈화되어 있고 공유하기도 좋습니다.
  • 예전: TensorFlow/Keras의 가장 큰 장점 중 하나는 입력 데이터 파이프라인입니다. 데이터셋 카탈로그도 풍부하고 계속 성장하고 있습니다. PyTorch의 데이터 파이프라인은 가장 큰 고통 포인트였습니다. 일반적인 PyTorch 코드에서는 데이터 다운로드/정리/전처리가 여러 파일에 흩어져 있는 경우가 많았습니다.
  • 지금: DataModule은 데이터 파이프라인을 하나의 공유 가능하고 재사용 가능한 클래스로 정리해 줍니다. train_dataloader, val_dataloader(들), test_dataloader(들)과 그에 맞는 transform, 데이터 처리/다운로드 단계를 모아 놓은 집합이라고 보면 됩니다.
  • 예전: Keras에서는 model.fit으로 모델을 학습시키고, model.predict로 추론을 수행할 수 있습니다. model.evaluate는 테스트 데이터에 대해 익숙한 단순 평가를 제공했죠. PyTorch에서는 그렇지 않습니다. 보통은 train.pytest.py가 따로 존재합니다.
  • 지금: LightningModule이 있으면 Trainer가 모든 것을 자동화해 줍니다. 모델을 학습하고 평가하기 위해서는 trainer.fittrainer.test만 호출하면 됩니다.
  • 예전: TensorFlow는 TPU를 사랑하고, PyTorch는…
  • 지금: PyTorch Lightning을 사용하면 동일한 모델을 여러 GPU는 물론 TPU에서도 아주 쉽게 학습시킬 수 있습니다.
  • 예전: 저는 Callback의 큰 팬이고, 커스텀 Callback을 직접 작성하는 것을 선호합니다. Early Stopping처럼 사소해 보이는 것도 전통적인 PyTorch에서는 논쟁거리였습니다.
  • 지금: PyTorch Lightning에서는 Early Stopping과 Model Checkpointing을 사용하는 것이 정말 간단합니다. 여기에 커스텀 Callback도 직접 작성할 수 있습니다.

🎨 결론 및 참고 자료

이 리포트가 도움이 되었기를 바랍니다. 제공된 코드를 직접 만져 보면서, 원하는 데이터셋으로 이미지 분류기를 학습시켜 보시기를 권장합니다. PyTorch Lightning에 대해 더 배우는 데 도움이 되는 자료는 다음과 같습니다:
  • Step-by-step walk-through: 공식 튜토리얼 중 하나입니다. 문서가 매우 잘 정리되어 있어 학습 자료로 적극 추천합니다.
  • Use Pytorch Lightning with W&B: PyTorch Lightning에서 W&B를 사용하는 방법을 익힐 수 있는 간단한 Colab 예제입니다.