# Source code for cosmocore.mpi_utils

"""MPI utilities for shared memory and array broadcasting."""

from __future__ import annotations

import numpy as np

from ._mpi import MPI


class MPISharedMemoryMixin:
    """Mixin providing shared memory array distribution for MPI classes.

    Requires ``self.comm`` and ``self.rank``.

    ``_shared_array`` places one copy of each array per node in shared
    memory; intra-node ranks get a zero-copy view, and a single
    buffer-based ``Bcast`` seeds each node's local rank 0.
    ``_bcast_array`` is the buffer-based broadcast fallback.

    Call :meth:`close` (or use as a context manager) to release windows
    and split communicators.
    """

    # ------------------------------------------------------------------
    # Shared memory (intra-node, zero-copy)
    # ------------------------------------------------------------------

    def _setup_shared_comm(self) -> None:
        """Create intra- and inter-node communicators, if not already done."""
        if hasattr(self, "_shared_comm"):
            return

        self._shared_comm = self.comm.Split_type(MPI.COMM_TYPE_SHARED)
        self._shared_wins: list[MPI.Win] = []

        # Inter-node communicator: node-local rank 0 only. Using
        # self.rank as the key keeps global rank 0 as inter-comm root.
        node_rank = self._shared_comm.Get_rank()
        color = 0 if node_rank == 0 else MPI.UNDEFINED
        self._inter_node_comm = self.comm.Split(color, self.rank)

    def _bcast_array_meta(
        self, arr: np.ndarray | None
    ) -> tuple[tuple[int, ...], np.dtype] | None:
        """Broadcast ``(shape, dtype)`` of rank 0's array in one collective.

        Returns ``None`` on every rank when rank 0 passed ``None``, so all
        ranks take the same "nothing to distribute" branch. Whatever
        non-root ranks pass as ``arr`` is ignored.
        """
        meta = None
        if self.rank == 0 and arr is not None:
            meta = (tuple(arr.shape), np.dtype(arr.dtype))
        return self.comm.bcast(meta, root=0)

    def _shared_array(self, arr: np.ndarray | None = None) -> np.ndarray | None:
        """Share a numpy array via MPI shared memory (intra-node, zero-copy).

        Each node's local rank 0 allocates the shared buffer and is seeded
        with the data (either directly on the global root or via a
        buffer-based inter-node ``Bcast``). All other ranks on the node
        attach to that same buffer, so there are no inter-rank copies on
        a node.

        The returned array is a view onto the shared window. It is *not*
        write-protected by this method, so callers should treat it as
        read-only: a write from one rank is visible to every rank on the
        node. The view should not be used after :meth:`close`, which frees
        the window backing it.

        Parameters
        ----------
        arr : numpy.ndarray or None
            Array to share. Only global rank 0 needs to pass the actual
            array; other ranks may pass ``None`` (shape and dtype are
            broadcast from rank 0). If rank 0 also passes ``None``, all
            ranks receive ``None`` (nothing to share).

        Returns
        -------
        numpy.ndarray or None
            View of the node-local shared copy, or ``None`` on all ranks
            when rank 0 passed ``None``.
        """
        self._setup_shared_comm()
        comm_node = self._shared_comm
        inter_comm = self._inter_node_comm

        meta = self._bcast_array_meta(arr)
        if meta is None:
            return None
        shape, dtype = meta

        itemsize = dtype.itemsize
        nbytes = int(np.prod(shape, dtype=np.int64)) * itemsize
        is_node_root = comm_node.Get_rank() == 0

        # Only the node root contributes memory; the other ranks request
        # zero bytes and map the root's segment through Shared_query(0).
        alloc = nbytes if is_node_root else 0
        win = MPI.Win.Allocate_shared(alloc, itemsize, comm=comm_node)
        self._shared_wins.append(win)
        buf, _ = win.Shared_query(0)
        shared = np.ndarray(shape, dtype=dtype, buffer=buf)

        # Seed each node's shared buffer: global root copies in its
        # array, then node-roots exchange via buffer-based Bcast.
        if is_node_root:
            if self.rank == 0:
                # ``[...]`` rather than ``[:]`` so 0-d arrays work too.
                shared[...] = arr
            inter_comm.Bcast(shared, root=0)

        # Nobody on the node may read before the node root has the data.
        comm_node.Barrier()
        return shared

    def _cleanup_shared(self) -> None:
        """Free all shared-memory windows and split communicators.

        Idempotent — safe to call multiple times. Freeing is best-effort:
        a failure to free one handle is ignored so the remaining handles
        are still released and the bookkeeping attributes are still reset.
        """
        for win in getattr(self, "_shared_wins", []):
            try:
                win.Free()
            except Exception:
                # Deliberate best-effort: keep releasing the other handles.
                pass
        self._shared_wins = []

        inter = getattr(self, "_inter_node_comm", None)
        # Ranks that are not node-local rank 0 hold COMM_NULL here, which
        # must not be freed.
        if inter is not None and inter != MPI.COMM_NULL:
            try:
                inter.Free()
            except Exception:
                pass
        if hasattr(self, "_inter_node_comm"):
            del self._inter_node_comm

        shared_comm = getattr(self, "_shared_comm", None)
        if shared_comm is not None:
            try:
                shared_comm.Free()
            except Exception:
                pass
            finally:
                # Dropping the attribute lets _setup_shared_comm rebuild
                # the communicators if the instance is used again.
                del self._shared_comm

    def close(self) -> None:
        """Release all MPI shared-memory resources owned by this instance."""
        self._cleanup_shared()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    # ------------------------------------------------------------------
    # Buffer-based broadcast (fallback / non-shared use)
    # ------------------------------------------------------------------

    def _bcast_array(self, arr: np.ndarray | None = None) -> np.ndarray | None:
        """Broadcast a numpy array using buffer-based MPI.

        Uses ``comm.Bcast`` (uppercase) which sends raw memory buffers
        instead of ``comm.bcast`` (lowercase) which serializes via the
        standard library. This avoids the ~2 GB message-size limit that
        affects serialization-based broadcasts in many MPI implementations.

        Parameters
        ----------
        arr : numpy.ndarray or None
            Array to broadcast. Only rank 0 needs to pass the actual
            array; other ranks may pass ``None``. If rank 0 also passes
            ``None``, all ranks receive ``None`` (nothing to broadcast).

        Returns
        -------
        numpy.ndarray or None
            On rank 0 the array that was passed in; on other ranks a newly
            allocated C-ordered array holding the same values. ``None`` on
            all ranks when rank 0 passed ``None``.
        """
        meta = self._bcast_array_meta(arr)
        if meta is None:
            return None
        shape, dtype = meta

        if self.rank == 0:
            # Receivers allocate C-ordered buffers, so the bytes on the
            # wire must be in C order as well. This is a no-op for arrays
            # that are already C-contiguous and a staging copy for
            # Fortran-ordered or strided input.
            send = np.ascontiguousarray(arr)
            self.comm.Bcast(send, root=0)
            return arr

        recv = np.empty(shape, dtype=dtype)
        self.comm.Bcast(recv, root=0)
        return recv