15#ifndef B2SPLISS_SOLVER_H_
16#define B2SPLISS_SOLVER_H_
21#include "b2ppconfig.h"
22#include "b2sparse_solver.H"
23#include "spliss/spliss.h"
25namespace b2000::b2linalg {
32 std::string precond_type;
// Sparse solver backed by the Spliss library; publicly derives from
// LDLt_sparse_solver<T>.
// NOTE(review): the leading numbers on these lines are line-number residue from
// a rendered listing; the template header for T is not part of this excerpt.
38class Spliss_LDLt_sparse_solver :
public LDLt_sparse_solver<T> {
// Default constructor: the updated_value flag starts out false.
40 Spliss_LDLt_sparse_solver() : updated_value(false) {}
// Destructor with an empty body.
42 ~Spliss_LDLt_sparse_solver() {}
// Parameter list and body of the solver set-up routine. The line carrying the
// function name is absent from this listing; the extension class invokes it as
// Spliss_LDLt_sparse_solver<T>::init(...) with these same seven arguments.
// NOTE(review): leading numbers are listing line-number residue, and gaps in
// them mark source lines that are missing here (closing braces, chain tails,
// some declarations), so the comments below describe only the visible lines.
45 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
const T* value,
46 const int connectivity,
const Dictionary& dictionary) {
// Communicator that is handed to the vector family created further down.
48 auto communicator = std::make_shared<spliss::LocalCommunicator>();
53 std::vector<std::pair<spliss::LocalIndexStoreT, spliss::GlobalIndexStoreT>>
// Walk the compressed index arrays: r_ptr runs continuously through rowind, and
// for each j in [0, s) the entries up to colind[j + 1] are paired with j.
55 size_t r_ptr = colind[0];
56 for (
size_t j = 0; j != s; ++j) {
57 for (
size_t r_end = colind[j + 1]; r_ptr != r_end; ++r_ptr) {
// Record the stored position (rowind[r_ptr], j) for the sparsity pattern ...
58 vector_of_positions.emplace_back(rowind[r_ptr], j);
// ... and in `indices`, which resolve() later iterates to copy values.
61 indices.emplace_back(rowind[r_ptr], j);
// Off-diagonal entries also get the mirrored position (j, rowind[r_ptr]), so
// the pattern contains both (i, j) and (j, i).
63 if (rowind[r_ptr] != j) { vector_of_positions.emplace_back(j, rowind[r_ptr]); }
// Vector family constructed from the communicator, s and 1; it is used for
// both sides of the matrix and for the work vectors.
67 family = std::make_shared<const spliss::VectorFamily>(communicator, s, 1);
// Connectivity built from a CSR pattern created out of the recorded positions.
71 auto connect = std::make_shared<MatrixConnectivityT>(
72 family, family, CreateCSRPatternFromVectorOfPositions(*family, vector_of_positions));
// Second connectivity built from a diagonal sparsity pattern.
74 auto connectivityDiag = std::make_shared<MatrixConnectivityT>(
75 family, family, spliss::CreateDiagonalSparsityPattern(*family));
// Colors computed from the full connectivity; passed to the matrix constructor.
78 auto colors = spliss::ComputeColors(*connect);
// Create the matrix object and set all of its entries to zero.
85 A = std::make_shared<MatrixT>(family, family, colors, connect, connectivityDiag);
87 A->GetAccess().SetZero();
// Solver options are read from the dictionary; the second argument of each
// getter is the value written in the call as the fallback.
89 conf.residuum = dictionary.get_double(
"SPLISS_RESIDUUM", 1.0E-8);
90 conf.iteration = dictionary.get_int(
"SPLISS_MAX_ITER", 1000000);
91 conf.verbose = dictionary.get_bool(
"SPLISS_VERBOSE",
false);
92 conf.precond_type = dictionary.get_string(
"SPLISS_PRECOND",
"LU_JACOBI_CG");
93 conf.relaxation = dictionary.get_double(
"SPLISS_RELAXATION", 0.8);
95 using VectorT =
typename MatrixT::SourceVectorT;
96 VectorT vec_format(family);
// Stop criterion shared by the outer solver stages: built from conf.residuum,
// the Relative residual check mode, conf.iteration and conf.verbose.
97 auto criterion = std::make_shared<spliss::ConservativeStopCriterion<T>>(
98 conf.residuum, spliss::ResidualCheckMode::Relative, conf.iteration, conf.verbose);
// The solver is assembled by appending stages to a stack initialised from A;
// which stages are appended depends on conf.precond_type.
100 auto&& stack = spliss::InitSolverStack(A);
// "ILU_CG": ILU decomposition (given A's off-diagonal data) followed by CG.
102 if (conf.precond_type ==
"ILU_CG") {
103 solver = stack.template Append<SimpleILUDecomposition>(A->GetOffDiagonalData())
104 .template Append<spliss::CG>(criterion)
106 }
// "LU_CG": LU decomposition followed by CG.
else if (conf.precond_type ==
"LU_CG") {
107 solver = stack.template Append<spliss::LUDecomposition>()
108 .template Append<spliss::CG>(criterion)
110 }
// "CG": CG stage only.
else if (conf.precond_type ==
"CG") {
111 solver = stack.template Append<spliss::CG>(criterion).GetSolver();
112 }
// "JACOBI": LU decomposition followed by Jacobi with conf.relaxation.
else if (conf.precond_type ==
"JACOBI") {
113 solver = stack.template Append<spliss::LUDecomposition>()
114 .template Append<spliss::Jacobi>(criterion, conf.relaxation)
// Remaining case (its opening line is not in this listing; presumably the
// default "LU_JACOBI_CG" - confirm): LU decomposition, then a Jacobi stage using
// a ConfidentStopCriterion constructed with 5u, then CG with the shared criterion.
117 auto criterionJacobi = std::make_shared<spliss::ConfidentStopCriterion<T>>(5u);
118 solver = stack.template Append<spliss::LUDecomposition>()
119 .template Append<spliss::Jacobi>(criterionJacobi, conf.relaxation)
120 .template Append<spliss::CG>(criterion)
// Raise the flag so the matrix values are (re)loaded before the next solve.
124 updated_value =
true;
// Sets the updated_value flag to true. The visible lines of resolve() set the
// flag back to false after copying the matrix values.
127 void update_value() { updated_value =
true; }
// Parameter list and body of the solve routine. The line carrying the function
// name is absent from this listing; the extension class invokes it as
// Spliss_LDLt_sparse_solver<T>::resolve(s, nrhs, b, ldb, x, ldx).
// b holds nrhs right-hand sides read as b[j + i * ldb]; the results are written
// to x[j + i * ldx]. left_or_right is not used in the visible lines.
// NOTE(review): leading numbers are listing line-number residue; gaps in them
// mark missing source lines (e.g. guards, declarations, closing braces).
130 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
131 char left_or_right =
' ') {
// Nothing to do for an empty system.
133 if (s == 0) {
return; }
135 using VectorT =
typename MatrixT::SourceVectorT;
// Work vectors for the right-hand side and the solution.
136 VectorT b_value(family), x_value(family);
139 auto a = A->GetAccess();
// Copy the values from A_reference into the matrix at the positions recorded
// in `indices`. NOTE(review): the declaration/advance of r_ptr and any guard on
// updated_value are not visible in this listing.
144 for (
const auto& i : indices) {
145 a(i.first, i.second)(0, 0) = A_reference[r_ptr];
// Off-diagonal entries are mirrored so (i, j) and (j, i) hold the same value.
147 if (i.first != i.second) {
148 a(i.second, i.first)(0, 0) = a(i.first, i.second)(0, 0);
// The matrix now reflects the current values; clear the flag.
153 updated_value =
false;
// Solve for each right-hand side in turn.
156 for (
size_t i = 0; i != nrhs; ++i) {
// Load right-hand side i into the work vector.
160 auto b_access = b_value.GetAccess();
161 for (
size_t j = 0; j != s; ++j) { b_access[j][0] = b[j + i * ldb]; }
// Run the configured solver.
163 auto status = solver->Apply(b_value, x_value);
// Print the iteration number reported by the solver status.
165 std::cout <<
"Total number of iterations: " << status->LastIterationNumber()
// With SPLISS_DEBUG_OUTPUT defined, also print every convergence-data entry.
168#ifdef SPLISS_DEBUG_OUTPUT
169 auto convergedData = status->GetConvergenceData();
171 for (
size_t k = 0; k != convergedData.size(); k++) {
172 std::cout <<
"Obtained solution: " << k <<
", " << convergedData[k] << std::endl;
// Any stopping reason other than RelativeReductionSuccess is raised as an
// error through `e` and THROW (the declaration of `e` is not visible here).
176 if (status->GetReason() != spliss::SolverStoppingReason::RelativeReductionSuccess) {
178 e <<
"Solver stopped without success!" <<
THROW;
// Copy the solution for right-hand side i into the caller's array.
182 auto x_access = x_value.GetAccess();
183 for (
size_t j = 0; j != s; ++j) { x[j + i * ldx] = x_access[j][0]; }
// Vector family shared by the matrix and the work vectors; assigned during set-up.
190 std::shared_ptr<const spliss::VectorFamily> family;
// Matrix connectivity type based on a CSR pattern.
192 using MatrixConnectivityT = spliss::MatrixConnectivity<spliss::CSRPattern>;
// Block storage type instantiated with T, the connectivity type and 1, 1.
195 using StorageT = spliss::CompactBlockStorage<T, MatrixConnectivityT, 1, 1>;
// Matrix type used for the system matrix A.
198 using MatrixT = spliss::DataBasedMatrix<T, MatrixConnectivityT, 1, 1, T, StorageT>;
199 std::shared_ptr<MatrixT> A;
// Alias that passes the same storage type twice to spliss::ILUDecomposition,
// giving a two-parameter template usable in the solver stack.
201 template <
typename ScalarType,
typename StorageType>
202 using SimpleILUDecomposition = spliss::ILUDecomposition<ScalarType, StorageType, StorageType>;
// Non-owning pointer to the value array that is copied into A before solving.
// NOTE(review): the assignment is not visible in this listing - presumably the
// `value` argument of the set-up routine; confirm.
206 const T* A_reference;
// Index pairs (rowind[r_ptr], j) of the stored entries, in storage order.
208 std::vector<std::pair<size_t, size_t>> indices;
// Solver assembled from the stack chosen according to the preconditioner option.
210 std::shared_ptr<spliss::SolverInterface<T>> solver;
// Extension variant of the Spliss solver: derives from both
// LDLt_extension_sparse_solver<T> and Spliss_LDLt_sparse_solver<T>, and solves a
// system of size s by combining solves of size s - 1 with the extra vectors
// ma_, mb_ and scalar *mc_.
// NOTE(review): leading numbers are listing line-number residue; function-name
// lines, closing braces and the member declarations (m_ab, div) are missing.
214class Spliss_LDLt_extension_sparse_solver :
public LDLt_extension_sparse_solver<T>,
215 public Spliss_LDLt_sparse_solver<T> {
// Default constructor: default-constructs the Spliss base and sets div to 0.
217 Spliss_LDLt_extension_sparse_solver() : Spliss_LDLt_sparse_solver<T>(), div(0) {}
// Set-up parameters (function-name line missing). size_ext_ is not used in the
// visible lines.
220 size_t size_,
size_t nnz_,
const size_t* colind_,
const size_t* rowind_,
const T* value_,
221 size_t size_ext_,
const int connectivity,
const Dictionary& dictionary) {
// Forward the set-up to the Spliss base class.
224 Spliss_LDLt_sparse_solver<T>::init(
225 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
// Workspace for two vectors.
226 m_ab.resize(size_ * 2);
// Forwards to the base-class update_value().
229 void update_value() { Spliss_LDLt_sparse_solver<T>::update_value(); }
// Solve parameters (function-name line missing). ma_, mb_ and mc_ default to 0.
// NOTE(review): the visible lines read through ma_, mb_ and mc_ without a
// visible null check - confirm whether a guard exists on the missing lines.
232 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
const T* ma_ = 0,
233 const T* mb_ = 0,
const T* mc_ = 0,
char left_or_right =
' ') {
// Place the first s - 1 entries of ma_ and mb_ back to back in m_ab ...
235 std::copy(ma_, ma_ + s - 1, &m_ab[0]);
236 std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);
// ... and solve both in place as two right-hand sides of size s - 1.
237 Spliss_LDLt_sparse_solver<T>::resolve(s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);
// Reciprocal of (*mc_ minus the dot product of ma_ with the solved mb_ part).
238 div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));
// Per right-hand side: compute the last unknown, solve the size s - 1 system,
// then correct the first s - 1 unknowns.
241 for (
size_t i = 0; i != nrhs; ++i) {
// Last unknown: div * (last entry of b minus dot(b, solved mb_ part)).
242 const T x2 = x[ldx * i + s - 1] =
243 div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));
// Solve with the first s - 1 entries of this right-hand side.
244 Spliss_LDLt_sparse_solver<T>::resolve(s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);
// Subtract x2 times the solved ma_ part from the first s - 1 unknowns.
245 blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);
// Specialisation of the Spliss solver for csda<double>. In the visible lines the
// set-up and update_value() bodies are empty.
// NOTE(review): leading numbers are listing line-number residue; the bodies of
// the constructor and of the solve routine are missing from this listing, so
// their behaviour cannot be stated here.
255class Spliss_LDLt_sparse_solver<csda<double>> :
public LDLt_sparse_solver<csda<double>> {
// Constructor (body not visible).
257 Spliss_LDLt_sparse_solver() {
// Set-up parameters (function-name line missing); the body is empty.
262 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
263 const csda<double>* value,
const int connectivity,
const Dictionary& dictionary) {}
// No-op.
265 void update_value() {}
// Solve parameters (function-name line missing; body not visible).
268 size_t s,
size_t nrhs,
const csda<double>* b,
size_t ldb, csda<double>* x,
size_t ldx,
269 char left_or_right =
' ') {
// Specialisation of the extension solver for csda<double>; derives only from
// LDLt_extension_sparse_solver<csda<double>>. update_value() is a no-op.
// NOTE(review): leading numbers are listing line-number residue; the bodies of
// the constructor, set-up and solve routines are missing from this listing, so
// their behaviour cannot be stated here.
275class Spliss_LDLt_extension_sparse_solver<csda<double>>
276 :
public LDLt_extension_sparse_solver<csda<double>> {
// Constructor (body not visible).
278 Spliss_LDLt_extension_sparse_solver() {
// Set-up parameters (function-name line missing; body not visible). The sixth
// parameter is unnamed.
283 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
284 const csda<double>* value,
size_t,
const int connectivity,
const Dictionary& dictionary) {
// No-op.
287 void update_value() {}
// Solve parameters (function-name line missing; body not visible). ma, mb and
// mc default to 0.
290 size_t s,
size_t nrhs,
const csda<double>* b,
size_t ldb, csda<double>* x,
size_t ldx,
291 const csda<double>* ma = 0,
const csda<double>* mb = 0,
const csda<double>* mc = 0,
292 char left_or_right =
' ') {
#define THROW
Definition b2exception.H:198
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314