// tahoma2d/toonz/sources/tnzext/tlin/tlin_superlu_wrap.cpp
extern "C" {
#include "slu_ddefs.h"
#include "slu_util.h"
}
#include <algorithm>
#include "tlin/tlin_cblas_wrap.h"
#include "tlin/tlin_superlu_wrap.h"
//*************************************************************************
// Preliminaries
//*************************************************************************
// tlin's SuperMatrix / superlu_options_t types are defined here as empty
// subclasses of SuperLU's own C structs, so a tlin pointer converts implicitly
// to the ::SuperMatrix* / ::superlu_options_t* the SuperLU C API expects.
// (Presumably they are only forward-declared in tlin_superlu_wrap.h.)
struct tlin::SuperMatrix final : ::SuperMatrix {};
struct tlin::superlu_options_t final : ::superlu_options_t {};
//=======================================================================
2016-06-15 18:43:10 +12:00
namespace {
2016-03-19 06:57:51 +13:00
static const tlin::spmat::HashMap::size_t neg = -1;
bool initialized = false;
static tlin::superlu_options_t defaultOpt;
struct DefaultOptsInitializer {
2016-06-15 18:43:10 +12:00
DefaultOptsInitializer() {
set_default_options(&defaultOpt);
defaultOpt.PrintStat = NO;
}
2016-03-19 06:57:51 +13:00
} _instance;
2016-06-15 18:43:10 +12:00
inline bool rowLess(const tlin::spmat::HashMap::BucketNode *a,
const tlin::spmat::HashMap::BucketNode *b) {
return a->m_key.first < b->m_key.first;
2016-03-19 06:57:51 +13:00
}
}
//*************************************************************************
// SuperLU-specific Functions
//*************************************************************************
2016-06-15 18:43:10 +12:00
void tlin::allocS(SuperMatrix *&A, int rows, int cols, int nnz) {
A = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
double *values = doubleMalloc(nnz);
int *rowind = intMalloc(nnz);
int *colptr = intMalloc(cols + 1);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
dCreate_CompCol_Matrix(A, rows, cols, nnz, values, rowind, colptr, SLU_NC,
SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::allocD(SuperMatrix *&A, int rows, int cols) {
A = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
double *values = doubleMalloc(rows * cols * sizeof(double));
dCreate_Dense_Matrix(A, rows, cols, values, rows, SLU_DN, SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::allocS(SuperMatrix *&A, int rows, int cols, int nnz, int *colptr,
int *rowind, double *values) {
A = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
dCreate_CompCol_Matrix(A, rows, cols, nnz, values, rowind, colptr, SLU_NC,
SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::allocD(SuperMatrix *&A, int rows, int cols, int lda,
double *values) {
A = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
dCreate_Dense_Matrix(A, rows, cols, values, lda, SLU_DN, SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::freeS(SuperMatrix *m) {
if (!m) return;
Destroy_CompCol_Matrix(m);
SUPERLU_FREE(m);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::freeD(SuperMatrix *m) {
if (!m) return;
Destroy_Dense_Matrix(m);
SUPERLU_FREE(m);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::createS(SuperMatrix &A, int rows, int cols, int nnz) {
double *values = doubleMalloc(nnz);
int *rowind = intMalloc(nnz);
int *colptr = intMalloc(cols + 1);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
dCreate_CompCol_Matrix(&A, rows, cols, nnz, values, rowind, colptr, SLU_NC,
SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::createD(SuperMatrix &A, int rows, int cols) {
double *values = doubleMalloc(rows * cols * sizeof(double));
dCreate_Dense_Matrix(&A, rows, cols, values, rows, SLU_DN, SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::createS(SuperMatrix &A, int rows, int cols, int nnz, int *colptr,
int *rowind, double *values) {
dCreate_CompCol_Matrix(&A, rows, cols, nnz, values, rowind, colptr, SLU_NC,
SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::createD(SuperMatrix &A, int rows, int cols, int lda,
double *values) {
dCreate_Dense_Matrix(&A, rows, cols, values, lda, SLU_DN, SLU_D, SLU_GE);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::destroyS(SuperMatrix &A, bool destroyData) {
if (destroyData)
Destroy_CompCol_Matrix(&A);
else
SUPERLU_FREE(A.Store);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::destroyD(SuperMatrix &A, bool destroyData) {
if (destroyData)
Destroy_Dense_Matrix(&A);
else
SUPERLU_FREE(A.Store);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::freeF(SuperFactors *F) {
if (!F) return;
Destroy_SuperNode_Matrix(F->L);
Destroy_CompCol_Matrix(F->U);
SUPERLU_FREE(F->L);
SUPERLU_FREE(F->U);
SUPERLU_FREE(F->perm_r);
SUPERLU_FREE(F->perm_c);
SUPERLU_FREE(F);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::readDN(SuperMatrix *A, int &lda, double *&values) {
assert(A->Stype == SLU_DN);
DNformat *storage = (DNformat *)A->Store;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
lda = storage->lda;
values = (double *)storage->nzval;
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::readNC(SuperMatrix *A, int &nnz, int *&colptr, int *&rowind,
double *&values) {
assert(A->Stype == SLU_NC); // Only SLU_NC (CCS) format is supported here
NCformat *storage = (NCformat *)A->Store;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
nnz = storage->nnz;
values = (double *)storage->nzval;
rowind = (int *)storage->rowind;
colptr = (int *)storage->colptr;
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::traduceD(const mat &m, SuperMatrix *&A) {
int rows = m.rows(), cols = m.cols();
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!A) allocD(A, rows, cols);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
int lda;
double *Avalues = 0;
readDN(A, lda, Avalues);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
assert(A->nrow == rows && A->ncol == cols && lda == rows);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
memcpy(Avalues, m.values(), rows * cols * sizeof(double));
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::traduceS(spmat &m, SuperMatrix *&A) {
int rows = m.rows(), cols = m.cols(), nnz = (int)m.entries().size();
spmat::HashMap &entries = m.entries();
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Build or extract pointers to out's data
double *values;
int *rowind, *colptr;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!A) allocS(A, rows, cols, nnz);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Retrieve NC arrays from A
int Annz;
readNC(A, Annz, colptr, rowind, values);
assert(A->nrow == rows && A->ncol == cols && Annz == nnz);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Rehash to cols buckets
if (entries.hashFunctor().m_cols != cols) entries.hashFunctor().m_cols = cols;
entries.rehash(cols);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Copy each bucket to the corresponding col
const std::vector<spmat::HashMap::size_t> &buckets = m.entries().buckets();
const tcg::list<spmat::HashMap::BucketNode> &nodes = m.entries().items();
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
std::vector<const spmat::HashMap::BucketNode *> colEntries;
std::vector<const spmat::HashMap::BucketNode *>::size_type j, size;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
double *currVal = values;
int *currRowind = rowind;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
for (int i = 0; i < cols; ++i) {
colptr[i] = (int)(currVal - values);
spmat::HashMap::size_t nodeIdx = buckets[i];
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Retrieve all column entry pointers.
colEntries.clear();
while (nodeIdx != neg) {
const spmat::HashMap::BucketNode &node = nodes[nodeIdx];
colEntries.push_back(&node);
nodeIdx = nodes[nodeIdx].m_next;
}
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Sort them by row
std::sort(colEntries.begin(), colEntries.end(), rowLess);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Finally, write them in the SuperMatrix
size = colEntries.size();
for (j = 0; j < size; ++j) {
const spmat::HashMap::BucketNode &node = *colEntries[j];
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
*currRowind = node.m_key.first;
*currVal = node.m_val;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
++currVal, ++currRowind;
}
}
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
colptr[cols] = nnz;
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::traduceD(const tlin::sparse_matrix<double> &m, SuperMatrix *&A) {
int rows = m.rows(), cols = m.cols();
const spmat::HashMap &entries = m.entries();
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Build or extract pointers to out's data
double *values;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!A) allocD(A, rows, cols);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Retrieve DN arrays from A
int lda;
readDN(A, lda, values);
assert(A->nrow == rows && A->ncol == cols && lda == rows);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
// Copy each value in entries to A
spmat::HashMap::const_iterator it;
for (it = entries.begin(); it != entries.end(); ++it)
values[it->m_key.second * rows + it->m_key.first] = it->m_val;
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::factorize(SuperMatrix *A, SuperFactors *&F, superlu_options_t *opt) {
assert(A->nrow == A->ncol);
int n = A->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!F) F = (SuperFactors *)SUPERLU_MALLOC(sizeof(SuperFactors));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!opt) opt = &defaultOpt;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
F->perm_c = intMalloc(n);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
get_perm_c(3, A, F->perm_c);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperMatrix AC;
int *etree = intMalloc(n);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
sp_preorder(opt, A, F->perm_c, etree, &AC);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
F->L = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
F->U = (SuperMatrix *)SUPERLU_MALLOC(sizeof(SuperMatrix));
F->perm_r = intMalloc(n);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperLUStat_t stat;
StatInit(&stat);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
int result;
dgstrf(opt, &AC, sp_ienv(1), sp_ienv(2), etree, NULL, 0, F->perm_c, F->perm_r,
F->L, F->U, &stat, &result);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
StatFree(&stat);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
Destroy_CompCol_Permuted(&AC);
SUPERLU_FREE(etree);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (result != 0) freeF(F), F = 0;
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperFactors *F, SuperMatrix *BX, superlu_options_t *opt) {
assert(F);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!opt) opt = &defaultOpt;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperLUStat_t stat;
StatInit(&stat);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
int result;
dgstrs(NOTRANS, F->L, F->U, F->perm_c, F->perm_r, BX, &stat, &result);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
StatFree(&stat);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperFactors *F, SuperMatrix *B, SuperMatrix *&X,
superlu_options_t *opt) {
if (!X) allocD(X, B->nrow, B->ncol);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
double *Bvalues = 0, *Xvalues = 0;
int lda;
readDN(B, lda, Bvalues);
readDN(X, lda, Xvalues);
memcpy(Xvalues, Bvalues, B->nrow * B->ncol * sizeof(double));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
solve(F, X, opt);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperMatrix *A, SuperMatrix *BX, superlu_options_t *opt) {
assert(A->nrow == A->ncol);
int n = A->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!opt) opt = &defaultOpt;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperMatrix L, U;
int *perm_c, *perm_r;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
perm_c = intMalloc(n);
perm_r = intMalloc(n);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperLUStat_t stat;
StatInit(&stat);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
int result;
dgssv(opt, A, perm_c, perm_r, &L, &U, BX, &stat, &result);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
Destroy_SuperNode_Matrix(&L);
Destroy_CompCol_Matrix(&U);
SUPERLU_FREE(perm_r);
SUPERLU_FREE(perm_c);
StatFree(&stat);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperMatrix *A, SuperMatrix *B, SuperMatrix *&X,
superlu_options_t *opt) {
if (!X) allocD(X, B->nrow, B->ncol);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
double *Bvalues = 0, *Xvalues = 0;
int lda;
readDN(B, lda, Bvalues);
readDN(X, lda, Xvalues);
memcpy(Xvalues, Bvalues, B->nrow * B->ncol * sizeof(double));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
solve(A, X, opt);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperFactors *F, double *bx, superlu_options_t *opt) {
SuperMatrix BX;
int rows = F->L->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
createD(BX, rows, 1, rows, bx);
tlin::solve(F, &BX, opt);
SUPERLU_FREE(BX.Store);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperFactors *F, double *b, double *&x,
superlu_options_t *opt) {
SuperMatrix B, X;
int rows = F->L->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!x) x = (double *)malloc(rows * sizeof(double));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
createD(B, rows, 1, rows, b);
createD(X, rows, 1, rows, x);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperMatrix *Xptr = &X; //&X is const
tlin::solve(F, &B, Xptr, opt);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SUPERLU_FREE(B.Store);
SUPERLU_FREE(X.Store);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperMatrix *A, double *bx, superlu_options_t *opt) {
SuperMatrix BX;
int rows = A->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
createD(BX, rows, 1, rows, bx);
tlin::solve(A, &BX, opt);
SUPERLU_FREE(BX.Store);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::solve(SuperMatrix *A, double *b, double *&x,
superlu_options_t *opt) {
SuperMatrix B, X;
int rows = A->nrow;
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!x) x = (double *)malloc(rows * sizeof(double));
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
createD(B, rows, 1, rows, b);
createD(X, rows, 1, rows, x);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperMatrix *Xptr = &X; //&X is const
tlin::solve(A, &B, Xptr, opt);
SUPERLU_FREE(B.Store);
SUPERLU_FREE(X.Store);
2016-03-19 06:57:51 +13:00
}
//*************************************************************************
// BLAS-related Functions
//*************************************************************************
2016-06-15 18:43:10 +12:00
void tlin::multiplyS(const SuperMatrix *A, const double *x, double *&y) {
/*
int sp_dgemv (char *, double, SuperMatrix *, double *,
int, double, double *, int);
*/
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
if (!y) {
y = (double *)malloc(A->nrow * sizeof(double));
memset(y, 0, A->nrow * sizeof(double));
}
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
SuperMatrix *_A = const_cast<SuperMatrix *>(A);
double *_x = const_cast<double *>(x);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
sp_dgemv("N", 1.0, _A, _x, 1, 1.0, y, 1);
2016-03-19 06:57:51 +13:00
}
//---------------------------------------------------------------
2016-06-15 18:43:10 +12:00
void tlin::multiplyD(const SuperMatrix *A, const double *x, double *&y) {
int lda;
double *values;
tlin::readDN(const_cast<SuperMatrix *>(A), lda, values);
2016-03-19 06:57:51 +13:00
2016-06-15 18:43:10 +12:00
tlin::multiply(A->nrow, A->ncol, values, x, y);
2016-03-19 06:57:51 +13:00
}