xpag.agents.rljax_agents.network.base.MLP

class MLP(output_dim, hidden_units, hidden_activation=<jax._src.custom_derivatives.custom_jvp object>, output_activation=None, hidden_scale=1.0, output_scale=1.0, d2rl=False)

Bases: haiku.Module

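Multi-layer perceptron implemented as a Haiku module. The constructor description below is inherited from haiku.Module.__init__:

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables so that those modules are named correctly.

Parameters:

name – An optional string name for the class. Must be a valid Python identifier. If name is not provided, the class name of the current instance is converted to lower_snake_case and used instead. Note that the MLP signature above has no name argument, so the name cannot be set by the caller.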
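The constructor arguments describe how the layers are laid out and initialized. The sketch below is a hypothetical pure-JAX re-implementation of that layout, not xpag's code; `init_mlp` is an illustrative name. It assumes that each entry of `hidden_units` is the width of one hidden layer, that weights use orthogonal initialization scaled by `hidden_scale` (hidden layers) or `output_scale` (output layer), that biases start at zero, and that `d2rl=True` follows the D2RL dense-connection scheme, concatenating the raw input to every hidden output except the last.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_dim, output_dim, hidden_units,
             hidden_scale=1.0, output_scale=1.0, d2rl=False):
    """Hypothetical initializer: returns a list of (w, b) pairs, one per layer."""
    params = []
    in_dim = input_dim
    for i, units in enumerate(hidden_units):
        key, sub = jax.random.split(key)
        # Assumption: orthogonal weights scaled by hidden_scale, zero biases.
        w = jax.nn.initializers.orthogonal(scale=hidden_scale)(sub, (in_dim, units))
        params.append((w, jnp.zeros(units)))
        in_dim = units
        # Assumption (D2RL): the raw input is concatenated to every hidden
        # output except the last, so the next layer's input is wider.
        if d2rl and i + 1 != len(hidden_units):
            in_dim += input_dim
    key, sub = jax.random.split(key)
    # Assumption: the output layer has its own scale, output_scale.
    w = jax.nn.initializers.orthogonal(scale=output_scale)(sub, (in_dim, output_dim))
    params.append((w, jnp.zeros(output_dim)))
    return params


params = init_mlp(jax.random.PRNGKey(0), input_dim=4, output_dim=2,
                  hidden_units=(8, 8), d2rl=True)
print([w.shape for w, _ in params])  # [(4, 8), (12, 8), (8, 2)]
```

With `d2rl=True`, the second hidden layer reads 12 features (8 hidden units plus the 4 input features). In this sketch, separate `hidden_scale` and `output_scale` values only change the norm of the corresponding weight matrices.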

Methods

params_dict() – Returns parameters keyed by name for this module and submodules.

state_dict() – Returns state keyed by name for this module and submodules.

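__call__(x)

Applies the network to the input x and returns its output, whose last dimension is output_dim. If output_activation is given, it is applied to the output of the final layer.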
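The forward pass of `__call__(x)` can be sketched in plain JAX. This is a hypothetical re-implementation (`mlp_apply` is an illustrative name, not an xpag function) that passes parameters explicitly as a list of `(w, b)` pairs instead of going through Haiku. It assumes each hidden layer is an affine map followed by `hidden_activation`, that the output layer is affine with `output_activation` applied only when given, and that `d2rl=True` re-appends the raw input after every hidden layer except the last. The signature shows the default `hidden_activation` only as a `custom_jvp` object; `jax.nn.relu` is defined that way, and using it here as the default is an assumption.

```python
import jax
import jax.numpy as jnp


def mlp_apply(params, x, hidden_activation=jax.nn.relu,
              output_activation=None, d2rl=False):
    """Hypothetical forward pass; the last (w, b) pair is the output layer."""
    x_input = x
    *hidden, (w_out, b_out) = params
    for i, (w, b) in enumerate(hidden):
        x = hidden_activation(x @ w + b)
        # Assumed D2RL behaviour: re-append the raw input after each hidden
        # layer except the last one.
        if d2rl and i + 1 != len(hidden):
            x = jnp.concatenate([x, x_input], axis=-1)
    x = x @ w_out + b_out
    # The output layer stays linear unless an output_activation is given.
    if output_activation is not None:
        x = output_activation(x)
    return x


# Usage: a 4 -> 8 -> 8 -> 2 network with D2RL, so the second hidden layer
# sees 8 + 4 = 12 features.
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
params = [
    (jax.random.normal(k0, (4, 8)), jnp.zeros(8)),
    (jax.random.normal(k1, (12, 8)), jnp.zeros(8)),
    (jax.random.normal(k2, (8, 2)), jnp.zeros(2)),
]
x = jnp.ones((3, 4))  # batch of 3 inputs with 4 features each
y = mlp_apply(params, x, output_activation=jnp.tanh, d2rl=True)
print(y.shape)  # (3, 2)
```

Because the sketch's output layer has no activation by default, it returns unbounded values unless an `output_activation` such as `jnp.tanh` is supplied.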

params_dict()

Returns parameters keyed by name for this module and submodules.

Return type:

Mapping[str, Array]
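"Keyed by name" means a flat mapping whose keys combine a module's name with a parameter's name. Haiku modules, including MLP, can only be constructed inside a function wrapped by `hk.transform`, so the sketch below does not call Haiku at all: it illustrates the flattening in plain Python with a hypothetical helper `flatten_params`, and the module names (`mlp/~/linear`, `mlp/~/linear_1`) and parameter names (`w`, `b`) are assumed for illustration.

```python
from typing import Mapping

import jax.numpy as jnp


def flatten_params(params: Mapping[str, Mapping[str, jnp.ndarray]]) -> dict:
    """Hypothetical helper: {"module": {"w": ...}} -> {"module/w": ...}."""
    return {
        f"{module_name}/{param_name}": value
        for module_name, module_params in params.items()
        for param_name, value in module_params.items()
    }


# Illustrative nested layout for a one-hidden-layer MLP (names assumed).
nested = {
    "mlp/~/linear": {"w": jnp.zeros((4, 8)), "b": jnp.zeros(8)},
    "mlp/~/linear_1": {"w": jnp.zeros((8, 2)), "b": jnp.zeros(2)},
}
flat = flatten_params(nested)
print(sorted(flat))
# ['mlp/~/linear/b', 'mlp/~/linear/w', 'mlp/~/linear_1/b', 'mlp/~/linear_1/w']
```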

state_dict()

Returns state keyed by name for this module and submodules.

Return type:

Mapping[str, Array]