xpag.agents.rljax_agents.rljax_interface.RljaxSAC
- class RljaxSAC(observation_dim, action_dim, params=None)
Bases: Agent
Interface to the SAC agent from RLJAX (https://github.com/toshikwa/rljax)
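A minimal construction sketch (the dimensions are illustrative, and the params dict is assumed to accept the keys documented under sac_params below):

    from xpag.agents.rljax_agents.rljax_interface import RljaxSAC

    # 8-dimensional observations, 2-dimensional actions (illustrative values)
    agent = RljaxSAC(observation_dim=8, action_dim=2, params={"discount": 0.98})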
Methods:
value()
- computes Q-values given a batch of observations and a batch of actions (a usage sketch follows this list).
select_action()
- selects actions given a batch of observations; there are two modes: one that includes stochasticity for exploration (eval_mode==False), and one that deterministically returns the best possible action (eval_mode==True).
train_on_batch()
- trains the agent on a batch of transitions (one gradient step).
save()
- saves the agent to disk.
load()
- loads a saved agent.
write_config()
- writes the configuration of the agent (mainly its non-default parameters) in a file.
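A usage sketch for the methods above. The exact signatures (batched NumPy inputs, the eval_mode flag, a directory argument for save/load, and the key names of the transition batch) are assumptions inferred from the descriptions, not confirmed from the library:

    import numpy as np

    batch_size = 64
    obs = np.random.randn(batch_size, 8).astype(np.float32)

    # Exploration (stochastic) vs. evaluation (deterministic) action selection.
    actions = agent.select_action(obs, eval_mode=False)
    greedy_actions = agent.select_action(obs, eval_mode=True)

    # Q-values for a batch of (observation, action) pairs.
    q_values = agent.value(obs, greedy_actions)

    # One gradient step on a batch of transitions; the dict keys here are
    # hypothetical placeholders for whatever format the buffer produces.
    transitions = {
        "observation": obs,
        "action": actions,
        "reward": np.random.randn(batch_size, 1).astype(np.float32),
        "next_observation": np.random.randn(batch_size, 8).astype(np.float32),
        "done": np.zeros((batch_size, 1), dtype=np.float32),
    }
    agent.train_on_batch(transitions)

    # Persistence and configuration dump (paths are illustrative).
    agent.save("/tmp/rljax_sac")
    agent.load("/tmp/rljax_sac")
    agent.write_config("/tmp/rljax_sac/config.txt")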
Attributes:
_config_string
- the configuration of the agent (mainly its non-default parameters)
sac_params
- the SAC parameters in a dict (an override sketch follows this list):
  "actor_lr" (default=3e-3): the actor learning rate
  "critic_lr" (default=3e-3): the critic learning rate
  "temp_lr" (default=3e-3): the temperature learning rate
  "discount" (default=0.99): the discount factor
  "hidden_dims" (default=(256,256)): the hidden layer dimensions for the actor and critic networks
  "init_temperature" (default=1.): the initial temperature
  "target_update_period" (default=1): defines how often a soft update of the target critic is performed
  "tau" (default=5e-2): the soft update coefficient
sac
- the SAC algorithm as implemented in the RLJAX library
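A sketch of overriding a subset of the sac_params keys at construction time, assuming unspecified keys keep their defaults:

    custom_params = {
        "actor_lr": 1e-4,           # slower actor updates than the 3e-3 default
        "hidden_dims": (128, 128),  # smaller actor/critic networks
        "tau": 5e-3,                # softer target-critic updates
    }
    agent = RljaxSAC(observation_dim=8, action_dim=2, params=custom_params)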