enn_trainer

Functions

  • train(): Train an agent in an entity-gym environment using proximal policy optimization.

  • load_checkpoint(): Loads a training checkpoint from a given path.

  • load_agent(): Loads a training checkpoint from a given path and returns the agent.

  • init_train_state(): Creates the initial state for training, given a config and context.

enn_trainer.train(state_manager: StateManager[TrainConfig, State], env: Union[Type[Environment], Callable[[EnvConfig, int, int, int], VecEnv]], create_opponent: Optional[Callable[[str, ObsSpace, Mapping[str, Union[CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace]], device], PPOAgent]] = None, agent: Optional[PPOAgent] = None) → float

Train an agent in an entity-gym environment using proximal policy optimization.

Parameters
  • state_manager – The hyperstate StateManager encapsulates the configuration and mutable state of the training run.

  • env – The entity-gym environment class to train on, or a callable that constructs a VecEnv.

  • create_opponent – A function that creates the opponent agent used during evaluation.

  • agent – Custom policy network to use.
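
A minimal training script, sketched under the assumption that hyperstate's stateful_command decorator is used to build the StateManager from command-line arguments; TreasureHunt stands in for any entity-gym environment class and its import path is hypothetical:

  import hyperstate

  from enn_trainer import State, TrainConfig, init_train_state, train
  from entity_gym.examples import TreasureHunt  # hypothetical example environment


  @hyperstate.stateful_command(TrainConfig, State, init_train_state)
  def main(state_manager: hyperstate.StateManager) -> None:
      # Runs PPO training to completion.
      train(state_manager=state_manager, env=TreasureHunt)


  if __name__ == "__main__":
      main()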

enn_trainer.load_checkpoint(path: str) → StateManager[TrainConfig, State]

Loads a training checkpoint from a given path.

The returned StateManager has a config attribute that contains the configuration of the training run and a state attribute that contains the state of the training run, including the agent.
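
For example, a checkpoint written during training could be inspected like this (the path is a placeholder):

  from enn_trainer import load_checkpoint

  checkpoint = load_checkpoint("checkpoints/latest")  # placeholder path
  print(checkpoint.config.total_timesteps)  # TrainConfig of the run
  print(checkpoint.state.step)              # elapsed environment steps
  policy = checkpoint.state.agent           # the trained policy network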

enn_trainer.load_agent(path: str) → RogueNetAgent

Loads a training checkpoint from a given path and returns the agent.
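
A sketch of loading the agent for inference; the final call is left as a comment because the exact entity_gym Agent interface is assumed here rather than documented above:

  from enn_trainer import load_agent

  agent = load_agent("checkpoints/latest")  # placeholder path

  # RogueNetAgent implements the entity_gym Agent interface, so it can be passed
  # anywhere entity_gym expects an Agent. Assuming that interface exposes an
  # act(obs) method, a single step would look like:
  # action = agent.act(obs)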

enn_trainer.init_train_state(cfg: TrainConfig, ctx: Dict[str, Any]) → State

Creates the initial state for training, given a config and context.

Classes

class enn_trainer.TrainConfig(env: EnvConfig, net: RogueNetConfig, optim: OptimizerConfig, ppo: PPOConfig, rollout: RolloutConfig, eval: Optional[EvalConfig] = None, vf_net: Optional[RogueNetConfig] = None, name: str = <factory>, seed: int = 1, total_timesteps: int = 1000000, max_train_time: Optional[int] = None, torch_deterministic: bool = True, cuda: bool = True, track: bool = False, wandb_project_name: str = 'enn-ppo', wandb_entity: str = 'entity-neural-network', capture_samples: Optional[str] = None, capture_logits: bool = False, capture_samples_subsample: int = 1, trial: Optional[int] = None, data_dir: str = '.', cuda_empty_cache: bool = False)

Training settings.

Parameters
  • env – Settings for the environment.

  • net – Hyperparameters for policy network.

  • optim – Hyperparameters for optimizer.

  • ppo – Hyperparameters for PPO.

  • rollout – Hyperparameters for rollout phase.

  • eval – Optional evaluation settings.

  • vf_net – Hyperparameters for value function network (if not set, policy and value function share the same network).

  • name – The name of the experiment.

  • seed – Seed of the experiment.

  • total_timesteps – Total number of timesteps to run for.

  • max_train_time – Train for at most this many seconds.

  • torch_deterministic – Sets the value of torch.backends.cudnn.deterministic.

  • cuda – If True, training runs on the GPU (CUDA).

  • track – Track experiment metrics with Weights and Biases.

  • wandb_project_name – Name of the W&B project to log metrics to.

  • wandb_entity – The entity (team) of the W&B project to log metrics to.

  • capture_samples – Write all samples collected from environments during training to this file.

  • capture_logits – Record full logits of the agent (requires capture_samples).

  • capture_samples_subsample – Only persist every nth sample, chosen randomly (requires capture_samples).

  • data_dir – Directory to save output from training and logging.

  • cuda_empty_cache – Empty the torch cuda cache after each optimizer step.
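
The config objects are plain dataclasses and can also be constructed directly in Python (the values below are illustrative, not recommended settings):

  from enn_trainer import (
      EnvConfig,
      OptimizerConfig,
      PPOConfig,
      RogueNetConfig,
      RolloutConfig,
      TrainConfig,
  )

  cfg = TrainConfig(
      env=EnvConfig(id="MoveToOrigin"),
      net=RogueNetConfig(d_model=64, n_layer=2, n_head=2),
      optim=OptimizerConfig(lr=3e-4, bs=2048),
      ppo=PPOConfig(ent_coef=0.05),
      rollout=RolloutConfig(steps=32, num_envs=64, processes=4),
      total_timesteps=5_000_000,
      track=False,
  )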

Inheritance

Inheritance diagram of TrainConfig
classmethod upgrade_rules() → Dict[int, List[RewriteRule]]

Returns a mapping from config version to the rewrite rules that upgrade a config of that version to the next version.

class enn_trainer.OptimizerConfig(lr: float = 0.001, lr_warmup_steps: Optional[int] = None, bs: int = 1024, weight_decay: float = 0.0, micro_bs: Optional[int] = None, anneal_lr: bool = True, update_epochs: int = 3, max_grad_norm: float = 2.0)

Optimizer settings.

Parameters
  • lr – Adam learning rate.

  • bs – Batch size.

  • micro_bs – Micro batch size used for gradient accumulation. Using a smaller micro batch size reduces memory usage at the cost of throughput, without affecting training dynamics.

  • weight_decay – Adam weight decay.

  • anneal_lr – Linearly anneal learning rate from initial learning rate to 0.

  • update_epochs – Number of optimizer passes over each batch of rollout samples.

  • max_grad_norm – Gradient norm clipping.
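
For example, to keep the optimizer batch size at 4096 while materializing only 512 samples per backward pass (assuming bs is a multiple of micro_bs):

  from enn_trainer import OptimizerConfig

  optim = OptimizerConfig(
      lr=1e-3,
      bs=4096,       # effective batch size per optimizer step
      micro_bs=512,  # gradients accumulated over 4096 / 512 = 8 micro batches
  )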

Inheritance

Inheritance diagram of OptimizerConfig
class enn_trainer.PPOConfig(gae: bool = True, gamma: float = 0.99, gae_lambda: float = 0.95, norm_adv: bool = True, clip_coef: float = 0.2, clip_vloss: bool = True, ent_coef: float = 0.1, vf_coef: float = 0.5, target_kl: Optional[float] = None, anneal_entropy: bool = True)

Proximal Policy Optimization settings.

Parameters
  • gae – Whether to use generalized advantage estimation for advantage computation.

  • gamma – Temporal discount factor gamma.

  • gae_lambda – The lambda for the generalized advantage estimation.

  • norm_adv – Normalize advantages to 0 mean and 1 std.

  • clip_coef – The PPO surrogate clipping coefficient.

  • clip_vloss – Whether to use a clipped loss for the value function.

  • ent_coef – Coefficient for entropy loss term.

  • vf_coef – Coefficient for value function loss term.

  • target_kl – Stop optimization if the KL divergence between the old and new policy exceeds this threshold.

  • anneal_entropy – Linearly anneal the entropy coefficient from its initial value to 0.
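
An illustrative override of the PPO defaults, with tighter clipping and a KL early-stopping threshold (values are examples, not recommendations):

  from enn_trainer import PPOConfig

  ppo = PPOConfig(
      clip_coef=0.1,        # tighter surrogate clipping than the default 0.2
      target_kl=0.02,       # stop the update early if KL exceeds this value
      anneal_entropy=False,
  )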

Inheritance

Inheritance diagram of PPOConfig
class enn_trainer.RolloutConfig(steps: int = 16, num_envs: int = 128, processes: int = 4)

Settings for rollout phase of PPO.

Parameters
  • steps – The number of steps to run in each environment per policy rollout.

  • num_envs – The number of parallel game environments.

  • processes – The number of processes to use to collect env data. The envs are split as equally as possible across the processes.
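
With the defaults, each rollout collects steps * num_envs samples and the environments are split evenly across processes:

  from enn_trainer import RolloutConfig

  rollout = RolloutConfig(steps=16, num_envs=128, processes=4)
  # 128 envs / 4 processes = 32 envs per process
  # samples per rollout = 16 steps * 128 envs = 2048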

Inheritance

Inheritance diagram of RolloutConfig
class enn_trainer.RogueNetConfig(embd_pdrop: float = 0.0, resid_pdrop: float = 0.0, attn_pdrop: float = 0.0, n_layer: int = 2, n_head: int = 2, d_model: int = 32, pooling: Optional[Literal['mean', 'max', 'meanmax']] = None, relpos_encoding: Optional[RelposEncodingConfig] = None, d_qk: int = 16, translation: Optional[TranslationConfig] = None)

RogueNet network parameters.

Parameters
  • embd_pdrop – Dropout probability for embedding layer.

  • resid_pdrop – Dropout probability for residual branches.

  • attn_pdrop – Dropout probability for attention.

  • n_layer – Number of transformer layers.

  • n_head – Number of attention heads.

  • d_model – Dimension of embedding.

  • pooling – Replace attention layer with "mean", "max", or "meanmax" pooling.

  • relpos_encoding – Relative positional encoding settings.

  • d_qk – Dimension of keys and queries in select-entity action heads.

  • translation – Settings for transforming all position features to be centered on one entity.
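
A small transformer configuration; pooling is left unset so that entities attend to each other (values are illustrative):

  from enn_trainer import RogueNetConfig

  net = RogueNetConfig(
      n_layer=2,
      n_head=2,
      d_model=32,
      attn_pdrop=0.1,   # dropout on attention weights
      resid_pdrop=0.1,  # dropout on residual branches
      # pooling="meanmax" would replace attention with mean+max pooling
  )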

Inheritance

Inheritance diagram of RogueNetConfig
class enn_trainer.RelposEncodingConfig(extent: List[int], position_features: List[str], scale: float = 1.0, per_entity_values: bool = False, exclude_entities: List[str] = <factory>, value_relpos_projection: bool = False, key_relpos_projection: bool = False, per_entity_projections: bool = False, radial: bool = False, distance: bool = False, rotation_vec_features: Optional[List[str]] = None, rotation_angle_feature: Optional[str] = None, interpolate: bool = False, value_gate: Literal['linear', 'relu', 'gelu', 'sigmoid', None] = 'relu', enable_negative_distance_weight_bug: bool = False)

Settings for relative position encoding.

Parameters
  • extent – Each integer relative position in the interval [-extent, extent] receives a positional embedding, with positions outside the interval snapped to the closest end.

  • position_features – Names of position features used for relative position encoding.

  • scale – Relative positions are divided by the scale before being assigned an embedding.

  • per_entity_values – Whether to use per-entity embeddings for relative positional values.

  • exclude_entities – List of entity types to exclude from relative position encoding.

  • key_relpos_projection – Adds a learnable projection from the relative position/distance to the relative positional keys.

  • value_relpos_projection – Adds a learnable projection from the relative position/distance to the relative positional values.

  • per_entity_projections – Uses a different learned projection per entity type for the key_relpos_projection and value_relpos_projection.

  • radial – Buckets all relative positions by their angle. The extent is interpreted as the number of buckets.

  • distance – Buckets all relative positions by their distance. The extent is interpreted as the number of buckets.

  • rotation_vec_features – Names of features that give a unit orientation vector for each entity, by which relative positions are rotated.

  • rotation_angle_feature – Name of feature that gives an angle in radians by which to rotate relative positions.

  • interpolate – Whether to interpolate between the embeddings of neighboring positions.
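
As an illustration, a grid-world whose entities carry "x" and "y" position features might embed relative offsets of up to 10 cells per axis (the feature names are assumptions about the environment):

  from enn_trainer import RelposEncodingConfig, RogueNetConfig

  relpos = RelposEncodingConfig(
      extent=[10, 10],               # embed integer offsets in [-10, 10] per axis
      position_features=["x", "y"],  # position features used to compute offsets
      scale=1.0,                     # offsets are divided by this before lookup
      per_entity_values=True,
  )
  net = RogueNetConfig(relpos_encoding=relpos)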

Inheritance

Inheritance diagram of RelposEncodingConfig
class enn_trainer.EvalConfig(steps: int, interval: int, num_envs: Optional[int] = None, processes: Optional[int] = None, env: Optional[EnvConfig] = None, capture_videos: bool = False, capture_samples: str = '', capture_logits: bool = True, capture_samples_subsample: int = 1, run_on_first_step: bool = True, opponent: Optional[str] = None, opponent_only: bool = False)

Evaluation settings.

Parameters
  • interval – Number of environment steps between evaluations.

  • capture_videos – Render videos of the environments during evaluation.

  • capture_samples – Write samples from evals to this file.

  • capture_logits – Record full logits of the agent during evaluation (requires capture_samples).

  • capture_samples_subsample – Only persist every nth sample, chosen randomly (requires capture_samples).

  • run_on_first_step – Whether to run an eval on step 0.

  • env – Settings for the eval environment. If not set, use same settings as rollouts.

  • num_envs – The number of parallel game environments to use for evaluation. If not set, use same settings as rollouts.

  • processes – The number of processes used to run the environment. If not set, use same settings as rollouts.

  • opponent – Path to opponent policy to evaluate against.

  • opponent_only – Don’t evaluate the policy, but instead run the opponent against itself.
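
For example, to evaluate every 100,000 environment steps on 32 parallel environments against a fixed opponent checkpoint (the path is a placeholder, and steps is assumed to control the length of each evaluation):

  from enn_trainer import EvalConfig

  eval_cfg = EvalConfig(
      steps=1000,                       # steps to run per evaluation (assumed)
      interval=100_000,                 # evaluate every 100k training env steps
      num_envs=32,
      opponent="checkpoints/opponent",  # placeholder path to an opponent policy
  )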

Inheritance

Inheritance diagram of EvalConfig
class enn_trainer.EnvConfig(kwargs: str = '{}', id: str = 'MoveToOrigin', validate: bool = True)

Environment settings.

Parameters
  • kwargs – JSON dictionary with keyword arguments for the environment.

  • id – The id of the environment.

  • validate – Perform runtime checks to ensure that the environment correctly implements the interface.
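
The kwargs field is a JSON string whose entries are presumably passed as keyword arguments when the environment is constructed; for example (the num_targets argument is hypothetical):

  from enn_trainer import EnvConfig

  env_cfg = EnvConfig(
      id="MoveToOrigin",
      kwargs='{"num_targets": 3}',  # hypothetical keyword argument for the env
      validate=True,
  )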

Inheritance

Inheritance diagram of EnvConfig
class enn_trainer.State(step: int, restart: int, next_eval_step: Optional[int], agent: SerializableRogueNet, value_function: Optional[SerializableRogueNet], optimizer: SerializableAdamW, vf_optimizer: Optional[SerializableAdamW], obs_space: ObsSpace, action_space: Dict[str, Union[CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace]])

Mutable state of training run.

Parameters
  • step – The number of elapsed environment steps.

  • restart – The number of times the training has been restarted from a checkpoint.

  • agent – The policy network.

  • value_function – The value function, if separate from the policy network.

  • optimizer – AdamW optimizer for the policy network.

  • vf_optimizer – AdamW optimizer for the value function, if separate from the policy network.

  • obs_space – The observation space of the environment.

  • action_space – The action space of the environment.

Inheritance

Inheritance diagram of State
class enn_trainer.RogueNetAgent(agent: RogueNet)

Wraps a rogue_net network, exposing an entity_gym Agent interface.

Parameters
  • agent – The underlying RogueNet.

Inheritance

Inheritance diagram of RogueNetAgent