16#ifndef __B2UMFPACK_SOLVER_H__
17#define __B2UMFPACK_SOLVER_H__
19#include "b2blaslapack.H"
20#include "b2ppconfig.h"
21#include "b2sparse_solver.H"
24#ifdef HAVE_UMFPACK_UMFPACK_H
25#include "umfpack/umfpack.h"
26#elif defined HAVE_UFSPARSE_UMFPACK_H
27#include "ufsparse/umfpack.h"
28#elif defined HAVE_SUITESPARSE_UMFPACK_H
29#include "suitesparse/umfpack.h"
30#elif defined HAVE_UMFPACK_H
37namespace b2000 {
namespace b2linalg {
// UMFPACK-backed implementation of LU_sparse_solver<T>. The matrix is passed
// to UMFPACK as column pointers (`colind`), row indices (`rowind`) and values.
// NOTE(review): the template header declaring T is not visible in this listing.
40class UMFPACK_LU_sparse_direct_solver :
public LU_sparse_solver<T> {
// Default constructor: binds the "linear_algebra.sparse_lu_solver.umfpack"
// logger, logs an info message, sets Control[UMFPACK_SCALE] to
// UMFPACK_SCALE_SUM and, when debug logging is enabled, sets the UMFPACK
// print level Control[UMFPACK_PRL] to 2.
// NOTE(review): the rest of the member-initializer list (and whatever fills
// Control with defaults) is not visible in this listing.
42 UMFPACK_LU_sparse_direct_solver()
49 logger(logging::
get_logger(
"linear_algebra.sparse_lu_solver.umfpack")) {
53 logger.LOG(logging::info,
"Using UMFPACK as linear solver! ");
54 Control[UMFPACK_SCALE] = UMFPACK_SCALE_SUM;
55 if (logger.is_enabled_for(logging::debug)) { Control[UMFPACK_PRL] = 2; }
// Destructor. NOTE(review): its body is not visible in this listing;
// presumably it releases the UMFPACK Symbolic/Numeric objects — confirm.
58 ~UMFPACK_LU_sparse_direct_solver() {
// Receives the matrix to factorize: dimension s_, nnz_ stored entries,
// column pointers colind_, row indices rowind_ and values value_ (these are
// later handed to UMFPACK as its pattern/value arrays). `connectivity` and
// `dictionary` are not used in the visible lines.
// NOTE(review): the function-name line is missing from this listing; the
// derived classes call it as `init` with this argument order.
71 size_t s_,
size_t nnz_,
const size_t* colind_,
const size_t* rowind_,
const T* value_,
72 int connectivity,
const Dictionary& dictionary) {
// Empty system: nothing to set up.
// NOTE(review): this tests the member `s`, not the parameter `s_`; the
// assignment of `s` is not visible here — confirm it precedes this check.
74 if (s == 0) {
return; }
// Solves A * X = B for `nrhs` right-hand sides: column i of B starts at
// b + i * ldb and column i of the solution is written to x + i * ldx.
// `left_or_right` is not used in the visible lines.
// NOTE(review): the function-name line is missing from this listing
// (presumably `resolve`, the name the derived classes call). This body uses
// the complex, long-index umfpack_zl_* routines; the size_t index arrays are
// reinterpret_cast to long* and the T values to double*, which assumes
// sizeof(size_t) == sizeof(long) and that T is laid out as two doubles.
89 size_t s_,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
90 char left_or_right =
' ') {
// Empty system: nothing to solve.
91 if (s == 0) {
return; }
// Symbolic analysis of the sparsity pattern. The 0 passed after the value
// array is the separate imaginary-part pointer, left null. Any guarding
// condition around this step is not visible in this listing.
93 int status = umfpack_zl_symbolic(
94 s_, s_,
reinterpret_cast<long*
>(
const_cast<size_t*
>(colind)),
95 reinterpret_cast<long*
>(
const_cast<size_t*
>(rowind)),
96 reinterpret_cast<double*
>(
const_cast<T*
>(value)), 0, &Symbolic, Control, info);
// Any status other than UMFPACK_OK aborts with an exception.
97 if (status != UMFPACK_OK) {
98 Exception() <<
"umfpack_zl_symbolic return the error code " << status <<
": "
99 << get_error(status) <<
"." <<
THROW;
// Numeric factorization reusing the Symbolic object (the trailing call
// arguments are on a line not visible here). Only status values below
// UMFPACK_OK raise.
// NOTE(review): non-negative warning codes are accepted silently.
103 int status = umfpack_zl_numeric(
104 reinterpret_cast<long*
>(
const_cast<size_t*
>(colind)),
105 reinterpret_cast<long*
>(
const_cast<size_t*
>(rowind)),
106 reinterpret_cast<double*
>(
const_cast<T*
>(value)), 0, Symbolic, &Numeric, Control,
108 if (status < UMFPACK_OK) {
109 Exception() <<
"umfpack_zl_numeric return the error code " << status <<
": "
110 << get_error(status) <<
"." <<
THROW;
// Solve one right-hand side at a time, writing straight into x. The
// const_cast on b exists only because the C API takes non-const pointers.
// After each solve, UMFPACK statistics are printed when info logging is on.
116 for (
size_t i = 0; i != nrhs; ++i) {
117 int status = umfpack_zl_solve(
118 UMFPACK_A,
reinterpret_cast<long*
>(
const_cast<size_t*
>(colind)),
119 reinterpret_cast<long*
>(
const_cast<size_t*
>(rowind)),
120 reinterpret_cast<double*
>(
const_cast<T*
>(value)), 0,
121 reinterpret_cast<double*
>(x + i * ldx), 0,
122 reinterpret_cast<double*
>(
const_cast<T*
>(b + i * ldb)), 0, Numeric, Control,
124 if (logger.is_enabled_for(logging::info)) { umfpack_zl_report_info(Control, info); }
125 if (status < UMFPACK_OK) {
126 Exception() <<
"umfpack_zl_solve return the error code " << status <<
": "
127 << get_error(status) <<
"." <<
THROW;
// Alternative solve path (its guarding condition is not visible here): each
// right-hand side is first copied into the scratch buffer `bxcp` of
// 2 * s_ doubles, which is then passed to UMFPACK as B.
// NOTE(review): the name of the copy call at the start of the loop body is
// missing from this listing.
131 auto_ptr_array<double> bxcp(
new double[s_ * 2]);
132 for (
size_t i = 0; i != nrhs; ++i) {
134 reinterpret_cast<const double*
>(b + i * ldb),
135 reinterpret_cast<const double*
>(b + i * ldb + s_), bxcp.get());
136 int status = umfpack_zl_solve(
137 UMFPACK_A,
reinterpret_cast<long*
>(
const_cast<size_t*
>(colind)),
138 reinterpret_cast<long*
>(
const_cast<size_t*
>(rowind)),
139 reinterpret_cast<double*
>(
const_cast<T*
>(value)), 0,
140 reinterpret_cast<double*
>(x + i * ldx), 0, bxcp.get(), 0, Numeric, Control,
142 if (logger.is_enabled_for(logging::info)) { umfpack_zl_report_info(Control, info); }
143 if (status < UMFPACK_OK) {
144 Exception() <<
"umfpack_zl_solve return the error code " << status <<
": "
145 << get_error(status) <<
"." <<
THROW;
152 static std::string get_error(
int status) {
154 case UMFPACK_ERROR_out_of_memory:
155 return "out of memory";
156 case UMFPACK_ERROR_invalid_Numeric_object:
157 return "invalid Numeric object";
158 case UMFPACK_ERROR_invalid_Symbolic_object:
159 return "invalid Symbolic object";
160 case UMFPACK_ERROR_argument_missing:
161 return "argument missing";
162 case UMFPACK_ERROR_n_nonpositive:
163 return "n non-positive definite";
164 case UMFPACK_ERROR_invalid_matrix:
165 return "invalid matrix";
166 case UMFPACK_ERROR_different_pattern:
167 return "different pattern";
168 case UMFPACK_ERROR_invalid_system:
169 return "invalid system";
170 case UMFPACK_ERROR_invalid_permutation:
171 return "invalid permutation";
172 case UMFPACK_ERROR_internal_error:
173 return "internal error";
174 case UMFPACK_ERROR_file_IO:
175 return "file IO error";
177 return "unknown error code";
// UMFPACK control parameters, passed to every umfpack_zl_* call.
182 double Control[UMFPACK_CONTROL];
// UMFPACK statistics array, filled by the umfpack_zl_* calls and printed by
// umfpack_zl_report_info.
183 double info[UMFPACK_INFO];
// Column-pointer and row-index arrays handed to UMFPACK.
// NOTE(review): no allocation is visible for these; presumably non-owning
// views of the caller's arrays — confirm their lifetime.
187 const size_t* colind;
188 const size_t* rowind;
// Logger bound in the constructor.
190 logging::Logger& logger;
// Declared here; its definition is not visible in this listing.
192 void set_free_linker();
// Explicit specialization of resolve for real (double) matrices; only the
// beginning of its parameter list is visible in this listing.
// NOTE(review): the leading dimension of b is spelled `lbd` here but `ldb`
// in the generic version — likely a typo worth aligning.
201void UMFPACK_LU_sparse_direct_solver<double>::resolve(
202 size_t s,
size_t nrhs,
const double* b,
size_t lbd,
double* x,
size_t ldx,
// LU solver for a matrix bordered by one extra row and column. The inner
// block is factorized by the UMFPACK_LU_sparse_direct_solver base, and the
// border (ma_, mb_, mc_) is eliminated in resolve via a Schur complement.
206class UMFPACK_LU_extension_sparse_direct_solver :
public LU_extension_sparse_solver<T>,
207 public UMFPACK_LU_sparse_direct_solver<T> {
// Initializes the base solver with the inner matrix and marks the cached
// border data as stale.
// NOTE(review): the function-name line is missing from this listing, and
// `size_ext` is not used in the visible lines.
210 size_t size_,
size_t nnz_,
const size_t* colind_,
const size_t* rowind_,
const T* value_,
211 size_t size_ext,
int connectivity,
const Dictionary& dictionary) {
213 UMFPACK_LU_sparse_direct_solver<T>::init(
214 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
215 modified_value =
true;
// Marks the cached border data (mb, d, div) as stale so the next resolve
// recomputes them, then forwards the notification to the base solver.
218 void update_value() {
219 modified_value =
true;
220 UMFPACK_LU_sparse_direct_solver<T>::update_value();
// Solves the bordered system for `nrhs` right-hand sides. ma_ and mb_ hold
// the s_ - 1 border entries and mc_ the corner value. They are required
// whenever the matrix values changed since the last call.
// NOTE(review): the function-name line and several statement lines are
// missing from this listing.
224 size_t s_,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
const T* ma_ = 0,
225 const T* mb_ = 0,
const T* mc_ = 0,
char left_or_right =
' ') {
// Throw if the values changed and the border data needed to refresh the
// cache are missing (the message line is not visible).
// NOTE(review): for s_ == 1 with null ma_/mb_ this does not throw, but the
// refresh below also requires ma_/mb_, so div/d are not recomputed — confirm.
226 if (((s_ > 1 && (ma_ == 0 || mb_ == 0)) || mc_ == 0) && modified_value ==
true) {
227 Exception() <<
THROW;
// Refresh path: cache mb_ in mb, then solve the inner system once for ma_
// and the leading s_ - 1 entries of every right-hand side. These are packed
// column by column into d_tmp, and the solutions land in d.
232 if (modified_value && (ma_ != 0 && mb_ != 0 && mc_ != 0)) {
234 std::copy(mb_, mb_ + s_ - 1, mb.begin());
236 d.resize((s_ - 1) * (nrhs + 1));
238 std::vector<T> d_tmp((s_ - 1) * (nrhs + 1));
239 typename std::vector<T>::iterator iter =
240 std::copy(ma_, ma_ + s_ - 1, d_tmp.begin());
241 for (
size_t i = 0; i != nrhs; ++i) {
242 iter = std::copy(b + ldb * i, b + ldb * i + s_ - 1, iter);
244 UMFPACK_LU_sparse_direct_solver<T>::resolve(
245 s_ - 1, nrhs + 1, &d_tmp[0], s_ - 1, &d[0], s_ - 1, left_or_right);
// Schur complement of the inner block: mc - mb^T * (inner^-1 * ma),
// with inner^-1 * ma stored in the first s_ - 1 entries of d.
248 div = *mc_ - blas::dot(s_ - 1, &mb[0], 1, &d[0], 1);
// Last unknown per right-hand side, then correction of the inner solutions
// (some lines between these calls are not visible).
// NOTE(review): last-row entries are addressed from x / x + s_ - 1 with
// stride s_, while the copy below writes columns at x + ldx * i. This is
// consistent only when ldx == s_ — confirm callers guarantee that.
252 blas::scal(nrhs, div, x, s_);
255 'T', s_ - 1, nrhs, -div, &d[s_ - 1], s_ - 1, &mb[0], 1, div, x + s_ - 1, s_);
256 blas::ger(s_ - 1, nrhs, -1, &d[0], 1, x + s_ - 1, s_, &d[s_ - 1], s_ - 1);
// Copy the corrected inner solutions from d into the leading part of each
// x column.
258 for (
size_t i = 0; i != nrhs; ++i) {
260 d.begin() + (s_ - 1) * (i + 1), d.begin() + (s_ - 1) * (i + 2), x + ldx * i);
263 modified_value =
false;
// Otherwise (branch line not visible) the cached mb, d and div are reused:
// solve the inner system directly into x and apply the same corrections.
265 UMFPACK_LU_sparse_direct_solver<T>::resolve(
266 s_ - 1, nrhs, b, ldb, x, ldx, left_or_right);
268 blas::scal(nrhs, div, x, s_);
270 blas::gemv(
'T', s_ - 1, nrhs, -div, x, ldx, &mb[0], 1, div, x + s_ - 1, s_);
271 blas::ger(s_ - 1, nrhs, -1, &d[0], 1, x + s_ - 1, s_, x, ldx);
// LDLt_sparse_solver<T> implementation for symmetric matrices. It expands
// the input (presumably one triangle, with each off-diagonal entry stored
// once) into a full column-compressed pattern and delegates factorization
// and solves to an UMFPACK_LU_sparse_direct_solver<T> member.
287class UMFPACK_LDLt_sparse_direct_solver :
public LDLt_sparse_solver<T> {
// Default constructor; update_value_flag starts true (other initializers
// are not visible in this listing).
289 UMFPACK_LDLt_sparse_direct_solver()
297 update_value_flag(true) {}
// Destructor. NOTE(review): its body is not visible; presumably it releases
// colind_lu, rowind_lu and value_lu allocated with new[] in init — confirm.
299 ~UMFPACK_LDLt_sparse_direct_solver() {
// Builds the full pattern: for every stored entry (rowind[i], j) with
// rowind[i] != j, the mirrored entry (j, rowind[i]) is added as well. The
// result is held in colind_lu / rowind_lu / value_lu (allocated with new[])
// and handed to the LU solver.
// NOTE(review): the function-name line and several statement lines
// (including the assignments of s, colind, rowind and value) are missing
// from this listing.
306 size_t s_,
size_t nnz,
const size_t* colind_,
const size_t* rowind_,
const T* value_,
307 const int connectivity,
const Dictionary& dictionary) {
// Pass 1: count entries per column. Mirrored off-diagonal entries are
// counted here; the line counting the entry for column j itself is not
// visible.
315 colind_lu =
new size_t[s + 2];
316 std::fill_n(colind_lu, s + 2, 0);
318 size_t i = colind[0];
319 for (
size_t j = 0; j != s; ++j) {
320 for (
size_t i_end = colind[j + 1]; i != i_end; ++i) {
322 if (rowind[i] != j) { ++colind_lu[rowind[i]]; }
// Accumulate the counts into running offsets, then allocate the expanded
// row-index and value arrays (colind_lu[s] entries).
326 for (
size_t j = 0; j != s; ++j) { colind_lu[j + 1] += colind_lu[j]; }
327 rowind_lu =
new size_t[colind_lu[s]];
328 value_lu =
new T[colind_lu[s]];
// Pass 2: scatter each entry into its own column and, if off-diagonal, its
// mirror into column rowind[i], advancing the per-column insertion offsets.
// NOTE(review): `i` still holds its end value from pass 1; the line that
// resets it to colind[0] is not visible here — confirm it exists.
330 for (
size_t j = 0; j != s; ++j) {
331 for (
size_t i_end = colind[j + 1]; i != i_end; ++i) {
332 rowind_lu[colind_lu[j]] = rowind[i];
333 value_lu[colind_lu[j]++] = value[i];
334 if (rowind[i] != j) {
335 rowind_lu[colind_lu[rowind[i]]] = j;
336 value_lu[colind_lu[rowind[i]]++] = value[i];
// Every column must hold at least one entry; an empty column means the
// matrix is structurally singular.
342 for (
size_t j = 0; j != s; ++j) {
343 if (colind_lu[j] >= colind_lu[j + 1]) {
344 Exception() <<
"Matrix is singular in structure: dof " << j <<
THROW;
// Hand the expanded matrix to the LU solver; the expanded values are now in
// sync with the input.
348 solver.init(s, colind_lu[s], colind_lu, rowind_lu, value_lu, connectivity, dictionary);
349 update_value_flag =
false;
// Marks the expanded values stale (they are refilled in resolve) and
// forwards the notification to the LU solver.
352 void update_value() {
353 update_value_flag =
true;
354 solver.update_value();
// Solves A * X = B. If the values changed, value_lu is first refilled in
// place along the pattern built by init, and the entry count is checked
// against the previous one. The solve itself is delegated to the LU solver.
// NOTE(review): the function-name line and a few statement lines are
// missing from this listing.
358 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
359 char left_or_right =
' ') {
360 if (update_value_flag) {
// Remember the expanded entry count before the offsets are reused as
// insertion cursors.
361 const size_t nnz = colind_lu[s];
// Shift the column offsets right by one so the refill loop can advance them.
362 for (
size_t j = s; j != 0; --j) { colind_lu[j] = colind_lu[j - 1]; }
364 size_t i = colind[0];
365 for (
size_t j = 0; j != s; ++j) {
366 for (
size_t i_end = colind[j + 1]; i != i_end; ++i) {
367 value_lu[colind_lu[j]++] = value[i];
368 if (rowind[i] != j) { value_lu[colind_lu[rowind[i]]++] = value[i]; }
// The refill must reproduce the original entry count (the exception message
// line is not visible).
372 if (colind_lu[s] != nnz) { Exception() <<
THROW; }
373 update_value_flag =
false;
375 solver.resolve(s, nrhs, b, ldb, x, ldx, left_or_right);
// Input pattern arrays, read again when refilling values in resolve.
// NOTE(review): presumably non-owning views set in init — confirm lifetime.
380 const size_t* colind;
381 const size_t* rowind;
// True when value_lu must be refilled before the next solve.
386 bool update_value_flag;
// LU solver that factorizes the expanded full matrix.
387 UMFPACK_LU_sparse_direct_solver<T> solver;
// Symmetric (LDLt) solver for a matrix bordered by one extra row and column.
// The inner block is handled by UMFPACK_LDLt_sparse_direct_solver, and the
// border is eliminated via a Schur complement in resolve.
391class UMFPACK_LDLt_extension_sparse_direct_solver :
public LDLt_extension_sparse_solver<T>,
392 public UMFPACK_LDLt_sparse_direct_solver<T> {
// Default constructor; only the base is explicitly initialized.
394 UMFPACK_LDLt_extension_sparse_direct_solver() : UMFPACK_LDLt_sparse_direct_solver<T>() {}
// Initializes the inner LDLt solver and sizes m_ab, the scratch buffer that
// holds the two solved border vectors.
// NOTE(review): the function-name line is missing from this listing;
// `size_ext_` is unused in the visible lines, and resolve uses only
// 2 * (s - 1) entries of m_ab.
397 size_t size_,
size_t nnz_,
const size_t* colind_,
const size_t* rowind_,
const T* value_,
398 size_t size_ext_,
const int connectivity,
const Dictionary& dictionary) {
401 UMFPACK_LDLt_sparse_direct_solver<T>::init(
402 size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);
403 m_ab.resize(size_ * 2);
// Forwards the value-change notification to the inner LDLt solver.
406 void update_value() { UMFPACK_LDLt_sparse_direct_solver<T>::update_value(); }
// Solves the bordered system by block elimination, presumably with ma_ as
// the extra column, mb_ as the extra row and mc_ as the corner — confirm.
// When all three border pointers are given, inner^-1 * ma and inner^-1 * mb
// are recomputed into m_ab and the reciprocal Schur complement into div.
// Otherwise the m_ab and div left by a previous call are reused.
// NOTE(review): if the first call passes null border pointers, div and m_ab
// have not been computed yet. left_or_right is not forwarded to the inner
// solves. The function-name line and the end of the function are not
// visible in this listing.
409 size_t s,
size_t nrhs,
const T* b,
size_t ldb, T* x,
size_t ldx,
const T* ma_ = 0,
410 const T* mb_ = 0,
const T* mc_ = 0,
char left_or_right =
' ') {
// m_ab[0, s-1) = inner^-1 * ma and m_ab[s-1, 2(s-1)) = inner^-1 * mb,
// solved in place.
// NOTE(review): b and x alias in this call; confirm the underlying solve
// supports that.
413 if (ma_ != 0 && mb_ != 0 && mc_ != 0) {
414 std::copy(ma_, ma_ + s - 1, &m_ab[0]);
415 std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);
416 UMFPACK_LDLt_sparse_direct_solver<T>::resolve(
417 s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);
// div = 1 / (mc - ma^T * inner^-1 * mb).
418 div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));
// Per right-hand side, first compute the last unknown
// x2 = div * (b_last - b_top^T * inner^-1 * mb); this form relies on the
// inner matrix being symmetric. Then compute the inner unknowns as
// inner^-1 * b_top - x2 * inner^-1 * ma.
421 for (
size_t i = 0; i != nrhs; ++i) {
422 const T x2 = x[ldx * i + s - 1] =
423 div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));
424 UMFPACK_LDLt_sparse_direct_solver<T>::resolve(
425 s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);
426 blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);
#define THROW
Definition b2exception.H:198
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
Contains the base classes for implementing Finite Elements.
Definition b2boundary_condition.H:32
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314