RL Agent Example
This example shows a reinforcement learning agent configuration with multiple config groups.
Based on the test case test_proto_cli.py::test_rl_agent_help:
from params_proto import proto


@proto.prefix
class Environment:
    """dm_control environment configuration."""

    domain: str = "cartpole"  # Domain name (e.g., cartpole, walker)
    task: str = "swingup"  # Task name within the domain
    time_limit: float = 10.0  # Episode time limit in seconds


@proto.prefix
class Agent:
    """SAC agent hyperparameters."""

    algorithm: str = "SAC"  # RL algorithm (SAC or PPO)
    buffer_size: int = 1000000  # Replay buffer capacity
    gamma: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate


@proto.cli
def train_rl(
    total_steps: int = 1000000,  # Total environment steps
    eval_freq: int = 5000,  # Evaluation frequency
    seed: int = 0,  # Random seed
):
    """Train RL agent on dm_control environment."""
    import numpy as np

    print(f"Training {Agent.algorithm} on {Environment.domain}-{Environment.task}")
    print(f"Environment settings:")
    print(f" Domain: {Environment.domain}")
    print(f" Task: {Environment.task}")
    print(f" Time limit: {Environment.time_limit}s")
    print(f"\nAgent settings:")
    print(f" Algorithm: {Agent.algorithm}")
    print(f" Buffer size: {Agent.buffer_size}")
    print(f" Gamma: {Agent.gamma}")
    print(f" Tau: {Agent.tau}")
    print(f"\nTraining settings:")
    print(f" Total steps: {total_steps}")
    print(f" Eval frequency: {eval_freq}")
    print(f" Seed: {seed}")

    # Set random seed
    np.random.seed(seed)

    # Training loop (simplified)
    for step in range(0, total_steps, eval_freq):
        print(f"\nStep {step}/{total_steps}")
        # ... training code here ...
        if step % eval_freq == 0:
            print(f" Evaluating at step {step}")
            # ... evaluation code here ...

    print("\nTraining complete!")


if __name__ == "__main__":
    train_rl()
CLI Help Output
Running python train_rl.py --help shows:
usage: train_rl.py [-h] [--total-steps INT] [--eval-freq INT] [--seed INT] [OPTIONS]

Train RL agent on dm_control environment.

options:
  -h, --help           show this help message and exit
  --total-steps INT    Total environment steps (default: 1000000)
  --eval-freq INT      Evaluation frequency (default: 5000)
  --seed INT           Random seed (default: 0)

Environment options:
  dm_control environment configuration.

  --Environment.domain STR        Domain name (e.g., cartpole, walker) (default: cartpole)
  --Environment.task STR          Task name within the domain (default: swingup)
  --Environment.time-limit FLOAT  Episode time limit in seconds (default: 10.0)

Agent options:
  SAC agent hyperparameters.

  --Agent.algorithm STR     RL algorithm (SAC or PPO) (default: SAC)
  --Agent.buffer-size INT   Replay buffer capacity (default: 1000000)
  --Agent.gamma FLOAT       Discount factor (default: 0.99)
  --Agent.tau FLOAT         Target network update rate (default: 0.005)
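For readers curious how grouped help of this shape can be produced, here is a minimal sketch using only the standard library's argparse argument groups. It is not params_proto's implementation (the source does not show that); the function name build_parser and the dest names env_domain and agent_gamma are hypothetical, and only a few of the options are reproduced.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser whose help output is grouped per config prefix."""
    parser = argparse.ArgumentParser(
        prog="train_rl.py",
        description="Train RL agent on dm_control environment.",
    )
    # Top-level options land in the default "options" section.
    parser.add_argument("--total-steps", type=int, default=1000000, metavar="INT",
                        help="Total environment steps (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=0, metavar="INT",
                        help="Random seed (default: %(default)s)")

    # Each argument group gets its own titled section with a description.
    env = parser.add_argument_group("Environment options",
                                    "dm_control environment configuration.")
    env.add_argument("--Environment.domain", dest="env_domain", default="cartpole",
                     metavar="STR", help="Domain name (default: %(default)s)")

    agent = parser.add_argument_group("Agent options", "SAC agent hyperparameters.")
    agent.add_argument("--Agent.gamma", dest="agent_gamma", type=float, default=0.99,
                       metavar="FLOAT", help="Discount factor (default: %(default)s)")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--Environment.domain", "walker", "--Agent.gamma", "0.95"])
print(args.env_domain, args.agent_gamma, args.seed)
```

Calling build_parser().print_help() prints the grouped sections; the exact column layout depends on the terminal width and Python version.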
Usage Examples
Run with defaults:
python train_rl.py
Change environment:
python train_rl.py --Environment.domain walker --Environment.task walk
Adjust agent hyperparameters:
python train_rl.py --Agent.gamma 0.95 --Agent.tau 0.01
Full customization:
python train_rl.py \
    --seed 123 \
    --total-steps 500000 \
    --Environment.domain walker \
    --Environment.task walk \
    --Agent.algorithm PPO \
    --Agent.buffer-size 500000
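To make the effect of the prefixed flags concrete, the following sketch shows one way dotted overrides such as --Agent.gamma 0.95 could be mapped onto class-level attributes, with the type annotation choosing the conversion. This is an illustration under stated assumptions, not how params_proto does it: the plain classes stand in for the decorated ones, GROUPS and apply_overrides are hypothetical names, and only `--Group.field value` pairs are handled.

```python
from typing import get_type_hints


# Plain classes standing in for the @proto.prefix config groups (assumption).
class Environment:
    domain: str = "cartpole"
    task: str = "swingup"
    time_limit: float = 10.0


class Agent:
    algorithm: str = "SAC"
    buffer_size: int = 1000000
    gamma: float = 0.99


GROUPS = {"Environment": Environment, "Agent": Agent}


def apply_overrides(argv: list[str]) -> None:
    """Apply `--Group.field value` pairs to the matching class attribute."""
    for flag, raw in zip(argv[::2], argv[1::2]):
        group_name, _, field = flag.removeprefix("--").partition(".")
        cls = GROUPS[group_name]
        attr = field.replace("-", "_")  # --Agent.buffer-size -> buffer_size
        caster = get_type_hints(cls)[attr]  # the annotation picks the conversion
        setattr(cls, attr, caster(raw))


apply_overrides([
    "--Agent.gamma", "0.95",
    "--Agent.buffer-size", "500000",
    "--Environment.domain", "walker",
])
print(Agent.gamma, Agent.buffer_size, Environment.domain)  # 0.95 500000 walker
```

Because the values live on the classes themselves, any code that reads Agent.gamma after the overrides sees the new value, which is the behaviour the training function above relies on.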
Key Features Demonstrated
Modular Configuration: Separate config groups for Environment and Agent
Singleton Pattern: @proto.prefix creates global singletons
Automatic Grouping: CLI help automatically groups options by prefix
Inline Documentation: Comments become help text
Type Safety: Full type annotations for all parameters
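As an illustration of the inline-documentation idea, the sketch below pulls trailing comments out of a source string with the standard tokenize module and pairs each with the field name on the same line. The source does not describe how params_proto extracts comments, so treat this as one possible approach; SOURCE and comment_help are hypothetical names.

```python
import io
import tokenize

# A source snippet shaped like the Agent config group above.
SOURCE = '''
class Agent:
    gamma: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
'''


def comment_help(source: str) -> dict[str, str]:
    """Map the first name on each line to that line's trailing comment."""
    first_name_on_line: dict[int, str] = {}
    help_text: dict[str, str] = {}
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        row = tok.start[0]
        if tok.type == tokenize.NAME and row not in first_name_on_line:
            # Remember the field name that starts this line.
            first_name_on_line[row] = tok.string
        elif tok.type == tokenize.COMMENT and row in first_name_on_line:
            # Strip the leading '#' and surrounding whitespace.
            help_text[first_name_on_line[row]] = tok.string.lstrip("#").strip()
    return help_text


print(comment_help(SOURCE))
```

Lines without a trailing comment, such as the class statement, simply produce no entry.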
Next Steps
See Basic Usage for simpler examples
Check ML Training for another complete example
Read Prefixes Guide for more on @proto.prefix