A PyTorch library that adds automatic differentiation support to sparse linear system solvers.
- Automatic differentiation support for sparse linear system solvers
- Compatible with multiple solver backends (SciPy, PyPardiso, custom solvers)
- GPU support for PyTorch sparse tensors
- Seamless integration with PyTorch's autograd system
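For reference, these are the standard adjoint-method gradients of a linear solve. This README does not describe how the library implements its backward pass, so read them as the usual textbook derivation rather than a description of its code. One extra solve with $A^\top$ yields both gradients, and the gradient with respect to $A$ only has to be evaluated on $A$'s sparsity pattern:

```math
\begin{aligned}
x &= A^{-1} b, \qquad \bar{x} = \frac{\partial L}{\partial x},\\
A^{\top} \lambda &= \bar{x},\\
\frac{\partial L}{\partial b} &= \lambda, \qquad
\frac{\partial L}{\partial A_{ij}} = -\lambda_i \, x_j \quad \text{for each stored nonzero } (i, j).
\end{aligned}
```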
- Python 3.10 or above
- Conda environment (recommended)
- Note: PyPardiso requires Intel MKL and works best on Intel systems. For broader compatibility across different architectures, use SciPy solvers.
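Following the note above, one portable pattern is to use PyPardiso when it can be imported and fall back to SciPy otherwise. This is a sketch, not part of the library. `pick_solver` is a hypothetical helper, and it relies only on `pypardiso.spsolve` and `scipy.sparse.linalg.spsolve` both accepting a SciPy sparse matrix and a NumPy right-hand side.

```python
import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as spla


def pick_solver():
    """Hypothetical helper: prefer PyPardiso, fall back to SciPy's spsolve."""
    try:
        import pypardiso  # needs Intel MKL, so often unavailable on AMD/ARM machines
        return pypardiso.spsolve
    except (ImportError, OSError):
        # PyPardiso is not installed, or its MKL libraries could not be loaded
        return spla.spsolve  # portable direct solver


solver = pick_solver()

# Smoke test on a small symmetric, diagonally dominant system
A = sps.csr_matrix(np.array([[4.0, 1.0], [1.0, 3.0]]))
b = np.array([1.0, 2.0])
x = solver(A, b)
print("solver module:", solver.__module__)
print("x:", x)
```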
```bash
conda create -n torch_sparse_solve_autograd python=3.10
conda activate torch_sparse_solve_autograd
```

You can install the package directly from GitHub without cloning:
```bash
pip install git+https://github.com/AkshayK325/torch_sparse_solve_autograd.git
```

Alternatively, clone the repository and install the dependencies:

```bash
git clone https://github.com/AkshayK325/torch_sparse_solve_autograd.git
cd torch_sparse_solve_autograd
pip install -r requirements.txt
```

The requirements include:
- `numpy`: numerical computing
- `scipy`: scientific computing and sparse matrix operations (cross-platform)
- `torch`: PyTorch deep learning framework with GPU support
- `torchvision`: PyTorch vision utilities
- `pypardiso`: parallel sparse direct solver (Intel systems only; requires Intel MKL)
Solver compatibility:
- SciPy solvers (`sp.sparse.linalg.spsolve`, `sp.sparse.linalg.cg`): work on all platforms (Intel, AMD, ARM)
- PyPardiso (`pypardiso.spsolve`): requires Intel MKL; optimized for Intel processors
- Custom solvers: you can provide your own sparse solver function
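`scipy.sparse.linalg.cg` returns a tuple `(x, info)` rather than just `x`, so a thin wrapper is one way to turn it into a custom solver. The sketch below assumes the library calls the solver as `solver(A, b)` with a SciPy sparse matrix and a NumPy vector and expects the solution array back, as the `spsolve` call in the example below suggests; `cg_solver` is a hypothetical name.

```python
import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as spla


def cg_solver(A, b):
    """Hypothetical custom solver: conjugate gradients for symmetric positive-definite A.

    Wraps scipy.sparse.linalg.cg so it returns only the solution array,
    matching the spsolve-style call solver(A, b) -> x.
    """
    x, info = spla.cg(A, b)
    if info != 0:
        raise RuntimeError(f"CG did not converge (info={info})")
    return x


# Usage on a 1D Laplacian (tridiagonal, symmetric positive-definite)
n = 20
A = sps.diags([-1.0, 2.0, -1.0], [-1, 0, 1], shape=(n, n), format="csr")
b = np.ones(n)
x = cg_solver(A, b)
print("residual norm:", np.linalg.norm(A @ x - b))
```

Under that assumed calling convention it would be passed in place of `spsolve`, e.g. `DifferentiableSparseSolve.apply(A, b, cg_solver)`. CG only applies to symmetric positive-definite matrices, so the lower-triangular matrix in the example below would still need a direct solver.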
```python
import torch
import scipy as sp
import scipy.sparse.linalg  # explicitly load the submodule (needed on older SciPy versions)
from torch_sparse_solve_autograd import DifferentiableSparseSolve

torch.manual_seed(42)

n = 4
# r scales the rows of A; gradients will flow back to it
r = torch.ones(n, dtype=torch.float64, requires_grad=True)
# Random lower-triangular matrix (nonsingular as long as its diagonal has no zeros)
mask = torch.tril(torch.ones(n, n, dtype=torch.float64))
A = (r[:, None] * mask * torch.randn(n, n, dtype=torch.float64)).to_sparse()
b = torch.randn(n, dtype=torch.float64)

print("Shape of A:", A.shape)
print("Shape of b:", b.shape)

# Solve Ax = b with automatic differentiation support
x = DifferentiableSparseSolve.apply(A, b, sp.sparse.linalg.spsolve)

# Compare against the dense solver torch.linalg.solve
A_dense = A.to_dense()
print("Difference between DifferentiableSparseSolve and torch.linalg.solve:")
print(x - torch.linalg.solve(A_dense, b))

# Backward pass support
loss = x.sum()
loss.backward()
print("Loss value:", loss.item())
print("Gradient of loss with respect to r:", r.grad)
```

For a more comprehensive example, including optimization problems with sparse linear systems, see `exampleCode.py`. It demonstrates:
- Creating large sparse systems (1D Laplacian)
- Using different solvers (PyPardiso, SciPy solvers, custom solvers)
- Shape optimization with automatic differentiation
- Gradient verification with finite differences
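The adjoint backward pass and the finite-difference gradient check listed above can be sketched without the library. `SketchSparseSolve` and `laplacian_1d_pattern` are hypothetical names; this is not the code in `exampleCode.py` or in `DifferentiableSparseSolve`. To avoid depending on PyTorch's sparse-autograd behavior, the sketch takes the nonzero values as a dense vector plus a fixed index pattern instead of a torch sparse tensor. It assumes a solver with the `solver(A, b) -> x` convention used by `spsolve`.

```python
import scipy.sparse as sps
import scipy.sparse.linalg as spla
import torch
from torch.autograd.function import once_differentiable


class SketchSparseSolve(torch.autograd.Function):
    """Hypothetical adjoint-method sparse solve, for illustration only."""

    @staticmethod
    def forward(ctx, values, b, rows, cols, n, solver):
        # Assemble a SciPy CSR matrix from the fixed pattern (rows, cols) and the values
        A = sps.csr_matrix(
            (values.detach().cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())),
            shape=(n, n),
        )
        x_np = solver(A, b.detach().cpu().numpy())
        x = torch.as_tensor(x_np, dtype=b.dtype, device=b.device)
        ctx.A, ctx.solver = A, solver
        ctx.save_for_backward(rows, cols, x)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_x):
        rows, cols, x = ctx.saved_tensors
        # Adjoint solve A^T lam = dL/dx; lam is also dL/db
        lam_np = ctx.solver(ctx.A.T.tocsr(), grad_x.detach().cpu().numpy())
        lam = torch.as_tensor(lam_np, dtype=x.dtype, device=x.device)
        # dL/dA = -lam x^T, evaluated only at the stored nonzeros (rows[k], cols[k])
        grad_values = -lam[rows] * x[cols]
        # No gradients for the index pattern, the size, or the solver
        return grad_values, lam, None, None, None, None


def laplacian_1d_pattern(n):
    """Index pattern of a tridiagonal matrix: diagonal, sub-diagonal, super-diagonal."""
    i = torch.arange(n)
    rows = torch.cat([i, i[1:], i[:-1]])
    cols = torch.cat([i, i[:-1], i[1:]])
    return rows, cols


torch.manual_seed(0)
n = 8
rows, cols = laplacian_1d_pattern(n)
# 1D Laplacian values: 2 on the diagonal, -1 on both off-diagonals
values = torch.cat([2.0 * torch.ones(n), -torch.ones(2 * (n - 1))]).double().requires_grad_()
b = torch.randn(n, dtype=torch.float64, requires_grad=True)


def solve(values, b):
    return SketchSparseSolve.apply(values, b, rows, cols, n, spla.spsolve)


# Forward check against a dense solve
A_dense = torch.zeros(n, n, dtype=torch.float64)
A_dense[rows, cols] = values.detach()
x = solve(values, b)
print("max |x - x_dense|:", (x - torch.linalg.solve(A_dense, b.detach())).abs().max().item())

# Gradient check: compares the adjoint backward with finite differences
ok = torch.autograd.gradcheck(solve, (values, b))
print("gradcheck passed:", ok)
```

The backward pass costs one extra sparse solve with `A^T`. A direct solver could reuse the factorization of `A` for that solve, which this sketch does not attempt.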
This project is licensed under the GNU General Public License v3.0 (GPL-3.0). See the LICENSE file for details.