b2api
B2000++ API Reference Manual, VERSION 4.6
 
Loading...
Searching...
No Matches
b2umfpack_solver.H
1//------------------------------------------------------------------------
2// b2umfpack_solver.H --
3//
4//
5// written by Mathias Doreille
6//
7// Copyright (c) 2004-2012 SMR Engineering & Development SA
8// 2502 Bienne, Switzerland
9//
10// All Rights Reserved. Proprietary source code. The contents of
11// this file may not be disclosed to third parties, copied or
12// duplicated in any form, in whole or in part, without the prior
13// written permission of SMR.
14//------------------------------------------------------------------------
15
16#ifndef __B2UMFPACK_SOLVER_H__
17#define __B2UMFPACK_SOLVER_H__
18
19#include "b2blaslapack.H"
20#include "b2ppconfig.h"
21#include "b2sparse_solver.H"
22
23extern "C" {
24#ifdef HAVE_UMFPACK_UMFPACK_H
25#include "umfpack/umfpack.h"
26#elif defined HAVE_UFSPARSE_UMFPACK_H
27#include "ufsparse/umfpack.h"
28#elif defined HAVE_SUITESPARSE_UMFPACK_H
29#include "suitesparse/umfpack.h"
30#elif defined HAVE_UMFPACK_H
31#include "umfpack.h"
32#endif
33}
34
35#ifdef HAVE_UMFPACK
36
37namespace b2000 { namespace b2linalg {
38
39template <typename T>
40class UMFPACK_LU_sparse_direct_solver : public LU_sparse_solver<T> {
41public:
42 UMFPACK_LU_sparse_direct_solver()
43 : Symbolic(0),
44 Numeric(0),
45 s(0),
46 colind(0),
47 rowind(0),
48 value(0),
49 logger(logging::get_logger("linear_algebra.sparse_lu_solver.umfpack")) {
50 // set defaults is used for the Umfpack call, which is different
51 // depending on the data type (double or complex)
52 set_defaults();
53 logger.LOG(logging::info, "Using UMFPACK as linear solver! ");
54 Control[UMFPACK_SCALE] = UMFPACK_SCALE_SUM;
55 if (logger.is_enabled_for(logging::debug)) { Control[UMFPACK_PRL] = 2; }
56 }
57
58 ~UMFPACK_LU_sparse_direct_solver() {
59 /* The set_free_linker function allows to use the
60 * umfpack call to free Symbolic and Numeric, depending on the
61 * data type.*/
62 set_free_linker();
63 }
64
65 /* Initialization function is the same for all data types.
66 The only difference is the umfpack free call for Symbolic and Numeric,
67 which is different for double and complex. The set_free_linker function
68 takes care of this difference. The correct call will be
69 included later as required. */
70 void init(
71 size_t s_, size_t nnz_, const size_t* colind_, const size_t* rowind_, const T* value_,
72 int connectivity, const Dictionary& dictionary) {
73 s = s_;
74 if (s == 0) { return; }
75 colind = colind_;
76 rowind = rowind_;
77 value = value_;
78 set_free_linker();
79 Symbolic = 0;
80 Numeric = 0;
81 }
82
83 // Function to update the values
84 void update_value();
85
86 /*Function to resolve the linear system, valid for complex and
87 csda data type. */
88 void resolve(
89 size_t s_, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx,
90 char left_or_right = ' ') {
91 if (s == 0) { return; }
92 if (!Symbolic) {
93 int status = umfpack_zl_symbolic(
94 s_, s_, reinterpret_cast<long*>(const_cast<size_t*>(colind)),
95 reinterpret_cast<long*>(const_cast<size_t*>(rowind)),
96 reinterpret_cast<double*>(const_cast<T*>(value)), 0, &Symbolic, Control, info);
97 if (status != UMFPACK_OK) {
98 Exception() << "umfpack_zl_symbolic return the error code " << status << ": "
99 << get_error(status) << "." << THROW;
100 }
101 }
102 if (!Numeric) {
103 int status = umfpack_zl_numeric(
104 reinterpret_cast<long*>(const_cast<size_t*>(colind)),
105 reinterpret_cast<long*>(const_cast<size_t*>(rowind)),
106 reinterpret_cast<double*>(const_cast<T*>(value)), 0, Symbolic, &Numeric, Control,
107 info);
108 if (status < UMFPACK_OK) {
109 Exception() << "umfpack_zl_numeric return the error code " << status << ": "
110 << get_error(status) << "." << THROW;
111 }
112 }
113 if (left_or_right != ' ') { UnimplementedError() << THROW; }
114
115 if (b != x) {
116 for (size_t i = 0; i != nrhs; ++i) {
117 int status = umfpack_zl_solve(
118 UMFPACK_A, reinterpret_cast<long*>(const_cast<size_t*>(colind)),
119 reinterpret_cast<long*>(const_cast<size_t*>(rowind)),
120 reinterpret_cast<double*>(const_cast<T*>(value)), 0,
121 reinterpret_cast<double*>(x + i * ldx), 0,
122 reinterpret_cast<double*>(const_cast<T*>(b + i * ldb)), 0, Numeric, Control,
123 info);
124 if (logger.is_enabled_for(logging::info)) { umfpack_zl_report_info(Control, info); }
125 if (status < UMFPACK_OK) {
126 Exception() << "umfpack_zl_solve return the error code " << status << ": "
127 << get_error(status) << "." << THROW;
128 }
129 }
130 } else {
131 auto_ptr_array<double> bxcp(new double[s_ * 2]);
132 for (size_t i = 0; i != nrhs; ++i) {
133 std::copy(
134 reinterpret_cast<const double*>(b + i * ldb),
135 reinterpret_cast<const double*>(b + i * ldb + s_), bxcp.get());
136 int status = umfpack_zl_solve(
137 UMFPACK_A, reinterpret_cast<long*>(const_cast<size_t*>(colind)),
138 reinterpret_cast<long*>(const_cast<size_t*>(rowind)),
139 reinterpret_cast<double*>(const_cast<T*>(value)), 0,
140 reinterpret_cast<double*>(x + i * ldx), 0, bxcp.get(), 0, Numeric, Control,
141 info);
142 if (logger.is_enabled_for(logging::info)) { umfpack_zl_report_info(Control, info); }
143 if (status < UMFPACK_OK) {
144 Exception() << "umfpack_zl_solve return the error code " << status << ": "
145 << get_error(status) << "." << THROW;
146 }
147 }
148 }
149 }
150
151protected:
152 static std::string get_error(int status) {
153 switch (status) {
154 case UMFPACK_ERROR_out_of_memory:
155 return "out of memory";
156 case UMFPACK_ERROR_invalid_Numeric_object:
157 return "invalid Numeric object";
158 case UMFPACK_ERROR_invalid_Symbolic_object:
159 return "invalid Symbolic object";
160 case UMFPACK_ERROR_argument_missing:
161 return "argument missing";
162 case UMFPACK_ERROR_n_nonpositive:
163 return "n non-positive definite";
164 case UMFPACK_ERROR_invalid_matrix:
165 return "invalid matrix";
166 case UMFPACK_ERROR_different_pattern:
167 return "different pattern";
168 case UMFPACK_ERROR_invalid_system:
169 return "invalid system";
170 case UMFPACK_ERROR_invalid_permutation:
171 return "invalid permutation";
172 case UMFPACK_ERROR_internal_error:
173 return "internal error";
174 case UMFPACK_ERROR_file_IO:
175 return "file IO error";
176 default:
177 return "unknown error code";
178 }
179 }
180
181 void set_defaults();
182 double Control[UMFPACK_CONTROL];
183 double info[UMFPACK_INFO];
184 void* Symbolic;
185 void* Numeric;
186 size_t s;
187 const size_t* colind;
188 const size_t* rowind;
189 const T* value;
190 logging::Logger& logger;
191 // Function to free Symbolic and Numeric
192 void set_free_linker();
193};
194
195/* Resolve for double data_type, which is defined in the C file.
196 By default we define the resolve to be of type complex, as both
197 data types (complex and csda) do the same operations, thus the
198 function body is identical!Only in the case of data type double
199 the operation differs. */
200template <>
201void UMFPACK_LU_sparse_direct_solver<double>::resolve(
202 size_t s, size_t nrhs, const double* b, size_t lbd, double* x, size_t ldx,
203 char left_or_right);
204
205template <typename T>
206class UMFPACK_LU_extension_sparse_direct_solver : public LU_extension_sparse_solver<T>,
207 public UMFPACK_LU_sparse_direct_solver<T> {
208public:
209 void init(
210 size_t size_, size_t nnz_, const size_t* colind_, const size_t* rowind_, const T* value_,
211 size_t size_ext, int connectivity, const Dictionary& dictionary) {
212 if (size_ext != 1) { UnimplementedError() << THROW; }
213 UMFPACK_LU_sparse_direct_solver<T>::init(
214 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
215 modified_value = true;
216 }
217
218 void update_value() {
219 modified_value = true;
220 UMFPACK_LU_sparse_direct_solver<T>::update_value();
221 }
222
223 void resolve(
224 size_t s_, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx, const T* ma_ = 0,
225 const T* mb_ = 0, const T* mc_ = 0, char left_or_right = ' ') {
226 if (((s_ > 1 && (ma_ == 0 || mb_ == 0)) || mc_ == 0) && modified_value == true) {
227 Exception() << THROW;
228 }
229
230 if (left_or_right != ' ') { UnimplementedError() << THROW; }
231
232 if (modified_value && (ma_ != 0 && mb_ != 0 && mc_ != 0)) {
233 mb.resize(s_ - 1);
234 std::copy(mb_, mb_ + s_ - 1, mb.begin());
235
236 d.resize((s_ - 1) * (nrhs + 1));
237 {
238 std::vector<T> d_tmp((s_ - 1) * (nrhs + 1));
239 typename std::vector<T>::iterator iter =
240 std::copy(ma_, ma_ + s_ - 1, d_tmp.begin());
241 for (size_t i = 0; i != nrhs; ++i) {
242 iter = std::copy(b + ldb * i, b + ldb * i + s_ - 1, iter);
243 }
244 UMFPACK_LU_sparse_direct_solver<T>::resolve(
245 s_ - 1, nrhs + 1, &d_tmp[0], s_ - 1, &d[0], s_ - 1, left_or_right);
246 }
247
248 div = *mc_ - blas::dot(s_ - 1, &mb[0], 1, &d[0], 1);
249 if (div == T(0)) { UnimplementedError() << THROW; }
250 div = 1.0 / div;
251 if (s_ == 1) {
252 blas::scal(nrhs, div, x, s_);
253 } else {
254 blas::gemv(
255 'T', s_ - 1, nrhs, -div, &d[s_ - 1], s_ - 1, &mb[0], 1, div, x + s_ - 1, s_);
256 blas::ger(s_ - 1, nrhs, -1, &d[0], 1, x + s_ - 1, s_, &d[s_ - 1], s_ - 1);
257 }
258 for (size_t i = 0; i != nrhs; ++i) {
259 std::copy(
260 d.begin() + (s_ - 1) * (i + 1), d.begin() + (s_ - 1) * (i + 2), x + ldx * i);
261 }
262 d.resize(s_ - 1);
263 modified_value = false;
264 } else {
265 UMFPACK_LU_sparse_direct_solver<T>::resolve(
266 s_ - 1, nrhs, b, ldb, x, ldx, left_or_right);
267 if (s_ == 1) {
268 blas::scal(nrhs, div, x, s_);
269 } else {
270 blas::gemv('T', s_ - 1, nrhs, -div, x, ldx, &mb[0], 1, div, x + s_ - 1, s_);
271 blas::ger(s_ - 1, nrhs, -1, &d[0], 1, x + s_ - 1, s_, x, ldx);
272 }
273 }
274 }
275
276private:
277 bool modified_value;
278 std::vector<T> d;
279 std::vector<T> mb;
280 T div;
281};
282
286template <typename T>
287class UMFPACK_LDLt_sparse_direct_solver : public LDLt_sparse_solver<T> {
288public:
289 UMFPACK_LDLt_sparse_direct_solver()
290 : s(0),
291 colind(0),
292 rowind(0),
293 value(0),
294 colind_lu(0),
295 rowind_lu(0),
296 value_lu(0),
297 update_value_flag(true) {}
298
299 ~UMFPACK_LDLt_sparse_direct_solver() {
300 delete[] colind_lu;
301 delete[] rowind_lu;
302 delete[] value_lu;
303 }
304
305 void init(
306 size_t s_, size_t nnz, const size_t* colind_, const size_t* rowind_, const T* value_,
307 const int connectivity, const Dictionary& dictionary) {
308 s = s_;
309 colind = colind_;
310 rowind = rowind_;
311 value = value_;
312 delete[] colind_lu;
313 delete[] rowind_lu;
314 delete[] value_lu;
315 colind_lu = new size_t[s + 2];
316 std::fill_n(colind_lu, s + 2, 0);
317 colind_lu += 2;
318 size_t i = colind[0];
319 for (size_t j = 0; j != s; ++j) {
320 for (size_t i_end = colind[j + 1]; i != i_end; ++i) {
321 ++colind_lu[j];
322 if (rowind[i] != j) { ++colind_lu[rowind[i]]; }
323 }
324 }
325 --colind_lu;
326 for (size_t j = 0; j != s; ++j) { colind_lu[j + 1] += colind_lu[j]; }
327 rowind_lu = new size_t[colind_lu[s]];
328 value_lu = new T[colind_lu[s]];
329 i = colind[0];
330 for (size_t j = 0; j != s; ++j) {
331 for (size_t i_end = colind[j + 1]; i != i_end; ++i) {
332 rowind_lu[colind_lu[j]] = rowind[i];
333 value_lu[colind_lu[j]++] = value[i];
334 if (rowind[i] != j) {
335 rowind_lu[colind_lu[rowind[i]]] = j;
336 value_lu[colind_lu[rowind[i]]++] = value[i];
337 }
338 }
339 }
340 --colind_lu;
341
342 for (size_t j = 0; j != s; ++j) {
343 if (colind_lu[j] >= colind_lu[j + 1]) {
344 Exception() << "Matrix is singular in structure: dof " << j << THROW;
345 }
346 }
347
348 solver.init(s, colind_lu[s], colind_lu, rowind_lu, value_lu, connectivity, dictionary);
349 update_value_flag = false;
350 }
351
352 void update_value() {
353 update_value_flag = true;
354 solver.update_value();
355 }
356
357 void resolve(
358 size_t s, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx,
359 char left_or_right = ' ') {
360 if (update_value_flag) {
361 const size_t nnz = colind_lu[s];
362 for (size_t j = s; j != 0; --j) { colind_lu[j] = colind_lu[j - 1]; }
363 ++colind_lu;
364 size_t i = colind[0];
365 for (size_t j = 0; j != s; ++j) {
366 for (size_t i_end = colind[j + 1]; i != i_end; ++i) {
367 value_lu[colind_lu[j]++] = value[i];
368 if (rowind[i] != j) { value_lu[colind_lu[rowind[i]]++] = value[i]; }
369 }
370 }
371 --colind_lu;
372 if (colind_lu[s] != nnz) { Exception() << THROW; }
373 update_value_flag = false;
374 }
375 solver.resolve(s, nrhs, b, ldb, x, ldx, left_or_right);
376 }
377
378private:
379 size_t s;
380 const size_t* colind;
381 const size_t* rowind;
382 const T* value;
383 size_t* colind_lu;
384 size_t* rowind_lu;
385 T* value_lu;
386 bool update_value_flag;
387 UMFPACK_LU_sparse_direct_solver<T> solver;
388};
389
390template <typename T>
391class UMFPACK_LDLt_extension_sparse_direct_solver : public LDLt_extension_sparse_solver<T>,
392 public UMFPACK_LDLt_sparse_direct_solver<T> {
393public:
394 UMFPACK_LDLt_extension_sparse_direct_solver() : UMFPACK_LDLt_sparse_direct_solver<T>() {}
395
396 void init(
397 size_t size_, size_t nnz_, const size_t* colind_, const size_t* rowind_, const T* value_,
398 size_t size_ext_, const int connectivity, const Dictionary& dictionary) {
399 if (size_ext_ != 1) { UnimplementedError() << THROW; }
400
401 UMFPACK_LDLt_sparse_direct_solver<T>::init(
402 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
403 m_ab.resize(size_ * 2);
404 }
405
406 void update_value() { UMFPACK_LDLt_sparse_direct_solver<T>::update_value(); }
407
408 void resolve(
409 size_t s, size_t nrhs, const T* b, size_t ldb, T* x, size_t ldx, const T* ma_ = 0,
410 const T* mb_ = 0, const T* mc_ = 0, char left_or_right = ' ') {
411 if (left_or_right != ' ') { UnimplementedError() << THROW; }
412
413 if (ma_ != 0 && mb_ != 0 && mc_ != 0) {
414 std::copy(ma_, ma_ + s - 1, &m_ab[0]);
415 std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);
416 UMFPACK_LDLt_sparse_direct_solver<T>::resolve(
417 s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);
418 div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));
419 }
420
421 for (size_t i = 0; i != nrhs; ++i) {
422 const T x2 = x[ldx * i + s - 1] =
423 div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));
424 UMFPACK_LDLt_sparse_direct_solver<T>::resolve(
425 s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);
426 blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);
427 }
428 }
429
430protected:
431 std::vector<T> m_ab;
432 T div;
433};
434
435}} // namespace b2000::b2linalg
436
437#endif
438
439#endif
#define THROW
Definition b2exception.H:198
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
Contains the base classes for implementing Finite Elements.
Definition b2boundary_condition.H:32
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314