xpag.buffers.jax_buffer.JaxBuffer
- class JaxBuffer(buffer_size, sampler)
Bases: Buffer
Methods
init_rp_buffer(dummy_step, rng): Initializes the replay buffer from a dummy step.
insert(step_batch): Inserts a batch of transitions into the buffer.
sample(batch_size): Returns a batch of transitions.
- init_rp_buffer(dummy_step, rng)
Initializes the replay buffer from a dummy step. The dummy step describes a single transition, so it must not include the batch dimension.
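The reason for omitting the batch dimension is easiest to see from how a JAX buffer can be preallocated: every entry of the dummy step gets a leading axis of length buffer_size. The following is a minimal sketch under stated assumptions, not xpag's implementation; the helper `init_storage` and the state layout (`storage`, `position`, `size`) are hypothetical, and the sketch ignores `rng`.

```python
import jax
import jax.numpy as jnp

def init_storage(dummy_step, buffer_size):
    # Hypothetical helper (not part of xpag): allocate one zero array per key,
    # shaped (buffer_size, *per_step_shape). If dummy_step carried a batch
    # axis, it would wrongly end up inside the per-step shape.
    storage = jax.tree_util.tree_map(
        lambda x: jnp.zeros((buffer_size,) + jnp.shape(x), dtype=jnp.asarray(x).dtype),
        dummy_step,
    )
    # position: next index to write; size: number of valid entries so far.
    return {"storage": storage, "position": 0, "size": 0}

# A single transition, without any batch dimension.
dummy_step = {
    "observation": jnp.zeros(3),
    "action": jnp.zeros(1),
    "reward": jnp.zeros(1),
}
state = init_storage(dummy_step, buffer_size=100)
print(state["storage"]["observation"].shape)  # (100, 3)
```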
- insert(step_batch)
Inserts a batch of transitions into the buffer.
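One common way to insert a batch into a fixed-size JAX buffer is a circular write with functional `.at[...].set(...)` updates, overwriting the oldest entries once the buffer is full. This is a hedged sketch, not xpag's code: the function `insert_batch` and its `storage`/`position`/`size` arguments are hypothetical, and Python ints are used for the counters for readability (a jit-compiled version would keep them as arrays).

```python
import jax
import jax.numpy as jnp

def insert_batch(storage, position, size, step_batch, buffer_size):
    # Hypothetical circular insert: every leaf of step_batch has shape (batch, ...).
    batch = jax.tree_util.tree_leaves(step_batch)[0].shape[0]
    # Write indices wrap around, so the oldest entries are overwritten first.
    idx = (position + jnp.arange(batch)) % buffer_size
    storage = jax.tree_util.tree_map(
        lambda buf, x: buf.at[idx].set(x), storage, step_batch
    )
    position = (position + batch) % buffer_size
    size = min(size + batch, buffer_size)
    return storage, position, size

buffer_size = 4
storage = {"reward": jnp.zeros((buffer_size, 1))}
position, size = 0, 0
# Two batches of 3 transitions each; the second one wraps around.
for start in (1.0, 4.0):
    step_batch = {"reward": jnp.arange(start, start + 3).reshape(3, 1)}
    storage, position, size = insert_batch(storage, position, size, step_batch, buffer_size)
print(storage["reward"].ravel())  # [5. 6. 3. 4.]
```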
- sample(batch_size)
Returns a batch of batch_size transitions.
- Return type:
Dict[str, Array]
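The return type Dict[str, Array] means each key maps to an array whose leading axis is the batch. Since the constructor takes a `sampler`, the actual sampling strategy is presumably delegated to it; the sketch below substitutes plain uniform sampling with replacement. The function `sample_batch` is hypothetical and only illustrates the shape of the result: the same row indices are gathered from every key so the fields of each transition stay aligned.

```python
import jax
import jax.numpy as jnp

def sample_batch(storage, size, batch_size, key):
    # Hypothetical uniform sampler: draw batch_size indices in [0, size),
    # with replacement, and gather the same rows from every key.
    idx = jax.random.randint(key, (batch_size,), 0, size)
    return {k: v[idx] for k, v in storage.items()}

storage = {
    "observation": jnp.arange(10.0).reshape(5, 2),  # row i = [2i, 2i + 1]
    "reward": jnp.arange(5.0).reshape(5, 1),        # row i = [i]
}
batch = sample_batch(storage, size=5, batch_size=8, key=jax.random.PRNGKey(0))
print({k: v.shape for k, v in batch.items()})  # {'observation': (8, 2), 'reward': (8, 1)}
```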