b2api
B2000++ API Reference Manual, VERSION 4.6
 
Loading...
Searching...
No Matches
b2pastix_solver.H
1//------------------------------------------------------------------------
2// b2pastix_solver.H --
3//
4// written by Mathias Doreille
5//
6// (c) 2011 SMR Engineering & Development SA
7// 2502 Bienne, Switzerland
8//
9// All Rights Reserved. Proprietary source code. The contents of
10// this file may not be disclosed to third parties, copied or
11// duplicated in any form, in whole or in part, without the prior
12// written permission of SMR.
13//------------------------------------------------------------------------
14
15// TODO Pastix is only implemented for sequential runs!
16// MPI is not supported!
17#ifndef B2PASTIX_SOLVER_H_
18#define B2PASTIX_SOLVER_H_
19#include "b2sparse_solver.H"
20
21extern "C" {
22#include <pastix.h>
23#include <spm.h>
24}
25
26namespace b2000 { namespace b2linalg {
27
28template <typename T>
29class PASTIX_LDLt_seq_sparse_direct_solver : public LDLt_sparse_solver<T> {
30public:
31 PASTIX_LDLt_seq_sparse_direct_solver() {
32 UnimplementedError() << "The Pastix library is not enabled." << THROW;
33 }
34
35 void init(
36 size_t s, size_t nnz, const size_t* colind, const size_t* rowind, const T* value,
37 const int connectivity, const Dictionary& dictionary) {}
38
39 void update_value() {}
40
41 void resolve(
42 size_t s, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx,
43 char left_or_right = ' ') {
45 }
46};
47
48template <typename T>
49class PASTIX_LDLt_seq_extension_sparse_direct_solver : public LDLt_extension_sparse_solver<T> {
50public:
51 PASTIX_LDLt_seq_extension_sparse_direct_solver() {
52 UnimplementedError() << "The Pastix library is not enabled." << THROW;
53 }
54
55 void init(
56 size_t s, size_t nnz, const size_t* colind, const size_t* rowind, const T* value,
57 size_t s_ext, const int connectivity, const Dictionary& dictionary) {}
58
59 void update_value() {}
60
61 void resolve(
62 size_t s, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx, const T* ma = 0,
63 const T* mb = 0, const T* mc = 0, char left_or_right = ' ') {
65 }
66};
67
// TODO: The helper below may be needed later if PaStiX parameters
// (iparm/dparm) are to be supplied from the command line.
70
/*
 inline void decode_pastix_param(const Dictionary& d, int* iparam, double* dparam) {
   // Get the IPARAM string from the dictionary d; format "k:v,k:v,...".
   std::string icntl = d.get_string("PASTIX_IPARAM", "");
   if (icntl != "") {
     std::istringstream iss(icntl);
     for (;;) {
       int k, v;
       char c;
       if (iss >> k && iss >> c && iss >> v) {
         // NOTE: this bounds check previously read (k < 0 && k > 65), which
         // can never be true; || is required to reject out-of-range indices.
         if (c != ':' || (k < 0 || k > 65))
           Exception() << "Cannot parse the PASTIX_IPARAM analysis directive " << icntl << THROW;
         iparam[k] = v;
       } else
         Exception() << "Cannot parse the PASTIX_IPARAM analysis directive " << icntl << THROW;
       iss >> c;
       if (!iss) break;
       if (c != ',')
         Exception() << "Cannot parse the PASTIX_IPARAM analysis directive " << icntl << THROW;
     }
   }
   // Get the DPARAM string from the dictionary d; format "k:v,k:v,...".
   std::string cntl = d.get_string("PASTIX_DPARAM", "");
   if (cntl != "") {
     std::istringstream dss(cntl);
     for (;;) {
       int k;
       double v;
       char c;
       if (dss >> k && dss >> c && dss >> v) {
         // NOTE: previously (k < 0 && k > 24), which can never be true.
         if (c != ':' || (k < 0 || k > 24))
           Exception() << "Cannot parse the PASTIX_DPARAM analysis directive " << cntl << THROW;
         dparam[k] = v;
       } else
         Exception() << "Cannot parse the PASTIX_DPARAM analysis directive " << cntl << THROW;
       dss >> c;
       if (!dss) break;
       if (c != ',')
         Exception() << "Cannot parse the PASTIX_DPARAM analysis directive " << cntl << THROW;
     }
   }
 }
*/
105
106template <>
107class PASTIX_LDLt_seq_sparse_direct_solver<double> : public LDLt_sparse_solver<double> {
108public:
109 // constructor
110 PASTIX_LDLt_seq_sparse_direct_solver() {
111 // initialize iparm and dparm
112 pastixInitParam(iparm, dparm);
113 // We have LDL^t => symmetric matrix type
114 iparm[IPARM_FACTORIZATION] = PastixFactLDLT;
115 // Number of threads to use for computation
116 // TODO Adjust later for shared memory
117 iparm[IPARM_THREAD_NBR] = 1;
118 // Sequential run
119 // TODO Adjust later for parallel execution
120 iparm[IPARM_SCHEDULER] = PastixSchedSequential;
121 // Factor to prevent pivoting, criteria relative to
122 // vector norm, related to the convergence of the
123 // equation system!
124 dparm[DPARM_EPSILON_MAGN_CTRL] = -1e-14;
125 // Initialize pastix
126 pastixInit(&pastix_data, MPI_COMM_WORLD, iparm, dparm);
127 }
128
129 // destructor
130 ~PASTIX_LDLt_seq_sparse_direct_solver() { pastixFinalize(&pastix_data); }
131
132 void init(
133 size_t s, size_t nnz, const size_t* colind, const size_t* rowind, const double* value,
134 const int connectivity, const Dictionary& dictionary) {
135 if (s == 0) { return; }
136 logging::Logger logger =
137 logging::get_logger("linear_algebra.sparse_symmetric_solver.pastix");
138
139 // Assign colind and rowind to colptr_loc and row
140 colptr_loc.assign(colind, colind + s + 1);
141 row.assign(rowind, rowind + nnz);
142
143 // Initialize and allocate space for the sparse matrix
144 spm = std::make_unique<spmatrix_t>();
145 spmInit(spm.get());
146
147 spm->values = const_cast<double*>(value);
148 // Number of vertices in the local graph
149 spm->n = s;
150 // Number of non-zero entries
151 spm->nnz = nnz;
152 // Compressed sparse column format
153 spm->fmttype = SpmCSC;
154 // Matrix type, symmetric but not necessarily
155 // positive-definite
156 spm->mtxtype = SpmSymmetric;
157 spm->colptr = colptr_loc.data();
158 spm->rowptr = row.data();
159 // Basval is 1 for Fortran notation and 0 for C
160 spm->baseval = 0;
161 // Our datatype is of double, see definition of template
162 spm->flttype = SpmDouble;
163
164 // Set NON computed values (e.g. gnnz or gN)
165 spmUpdateComputedFields(spm.get());
166
167#ifdef PASTIX_DEBUG_OUTPUT
168 /* Use check and correct only if the matrix is not
169 * symmetric. It is a costly call and should be avoided!
170 * Useful for debugging! */
171 int rc = 0;
172 spmatrix_t spm2;
173 // Correct spm to correspond to PaStiX standards
174 rc = spmCheckAndCorrect(spm.get(), &spm2);
175 if (rc != 0) {
176 spmExit(spm.get());
177 *spm = spm2;
178 rc = 0;
179 }
180
181 // Dump the matrix to a file, used for debugging!
182 FILE* myMatrix = fopen("spm.out", "a");
183 spmSave(spm.get(), myMatrix);
184 fclose(myMatrix);
185
186#endif // PASTIX_DEBUG_OUTPUT
187
188 /*Perform all the preprocessing steps:
189 * ordering, symbolic factorization,
190 * reordering, proportionnal mapping */
191 pastix_task_analyze(pastix_data, spm.get());
192
193 /*Perform all the numerical factorization
194 * steps: fill the internal block CSC and
195 * the solver matrix structures, then apply
196 * the factorization step. */
197 pastix_task_numfact(pastix_data, spm.get());
198
199 updated_value = false;
200 }
201
202 void update_value() { updated_value = true; }
203
204 void resolve(
205 size_t s, size_t nrhs, const double* b, size_t ldb, double* x, size_t ldx,
206 char left_or_right = ' ') {
207 if (s == 0) { return; }
208 logging::Logger logger =
209 logging::get_logger("linear_algebra.sparse_symmetric_solver.pastix");
210
211 if (updated_value) {
212 pastix_task_numfact(pastix_data, spm.get());
213 updated_value = false;
214 }
215
216#ifdef PASTIX_DEBUG_OUTPUT
217 // Dump the rhs to a file, meant for debugging!
218
219 FILE* myMatrix = fopen("rhs.out", "a");
220 spmPrintRHS(spm.get(), nrhs, x, ldx, myMatrix);
221 fclose(myMatrix);
222
223#endif // PASTIX_DEBUG_OUTPUT
224
225 if (nrhs == 1) {
226 if (b != x) { std::copy(b, b + s, x); }
227 pastix_task_solve(pastix_data, nrhs, x, ldx);
228
229 } else {
230 std::vector<double> x_value(ldx * nrhs);
231
232 for (size_t i = 0; i != nrhs; ++i) {
233 std::copy(b + i * ldb, b + i * ldb + s, x_value.begin() + i * ldx);
234 }
235 pastix_task_solve(pastix_data, nrhs, x_value.data(), ldx);
236
237 // Copy the data back to x
238 for (size_t i = 0; i != nrhs; ++i) {
239 std::copy(x_value.begin() + i * ldx, x_value.begin() + i * ldx + s, x + i * ldx);
240 }
241 }
242 }
243
244protected:
245 // Pointer to a storage structure
246 pastix_data_t* pastix_data{nullptr};
247 bool updated_value{false};
248 // Used to copy rowind and colind
249 std::vector<int> colptr_loc;
250 std::vector<int> row;
251 // Sparse matrix storage
252 std::unique_ptr<spmatrix_t> spm;
253 spm_int_t iparm[IPARM_SIZE];
254 double dparm[DPARM_SIZE];
255};
256
257template <>
258class PASTIX_LDLt_seq_extension_sparse_direct_solver<double>
259 : public LDLt_extension_sparse_solver<double>,
260 public PASTIX_LDLt_seq_sparse_direct_solver<double> {
261public:
262 PASTIX_LDLt_seq_extension_sparse_direct_solver()
263 : PASTIX_LDLt_seq_sparse_direct_solver<double>(), div(0) {}
264
265 void init(
266 size_t size_, size_t nnz_, const size_t* colind_, const size_t* rowind_,
267 const double* value_, size_t size_ext_, const int connectivity,
268 const Dictionary& dictionary) {
269 if (size_ext_ != 1) { UnimplementedError() << THROW; }
270
271 PASTIX_LDLt_seq_sparse_direct_solver<double>::init(
272 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
273 m_ab.resize(size_ * 2);
274 }
275
276 void update_value() { PASTIX_LDLt_seq_sparse_direct_solver<double>::update_value(); }
277
278 void resolve(
279 size_t s, size_t nrhs, const double* b, size_t ldb, double* x, size_t ldx,
280 const double* ma_ = 0, const double* mb_ = 0, const double* mc_ = 0,
281 char left_or_right = ' ') {
282 if (mc_ != 0) {
283 std::copy(ma_, ma_ + s - 1, &m_ab[0]);
284 std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);
285 PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(
286 s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);
287 div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));
288 }
289
290 for (size_t i = 0; i != nrhs; ++i) {
291 const double x2 = x[ldx * i + s - 1] =
292 div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));
293 PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(
294 s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);
295 blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);
296 }
297 }
298
299protected:
300 std::vector<double> m_ab;
301 double div;
302};
303
304}} // namespace b2000::b2linalg
305
306#endif /* B2PASTIX_SOLVER_H_ */
#define THROW
Definition b2exception.H:198
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
Contains the base classes for implementing Finite Elements.
Definition b2boundary_condition.H:32
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314