메인 콘텐츠로 건너뛰기
Colab에서 실행해 보기 torchtune은 대규모 언어 모델(LLM)의 작성, 미세 튜닝 및 실험 과정을 간소화하기 위해 설계된 PyTorch 기반 라이브러리입니다. 또한 torchtune에는 W&B로 로깅을 위한 기능이 내장되어 있어 학습 과정의 추적 및 시각화를 한층 더 강화할 수 있습니다.
TorchTune 학습 대시보드
torchtune을 사용한 Mistral 7B 미세 튜닝에 대한 W&B 블로그 게시물을 확인하세요.

손쉽게 W&B 로깅 사용하기

실행 시 명령줄 인수를 재정의하세요:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
  metric_logger._component_=torchtune.utils.metric_logging.WandBLogger \
  metric_logger.project="llama3_lora" \
  log_every_n_steps=5

W&B 메트릭 로거 사용하기

레시피의 config 파일에서 metric_logger 섹션을 수정해 W&B 로깅을 활성화합니다. _component_ 값을 torchtune.utils.metric_logging.WandBLogger 클래스로 변경하세요. 또한 로깅 동작을 세부 설정하기 위해 project 이름과 log_every_n_steps 값을 지정할 수 있습니다. wandb.init() 함수에 인수를 전달하는 것과 동일한 방식으로 다른 kwargs들도 넘길 수 있습니다. 예를 들어 팀 단위로 작업하는 경우, 팀 이름을 지정하기 위해 entity 인수를 WandBLogger 클래스에 전달할 수 있습니다.
# inside llama3/8B_lora_single_device.yaml
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: llama3_lora
  entity: my_project
  job_type: lora_finetune_single_device
  group: my_awesome_experiments
log_every_n_steps: 5

무엇이 기록되나요?

기록된 지표는 W&B 대시보드에서 확인할 수 있습니다. 기본적으로 W&B는 구성(config) 파일과 실행 시 덮어쓴 값(launch overrides)에 포함된 모든 하이퍼파라미터를 기록합니다. W&B는 최종 구성(resolved config)을 Overview 탭에 표시합니다. 또한 이 구성을 YAML 형식으로 Files 탭에 저장합니다.
TorchTune configuration

기록된 지표

각 recipe는 자체적인 학습 루프를 사용합니다. 각 recipe별로 기록되는 지표를 확인하세요. 기본적으로 다음 지표들이 포함됩니다:
MetricDescription
loss모델의 loss
lrlearning rate
tokens_per_second모델의 초당 토큰 수
grad_norm모델의 gradient norm
global_step학습 루프에서의 현재 step에 해당합니다. gradient accumulation을 고려하며, 기본적으로 optimizer step이 수행될 때마다 gradient가 누적되고, gradient_accumulation_steps에 한 번씩 모델이 업데이트됩니다.
global_step은 학습 step의 개수와 동일하지 않습니다. 이는 학습 루프에서의 현재 step에 해당합니다. gradient accumulation을 고려하며, 기본적으로 optimizer step이 수행될 때마다 global_step이 1씩 증가합니다. 예를 들어, dataloader에 batch가 10개 있고, gradient accumulation steps가 2이며 3 epoch 동안 실행하는 경우, optimizer는 15번 step을 수행하게 되며, 이때 global_step은 1에서 15까지의 값을 갖습니다.
torchtune의 간결한 설계 덕분에 커스텀 지표를 쉽게 추가하거나 기존 지표를 수정할 수 있습니다. 해당 recipe 파일만 수정하면 됩니다. 예를 들어, 전체 epoch 수에서의 진행 비율로 current_epoch를 계산해 다음과 같이 기록할 수 있습니다:
# 레시피 파일의 `train.py` 함수 내부
self._metric_logger.log_dict(
    {"current_epoch": self.epochs * self.global_step / self._steps_per_epoch},
    step=self.global_step,
)
이 라이브러리는 빠르게 발전하고 있으며, 현재 제공되는 메트릭은 변경될 수 있습니다. 사용자 정의 메트릭을 추가하려면 레시피를 수정하고 해당 self._metric_logger.* 함수를 호출해야 합니다.

체크포인트 저장 및 로드

torchtune 라이브러리는 다양한 체크포인트 형식을 지원합니다. 사용하는 모델의 출처에 따라 적절한 checkpointer 클래스로 전환해야 합니다. 모델 체크포인트를 W&B 아티팩트에 저장하려는 경우, 가장 간단한 방법은 해당 레시피 내의 save_checkpoint 함수를 오버라이드하는 것입니다. 다음은 save_checkpoint 함수를 오버라이드하여 모델 체크포인트를 W&B 아티팩트에 저장하는 예시입니다.
def save_checkpoint(self, epoch: int) -> None:
    ...
    ## W&B에 체크포인트를 저장합니다
    ## Checkpointer 클래스에 따라 파일 이름이 달라집니다
    ## full_finetune의 경우에 대한 예시입니다
    checkpoint_file = Path.joinpath(
        self._checkpointer._output_dir, f"torchtune_model_{epoch}"
    ).with_suffix(".pt")
    wandb_artifact = wandb.Artifact(
        name=f"torchtune_model_{epoch}",
        type="model",
        # 모델 체크포인트 설명
        description="Model checkpoint",
        # dict 형태로 원하는 메타데이터를 추가할 수 있습니다
        metadata={
            utils.SEED_KEY: self.seed,
            utils.EPOCHS_KEY: self.epochs_run,
            utils.TOTAL_EPOCHS_KEY: self.total_epochs,
            utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
        },
    )
    wandb_artifact.add_file(checkpoint_file)
    wandb.log_artifact(wandb_artifact)