메인 콘텐츠로 건너뛰기
Colab에서 실행해 보기 PyTorch Lightning은 PyTorch 코드를 체계적으로 구성하고, 분산 학습이나 16비트 정밀도와 같은 고급 기능을 쉽게 추가할 수 있도록 하는 경량 래퍼를 제공합니다. W&B는 ML 실험을 기록하기 위한 경량 래퍼를 제공합니다. 하지만 이 둘을 직접 결합할 필요는 없습니다. W&B는 WandbLogger를 통해 PyTorch Lightning 라이브러리에 직접 통합되어 있습니다.

Lightning과 연동하기

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)
wandb.log() 사용: WandbLogger는 Trainer의 global_step을 사용해 W&B에 로그를 기록합니다. 코드에서 추가로 wandb.log를 직접 호출하는 경우, wandb.log()에서 step 인자를 사용하지 마십시오.대신, 다른 지표와 마찬가지로 Trainer의 global_step을 함께 기록하십시오:
wandb.log({"accuracy":0.99, "trainer/global_step": step})
대화형 대시보드

회원가입 후 API key 생성하기

API key를 통해 사용 중인 머신을 W&B에 인증할 수 있습니다. 사용자 프로필에서 API key를 생성할 수 있습니다.
보다 간편하게 설정하려면 User Settings로 바로 이동해 API key를 생성하세요. 새로 생성된 API key는 즉시 복사하여 비밀번호 관리자와 같은 안전한 위치에 저장하세요.
  1. 오른쪽 상단의 사용자 프로필 아이콘을 클릭합니다.
  2. User Settings를 선택한 다음 아래로 스크롤하여 API Keys 섹션으로 이동합니다.

wandb 라이브러리를 설치하고 로그인하기

로컬 환경에서 wandb 라이브러리를 설치하고 로그인하려면:
  1. WANDB_API_KEY 환경 변수를 본인의 API key로 설정합니다.
    export WANDB_API_KEY=<your_api_key>
    
  2. wandb 라이브러리를 설치하고 로그인합니다.
    pip install wandb
    
    wandb login
    

PyTorch Lightning의 WandbLogger 사용

PyTorch Lightning에는 메트릭, 모델 가중치, 미디어 등을 로깅하기 위한 여러 종류의 WandbLogger 클래스가 있습니다. Lightning과 연동하려면 WandbLogger 인스턴스를 생성한 후 Lightning의 Trainer 또는 Fabric에 전달합니다.
trainer = Trainer(logger=wandb_logger)

공통 logger 인자

아래는 WandbLogger에서 가장 많이 사용되는 파라미터입니다. 모든 logger 인자에 대한 자세한 내용은 PyTorch Lightning 문서를 확인하세요.
ParameterDescription
project로그를 기록할 wandb Project를 지정합니다
namewandb 실행에 이름을 지정합니다
log_modellog_model="all"이면 모든 모델을, log_model=True이면 학습 종료 시점에 모델을 기록합니다
save_dir데이터가 저장될 경로입니다

하이퍼파라미터 기록하기

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

추가 구성 파라미터 기록

# 파라미터 하나 추가
wandb_logger.experiment.config["key"] = value

# 여러 파라미터 추가
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# wandb 모듈 직접 사용
wandb.config["key"] = value
wandb.config.update()

그래디언트, 파라미터 히스토그램 및 모델 토폴로지 기록

훈련 중 모델의 그래디언트와 파라미터를 모니터링하려면 모델 객체를 wandblogger.watch()에 전달하면 됩니다. 자세한 내용은 PyTorch Lightning WandbLogger 문서를 참고하세요.

메트릭 로깅하기

LightningModule 안에서, 예를 들어 training_step 또는 validation_step 메서드에서 self.log('my_metric_name', metric_vale)를 호출하여 WandbLogger를 사용할 때 메트릭을 W&B에 로깅할 수 있습니다.아래 코드 스니펫은 메트릭과 LightningModule 하이퍼파라미터를 로깅할 수 있도록 LightningModule을 정의하는 방법을 보여줍니다. 이 예제에서는 메트릭을 계산하기 위해 torchmetrics 라이브러리를 사용합니다.
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """모델 파라미터를 정의하는 메서드"""
        super().__init__()

        # MNIST 이미지는 (1, 28, 28) (channels, width, height)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # 하이퍼파라미터를 self.hparams에 저장 (W&B에서 자동 로깅됨)
        self.save_hyperparameters()

    def forward(self, x):
        """추론에 사용되는 메서드 input -> output"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # 3번의 (linear + relu)를 수행
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """단일 배치에서 loss를 반환해야 함"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # loss와 메트릭 로깅
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """메트릭 로깅에 사용"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # loss와 메트릭 로깅
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """모델 옵티마이저 정의"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """train/valid/test 단계가 유사하므로 편의를 위한 함수"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc

메트릭의 최소/최대값 로깅하기

wandb의 define_metric 함수를 사용하면 W&B Summary 메트릭에 해당 메트릭의 최소값, 최대값, 평균값 또는 최적값 중 어떤 값을 표시할지 정의할 수 있습니다. define_metric _ 을 사용하지 않으면 마지막으로 로깅된 값이 Summary 메트릭에 표시됩니다. 자세한 내용은 define_metric 레퍼런스 문서가이드를 참고하세요. W&B Summary 메트릭에서 검증 정확도(val_accuracy)의 최대값을 추적하려면, 학습 시작 시점에 단 한 번만 wandb.define_metric을 호출하세요:
class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0:
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # 손실과 메트릭 로깅
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

모델 체크포인트 저장

모델 체크포인트를 W&B 아티팩트로 저장하려면, Lightning ModelCheckpoint 콜백을 사용한 다음 WandbLogger에서 log_model 인수를 설정하세요.
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
latestbest 별칭은 W&B 아티팩트에서 모델 체크포인트를 쉽게 가져올 수 있도록 자동으로 설정됩니다.
# 아티팩트 패널에서 참조를 가져올 수 있습니다
# "VERSION"은 버전(예: "v2") 또는 별칭("latest" 또는 "best")일 수 있습니다
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# 체크포인트를 로컬에 다운로드합니다 (이미 캐시되어 있지 않은 경우)
wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# 체크포인트를 로드합니다
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
로그한 모델 체크포인트는 W&B 아티팩트 UI에서 확인할 수 있으며, 전체 모델 계보 정보가 포함됩니다 (UI에서 모델 체크포인트 예시는 여기에서 볼 수 있습니다). 최고의 모델 체크포인트를 북마크하고 팀 전체에서 한 곳에 모아 관리하려면, 이를 W&B Model Registry에 연결하면 됩니다. 여기에서 작업별로 최상의 모델을 구성하고, 모델 수명주기를 관리하며, ML 수명주기 전반에 걸친 손쉬운 추적과 감사를 지원하고, 웹훅 또는 잡을 사용해 하위 단계 작업을 자동화할 수 있습니다.

이미지, 텍스트 등을 로깅하기

WandbLogger에는 미디어를 로깅하기 위한 log_image, log_text, log_table 메서드가 있습니다. wandb.log 또는 trainer.logger.experiment.log를 직접 호출해서 오디오, 분자 구조(Molecules), 포인트 클라우드(Point Clouds), 3D 객체(3D Objects) 등의 다른 미디어 유형도 로깅할 수 있습니다.
# tensors, numpy 배열 또는 PIL 이미지를 사용
wandb_logger.log_image(key="samples", images=[img1, img2])

# 캡션 추가
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# 파일 경로 사용
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# trainer에서 .log 사용
trainer.logger.experiment.log(
    {"samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images]},
    step=current_trainer_global_step,
)
Lightning의 Callback 시스템을 사용해 WandbLogger를 통해 W&B에 로그를 남기는 시점을 제어할 수 있습니다. 이 예시에서는 검증 이미지와 예측 결과 중 일부 샘플을 로깅합니다.
import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

# or
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """검증 배치가 끝날 때 호출됩니다."""

        # `outputs`는 `LightningModule.validation_step`에서 반환되며
        # 이 경우 모델의 예측값에 해당합니다

        # 첫 번째 배치에서 샘플 이미지 예측값 20개를 로깅합니다
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {y_i} - Prediction: {y_pred}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # 옵션 1: `WandbLogger.log_image`로 이미지 로깅
            wandb_logger.log_image(key="sample_images", images=images, caption=captions)

            # 옵션 2: 이미지와 예측값을 W&B Table로 로깅
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), y_i, y_pred] or x_i,
                y_i,
                y_pred in list(zip(x[:n], y[:n], outputs[:n])),
            ]
            wandb_logger.log_table(key="sample_table", columns=columns, data=data)


trainer = pl.Trainer(callbacks=[LogPredictionSamplesCallback()])

Lightning과 W&B로 여러 GPU 사용하기

PyTorch Lightning은 DDP 인터페이스를 통해 멀티 GPU를 지원합니다. 하지만 PyTorch Lightning의 설계 특성상 GPU를 초기화하는 방식에 주의를 기울여야 합니다. Lightning은 학습 루프에서 각 GPU(또는 rank)가 정확히 동일한 방식, 즉 동일한 초기 조건으로 초기화된다고 가정합니다. 그러나 rank 0 프로세스만 wandb.run 객체에 접근할 수 있고, 0이 아닌 rank 프로세스에서는 wandb.run = None입니다. 이로 인해 rank가 0이 아닌 프로세스가 실패할 수 있습니다. 이런 상황에서는 rank 0 프로세스가 이미 크래시된 rank 0이 아닌 프로세스가 조인하기를 계속 기다리게 되어 **데드락(교착 상태)**에 빠질 수 있습니다. 이러한 이유로, 학습 코드를 설정하는 방식에 특히 주의해야 합니다. 권장되는 방법은 코드가 wandb.run 객체에 의존하지 않도록 구성하는 것입니다.
class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return {"train_loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # 모든 랜덤 시드를 동일한 값으로 설정합니다.
    # 분산 학습 환경에서 중요한 설정입니다.
    # 각 rank는 고유한 초기 가중치를 갖게 됩니다.
    # 초기 가중치가 일치하지 않으면 그래디언트도 일치하지 않아
    # 학습이 수렴하지 않을 수 있습니다.
    pl.seed_everything(1)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3, gpus=2, logger=wandb_logger, strategy="ddp", callbacks=callbacks
    )
    trainer.fit(model, train_loader, val_loader)

예시

Colab 노트북이 포함된 비디오 튜토리얼을 보면서 함께 따라 할 수 있습니다.

자주 묻는 질문

W&B는 Lightning과 어떻게 통합되나요?

핵심 통합 방식은 Lightning loggers API를 기반으로 하며, 이를 통해 프레임워크에 구애받지 않는 방식으로 대부분의 로깅 코드를 작성할 수 있습니다. Logger 인스턴스는 Lightning Trainer에 전달되며, 해당 API의 강력한 훅(hook) 및 콜백(callback) 시스템을 기반으로 동작합니다. 이를 통해 연구 코드와 엔지니어링/로깅 코드를 명확하게 분리할 수 있습니다.

별도의 코드를 작성하지 않아도 이 통합은 무엇을 기록하나요?

모델 체크포인트를 W&B에 저장하므로, 이후 실행에서 사용할 수 있도록 확인하거나 다운로드할 수 있습니다. 또한 GPU 사용량과 네트워크 I/O 같은 system metrics, 하드웨어 및 OS 정보와 같은 환경 정보, git 커밋과 diff 패치, 노트북 내용 및 세션 이력을 포함한 code state, 그리고 표준 출력(stdout)에 출력되는 모든 내용을 수집합니다.

학습 설정에서 wandb.run을 사용해야 하는 경우에는 어떻게 하나요?

직접 접근해야 하는 변수의 범위를 더 넓게 지정해야 합니다. 다시 말해, 모든 프로세스에서 초기 조건이 동일하도록 설정해야 합니다.
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
그렇다면 os.environ["WANDB_DIR"]를 사용해서 모델 체크포인트 디렉터리를 설정할 수 있습니다. 이렇게 하면 0이 아닌 rank의 프로세스도 wandb.run.dir에 접근할 수 있습니다.