xpag.agents.rljax_agents.rljax_interface.RljaxSAC

class RljaxSAC(observation_dim, action_dim, params=None)

Bases: Agent

Interface to the SAC agent from RLJAX (https://github.com/toshikwa/rljax)
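
For orientation, here is a minimal instantiation sketch. The environment dimensions below are placeholders; only the constructor signature itself comes from this page.

```python
from xpag.agents.rljax_agents.rljax_interface import RljaxSAC

# Placeholder dimensions for a hypothetical environment with
# 3-dimensional observations and 1-dimensional actions.
agent = RljaxSAC(
    observation_dim=3,
    action_dim=1,
    params=None,  # None keeps the default SAC hyperparameters
)
```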

Methods:

  • value() - computes Q-values given a batch of observations and a batch of actions.

  • select_action() - selects actions given a batch of observations; there are two modes: one that includes stochasticity for exploration (eval_mode==False), and one that deterministically returns the best possible action (eval_mode==True). See the usage sketch after this list.

  • train_on_batch() - trains the agent on a batch of transitions (one gradient step).

  • save() - saves the agent to disk.

  • load() - loads a saved agent.

  • write_config() - writes the configuration of the agent (mainly its non-default parameters) to a file.
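
To make these methods concrete, here is a hedged usage sketch. The array shapes, the keys of the transition batch, and the path arguments to save(), load(), and write_config() are illustrative assumptions, not part of the documented interface; the actual batch layout depends on the buffer feeding train_on_batch().

```python
import numpy as np
from xpag.agents.rljax_agents.rljax_interface import RljaxSAC

agent = RljaxSAC(observation_dim=3, action_dim=1)

# A batch of 8 observations (shape assumed for illustration).
obs = np.random.rand(8, 3)

# Stochastic actions for exploration vs. deterministic "best" actions.
explore_actions = agent.select_action(obs, eval_mode=False)
greedy_actions = agent.select_action(obs, eval_mode=True)

# Q-values for a batch of observation/action pairs.
q_values = agent.value(obs, greedy_actions)

# One gradient step on a batch of transitions. These dict keys are an
# assumed layout, not part of the documented interface; use whatever
# your replay buffer actually returns.
batch = {
    "observation": obs,
    "action": explore_actions,
    "reward": np.random.rand(8, 1),
    "next_observation": np.random.rand(8, 3),
    "done": np.zeros((8, 1)),
}
agent.train_on_batch(batch)

# Persistence and configuration logging (paths are placeholders).
agent.save("/tmp/rljax_sac_agent")
agent.load("/tmp/rljax_sac_agent")
agent.write_config("/tmp/rljax_sac_agent/config.txt")
```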

Attributes:

  • _config_string - the configuration of the agent (mainly its non-default parameters).

  • sac_params - the SAC parameters in a dict (a configuration sketch follows this list):

    "actor_lr" (default=3e-3): the actor learning rate
    "critic_lr" (default=3e-3): the critic learning rate
    "temp_lr" (default=3e-3): the temperature learning rate
    "discount" (default=0.99): the discount factor
    "hidden_dims" (default=(256,256)): the hidden layer dimensions for the actor and critic networks
    "init_temperature" (default=1.): the initial temperature
    "target_update_period" (default=1): defines how often a soft update of the target critic is performed
    "tau" (default=5e-2): the soft update coefficient

  • sac - the SAC algorithm as implemented in the RLJAX library
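
The defaults listed under sac_params can plausibly be overridden through the constructor's params argument; the sketch below assumes that entries passed via params replace the corresponding defaults in sac_params, which is an inference from the attribute description above rather than documented behavior.

```python
from xpag.agents.rljax_agents.rljax_interface import RljaxSAC

# Override a subset of the documented SAC hyperparameters; the
# assumption is that entries passed via params replace the
# corresponding defaults, while unspecified keys keep theirs.
agent = RljaxSAC(
    observation_dim=3,
    action_dim=1,
    params={
        "actor_lr": 3e-4,
        "critic_lr": 3e-4,
        "discount": 0.98,
        "hidden_dims": (128, 128),
        "tau": 5e-3,
    },
)
print(agent.sac_params)  # inspect the resulting configuration
```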
