xpag.agents.rljax_agents.network.base.MLP
- class MLP(output_dim, hidden_units, hidden_activation=<jax._src.custom_derivatives.custom_jvp object>, output_activation=None, hidden_scale=1.0, output_scale=1.0, d2rl=False)
Bases: Module
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters:
name – An optional string name for the class. Must be a valid Python identifier. If name is not provided, the class name of the current instance is converted to lower_snake_case and used instead.
Methods
params_dict() – Returns parameters keyed by name for this module and submodules.
state_dict() – Returns state keyed by name for this module and submodules.
- __call__(x)
Call self as a function.
- params_dict()
Returns parameters keyed by name for this module and submodules.
- Return type:
Mapping[str, Array]
- state_dict()
Returns state keyed by name for this module and submodules.
- Return type:
Mapping[str, Array]
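The constructor arguments listed above are not described on this page, so the following is only a minimal sketch in plain JAX of how a multilayer perceptron with such arguments is commonly wired. It is not xpag's implementation. The helper names `init_mlp` and `mlp_forward` are hypothetical, the `d2rl` behaviour (concatenating the raw input to every hidden output except the last) is an assumption based on the usual D2RL scheme, and `hidden_scale` and `output_scale` are left out because their meaning is not documented here.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_dim, output_dim, hidden_units, d2rl=False):
    """Hypothetical helper: build a list of (weight, bias) pairs."""
    params = []
    in_dim = input_dim
    sizes = list(hidden_units) + [output_dim]
    for i, out_dim in enumerate(sizes):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        params.append((w, jnp.zeros(out_dim)))
        in_dim = out_dim
        # Assumption: with d2rl, the next layer also receives the raw input,
        # except after the last hidden layer.
        if d2rl and i + 1 < len(hidden_units):
            in_dim += input_dim
    return params


def mlp_forward(params, x, hidden_activation=jax.nn.relu,
                output_activation=None, d2rl=False):
    """Hypothetical helper: apply the layers created by init_mlp."""
    x_input = x
    n_hidden = len(params) - 1
    for i, (w, b) in enumerate(params[:-1]):
        x = hidden_activation(x @ w + b)
        if d2rl and i + 1 < n_hidden:
            # Assumed D2RL-style skip connection from the network input.
            x = jnp.concatenate([x, x_input], axis=-1)
    w, b = params[-1]
    x = x @ w + b
    if output_activation is not None:
        x = output_activation(x)
    return x


key = jax.random.PRNGKey(0)
params = init_mlp(key, input_dim=4, output_dim=2, hidden_units=(8, 8), d2rl=True)
batch = jnp.ones((5, 4))
out = mlp_forward(params, batch, output_activation=jnp.tanh, d2rl=True)
print(out.shape)  # (5, 2)
```

In this sketch the second weight matrix has 12 input rows (8 hidden units plus the 4 input features), which is the only structural change the assumed `d2rl` flag introduces.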