메인 콘텐츠로 건너뛰기
Stable Baselines 3 (SB3)는 PyTorch로 구현된 강화 학습 알고리즘들의 신뢰할 수 있는 구현 모음입니다. W&B의 SB3 통합 기능은 다음을 제공합니다.
  • loss와 episodic return 같은 지표를 기록합니다.
  • 에이전트가 게임을 플레이하는 비디오를 업로드합니다.
  • 학습된 모델을 저장합니다.
  • 모델의 하이퍼파라미터를 기록합니다.
  • 모델의 그래디언트 히스토그램을 기록합니다.
예시 SB3 학습 실행을 확인하세요.

SB3 실험 기록하기

from wandb.integration.sb3 import WandbCallback

model.learn(..., callback=WandbCallback())
W&B를 사용한 Stable Baselines 3 학습

WandbCallback 인자

ArgumentUsage
verbosesb3 출력의 상세 수준
model_save_path모델이 저장될 폴더 경로입니다. 기본값은 None이므로 모델은 로깅되지 않습니다
model_save_freq모델을 저장하는 주기입니다
gradient_save_freq그래디언트를 로깅하는 주기입니다. 기본값은 0이므로 그래디언트는 로깅되지 않습니다

기본 예시

W&B SB3 통합 기능은 TensorBoard의 로그 출력을 사용해 메트릭을 기록합니다.
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback


config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_name": "CartPole-v1",
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # sb3의 tensorboard 메트릭 자동 업로드
    monitor_gym=True,  # 에이전트가 게임을 플레이하는 영상 자동 업로드
    save_code=True,  # 선택 사항
)


def make_env():
    env = gym.make(config["env_name"])
    env = Monitor(env)  # 리턴 등의 통계 기록
    return env


env = DummyVecEnv([make_env])
env = VecVideoRecorder(
    env,
    f"videos/{run.id}",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200,
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()