Learner API

Learner specifications and configurations

FrameworkHyperparameters([eager_tracing, ...])

The framework-specific hyperparameters.

LearnerHyperparameters([learning_rate, ...])

Hyperparameters for a Learner, derived from a subset of AlgorithmConfig values.
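The idea of deriving Learner hyperparameters from a subset of a larger config can be sketched with a frozen dataclass. This is a hypothetical stand-in, not the RLlib class. ToyLearnerHyperparameters, from_config, and the grad_clip field are invented for illustration; only learning_rate appears in the signature above.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class ToyLearnerHyperparameters:
    """Hypothetical stand-in for LearnerHyperparameters (not the RLlib class)."""

    learning_rate: float = 1e-3
    # Hypothetical field; the real field set may differ.
    grad_clip: Optional[float] = None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyLearnerHyperparameters":
        # Keep only the config keys that match a field of this class;
        # everything else in the larger algorithm config is ignored.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config.items() if k in names})


algo_config = {"learning_rate": 5e-4, "grad_clip": 40.0, "num_rollout_workers": 2}
hps = ToyLearnerHyperparameters.from_config(algo_config)
print(hps)  # ToyLearnerHyperparameters(learning_rate=0.0005, grad_clip=40.0)
```

Freezing the dataclass keeps a Learner from mutating its hyperparameters after construction. This is a design choice of the sketch.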

TorchLearner configurations

TorchCompileWhatToCompile(value)

Enumerates schemes of what parts of the TorchLearner can be compiled.

Constructor

Learner(*[, module_spec, module, ...])

Base class for Learners.

Learner.build()

Builds the Learner.

Learner._check_is_built()

Learner._make_module()

Construct the multi-agent RL module for the learner.
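The constructor entries separate construction from building. The toy below illustrates that two-step pattern in plain Python: build() calls _make_module(), and other methods guard on _check_is_built(). No description of _check_is_built is given above, so its behavior here (raising if build() was never called) is an assumption. ToyLearner and module_factory are hypothetical names.

```python
class ToyLearner:
    """Hypothetical sketch of the build-then-use pattern; not RLlib code."""

    def __init__(self, module_factory):
        # Construction is cheap: nothing heavy is created yet.
        self._module_factory = module_factory
        self.module = None

    def _make_module(self):
        return self._module_factory()

    def build(self):
        # Heavy setup (module creation, later optimizers) happens here.
        self.module = self._make_module()

    def _check_is_built(self):
        # Assumed behavior: fail loudly if build() was skipped.
        if self.module is None:
            raise ValueError("Call build() before using the Learner.")

    def update(self, batch):
        self._check_is_built()
        return {"num_samples": len(batch)}


learner = ToyLearner(module_factory=dict)
try:
    learner.update([1, 2, 3])
except ValueError as e:
    print("error:", e)
learner.build()
print(learner.update([1, 2, 3]))  # {'num_samples': 3}
```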

Performing Updates

Learner.update(batch, *[, minibatch_size, ...])

Do num_iters minibatch updates given the original batch.
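A minimal sketch of a minibatch update loop, in plain Python, is shown below. It reads num_iters as the number of passes over the batch, with the batch split into chunks of minibatch_size on each pass. That interpretation, the per-pass shuffling, and the toy_update and update_on_minibatch names are assumptions, not RLlib's implementation.

```python
import random
from typing import Callable, Dict, List, Sequence


def toy_update(
    batch: Sequence,
    update_on_minibatch: Callable[[Sequence], Dict[str, float]],
    *,
    minibatch_size: int,
    num_iters: int = 1,
    seed: int = 0,
) -> List[Dict[str, float]]:
    """Hypothetical sketch: num_iters passes over `batch` in minibatches."""
    rng = random.Random(seed)
    indices = list(range(len(batch)))
    results = []
    for _ in range(num_iters):
        # Reshuffle once per pass so minibatch composition varies.
        rng.shuffle(indices)
        for start in range(0, len(indices), minibatch_size):
            minibatch = [batch[i] for i in indices[start:start + minibatch_size]]
            results.append(update_on_minibatch(minibatch))
    return results


results = toy_update(
    list(range(10)),
    lambda mb: {"size": float(len(mb))},
    minibatch_size=4,
    num_iters=2,
)
print([r["size"] for r in results])  # [4.0, 4.0, 2.0, 4.0, 4.0, 2.0]
```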

Learner._update(batch, **kwargs)

Contains all logic for an in-graph/traceable update step.

Learner.additional_update(*[, ...])

Apply additional non-gradient-based updates to this Learner.

Learner.additional_update_for_module(*, ...)

Apply additional non-gradient-based updates for a single module.

Learner._convert_batch_type(batch)

Converts the elements of a MultiAgentBatch to Tensors on the correct device.

Computing Losses

Learner.compute_loss(*, fwd_out, batch)

Computes the loss for the module being optimized.

Learner.compute_loss_for_module(*, ...)

Computes the loss for a single module.
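The split between compute_loss and compute_loss_for_module suggests a dispatch pattern: compute a loss per module, then combine them. The sketch below shows that pattern with plain floats and a mean-squared-error loss. The MSE loss, the "total_loss" key, and the dict-of-lists data layout are hypothetical; RLlib's real signatures take framework tensors and batches.

```python
from typing import Dict, Mapping, Sequence

# Hypothetical key for the summed loss; RLlib may use a different name.
TOTAL_LOSS_KEY = "total_loss"


def compute_loss_for_module(
    *, module_id: str, fwd_out: Sequence[float], batch: Sequence[float]
) -> float:
    # Hypothetical per-module loss: mean squared error between the module's
    # predictions (fwd_out) and its targets (batch). module_id is unused here,
    # but a real implementation could specialize per module.
    return sum((p - t) ** 2 for p, t in zip(fwd_out, batch)) / len(batch)


def compute_loss(
    *, fwd_out: Mapping[str, Sequence[float]], batch: Mapping[str, Sequence[float]]
) -> Dict[str, float]:
    # One loss per module ID, plus the sum of all module losses.
    losses = {
        module_id: compute_loss_for_module(
            module_id=module_id, fwd_out=fwd_out[module_id], batch=batch[module_id]
        )
        for module_id in fwd_out
    }
    losses[TOTAL_LOSS_KEY] = sum(losses.values())
    return losses


losses = compute_loss(
    fwd_out={"policy_a": [1.0, 2.0], "policy_b": [0.0]},
    batch={"policy_a": [1.0, 4.0], "policy_b": [3.0]},
)
print(losses)  # {'policy_a': 2.0, 'policy_b': 9.0, 'total_loss': 11.0}
```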

Learner._is_module_compatible_with_learner(module)

Check whether the module is compatible with the learner.

Learner._get_tensor_variable(value[, dtype, ...])

Returns a framework-specific tensor variable with the initial given value.

Configuring Optimizers

Learner.configure_optimizers_for_module(...)

Configures an optimizer for the given module_id.

Learner.configure_optimizers()

Configures, creates, and registers the optimizers for this Learner.

Learner.register_optimizer(*[, module_id, ...])

Registers an optimizer with a ModuleID, name, param list and lr-scheduler.

Learner.get_optimizers_for_module([module_id])

Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.

Learner.get_optimizer([module_id, ...])

Returns the optimizer object, configured under the given module_id and name.
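The optimizer entries above describe optimizers registered under a (module_id, name) pair. The sketch below illustrates that bookkeeping with a dictionary keyed by both. ToyOptimizerRegistry is hypothetical, and strings stand in for optimizer objects. The real register_optimizer also takes a parameter list and an lr-scheduler, which are omitted here.

```python
from typing import Any, Dict, List, Tuple


class ToyOptimizerRegistry:
    """Hypothetical sketch of per-module optimizer bookkeeping; not RLlib code."""

    def __init__(self) -> None:
        # (module_id, optimizer_name) -> optimizer object
        self._optimizers: Dict[Tuple[str, str], Any] = {}

    def register_optimizer(self, *, module_id: str, optimizer_name: str, optimizer: Any) -> None:
        key = (module_id, optimizer_name)
        if key in self._optimizers:
            raise KeyError(f"Optimizer {key} is already registered.")
        self._optimizers[key] = optimizer

    def get_optimizers_for_module(self, module_id: str) -> List[Tuple[str, Any]]:
        # (optimizer_name, optimizer)-tuples, in registration order.
        return [(name, opt) for (mid, name), opt in self._optimizers.items() if mid == module_id]

    def get_optimizer(self, module_id: str, optimizer_name: str) -> Any:
        return self._optimizers[(module_id, optimizer_name)]


reg = ToyOptimizerRegistry()
reg.register_optimizer(module_id="policy", optimizer_name="default", optimizer="adam")
reg.register_optimizer(module_id="policy", optimizer_name="critic", optimizer="sgd")
reg.register_optimizer(module_id="other", optimizer_name="default", optimizer="rmsprop")
print(reg.get_optimizers_for_module("policy"))  # [('default', 'adam'), ('critic', 'sgd')]
```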

Learner.get_parameters(module)

Returns the list of parameters of a module.

Learner.get_param_ref(param)

Returns a hashable reference to a trainable parameter.

Learner.filter_param_dict_for_optimizer(...)

Reduces the given ParamDict to contain only parameters for the given optimizer.
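get_param_ref and filter_param_dict_for_optimizer work together. Parameters are keyed by a hashable reference, and a ParamDict can then be narrowed to one optimizer's parameters. The sketch below assumes parameters are plain Python lists, which are unhashable, and uses id() as the reference. The list-based Param type, and passing a parameter list instead of an optimizer object, are simplifications for illustration.

```python
from typing import Dict, Hashable, List

Param = List[float]  # hypothetical: a parameter is a mutable list of floats
ParamDict = Dict[Hashable, Param]


def get_param_ref(param: Param) -> Hashable:
    # Lists are unhashable, so use object identity as the reference.
    # id() is only unique while the object is alive.
    return id(param)


def filter_param_dict_for_optimizer(
    param_dict: ParamDict, optimizer_params: List[Param]
) -> ParamDict:
    # Keep only the entries whose parameter belongs to this optimizer.
    wanted = {get_param_ref(p) for p in optimizer_params}
    return {ref: p for ref, p in param_dict.items() if ref in wanted}


actor_w, critic_w = [0.1, 0.2], [0.3]
param_dict = {get_param_ref(p): p for p in (actor_w, critic_w)}
actor_only = filter_param_dict_for_optimizer(param_dict, [actor_w])
print(list(actor_only.values()))  # [[0.1, 0.2]]
```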

Learner._check_registered_optimizer(...)

Checks that the given optimizer and parameters are valid for the framework.

Learner._set_optimizer_lr(optimizer, lr)

Updates the learning rate of the given local optimizer.

Learner._get_clip_function()

Returns the gradient clipping function to use, given the framework.
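_get_clip_function returns a framework-specific clipping function. As an illustration, the sketch below shows one common scheme, clipping by global norm, written framework-free over plain Python lists. Whether and when RLlib uses this scheme is not stated here. The clip_by_global_norm name and the dict-of-lists gradient layout are hypothetical.

```python
import math
from typing import Dict, List

Grads = Dict[str, List[float]]


def clip_by_global_norm(grads: Grads, max_norm: float) -> Grads:
    """Hypothetical clip function: rescales all gradients jointly so that
    their combined L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for gs in grads.values() for g in gs))
    if global_norm <= max_norm:
        # Already small enough: return an unchanged copy.
        return {k: list(gs) for k, gs in grads.items()}
    scale = max_norm / global_norm
    return {k: [g * scale for g in gs] for k, gs in grads.items()}


clipped = clip_by_global_norm({"w": [3.0], "b": [4.0]}, max_norm=1.0)
print(clipped)  # roughly {'w': [0.6], 'b': [0.8]}
```

Scaling all gradients by one shared factor preserves their relative direction. Clipping each value independently would not.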

Gradient Computation

Learner.compute_gradients(loss_per_module, ...)

Computes the gradients based on the given losses.

Learner.postprocess_gradients(gradients_dict)

Applies potential postprocessing operations on the gradients.

Learner.postprocess_gradients_for_module(*, ...)

Applies postprocessing operations on the gradients of the given module.

Learner.apply_gradients(gradients_dict)

Applies the gradients to the MultiAgentRLModule parameters.
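The gradient methods above form a sequence: compute gradients from the losses, postprocess them, then apply them. The toy below runs that sequence on a one-point linear regression. It uses hand-derived gradients, elementwise clipping as the postprocessing step, and plain SGD. The function names echo the entries above, but their signatures, the clipping rule, and the SGD update are assumptions, not RLlib behavior.

```python
from typing import Dict

Params = Dict[str, float]


def compute_gradients(params: Params, x: float, y: float) -> Params:
    # Analytic gradients of the toy loss (w * x + b - y) ** 2.
    err = params["w"] * x + params["b"] - y
    return {"w": 2 * err * x, "b": 2 * err}


def postprocess_gradients(grads: Params, max_abs: float) -> Params:
    # Toy postprocessing step: clip each gradient value to [-max_abs, max_abs].
    return {k: max(-max_abs, min(max_abs, g)) for k, g in grads.items()}


def apply_gradients(params: Params, grads: Params, lr: float) -> Params:
    # Plain SGD step: move each parameter against its gradient.
    return {k: params[k] - lr * grads[k] for k in params}


params = {"w": 0.0, "b": 0.0}
for _ in range(200):
    grads = postprocess_gradients(compute_gradients(params, x=2.0, y=4.0), max_abs=10.0)
    params = apply_gradients(params, grads, lr=0.05)
print(2 * params["w"] + params["b"])  # close to the target 4.0
```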

Saving, Loading, Checkpointing, and Restoring States

Learner.save_state(path)

Save the state of the learner to path.

Learner.load_state(path)

Load the state of the learner from path.

Learner._save_optimizers(path)

Save the state of the optimizers to path.

Learner._load_optimizers(path)

Load the state of the optimizers from path.

Learner.get_state()

Get the state of the learner.

Learner.set_state(state)

Set the state of the learner.

Learner.get_optimizer_state()

Returns the state of all optimizers currently registered in this Learner.

Learner.set_optimizer_state(state)

Sets the state of all optimizers currently registered in this Learner.

Learner._get_metadata()
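The state methods above pair an in-memory round trip (get_state/set_state) with an on-disk one (save_state/load_state). The sketch below shows that layering, writing JSON files to a directory. The file names, the JSON format, the state layout, and what _get_metadata returns are all invented for illustration. This is not RLlib's checkpoint format.

```python
import json
import os
import tempfile
from typing import Any, Dict


class ToyStatefulLearner:
    """Hypothetical sketch; the on-disk layout is invented for illustration."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"w": 0.0}
        self.optimizer_state: Dict[str, Any] = {"step": 0}

    def get_state(self) -> Dict[str, Any]:
        return {"weights": dict(self.weights), "optimizer": dict(self.optimizer_state)}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.weights = dict(state["weights"])
        self.optimizer_state = dict(state["optimizer"])

    def _get_metadata(self) -> Dict[str, Any]:
        # Assumed content: just the class name.
        return {"learner_class": type(self).__name__}

    def save_state(self, path: str) -> None:
        # Disk persistence is layered on top of get_state().
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "state.json"), "w") as f:
            json.dump(self.get_state(), f)
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(self._get_metadata(), f)

    def load_state(self, path: str) -> None:
        with open(os.path.join(path, "state.json")) as f:
            self.set_state(json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    a = ToyStatefulLearner()
    a.weights["w"] = 1.5
    a.optimizer_state["step"] = 7
    a.save_state(tmp)
    b = ToyStatefulLearner()
    b.load_state(tmp)
    print(b.get_state())  # {'weights': {'w': 1.5}, 'optimizer': {'step': 7}}
```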

Adding and Removing Modules

Learner.add_module(*, module_id, module_spec)

Add a module to the underlying MultiAgentRLModule and the Learner.

Learner.remove_module(module_id)

Remove a module from the Learner.
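add_module and remove_module act on both the underlying multi-agent module and the Learner's own bookkeeping. The toy below keeps a module dict and a per-module optimizer dict in sync. ToyMultiModuleLearner is hypothetical. It uses a module_factory callable in place of the real module_spec argument, and the placeholder optimizer entry is a simplification.

```python
from typing import Any, Callable, Dict


class ToyMultiModuleLearner:
    """Hypothetical sketch: modules and their optimizers are kept in sync."""

    def __init__(self) -> None:
        self.modules: Dict[str, Any] = {}
        self.optimizers: Dict[str, Dict[str, float]] = {}

    def add_module(self, *, module_id: str, module_factory: Callable[[], Any]) -> None:
        if module_id in self.modules:
            raise ValueError(f"Module {module_id!r} already exists.")
        self.modules[module_id] = module_factory()
        # Placeholder "optimizer" entry; a real learner would build optimizers here.
        self.optimizers[module_id] = {"lr": 1e-3}

    def remove_module(self, module_id: str) -> None:
        # Drop the module together with everything registered for it.
        del self.modules[module_id]
        self.optimizers.pop(module_id, None)


learner = ToyMultiModuleLearner()
learner.add_module(module_id="p1", module_factory=dict)
learner.add_module(module_id="p2", module_factory=dict)
learner.remove_module("p1")
print(sorted(learner.modules), sorted(learner.optimizers))  # ['p2'] ['p2']
```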

Managing Results

Learner.compile_results(*, batch, fwd_out, ...)

Compile results from the update in a numpy-friendly format.

Learner.register_metric(module_id, key, value)

Registers a single key/value metric pair for loss- and gradient stats.

Learner.register_metrics(module_id, metrics_dict)

Registers several key/value metric pairs for loss- and gradient stats.
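register_metric and register_metrics store key/value stats per module. A minimal sketch of that bookkeeping follows. register_metrics delegates to register_metric, and a hypothetical collect_results method returns a plain-dict snapshot. Neither collect_results nor its reset-after-read behavior is part of the documented API; they are illustrative choices.

```python
from collections import defaultdict
from typing import Any, Dict, Mapping


class ToyMetrics:
    """Hypothetical per-module metric store; not RLlib code."""

    def __init__(self) -> None:
        self._metrics: Dict[str, Dict[str, Any]] = defaultdict(dict)

    def register_metric(self, module_id: str, key: str, value: Any) -> None:
        self._metrics[module_id][key] = value

    def register_metrics(self, module_id: str, metrics_dict: Mapping[str, Any]) -> None:
        for key, value in metrics_dict.items():
            self.register_metric(module_id, key, value)

    def collect_results(self) -> Dict[str, Dict[str, Any]]:
        # Hypothetical: return plain dicts, then reset for the next update.
        results = {mid: dict(m) for mid, m in self._metrics.items()}
        self._metrics.clear()
        return results


m = ToyMetrics()
m.register_metric("policy", "loss", 0.5)
m.register_metrics("policy", {"grad_norm": 1.2, "lr": 1e-3})
results = m.collect_results()
print(results)  # {'policy': {'loss': 0.5, 'grad_norm': 1.2, 'lr': 0.001}}
```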

Learner._check_result(result)

Checks whether the result has the correct format.

LearnerGroup API

Configuring a LearnerGroup

LearnerSpec(learner_class, module_spec, ...)

The spec for constructing Learner actors.

LearnerGroup(learner_spec[, max_queue_len])

Coordinator of Learners.
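LearnerGroup is described only as a coordinator of Learners, and LearnerSpec builds Learner actors. One way such a coordinator can work is data parallelism: split each batch into shards and give one shard to each Learner. The single-process toy below illustrates that idea with round-robin sharding. ToyLearner and ToyLearnerGroup are hypothetical. The sketch does not model Ray actors, remote execution, or gradient synchronization.

```python
from typing import Dict, List, Sequence


class ToyLearner:
    """Hypothetical stand-in for one Learner; returns simple stats."""

    def update(self, batch: Sequence[float]) -> Dict[str, float]:
        return {"num_samples": float(len(batch)), "mean": sum(batch) / len(batch)}


class ToyLearnerGroup:
    """Hypothetical single-process coordinator that shards a batch across learners."""

    def __init__(self, num_learners: int) -> None:
        self.learners: List[ToyLearner] = [ToyLearner() for _ in range(num_learners)]

    def update(self, batch: Sequence[float]) -> List[Dict[str, float]]:
        n = len(self.learners)
        if len(batch) < n:
            raise ValueError("Batch must contain at least one item per learner.")
        # Round-robin sharding: learner i receives items i, i + n, i + 2n, ...
        shards = [batch[i::n] for i in range(n)]
        return [learner.update(shard) for learner, shard in zip(self.learners, shards)]


group = ToyLearnerGroup(num_learners=2)
print(group.update([1.0, 2.0, 3.0, 4.0]))
# [{'num_samples': 2.0, 'mean': 2.0}, {'num_samples': 2.0, 'mean': 3.0}]
```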