Platform

Specify hardware and configuration for fitting.

The platform configuration controls how hibayes uses your hardware for model fitting. It manages CPU/GPU device selection, memory allocation, and parallelisation of MCMC chains.

Configuration options

| Option | Type | Default | Description |
|---|---|---|---|
| device_type | string | "cpu" | Device to use: "cpu" or "gpu" |
| num_devices | int | auto | Number of devices (auto-detected if not set) |
| gpu_memory_fraction | float | 0.9 | Fraction of GPU memory to use (0.1 to 1.0) |
| chain_method | string | "parallel" | How MCMC chains are run: "parallel", "sequential", or "vectorized" |
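The options in the table can be pictured as a small validated dataclass. This is a hypothetical sketch: the class name PlatformOptions and its validation rules are assumptions drawn from the table above, not hibayes' actual PlatformConfig.

```python
from dataclasses import dataclass
from typing import Optional

VALID_DEVICE_TYPES = {"cpu", "gpu"}
VALID_CHAIN_METHODS = {"parallel", "sequential", "vectorized"}


@dataclass
class PlatformOptions:
    """Hypothetical mirror of the platform options table (not hibayes' class)."""

    device_type: str = "cpu"
    num_devices: Optional[int] = None  # None stands for "auto-detect"
    gpu_memory_fraction: float = 0.9
    chain_method: str = "parallel"

    def __post_init__(self) -> None:
        # Reject values outside the ranges listed in the table.
        if self.device_type not in VALID_DEVICE_TYPES:
            raise ValueError(f"device_type must be one of {sorted(VALID_DEVICE_TYPES)}")
        if self.num_devices is not None and self.num_devices < 1:
            raise ValueError("num_devices must be a positive integer")
        if not 0.1 <= self.gpu_memory_fraction <= 1.0:
            raise ValueError("gpu_memory_fraction must be between 0.1 and 1.0")
        if self.chain_method not in VALID_CHAIN_METHODS:
            raise ValueError(f"chain_method must be one of {sorted(VALID_CHAIN_METHODS)}")


# A GPU configuration that passes validation.
gpu_opts = PlatformOptions(device_type="gpu", gpu_memory_fraction=0.8)
print(gpu_opts)
```

Validating at construction time means a typo such as chain_method="threaded" fails immediately, rather than partway through a long sampling run.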

Configuring in YAML

Set platform options in your hibayes.yaml:

platform:
  device_type: cpu
  num_devices: 4
  chain_method: parallel

For GPU:

platform:
  device_type: gpu
  gpu_memory_fraction: 0.8
  chain_method: parallel
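Loading a platform: block amounts to overlaying the keys you set on the defaults from the options table. The sketch below shows that merge, assuming the YAML has already been parsed into a dictionary. The function name merge_platform_block is hypothetical, and hibayes' own loader may differ, for example in how it treats unknown keys.

```python
from typing import Any, Dict

# Defaults copied from the options table; None means "auto-detect".
PLATFORM_DEFAULTS: Dict[str, Any] = {
    "device_type": "cpu",
    "num_devices": None,
    "gpu_memory_fraction": 0.9,
    "chain_method": "parallel",
}


def merge_platform_block(config: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay the parsed `platform:` block of a config onto the defaults."""
    block = config.get("platform") or {}
    unknown = set(block) - set(PLATFORM_DEFAULTS)
    if unknown:
        # Failing loudly catches typos such as `chain_methods`.
        raise KeyError(f"unknown platform option(s): {sorted(unknown)}")
    return {**PLATFORM_DEFAULTS, **block}


# The GPU example above, as a dictionary after YAML parsing.
parsed = {
    "platform": {
        "device_type": "gpu",
        "gpu_memory_fraction": 0.8,
        "chain_method": "parallel",
    }
}
print(merge_platform_block(parsed))
```

Any option left out of the YAML, such as num_devices here, keeps its default and is then auto-detected.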

Chain methods

The chain_method option controls how MCMC chains are executed:

  • parallel (default): Runs chains in parallel across available devices. Best for multi-core CPUs or multi-GPU setups.
  • sequential: Runs chains one after another. Use when memory is limited.
  • vectorized: Evaluates all chains together as one vectorised computation on a single device. Can be faster on GPU but uses more memory.
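These values match NumPyro's chain_method options. In NumPyro, "parallel" runs one chain per device and, when there are fewer devices than chains, emits a warning and draws the chains sequentially instead. The hypothetical helper below sketches that decision so you can predict what a given configuration will actually do. It is an illustration, not hibayes or NumPyro code.

```python
def effective_chain_method(chain_method: str, num_chains: int, num_devices: int) -> str:
    """Sketch of how NumPyro resolves chain_method for a given device count."""
    if chain_method not in {"parallel", "sequential", "vectorized"}:
        raise ValueError(f"unknown chain_method: {chain_method!r}")
    if chain_method == "parallel" and num_chains > num_devices:
        # Not enough devices for one chain each: NumPyro warns and runs sequentially.
        return "sequential"
    return chain_method


print(effective_chain_method("parallel", num_chains=4, num_devices=4))
print(effective_chain_method("parallel", num_chains=4, num_devices=1))
print(effective_chain_method("vectorized", num_chains=4, num_devices=1))
```

This is why chain_method: parallel on its own does not guarantee parallel execution: you also need at least as many devices as chains.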

Auto-detection

When num_devices is not specified, hibayes automatically detects available hardware:

  • CPU: Uses os.cpu_count() to detect available cores
  • GPU: Queries JAX for available GPU devices; falls back to CPU if none are found

The same settings can be created in Python with PlatformConfig:

from hibayes.platform import PlatformConfig

# Auto-detect everything
config = PlatformConfig()
print(f"Using {config.num_devices} {config.device_type} device(s)")

# Explicitly configure
config = PlatformConfig(
    device_type="gpu",
    num_devices=2,
    gpu_memory_fraction=0.7,
    chain_method="parallel",
)
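The detection rules listed above can be sketched in a few lines. This is an illustrative reimplementation, not hibayes' code, and the function name detect_devices is hypothetical. It uses os.cpu_count() for CPU cores and asks JAX for GPU devices, falling back to CPU when JAX is missing or reports no GPU.

```python
import os
from typing import Tuple


def detect_devices(device_type: str = "cpu") -> Tuple[str, int]:
    """Return (device_type, num_devices) using the auto-detection rules above."""
    if device_type == "gpu":
        try:
            import jax

            gpus = jax.devices("gpu")  # raises RuntimeError if no GPU backend exists
            if gpus:
                return "gpu", len(gpus)
        except (ImportError, RuntimeError):
            pass  # JAX not installed or no GPU visible: fall back to CPU
    # os.cpu_count() can return None on unusual platforms, so default to 1.
    return "cpu", os.cpu_count() or 1


device_type, num_devices = detect_devices("gpu")
print(f"Using {num_devices} {device_type} device(s)")
```

Calling jax.devices() initialises JAX's backends, so a real implementation has to set any GPU memory environment variables before running this check.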

GPU memory management

For GPU users, hibayes configures JAX memory settings via environment variables:

  • XLA_PYTHON_CLIENT_PREALLOCATE=false: Disables JAX's default preallocation, so GPU memory is allocated as needed instead of being reserved up front
  • XLA_PYTHON_CLIENT_MEM_FRACTION: Controls the fraction of GPU memory JAX can use

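These are set automatically based on your gpu_memory_fraction setting. JAX reads them when it initialises its GPU backend, so they have no effect if changed after JAX has already been initialised.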
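A minimal sketch of setting the two variables by hand, for example in a script that imports JAX directly. The helper name apply_gpu_memory_settings is hypothetical, and hibayes may format the values differently. The env parameter lets you target a plain dictionary instead of os.environ.

```python
import os
from typing import Dict, MutableMapping, Optional


def apply_gpu_memory_settings(
    fraction: float, env: Optional[MutableMapping[str, str]] = None
) -> Dict[str, str]:
    """Set the JAX GPU memory variables described above (hypothetical helper)."""
    if not 0.1 <= fraction <= 1.0:
        raise ValueError("gpu_memory_fraction must be between 0.1 and 1.0")
    settings = {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": str(fraction),
    }
    target = os.environ if env is None else env
    target.update(settings)
    return settings


# Run this before JAX initialises its GPU backend (safest: before `import jax`).
apply_gpu_memory_settings(0.8)
```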

Programmatic setup

The platform is configured automatically when using the CLI. For programmatic use:

from hibayes.platform import PlatformConfig, configure_computation_platform
from hibayes.ui import ModellingDisplay

config = PlatformConfig(device_type="cpu", num_devices=8)
display = ModellingDisplay()

configure_computation_platform(config, display)
Call this once before running any models. It sets up JAX/NumPyro device configuration.
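On CPU, JAX exposes a single device by default, so running chains in parallel requires JAX to present several host devices. NumPyro provides numpyro.set_host_device_count(n) for this, which rewrites the XLA_FLAGS environment variable. Whether configure_computation_platform relies on it is an assumption. The stdlib-only sketch below (force_host_device_count is a hypothetical name) reproduces that flag handling to show what such a call does.

```python
import os
import re
from typing import MutableMapping, Optional

_FLAG = "--xla_force_host_platform_device_count"


def force_host_device_count(n: int, env: Optional[MutableMapping[str, str]] = None) -> str:
    """Ask XLA to expose `n` CPU devices, replacing any earlier count flag."""
    if n < 1:
        raise ValueError("n must be at least 1")
    target = os.environ if env is None else env
    # Drop an existing device-count flag but keep any other XLA flags.
    others = re.sub(rf"{_FLAG}=\S+", "", target.get("XLA_FLAGS", "")).split()
    target["XLA_FLAGS"] = " ".join([f"{_FLAG}={n}"] + others)
    return target["XLA_FLAGS"]


# Demonstrated on a plain dict; pass no env to modify os.environ instead.
print(force_host_device_count(8, {}))
```

Like the GPU memory variables, this flag only takes effect if it is set before JAX initialises its backends, which is one reason to call the setup function once, before any model runs.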

Troubleshooting

GPU not detected: Ensure JAX is installed with GPU support (jax[cuda]). hibayes will automatically fall back to CPU if no GPU is found.
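To see what JAX itself detects, independently of hibayes, a quick check like the following can help. The helper name jax_backend_report is hypothetical. jax.default_backend() and jax.devices() are standard JAX calls. If the backend reports "cpu" on a machine with a GPU, the installed jaxlib most likely lacks working GPU support.

```python
from typing import List, Tuple


def jax_backend_report() -> Tuple[str, List[str]]:
    """Return JAX's default backend and visible devices, or a note if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "jax not installed", []
    # default_backend() returns "gpu" only when a working GPU backend was found.
    return jax.default_backend(), [str(d) for d in jax.devices()]


backend, devices = jax_backend_report()
print(backend, devices)
```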

Out of memory on GPU: Reduce gpu_memory_fraction or switch to chain_method: sequential.

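Slow on CPU: Ensure num_devices matches your available cores, and use chain_method: parallel. In NumPyro, parallel chains need at least as many devices as chains; with fewer devices, NumPyro warns and samples the chains sequentially instead.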