xpag.buffers.jax_buffer.JaxBuffer

class JaxBuffer(buffer_size, sampler)

Bases: Buffer

Methods

init_rp_buffer

Initializes the replay buffer with a dummy step.

insert

Inserts a transition into the buffer.

sample

Returns a batch of transitions.

init_rp_buffer(dummy_step, rng)

Initializes the replay buffer with a dummy step. The dummy step must describe a single transition: do not include the batch dimension.
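The requirement to omit the batch dimension can be illustrated with a minimal sketch. It assumes (this is not confirmed by the documentation above) that the dummy step serves only as a shape and dtype template to which the buffer prepends its own capacity axis. The helper `allocate_storage` and the field names are hypothetical and not part of xpag.

```python
import jax
import jax.numpy as jnp


def allocate_storage(dummy_step, buffer_size):
    """Hypothetical helper: one zero-filled array per field, with a leading capacity axis."""
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros((buffer_size,) + x.shape, dtype=x.dtype), dummy_step
    )


# A single transition, with no batch dimension (assumed field names).
dummy_step = {
    "observation": jnp.zeros((3,)),
    "action": jnp.zeros((1,)),
    "reward": jnp.zeros(()),
}

storage = allocate_storage(dummy_step, buffer_size=100)
shapes = {name: array.shape for name, array in storage.items()}
```

Under this assumption, a dummy step that already carried a batch axis would yield storage with one axis too many, for example `(100, 1, 3)` instead of `(100, 3)` for the observation field.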

insert(step_batch)

Inserts a transition into the buffer.
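A common way to insert into a fixed-size buffer is a circular write, in which the oldest entries are overwritten once the buffer is full. The sketch below shows that general pattern in JAX. It is an illustration of the technique, not a description of how `JaxBuffer.insert` is implemented, and `circular_insert` is a hypothetical name.

```python
import jax.numpy as jnp


def circular_insert(storage, step_batch, position):
    """Hypothetical helper: write a batch of transitions at `position`, wrapping around."""
    capacity = next(iter(storage.values())).shape[0]
    batch = next(iter(step_batch.values())).shape[0]
    # Target rows, wrapped modulo the capacity of the buffer.
    indices = (position + jnp.arange(batch)) % capacity
    new_storage = {
        name: storage[name].at[indices].set(step_batch[name]) for name in storage
    }
    return new_storage, (position + batch) % capacity


storage = {"reward": jnp.zeros((4,))}
storage, position = circular_insert(storage, {"reward": jnp.array([1.0, 2.0, 3.0])}, 0)
# The second write wraps around and overwrites the oldest entry.
storage, position = circular_insert(storage, {"reward": jnp.array([4.0, 5.0])}, position)
```

JAX arrays are immutable, so the sketch returns a new storage dictionary rather than modifying the old one in place.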

sample(batch_size)

Returns a batch of transitions.

Return type: Dict[str, Array]
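The shape of a `Dict[str, Array]` batch can be sketched with uniform sampling over stored transitions. This is a minimal illustration under assumptions: the actual sampling strategy is determined by the `sampler` passed to the constructor, and `uniform_sample`, its explicit `rng` argument, and the field names are hypothetical.

```python
import jax
import jax.numpy as jnp


def uniform_sample(storage, size, batch_size, rng):
    """Hypothetical helper: draw `batch_size` stored transitions uniformly at random."""
    rng, key = jax.random.split(rng)
    # Indices in [0, size), sampled with replacement.
    indices = jax.random.randint(key, (batch_size,), 0, size)
    batch = {name: array[indices] for name, array in storage.items()}
    return batch, rng


storage = {
    "observation": jnp.arange(12.0).reshape(4, 3),
    "reward": jnp.arange(4.0),
}
batch, rng = uniform_sample(storage, size=4, batch_size=8, rng=jax.random.PRNGKey(0))
```

Every field is indexed with the same `indices`, so row `i` of each array in the batch belongs to the same transition.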