17#ifndef B2PASTIX_SOLVER_H_
18#define B2PASTIX_SOLVER_H_
19#include "b2sparse_solver.H"
26namespace b2000 {
namespace b2linalg {
29class PASTIX_LDLt_seq_sparse_direct_solver :
public LDLt_sparse_solver<T> {
31 PASTIX_LDLt_seq_sparse_direct_solver() {
36 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
const T* value,
37 const int connectivity,
const Dictionary& dictionary) {}
// Generic template: does nothing. Only the double specialization tracks value
// changes (it sets an updated_value flag that triggers refactorization).
39 void update_value() {}
42 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
43 char left_or_right =
' ') {
49class PASTIX_LDLt_seq_extension_sparse_direct_solver :
public LDLt_extension_sparse_solver<T> {
51 PASTIX_LDLt_seq_extension_sparse_direct_solver() {
56 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
const T* value,
57 size_t s_ext,
const int connectivity,
const Dictionary& dictionary) {}
// Generic template: does nothing. Only the double specialization forwards
// value updates to the underlying PaStiX solver.
59 void update_value() {}
62 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
const T* ma = 0,
63 const T* mb = 0,
const T* mc = 0,
char left_or_right =
' ') {
107class PASTIX_LDLt_seq_sparse_direct_solver<double> :
public LDLt_sparse_solver<double> {
110 PASTIX_LDLt_seq_sparse_direct_solver() {
112 pastixInitParam(iparm, dparm);
114 iparm[IPARM_FACTORIZATION] = PastixFactLDLT;
117 iparm[IPARM_THREAD_NBR] = 1;
120 iparm[IPARM_SCHEDULER] = PastixSchedSequential;
124 dparm[DPARM_EPSILON_MAGN_CTRL] = -1e-14;
126 pastixInit(&pastix_data, MPI_COMM_WORLD, iparm, dparm);
// Releases the PaStiX instance that the constructor created with pastixInit().
// NOTE(review): this class owns pastix_data through a raw pointer. Calling
// pastixFinalize() twice on one handle is prevented only because the
// std::unique_ptr<spmatrix_t> member deletes the implicit copy operations, and
// this user-declared destructor suppresses the implicit move operations. Keep
// it that way (or explicitly delete copies) if the member list changes.
130 ~PASTIX_LDLt_seq_sparse_direct_solver() { pastixFinalize(&pastix_data); }
133 size_t s,
size_t nnz,
const size_t* colind,
const size_t* rowind,
const double* value,
134 const int connectivity,
const Dictionary& dictionary) {
135 if (s == 0) {
return; }
136 logging::Logger logger =
140 colptr_loc.assign(colind, colind + s + 1);
141 row.assign(rowind, rowind + nnz);
144 spm = std::make_unique<spmatrix_t>();
147 spm->values =
const_cast<double*
>(value);
153 spm->fmttype = SpmCSC;
156 spm->mtxtype = SpmSymmetric;
157 spm->colptr = colptr_loc.data();
158 spm->rowptr = row.data();
162 spm->flttype = SpmDouble;
165 spmUpdateComputedFields(spm.get());
167#ifdef PASTIX_DEBUG_OUTPUT
174 rc = spmCheckAndCorrect(spm.get(), &spm2);
182 FILE* myMatrix = fopen(
"spm.out",
"a");
183 spmSave(spm.get(), myMatrix);
191 pastix_task_analyze(pastix_data, spm.get());
197 pastix_task_numfact(pastix_data, spm.get());
199 updated_value =
false;
// Flags that the numeric values referenced by spm->values (the caller's array
// passed to init, used without copying) have changed. resolve() then re-runs
// pastix_task_numfact and resets the flag before solving. Presumably it does
// this only when the flag is set; TODO confirm the guard in resolve().
202 void update_value() { updated_value =
true; }
205 size_t s,
size_t nrhs,
const double* b,
size_t ldb,
double* x,
size_t ldx,
206 char left_or_right =
' ') {
207 if (s == 0) {
return; }
208 logging::Logger logger =
212 pastix_task_numfact(pastix_data, spm.get());
213 updated_value =
false;
216#ifdef PASTIX_DEBUG_OUTPUT
219 FILE* myMatrix = fopen(
"rhs.out",
"a");
220 spmPrintRHS(spm.get(), nrhs, x, ldx, myMatrix);
226 if (b != x) { std::copy(b, b + s, x); }
227 pastix_task_solve(pastix_data, nrhs, x, ldx);
230 std::vector<double> x_value(ldx * nrhs);
232 for (
size_t i = 0; i != nrhs; ++i) {
233 std::copy(b + i * ldb, b + i * ldb + s, x_value.begin() + i * ldx);
235 pastix_task_solve(pastix_data, nrhs, x_value.data(), ldx);
238 for (
size_t i = 0; i != nrhs; ++i) {
239 std::copy(x_value.begin() + i * ldx, x_value.begin() + i * ldx + s, x + i * ldx);
246 pastix_data_t* pastix_data{
nullptr};
247 bool updated_value{
false};
249 std::vector<int> colptr_loc;
250 std::vector<int> row;
252 std::unique_ptr<spmatrix_t> spm;
253 spm_int_t iparm[IPARM_SIZE];
254 double dparm[DPARM_SIZE];
258class PASTIX_LDLt_seq_extension_sparse_direct_solver<double>
259 :
public LDLt_extension_sparse_solver<double>,
260 public PASTIX_LDLt_seq_sparse_direct_solver<double> {
// Default-constructs the underlying PaStiX LDL^T solver, which calls
// pastixInit, and zero-initializes div. resolve() later sets div to
// 1 / (mc - ma . A^{-1} mb), the reciprocal of the Schur complement of the
// bordered (extended) system.
262 PASTIX_LDLt_seq_extension_sparse_direct_solver()
263 : PASTIX_LDLt_seq_sparse_direct_solver<double>(), div(0) {}
266 size_t size_,
size_t nnz_,
const size_t* colind_,
const size_t* rowind_,
267 const double* value_,
size_t size_ext_,
const int connectivity,
268 const Dictionary& dictionary) {
271 PASTIX_LDLt_seq_sparse_direct_solver<double>::init(
272 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
273 m_ab.resize(size_ * 2);
// Forwards to the base PaStiX solver, which marks its matrix values as changed
// so the next solve refactorizes.
276 void update_value() { PASTIX_LDLt_seq_sparse_direct_solver<double>::update_value(); }
279 size_t s,
size_t nrhs,
const double* b,
size_t ldb,
double* x,
size_t ldx,
280 const double* ma_ = 0,
const double* mb_ = 0,
const double* mc_ = 0,
281 char left_or_right =
' ') {
283 std::copy(ma_, ma_ + s - 1, &m_ab[0]);
284 std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);
285 PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(
286 s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);
287 div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));
290 for (
size_t i = 0; i != nrhs; ++i) {
291 const double x2 = x[ldx * i + s - 1] =
292 div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));
293 PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(
294 s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);
295 blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);
300 std::vector<double> m_ab;
#define THROW
Definition b2exception.H:198
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
Contains the base classes for implementing Finite Elements.
Definition b2boundary_condition.H:32
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314