# Source code for cudass.backends.cudss_backend

"""cuDSS backend: COO->CSR, ANALYSIS/FACTORIZATION/SOLVE, multiple RHS.

Prefers our cudss_bindings; falls back to nvmath's cudss only if our bindings
are not built/importable.
"""

from typing import Any, Optional, Tuple

import torch

from cudass.backends.base import BackendBase
from cudass.cuda.cuda_types import CUDA_R_32F, CUDA_R_32I, CUDA_R_64F
from cudass.types import MatrixType


# Pre-load libcudss.so.0 from nvidia-cudss-cu* so the Cython extension finds it
# without requiring users to set LD_LIBRARY_PATH manually.
def _preload_cudss_lib():
    import ctypes
    import os
    import site

    for cu in ("cu12", "cu13"):
        for base in site.getsitepackages():
            lib = os.path.join(base, "nvidia", cu, "lib", "libcudss.so.0")
            if os.path.isfile(lib):
                try:
                    ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)
                    return
                except OSError:
                    pass

# Best-effort: preloading is an optimization, never a hard requirement.
# If it fails, the bindings import below reports the real, actionable error.
try:
    _preload_cudss_lib()
except Exception:
    pass

# Prefer our own cudss_bindings; fall back to nvmath only if ours are not available
# Resolved cuDSS API surface: either our cudss_bindings module or an adapter
# object over nvmath's cudss bindings. Stays None if neither is importable.
_cudss = None
# Human-readable import failure reason, reported by CUDSSBackend.__init__.
_cudss_import_error = None

try:
    import cudass.cuda.bindings.cudss_bindings as _cudss
except ImportError as e:
    _cudss_import_error = str(e)
    _nvmath_cudss = None
    # nvmath has exposed cudss under two module paths; try both.
    try:
        from nvmath import cudss as _nvmath_cudss
    except ImportError:
        try:
            from nvmath.bindings import cudss as _nvmath_cudss
        except ImportError:
            pass

    if _nvmath_cudss is not None:
        # Adapter: nvmath -> our bindings API (create_handle, matrix_create_csr row_end_ptr=, etc.)
        class _CudssApi:
            # Bare attribute namespace; the API surface is attached below so the
            # rest of this module can use one calling convention for both backends.
            pass

        _api = _CudssApi()
        # Direct renames: our bindings' names on the left, nvmath's on the right.
        _api.create_handle = _nvmath_cudss.create
        _api.destroy_handle = _nvmath_cudss.destroy
        _api.config_create = _nvmath_cudss.config_create
        _api.config_destroy = _nvmath_cudss.config_destroy
        _api.data_create = _nvmath_cudss.data_create
        _api.data_destroy = _nvmath_cudss.data_destroy
        _api.set_stream = _nvmath_cudss.set_stream
        _api.matrix_destroy = _nvmath_cudss.matrix_destroy
        _api.matrix_set_values = _nvmath_cudss.matrix_set_values
        _api.execute = _nvmath_cudss.execute
        # CUDA data-type codes come from our own cuda_types, not from nvmath.
        _api.CUDA_R_32F = CUDA_R_32F
        _api.CUDA_R_64F = CUDA_R_64F
        _api.CUDA_R_32I = CUDA_R_32I
        # Enum values flattened to ints to match the Cython bindings' constants.
        _api.CUDSS_MTYPE_GENERAL = int(_nvmath_cudss.MatrixType.GENERAL)
        _api.CUDSS_MTYPE_SYMMETRIC = int(_nvmath_cudss.MatrixType.SYMMETRIC)
        _api.CUDSS_MTYPE_SPD = int(_nvmath_cudss.MatrixType.SPD)
        _api.CUDSS_MVIEW_FULL = int(_nvmath_cudss.MatrixViewType.FULL)
        _api.CUDSS_BASE_ZERO = int(_nvmath_cudss.IndexBase.ZERO)
        _api.CUDSS_LAYOUT_COL_MAJOR = int(_nvmath_cudss.Layout.COL_MAJOR)
        _api.CUDSS_PHASE_ANALYSIS = int(_nvmath_cudss.Phase.ANALYSIS)
        _api.CUDSS_PHASE_FACTORIZATION = int(_nvmath_cudss.Phase.FACTORIZATION)
        _api.CUDSS_PHASE_SOLVE = int(_nvmath_cudss.Phase.SOLVE)

        def _nvmath_matrix_create_csr(
            m, n, nnz, row_start_ptr, col_indices_ptr, values_ptr,
            index_type, value_type, mtype, mview, index_base, row_end_ptr=0,
        ):
            # Our API takes row_end_ptr as a keyword; nvmath takes rowEnd
            # positionally. It is always forwarded as 0 (NULL) so cuDSS derives
            # each row's end from rowStart[i+1].
            # NOTE(review): a caller-supplied non-zero row_end_ptr is silently
            # dropped here — confirm no caller relies on it under nvmath.
            return _nvmath_cudss.matrix_create_csr(
                m, n, nnz, row_start_ptr, 0, col_indices_ptr, values_ptr,
                index_type, value_type, mtype, mview, index_base,
            )

        _api.matrix_create_csr = _nvmath_matrix_create_csr

        def _nvmath_matrix_create_dn(nrows, ncols, ld, values_ptr, vt, layout):
            # Pass-through; signature already matches our bindings.
            return _nvmath_cudss.matrix_create_dn(nrows, ncols, ld, values_ptr, vt, layout)

        _api.matrix_create_dn = _nvmath_matrix_create_dn
        # Fallback succeeded: expose the adapter and clear the recorded error.
        _cudss = _api
        _cudss_import_error = None


def _coo_to_csr(
    index: torch.Tensor, value: torch.Tensor, m: int, n: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """COO (index [2,nnz], value [nnz]) -> CSR rowStart (m+1), colIndices (nnz), values (nnz).

    Merges duplicate (i,j) by summing. cuDSS requires no duplicate column indices in a row.

    Args:
        index: COO indices [2, nnz].
        value: COO values [nnz].
        m: Number of rows.
        n: Number of columns.

    Returns:
        Tuple of (row_start [m+1], col_indices [nnz], values [nnz]).
    """
    nnz = index.shape[1]
    if nnz == 0:
        row_start = torch.zeros(m + 1, dtype=torch.int32, device=index.device)
        col_indices = torch.empty(0, dtype=torch.int32, device=index.device)
        values = torch.empty(0, dtype=value.dtype, device=value.device)
        return row_start, col_indices, values
    perm = torch.argsort(index[0] * (max(n, m) + 1) + index[1])
    idx_sorted = index[:, perm]
    val_sorted = value[perm].contiguous()
    rows, cols = idx_sorted[0], idx_sorted[1]
    # Coalesce duplicate (row,col): sum values. cuDSS forbids duplicates in a row.
    keys = rows * (max(n, m) + 1) + cols
    if keys.numel() > 1:
        run_end = (keys[1:] != keys[:-1]).nonzero(as_tuple=True)[0] + 1
        run_start = torch.cat([torch.tensor([0], device=keys.device, dtype=run_end.dtype), run_end])
        tail = torch.tensor([keys.numel()], device=keys.device, dtype=run_end.dtype)
        run_end = torch.cat([run_end, tail])
        rows = rows[run_start]
        cols = cols[run_start]
        pairs = zip(run_start.tolist(), run_end.tolist())
        values = torch.stack([val_sorted[s:e].sum() for s, e in pairs])
    else:
        values = val_sorted
    counts = torch.bincount(rows, minlength=m).to(torch.int32)
    z = torch.zeros(1, dtype=torch.int32, device=index.device)
    # cumsum of int32 promotes to int64; cuDSS with CUDA_R_32I expects int32
    row_start = torch.cat([z, torch.cumsum(counts, 0)], dim=0).to(torch.int32)
    col_indices = cols.to(torch.int32).contiguous()
    return row_start, col_indices, values


def _matrix_type_to_cudss_mtype(matrix_type: MatrixType) -> int:
    """Translate our MatrixType enum into the cuDSS matrix-type constant."""
    if matrix_type in (MatrixType.GENERAL, MatrixType.GENERAL_RECTANGULAR):
        return _cudss.CUDSS_MTYPE_GENERAL
    if matrix_type == MatrixType.SYMMETRIC:
        return _cudss.CUDSS_MTYPE_SYMMETRIC
    if matrix_type == MatrixType.SPD:
        return _cudss.CUDSS_MTYPE_SPD
    raise ValueError(f"cuDSS does not support {matrix_type}")


def _dtype_to_value_type(dtype: torch.dtype) -> int:
    if dtype == torch.float32:
        return _cudss.CUDA_R_32F
    if dtype == torch.float64:
        return _cudss.CUDA_R_64F
    raise ValueError(f"dtype must be float32 or float64, got {dtype}")


class CUDSSBackend(BackendBase):
    """cuDSS backend for GENERAL, SYMMETRIC, SPD, and tentative GENERAL_RECTANGULAR."""

    @property
    def backend_name(self) -> str:
        # Stable identifier used to select this backend.
        return "cudss"

    def __init__(
        self,
        matrix_type: MatrixType,
        device: torch.device,
        dtype: torch.dtype,
        use_cache: bool = True,
        cache: Optional[Any] = None,
    ):
        """Create a cuDSS backend instance (no GPU work happens here).

        Args:
            matrix_type: Structure of A; mapped to a cuDSS matrix type on update.
            device: CUDA device the factorization lives on.
            dtype: float32 or float64 (validated later by _dtype_to_value_type).
            use_cache: Reuse handle/config/data/matrix across update_matrix calls.
            cache: Object exposing get(key, device) / put(key, entry, device);
                concrete type is defined elsewhere in the project.

        Raises:
            RuntimeError: If neither our cudss_bindings nor the nvmath fallback
                was importable at module load time.
        """
        if _cudss is None:
            raise RuntimeError(f"cudss_bindings not available: {_cudss_import_error}")
        self._matrix_type = matrix_type
        self._device = device
        self._dtype = dtype
        self._use_cache = use_cache
        self._cache = cache
        # Opaque cuDSS object handles from the bindings; created lazily.
        self._handle: Optional[int] = None
        self._config: Optional[int] = None
        self._data: Optional[int] = None
        self._mat_a: Optional[int] = None
        # CSR buffers must outlive the cuDSS matrix that points into them.
        self._row_start: Optional[torch.Tensor] = None
        self._col_indices: Optional[torch.Tensor] = None
        self._values: Optional[torch.Tensor] = None  # keep CSR values alive; cuDSS does not copy
        # Dimensions / nnz of the currently installed matrix (0 = none yet).
        self._m = 0
        self._n = 0
        self._nnz = 0

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def _ensure_handle_config_data(self) -> None:
        """Lazily create the cuDSS handle/config/data objects (idempotent)."""
        if self._handle is not None:
            return
        self._handle = _cudss.create_handle()
        self._config = _cudss.config_create()
        self._data = _cudss.data_create(self._handle)
        # Bind cuDSS work to the current torch stream on our device.
        stream = torch.cuda.current_stream(self._device).cuda_stream
        _cudss.set_stream(self._handle, stream)

    def _destroy_mat_a(self) -> None:
        """Destroy the cuDSS CSR matrix object and drop the tensors backing it."""
        if self._mat_a is not None:
            _cudss.matrix_destroy(self._mat_a)
            self._mat_a = None
        self._row_start = None
        self._col_indices = None
        self._values = None

    def _cache_key(self, index: torch.Tensor, m: int, n: int) -> str:
        """Cache key for a sparsity pattern: shape, nnz, and index storage address.

        NOTE(review): hashing index.data_ptr() assumes each distinct pattern
        keeps its own live index tensor; a reused allocation holding a new
        pattern would collide — confirm against callers.
        """
        h = hash((m, n, index.shape[1], index.data_ptr()))
        return f"cudss_{m}_{n}_{h}"

    def update_matrix(
        self,
        A_sparse: Tuple[torch.Tensor, torch.Tensor, int, int],
        structure_changed: bool = False,
    ) -> None:
        """Install (or refresh) A and run the needed cuDSS phases.

        Converts COO -> CSR, then either reuses a cached/live cuDSS matrix
        (values-only refactorization) or rebuilds the matrix and runs
        ANALYSIS + FACTORIZATION.

        Args:
            A_sparse: Tuple (index [2,nnz], value [nnz], m, n) in COO form.
            structure_changed: Force a full rebuild even if sizes match.
        """
        index, value, m, n = A_sparse
        vt = _dtype_to_value_type(value.dtype)
        mtype = _matrix_type_to_cudss_mtype(self._matrix_type)
        row_start, col_indices, values = _coo_to_csr(index, value, m, n)
        nnz = values.shape[0]  # after coalesce, may differ from index.shape[1]

        def _run_phase(phase: int) -> None:
            # cuDSS execute() always wants dense B/X descriptors, even for
            # analysis/factorization; use throwaway n x 1 buffers.
            # NOTE(review): both buffers use n rows — for a rectangular system
            # (m != n) B arguably needs m rows; confirm against the cuDSS docs.
            buf_b = torch.zeros(n, 1, dtype=values.dtype, device=self._device)
            buf_x = torch.zeros(n, 1, dtype=values.dtype, device=self._device)
            layout = _cudss.CUDSS_LAYOUT_COL_MAJOR
            mb = _cudss.matrix_create_dn(n, 1, n, buf_b.data_ptr(), vt, layout)
            mx = _cudss.matrix_create_dn(n, 1, n, buf_x.data_ptr(), vt, layout)
            try:
                _cudss.execute(
                    self._handle, phase, self._config, self._data, self._mat_a,
                    mx, mb,
                )
            finally:
                # Descriptors are cheap; always release them even on failure.
                _cudss.matrix_destroy(mb)
                _cudss.matrix_destroy(mx)

        cache = self._cache
        ck = self._cache_key(index, m, n) if cache else None
        if self._use_cache and cache and ck:
            cached = cache.get(ck, self._device)
            if cached is not None and not structure_changed:
                # Cache hit with unchanged structure: adopt the cached cuDSS
                # objects, swap in the new values, and refactorize only.
                (h, cfg, d, ma, rs, ci) = cached
                self._handle, self._config, self._data = h, cfg, d
                self._mat_a = ma
                self._row_start, self._col_indices = rs, ci
                self._values = values
                _cudss.matrix_set_values(self._mat_a, values.data_ptr())
                _run_phase(_cudss.CUDSS_PHASE_FACTORIZATION)
                self._m, self._n, self._nnz = m, n, nnz
                return
        self._ensure_handle_config_data()
        # Any size/nnz change implies a new sparsity pattern -> full rebuild.
        struct_or_size = (
            structure_changed
            or self._mat_a is None
            or self._m != m
            or self._n != n
            or self._nnz != nnz
        )
        if struct_or_size:
            self._destroy_mat_a()
            # Recreate the data object so stale analysis state is discarded.
            if self._data is not None:
                _cudss.data_destroy(self._handle, self._data)
                self._data = _cudss.data_create(self._handle)
            # cuDSS CSR: rowStart (m+1). rowEnd: pass 0 (NULL); cuDSS uses rowStart[i+1] as end.
            self._mat_a = _cudss.matrix_create_csr(
                m, n, nnz,
                row_start.data_ptr(), col_indices.data_ptr(), values.data_ptr(),
                _cudss.CUDA_R_32I, vt, mtype,
                _cudss.CUDSS_MVIEW_FULL, _cudss.CUDSS_BASE_ZERO, row_end_ptr=0,
            )
            # Hold references so the device memory backing mat_a stays alive.
            self._row_start = row_start
            self._col_indices = col_indices
            self._values = values
            _run_phase(_cudss.CUDSS_PHASE_ANALYSIS)
            _run_phase(_cudss.CUDSS_PHASE_FACTORIZATION)
        else:
            # Same pattern: update values in place and refactorize.
            self._values = values
            _cudss.matrix_set_values(self._mat_a, values.data_ptr())
            _run_phase(_cudss.CUDSS_PHASE_FACTORIZATION)
        self._m, self._n, self._nnz = m, n, nnz
        if self._use_cache and cache and ck:
            # Note: the values tensor is deliberately NOT cached; it is replaced
            # on every update and kept alive on self._values instead.
            ent = (
                self._handle, self._config, self._data, self._mat_a,
                self._row_start, self._col_indices,
            )
            cache.put(ck, ent, self._device)

    def solve(self, b: torch.Tensor) -> torch.Tensor:
        """Solve A x = b for one or more right-hand sides.

        Args:
            b: [n] or [n, nrhs] tensor on this backend's device/dtype.

        Returns:
            x with the same shape as b ([n] for a 1-D input, else [n, nrhs]).
        """
        n = self._n
        # NOTE(review): all dense descriptors below use n = self._n for both B
        # and X; for a rectangular system (m != n) B has m rows — confirm
        # intended behavior for GENERAL_RECTANGULAR.
        if b.dim() == 1:
            nrhs = 1
            b_ = b.unsqueeze(1).contiguous()
        else:
            nrhs = b.shape[1]
            b_ = b.contiguous()
        # cuDSS expects column-major: column j stored at [j*n : (j+1)*n].
        # b_ is [n, nrhs] row-major; b_.T.contiguous() gives [nrhs, n] whose first n
        # elements are b_[:,0], i.e. column-major layout for n x nrhs.
        b_t = b_.T.contiguous()
        # x_buf is [nrhs, n] row-major == n x nrhs column-major, matching mat_x.
        x_buf = torch.empty(nrhs, n, device=b.device, dtype=b.dtype)
        vt = _dtype_to_value_type(b.dtype)
        # Re-bind to the caller's current stream in case it changed since setup.
        _cudss.set_stream(
            self._handle,
            int(torch.cuda.current_stream(self._device).cuda_stream),
        )
        mat_b = _cudss.matrix_create_dn(
            n, nrhs, n, b_t.data_ptr(), vt, _cudss.CUDSS_LAYOUT_COL_MAJOR
        )
        mat_x = _cudss.matrix_create_dn(
            n, nrhs, n, x_buf.data_ptr(), vt, _cudss.CUDSS_LAYOUT_COL_MAJOR
        )
        try:
            _cudss.execute(
                self._handle, _cudss.CUDSS_PHASE_SOLVE, self._config, self._data,
                self._mat_a, mat_x, mat_b,
            )
        finally:
            # Always release the dense descriptors, even if execute raised.
            _cudss.matrix_destroy(mat_b)
            _cudss.matrix_destroy(mat_x)
        # Make x_buf safe to read/transpose on the host side.
        torch.cuda.current_stream(self._device).synchronize()
        x = x_buf.T
        return x.squeeze(1) if nrhs == 1 else x