Documentation Index
Fetch the complete documentation index at: https://wb-21fd5541-update-training-api-26.mintlify.app/llms.txt
Use this file to discover all available pages before exploring further.
Stable Baselines 3 (SB3)는 PyTorch 기반의 신뢰할 수 있는 강화학습 알고리즘 구현체 모음입니다. W&B의 SB3 인테그레이션은 다음과 같은 기능을 제공합니다:
- 손실(loss) 및 에피소드 리턴(episodic returns)과 같은 메트릭 기록.
- 게임을 플레이하는 에이전트의 비디오 업로드.
- 트레이닝된 모델 저장.
- 모델의 하이퍼파라미터 로그 기록.
- 모델 그레이디언트 히스토그램 로그 기록.
SB3 트레이닝 run 예시를 확인해 보세요.
SB3 Experiments 로그 기록하기
from wandb.integration.sb3 import WandbCallback
model.learn(..., callback=WandbCallback())
WandbCallback 인수
| 인수 | 용도 |
|---|
verbose | sb3 출력의 상세 수준 |
model_save_path | 모델이 저장될 폴더 경로. 기본값은 `None`이며 모델이 로그에 기록되지 않음 |
model_save_freq | 모델 저장 주기 |
gradient_save_freq | 그레이디언트 로그 기록 주기. 기본값은 0이며 그레이디언트가 로그에 기록되지 않음 |
기본 예제
W&B SB3 인테그레이션은 TensorBoard의 로그 출력을 사용하여 메트릭을 로그로 기록합니다.
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_name": "CartPole-v1",
}
run = wandb.init(
project="sb3",
config=config,
sync_tensorboard=True, # sb3의 tensorboard 메트릭 자동 업로드
monitor_gym=True, # 게임을 플레이하는 에이전트의 비디오 자동 업로드
save_code=True, # 선택 사항
)
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # 리턴과 같은 통계 기록
return env
env = DummyVecEnv([make_env])
env = VecVideoRecorder(
env,
f"videos/{run.id}",
record_video_trigger=lambda x: x % 2000 == 0,
video_length=200,
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(
gradient_save_freq=100,
model_save_path=f"models/{run.id}",
verbose=2,
),
)
run.finish()