enn_trainer
Functions
train()
: Train an agent in an entity-gym environment using proximal policy optimization.
load_checkpoint()
: Loads a training checkpoint from a given path.
load_agent()
: Loads a training checkpoint from a given path and returns the agent.
init_train_state()
: Creates the initial state for training, given a config and context.
- enn_trainer.train(state_manager: StateManager[TrainConfig, State], env: Union[Type[Environment], Callable[[EnvConfig, int, int, int], VecEnv]], create_opponent: Optional[Callable[[str, ObsSpace, Mapping[str, Union[CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace]], device], PPOAgent]] = None, agent: Optional[PPOAgent] = None) -> float
Train an agent in an entity-gym environment using proximal policy optimization.
- Parameters
state_manager – The hyperstate StateManager encapsulates the configuration and mutable state of the training run.
env – The class of the entity-gym environment to train on, or a callable that creates a VecEnv (see the type in the signature).
create_opponent – A function that creates a new opponent agent used for evaluating the agent.
agent – Custom policy network to use.
- enn_trainer.load_checkpoint(path: str) -> StateManager[TrainConfig, State]
Loads a training checkpoint from a given path.
The returned StateManager has a config attribute that contains the configuration of the training run and a state attribute that contains the state of the training run, including the agent.
- enn_trainer.load_agent(path: str) -> RogueNetAgent
Loads a training checkpoint from a given path and returns the agent.
Classes
TrainConfig
: Training settings.
OptimizerConfig
: Optimizer settings.
PPOConfig
: Proximal Policy Optimization settings.
RolloutConfig
: Settings for rollout phase of PPO.
RogueNetConfig
: RogueNet network parameters.
RelposEncodingConfig
: Settings for relative position encoding.
EvalConfig
: Evaluation settings.
EnvConfig
: Environment settings.
State
: Mutable state of training run.
RogueNetAgent
: Wraps a rogue_net network, exposing an entity_gym Agent interface.
- class enn_trainer.TrainConfig(env: EnvConfig, net: RogueNetConfig, optim: OptimizerConfig, ppo: PPOConfig, rollout: RolloutConfig, eval: Optional[EvalConfig] = None, vf_net: Optional[RogueNetConfig] = None, name: str = <factory>, seed: int = 1, total_timesteps: int = 1000000, max_train_time: Optional[int] = None, torch_deterministic: bool = True, cuda: bool = True, track: bool = False, wandb_project_name: str = 'enn-ppo', wandb_entity: str = 'entity-neural-network', capture_samples: Optional[str] = None, capture_logits: bool = False, capture_samples_subsample: int = 1, trial: Optional[int] = None, data_dir: str = '.', cuda_empty_cache: bool = False)
Training settings.
- Parameters
env – Settings for the environment.
net – Hyperparameters for policy network.
optim – Hyperparameters for optimizer.
ppo – Hyperparameters for PPO.
rollout – Hyperparameters for rollout phase.
eval – Optional evaluation settings.
vf_net – Hyperparameters for value function network (if not set, policy and value function share the same network).
name – The name of the experiment.
seed – Seed of the experiment.
total_timesteps – Total number of timesteps to run for.
max_train_time – Train for at most this many seconds.
torch_deterministic – Sets the value of torch.backends.cudnn.deterministic.
cuda – If True, CUDA will be enabled by default.
track – Track experiment metrics with Weights and Biases.
wandb_project_name – Name of the W&B project to log metrics to.
wandb_entity – The entity (team) of the W&B project to log metrics to.
capture_samples – Write all samples collected from environments during training to this file.
capture_logits – Record full logits of the agent (requires capture_samples).
capture_samples_subsample – Only persist every nth sample, chosen randomly (requires capture_samples).
data_dir – Directory to save output from training and logging.
cuda_empty_cache – Empty the torch cuda cache after each optimizer step.
- class enn_trainer.OptimizerConfig(lr: float = 0.001, lr_warmup_steps: Optional[int] = None, bs: int = 1024, weight_decay: float = 0.0, micro_bs: Optional[int] = None, anneal_lr: bool = True, update_epochs: int = 3, max_grad_norm: float = 2.0)
Optimizer settings.
- Parameters
lr – Adam learning rate.
bs – Batch size.
micro_bs – Micro batch size used for gradient accumulation. Using a lower micro batch size reduces memory usage and throughput without affecting training dynamics.
weight_decay – Adam weight decay.
anneal_lr – Linearly anneal learning rate from initial learning rate to 0.
update_epochs – Number of optimizer passes over each batch of rollout samples.
max_grad_norm – Gradient norm clipping.
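The claim that micro_bs does not affect training dynamics can be illustrated with a small, library-free sketch. It assumes a toy one-parameter least-squares model and hypothetical helper names (grad_mse, accumulated_grad); it is not the trainer's own code. The gradient accumulated over micro batches, each weighted by its share of the batch, matches the gradient of the full batch.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair for the toy model y ≈ w * x


def grad_mse(w: float, batch: List[Sample]) -> float:
    """Gradient of the mean squared error of y ≈ w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Sample], micro_bs: int) -> float:
    """Accumulate gradients over micro batches of at most micro_bs samples.

    Each micro batch gradient is weighted by its share of the full batch,
    so the sum equals the full-batch mean gradient.
    """
    total = 0.0
    for start in range(0, len(batch), micro_bs):
        micro = batch[start:start + micro_bs]
        total += grad_mse(w, micro) * len(micro) / len(batch)
    return total


# Hypothetical data: 8 samples from y = 3x + 1.
batch = [(float(i), 3.0 * i + 1.0) for i in range(8)]
full = grad_mse(0.5, batch)
accum = accumulated_grad(0.5, batch, micro_bs=2)
```

Only one micro batch needs to be held in memory at a time, which is where the memory saving comes from; the extra forward and backward passes are where the throughput cost comes from.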
- class enn_trainer.PPOConfig(gae: bool = True, gamma: float = 0.99, gae_lambda: float = 0.95, norm_adv: bool = True, clip_coef: float = 0.2, clip_vloss: bool = True, ent_coef: float = 0.1, vf_coef: float = 0.5, target_kl: Optional[float] = None, anneal_entropy: bool = True)
Proximal Policy Optimization settings.
- Parameters
gae – Whether to use generalized advantage estimation for advantage computation.
gamma – Temporal discount factor gamma.
gae_lambda – The lambda for the generalized advantage estimation.
norm_adv – Normalize advantages to 0 mean and 1 std.
clip_coef – The PPO surrogate clipping coefficient.
clip_vloss – Whether to use a clipped loss for the value function.
ent_coef – Coefficient for entropy loss term.
vf_coef – Coefficient for value function loss term.
target_kl – Stop optimization if the KL divergence between the old and new policy exceeds this threshold.
anneal_entropy – Linearly anneal the entropy coefficient from its initial value to 0.
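As a rough illustration of how gae, gamma, gae_lambda, norm_adv, and clip_coef interact, here is a minimal sketch of the textbook formulas in plain Python. The function names and the convention that dones[t] marks an episode ending after step t are assumptions made for this example, not the trainer's internals.

```python
from typing import List


def gae_advantages(rewards: List[float], values: List[float], last_value: float,
                   dones: List[bool], gamma: float = 0.99,
                   gae_lambda: float = 0.95) -> List[float]:
    """Generalized advantage estimation, computed backwards over one rollout.

    Assumed convention: dones[t] is True if the episode ended after step t.
    """
    advantages = [0.0] * len(rewards)
    next_value, next_adv = last_value, 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        next_adv = delta + gamma * gae_lambda * not_done * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages


def normalize(advantages: List[float]) -> List[float]:
    """Shift and scale advantages to zero mean and unit standard deviation."""
    mean = sum(advantages) / len(advantages)
    var = sum((a - mean) ** 2 for a in advantages) / len(advantages)
    return [(a - mean) / (var ** 0.5 + 1e-8) for a in advantages]


def clipped_policy_loss(ratio: float, advantage: float, clip_coef: float = 0.2) -> float:
    """PPO clipped surrogate loss for one sample (negated objective)."""
    clipped_ratio = max(1.0 - clip_coef, min(ratio, 1.0 + clip_coef))
    return max(-advantage * ratio, -advantage * clipped_ratio)


advs = gae_advantages([1.0, 1.0], [0.0, 0.0], 0.0, [False, True],
                      gamma=1.0, gae_lambda=1.0)
```

With a positive advantage, a probability ratio above 1 + clip_coef earns no additional reduction in loss, which limits how far a single update can move the policy.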
- class enn_trainer.RolloutConfig(steps: int = 16, num_envs: int = 128, processes: int = 4)
Settings for rollout phase of PPO.
- Parameters
steps – The number of steps to run in each environment per policy rollout.
num_envs – The number of parallel game environments.
processes – The number of processes to use to collect env data. The envs are split as equally as possible across the processes.
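To make the arithmetic concrete, the sketch below shows one way environments could be divided across processes and how many samples a single rollout yields. The helper names are hypothetical, and which processes receive the extra environments when the split is uneven is an assumption.

```python
from typing import List


def split_envs(num_envs: int, processes: int) -> List[int]:
    """Split num_envs across processes as equally as possible.

    Assumption: the first (num_envs % processes) processes get one extra env.
    """
    base, extra = divmod(num_envs, processes)
    return [base + 1 if i < extra else base for i in range(processes)]


def samples_per_rollout(steps: int, num_envs: int) -> int:
    """Each rollout collects `steps` transitions from every environment."""
    return steps * num_envs


default_split = split_envs(num_envs=128, processes=4)
uneven_split = split_envs(num_envs=10, processes=4)
default_samples = samples_per_rollout(steps=16, num_envs=128)
```

With the defaults (steps=16, num_envs=128, processes=4), each process runs 32 environments and one rollout produces 2048 samples.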
- class enn_trainer.RogueNetConfig(embd_pdrop: float = 0.0, resid_pdrop: float = 0.0, attn_pdrop: float = 0.0, n_layer: int = 2, n_head: int = 2, d_model: int = 32, pooling: Optional[Literal['mean', 'max', 'meanmax']] = None, relpos_encoding: Optional[RelposEncodingConfig] = None, d_qk: int = 16, translation: Optional[TranslationConfig] = None)
RogueNet network parameters.
- Parameters
embd_pdrop – Dropout probability for embedding layer.
resid_pdrop – Dropout probability for residual branches.
attn_pdrop – Dropout probability for attention.
n_layer – Number of transformer layers.
n_head – Number of attention heads.
d_model – Dimension of embedding.
pooling – Replace attention layer with "mean", "max", or "meanmax" pooling.
relpos_encoding – Relative positional encoding settings.
d_qk – Dimension of keys and queries in select-entity action heads.
translation – Settings for transforming all position features to be centered on one entity.
- class enn_trainer.RelposEncodingConfig(extent: List[int], position_features: List[str], scale: float = 1.0, per_entity_values: bool = False, exclude_entities: List[str] = <factory>, value_relpos_projection: bool = False, key_relpos_projection: bool = False, per_entity_projections: bool = False, radial: bool = False, distance: bool = False, rotation_vec_features: Optional[List[str]] = None, rotation_angle_feature: Optional[str] = None, interpolate: bool = False, value_gate: Literal['linear', 'relu', 'gelu', 'sigmoid', None] = 'relu', enable_negative_distance_weight_bug: bool = False)
Settings for relative position encoding.
- Parameters
extent – Each integer relative position in the interval [-extent, extent] receives a positional embedding, with positions outside the interval snapped to the closest end.
position_features – Names of position features used for relative position encoding.
scale – Relative positions are divided by the scale before being assigned an embedding.
per_entity_values – Whether to use per-entity embeddings for relative positional values.
exclude_entities – List of entity types to exclude from relative position encoding.
key_relpos_projection – Adds a learnable projection from the relative position/distance to the relative positional keys.
value_relpos_projection – Adds a learnable projection from the relative position/distance to the relative positional values.
per_entity_projections – Uses a different learned projection per entity type for the key_relpos_projection and value_relpos_projection.
radial – Buckets all relative positions by their angle. The extent is interpreted as the number of buckets.
distance – Buckets all relative positions by their distance. The extent is interpreted as the number of buckets.
rotation_vec_features – Names of features that give a unit orientation vector for each entity by which to rotate relative positions.
rotation_angle_feature – Name of feature that gives an angle in radians by which to rotate relative positions.
interpolate – Whether to interpolate between the embeddings of neighboring positions.
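The interplay of extent and scale can be sketched as a lookup from a relative position to an embedding index. This is an illustrative reading of the parameter descriptions only: the rounding rule, the flat index layout, and the function names are assumptions, and the actual implementation may differ.

```python
from typing import Sequence


def num_embeddings(extent: Sequence[int]) -> int:
    """Number of distinct buckets: (2 * extent + 1) per position feature."""
    count = 1
    for ext in extent:
        count *= 2 * ext + 1
    return count


def relpos_index(relpos: Sequence[float], extent: Sequence[int],
                 scale: float = 1.0) -> int:
    """Map a relative position to a flat embedding index.

    Each coordinate is divided by `scale`, rounded to an integer (assumed),
    and snapped to the closest end of [-extent, extent].
    """
    index = 0
    for value, ext in zip(relpos, extent):
        bucket = round(value / scale)
        bucket = max(-ext, min(ext, bucket))  # snap out-of-range positions
        index = index * (2 * ext + 1) + (bucket + ext)
    return index


extent = [2, 2]
center = relpos_index((0.0, 0.0), extent)          # both coordinates in bucket 0
far_away = relpos_index((10.0, -10.0), extent)     # snapped to (2, -2)
scaled = relpos_index((4.0, 0.0), extent, scale=2.0)  # 4 / 2 = 2, bucket (2, 0)
```

A larger scale makes each bucket cover a wider range of raw positions, so the same extent spans a larger area at coarser resolution.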
- class enn_trainer.EvalConfig(steps: int, interval: int, num_envs: Optional[int] = None, processes: Optional[int] = None, env: Optional[EnvConfig] = None, capture_videos: bool = False, capture_samples: str = '', capture_logits: bool = True, capture_samples_subsample: int = 1, run_on_first_step: bool = True, opponent: Optional[str] = None, opponent_only: bool = False)
Evaluation settings.
- Parameters
interval – Number of environment steps between evaluations.
capture_videos – Render videos of the environments during evaluation.
capture_samples – Write samples from evals to this file.
capture_logits – Record full logits of the agent during evaluation (requires capture_samples).
capture_samples_subsample – Only persist every nth sample, chosen randomly (requires capture_samples).
run_on_first_step – Whether to run an eval on step 0.
env – Settings for the eval environment. If not set, use same settings as rollouts.
num_envs – The number of parallel game environments to use for evaluation. If not set, use same settings as rollouts.
processes – The number of processes used to run the environment. If not set, use same settings as rollouts.
opponent – Path to opponent policy to evaluate against.
opponent_only – Don’t evaluate the policy, but instead run the opponent against itself.
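Because training advances in whole rollouts, evaluations cannot land exactly on multiples of interval. The sketch below shows one plausible scheduling rule, assuming an evaluation runs at the first rollout boundary at or past the next scheduled step. The function name and the exact rule are assumptions for illustration, not a description of the trainer's behavior.

```python
from typing import List


def eval_steps(total_timesteps: int, steps_per_rollout: int, interval: int,
               run_on_first_step: bool = True) -> List[int]:
    """Environment steps at which an evaluation would be triggered."""
    evals = []
    next_eval_step = 0 if run_on_first_step else interval
    step = 0
    while step < total_timesteps:
        if step >= next_eval_step:
            evals.append(step)
            next_eval_step += interval
        step += steps_per_rollout  # one rollout of steps * num_envs samples
    return evals


# Hypothetical run: 2048 samples per rollout, evaluation every 4000 steps.
with_first = eval_steps(10_000, 2048, 4000, run_on_first_step=True)
without_first = eval_steps(10_000, 2048, 4000, run_on_first_step=False)
```

Under this rule, choosing an interval that is a multiple of the rollout size keeps evaluations evenly spaced.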
- class enn_trainer.EnvConfig(kwargs: str = '{}', id: str = 'MoveToOrigin', validate: bool = True)
Environment settings.
- Parameters
kwargs – JSON dictionary with keyword arguments for the environment.
id – The id of the environment.
validate – Perform runtime checks to ensure that the environment correctly implements the interface.
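Since kwargs is a JSON dictionary encoded as a string, it has to be decoded before it can be passed to an environment constructor. The following sketch shows the general pattern using only the standard library; GridEnv and make_env are hypothetical stand-ins, not part of enn_trainer or entity-gym.

```python
import json
from dataclasses import dataclass
from typing import Any, Type


@dataclass
class GridEnv:
    """Hypothetical environment whose constructor takes keyword arguments."""
    width: int = 10
    height: int = 10


def make_env(env_cls: Type[Any], kwargs: str) -> Any:
    """Decode a JSON dictionary string and pass it on as keyword arguments."""
    parsed = json.loads(kwargs)
    if not isinstance(parsed, dict):
        raise ValueError("kwargs must encode a JSON dictionary")
    return env_cls(**parsed)


# Build the string with json.dumps to avoid quoting mistakes.
kwargs = json.dumps({"width": 4, "height": 6})
env = make_env(GridEnv, kwargs)
default_env = make_env(GridEnv, "{}")  # the default '{}' passes no arguments
```

Generating the string with json.dumps rather than writing it by hand avoids invalid JSON such as single-quoted keys.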
- class enn_trainer.State(step: int, restart: int, next_eval_step: Optional[int], agent: SerializableRogueNet, value_function: Optional[SerializableRogueNet], optimizer: SerializableAdamW, vf_optimizer: Optional[SerializableAdamW], obs_space: ObsSpace, action_space: Dict[str, Union[CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace]])
Mutable state of training run.
- Parameters
step – The number of elapsed environment steps.
restart – The number of times the training has been restarted from a checkpoint.
agent – The policy network.
value_function – The value function, if separate from the policy network.
optimizer – AdamW optimizer for the policy network.
vf_optimizer – AdamW optimizer for the value function, if separate from the policy network.
obs_space – The observation space of the environment.
action_space – The action space of the environment.
- class enn_trainer.RogueNetAgent(agent: RogueNet)
Wraps a rogue_net network, exposing an entity_gym Agent interface.
- Parameters
agent – The underlying RogueNet.