# Platform
The platform configuration controls how hibayes uses your hardware for model fitting. It manages CPU/GPU device selection, memory allocation, and parallelisation of MCMC chains.
## Configuration options
| Option | Type | Default | Description |
|---|---|---|---|
| `device_type` | string | `"cpu"` | Device to use: `"cpu"` or `"gpu"` |
| `num_devices` | int | auto | Number of devices (auto-detects if not set) |
| `gpu_memory_fraction` | float | `0.9` | Fraction of GPU memory to use (0.1-1.0) |
| `chain_method` | string | `"parallel"` | How to run chains: `"parallel"`, `"sequential"`, or `"vectorized"` |
## Configuring in YAML
Set platform options in your `hibayes.yaml`:
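The exact file layout isn't reproduced on this page, so the sketch below assumes the options sit under a top-level `platform:` key, with option names taken from the table above:

```yaml
# hibayes.yaml — CPU fitting (assumed `platform:` nesting)
platform:
  device_type: "cpu"
  num_devices: 4           # omit to auto-detect available cores
  chain_method: "parallel"
```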
For GPU:
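A GPU sketch under the same assumption (top-level `platform:` key not shown on this page):

```yaml
# hibayes.yaml — GPU fitting (assumed `platform:` nesting)
platform:
  device_type: "gpu"
  gpu_memory_fraction: 0.8   # leave headroom for other processes
  chain_method: "vectorized"
```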
## Chain methods
The chain_method option controls how MCMC chains are executed:
- `parallel` (default): Runs chains in parallel across available devices. Best for multi-core CPUs or multi-GPU setups.
- `sequential`: Runs chains one after another. Use when memory is limited.
- `vectorized`: Vectorises chain computation. Can be faster on GPU but uses more memory.
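These three values mirror NumPyro's own `chain_method` argument. For example, on a memory-limited machine you might switch to sequential execution (again assuming a top-level `platform:` key, which isn't shown on this page):

```yaml
platform:
  chain_method: "sequential"
```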
## Auto-detection
When num_devices is not specified, hibayes automatically detects available hardware:
- CPU: Uses `os.cpu_count()` to detect available cores
- GPU: Queries JAX for available GPU devices; falls back to CPU if none found
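A minimal sketch of this detection logic; the function name `detect_num_devices` is illustrative, not part of the hibayes API:

```python
import os

def detect_num_devices(device_type: str = "cpu") -> int:
    """Illustrative sketch of device auto-detection, per the rules above."""
    if device_type == "gpu":
        try:
            import jax  # only needed for GPU detection
            gpus = [d for d in jax.devices() if d.platform == "gpu"]
            if gpus:
                return len(gpus)
        except (ImportError, RuntimeError):
            pass  # no JAX, or no GPU backend: fall through to CPU
    # CPU path (and GPU fallback): count available cores
    return os.cpu_count() or 1
```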
## GPU memory management
For GPU users, hibayes configures JAX memory settings via environment variables:
- `XLA_PYTHON_CLIENT_PREALLOCATE=false`: Disables memory preallocation for better memory management
- `XLA_PYTHON_CLIENT_MEM_FRACTION`: Controls the fraction of GPU memory JAX can use

These are set automatically based on your `gpu_memory_fraction` setting.
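A sketch of the equivalent manual setup. Note the environment variables must be set before JAX first initialises; `configure_gpu_memory` is an illustrative name, not a hibayes function:

```python
import os

def configure_gpu_memory(gpu_memory_fraction: float = 0.9) -> None:
    # Disable preallocation so JAX allocates GPU memory on demand
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # Cap JAX at the requested fraction of GPU memory
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(gpu_memory_fraction)

configure_gpu_memory(0.8)
```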
## Programmatic setup
The platform is configured automatically when using the CLI. For programmatic use:
Call the platform setup function once, before running any models. It sets up the JAX/NumPyro device configuration.
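The exact entry point isn't shown on this page, so the names below (`PlatformConfig`, `configure_platform`) are assumptions about the hibayes API rather than documented calls:

```python
# Hypothetical sketch — `configure_platform` / `PlatformConfig` are assumed names
from hibayes.platform import PlatformConfig, configure_platform

config = PlatformConfig(
    device_type="cpu",
    num_devices=4,
    chain_method="parallel",
)
configure_platform(config)  # call once, before fitting any models
```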
## Troubleshooting
**GPU not detected**: Ensure JAX is installed with GPU support (`jax[cuda]`). hibayes will automatically fall back to CPU if no GPU is found.

**Out of memory on GPU**: Reduce `gpu_memory_fraction` or switch to `chain_method: sequential`.

**Slow on CPU**: Ensure `num_devices` matches your available cores, and use `chain_method: parallel`.