My Project
TridiagonalMatrix.hpp
Go to the documentation of this file.
1 // -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*-
2 // vi: set et ts=4 sw=4 sts=4:
3 /*
4  This file is part of the Open Porous Media project (OPM).
5 
6  OPM is free software: you can redistribute it and/or modify
7  it under the terms of the GNU General Public License as published by
8  the Free Software Foundation, either version 2 of the License, or
9  (at your option) any later version.
10 
11  OPM is distributed in the hope that it will be useful,
12  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  GNU General Public License for more details.
15 
16  You should have received a copy of the GNU General Public License
17  along with OPM. If not, see <http://www.gnu.org/licenses/>.
18 
19  Consult the COPYING file in the top-level source directory of this
20  module for the precise wording of the license and the list of
21  copyright holders.
22 */
27 #ifndef OPM_TRIDIAGONAL_MATRIX_HH
28 #define OPM_TRIDIAGONAL_MATRIX_HH
29 
30 #include <iostream>
31 #include <vector>
32 #include <algorithm>
33 #include <cmath>
34 
35 #include <assert.h>
36 
37 namespace Opm {
38 
/*!
 * \brief Provides a tridiagonal matrix that also supports non-zero entries in
 *        the upper-right and lower-left corners.
 *
 * Storage: three column-indexed vectors of length n (n = size()):
 *   - diag_[0][j] holds A(j+1, j) (sub-diagonal); diag_[0][n-1] holds the
 *     lower-left corner A(n-1, 0),
 *   - diag_[1][j] holds A(j, j) (main diagonal),
 *   - diag_[2][j] holds A(j-1, j) (super-diagonal); diag_[2][0] holds the
 *     upper-right corner A(0, n-1).
 * The two corner slots are only part of the matrix if n > 2. For smaller
 * matrices at() never addresses them and all arithmetic ignores them.
 */
template <class Scalar>
class TridiagonalMatrix
{
    /*!
     * \brief Proxy for a single row of the matrix; doubles as the row iterator.
     */
    struct TridiagRow_ {
        TridiagRow_(TridiagonalMatrix& m, size_t rowIdx)
            : matrix_(m)
            , rowIdx_(rowIdx)
        {}

        //! Access the entry of this row in column colIdx.
        Scalar& operator[](size_t colIdx)
        { return matrix_.at(rowIdx_, colIdx); }

        //! Return the entry of this row in column colIdx.
        Scalar operator[](size_t colIdx) const
        { return matrix_.at(rowIdx_, colIdx); }

        //! Advance to the next row.
        TridiagRow_& operator++()
        { ++ rowIdx_; return *this; }

        //! Go back to the previous row.
        TridiagRow_& operator--()
        { -- rowIdx_; return *this; }

        //! Two row proxies are equal if they refer to the same row of the same matrix.
        bool operator==(const TridiagRow_& other) const
        { return other.rowIdx_ == rowIdx_ && &other.matrix_ == &matrix_; }

        //! Negation of operator==.
        bool operator!=(const TridiagRow_& other) const
        { return !operator==(other); }

        //! Dereferencing the iterator yields the row proxy itself.
        TridiagRow_& operator*()
        { return *this; }

        //! Return the index of the row this proxy refers to.
        size_t index() const
        { return rowIdx_; }

    private:
        TridiagonalMatrix& matrix_;
        mutable size_t rowIdx_;
    };

public:
    typedef Scalar FieldType;
    typedef TridiagRow_ RowType;
    typedef size_t SizeType;
    typedef TridiagRow_ iterator;
    typedef TridiagRow_ const_iterator;

    //! Create a numRows x numRows matrix; new entries are value-initialized.
    explicit TridiagonalMatrix(size_t numRows = 0)
    {
        resize(numRows);
    }

    //! Create a numRows x numRows matrix with every stored slot set to value.
    TridiagonalMatrix(size_t numRows, Scalar value)
    {
        resize(numRows);
        this->operator=(value);
    }

    //! Copy constructor.
    TridiagonalMatrix(const TridiagonalMatrix& source)
    { *this = source; }

    //! Return the number of rows/columns of the matrix.
    size_t size() const
    { return diag_[0].size(); }

    //! Return the number of rows of the matrix.
    size_t rows() const
    { return size(); }

    //! Return the number of columns of the matrix.
    size_t cols() const
    { return size(); }

    //! Change the number of rows (and columns) of the matrix.
    void resize(size_t n)
    {
        if (n == size())
            return;

        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx].resize(n);
    }

    //! Access an entry; (rowIdx, colIdx) must be part of the sparsity pattern.
    Scalar& at(size_t rowIdx, size_t colIdx)
    { return entry_(rowIdx, colIdx); }

    //! Return an entry; uses exactly the same index mapping as the non-const overload.
    Scalar at(size_t rowIdx, size_t colIdx) const
    { return entry_(rowIdx, colIdx); }

    //! Assignment operator from another tridiagonal matrix.
    TridiagonalMatrix& operator=(const TridiagonalMatrix& source)
    {
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx] = source.diag_[diagIdx];

        return *this;
    }

    //! Set every stored slot (including the corner slots) to value.
    TridiagonalMatrix& operator=(Scalar value)
    {
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx].assign(size(), value);

        return *this;
    }

    //! Iterator for the first row.
    iterator begin()
    { return TridiagRow_(*this, 0); }

    //! Const iterator for the first row.
    const_iterator begin() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), 0); }

    //! Const iterator positioned one past the last row.
    const_iterator end() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), size()); }

    //! Row access operator.
    TridiagRow_ operator[](size_t rowIdx)
    { return TridiagRow_(*this, rowIdx); }

    //! Row access operator (const). TridiagRow_ stores a non-const reference,
    //! hence the cast; the const row proxy only exposes read access.
    const TridiagRow_ operator[](size_t rowIdx) const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), rowIdx); }

    //! Multiply every stored entry by a Scalar.
    TridiagonalMatrix& operator*=(Scalar alpha)
    {
        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            for (size_t i = 0; i < n; ++i)
                diag_[diagIdx][i] *= alpha;

        return *this;
    }

    //! Divide every stored entry by a Scalar.
    TridiagonalMatrix& operator/=(Scalar alpha)
    {
        // multiply by the reciprocal instead of dividing each entry
        alpha = 1.0/alpha;
        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            for (size_t i = 0; i < n; ++i)
                diag_[diagIdx][i] *= alpha;

        return *this;
    }

    //! Subtraction operator.
    TridiagonalMatrix& operator-=(const TridiagonalMatrix& other)
    { return axpy(-1.0, other); }

    //! Addition operator.
    TridiagonalMatrix& operator+=(const TridiagonalMatrix& other)
    { return axpy(1.0, other); }

    /*!
     * \brief Add a scaled copy of another matrix: this += alpha * other.
     *
     * Both matrices must have the same size. The operation works directly on
     * the diagonal storage, so corner slots are combined as well.
     */
    TridiagonalMatrix& axpy(Scalar alpha, const TridiagonalMatrix& other)
    {
        assert(size() == other.size());

        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            for (size_t i = 0; i < n; ++ i)
                diag_[diagIdx][i] += alpha * other.diag_[diagIdx][i];

        return *this;
    }

    //! Matrix-vector product: dest = A*source. Requires size() > 1.
    template<class Vector>
    void mv(const Vector& source, Vector& dest) const
    { applyProduct_(source, dest, Assign_()); }

    //! Additive matrix-vector product: dest += A*source. Requires size() > 1.
    template<class Vector>
    void umv(const Vector& source, Vector& dest) const
    { applyProduct_(source, dest, Add_()); }

    //! Subtractive matrix-vector product: dest -= A*source. Requires size() > 1.
    template<class Vector>
    void mmv(const Vector& source, Vector& dest) const
    { applyProduct_(source, dest, Subtract_()); }

    //! Scaled additive matrix-vector product: dest += alpha*A*source. Requires size() > 1.
    template<class Vector>
    void usmv(Scalar alpha, const Vector& source, Vector& dest) const
    { applyProduct_(source, dest, AddScaled_(alpha)); }

    //! Transposed matrix-vector product: dest = A^T*source. Requires size() > 1.
    template<class Vector>
    void mtv(const Vector& source, Vector& dest) const
    { applyTransposedProduct_(source, dest, Assign_()); }

    //! Transposed additive matrix-vector product: dest += A^T*source. Requires size() > 1.
    template<class Vector>
    void umtv(const Vector& source, Vector& dest) const
    { applyTransposedProduct_(source, dest, Add_()); }

    //! Transposed subtractive matrix-vector product: dest -= A^T*source. Requires size() > 1.
    template<class Vector>
    void mmtv (const Vector& source, Vector& dest) const
    { applyTransposedProduct_(source, dest, Subtract_()); }

    //! Transposed scaled additive product: dest += alpha*A^T*source. Requires size() > 1.
    template<class Vector>
    void usmtv(Scalar alpha, const Vector& source, Vector& dest) const
    { applyTransposedProduct_(source, dest, AddScaled_(alpha)); }

    //! Calculate the Frobenius norm.
    Scalar frobeniusNorm() const
    { return std::sqrt(frobeniusNormSquared()); }

    //! Calculate the squared Frobenius norm (sum of squares of all matrix entries).
    Scalar frobeniusNormSquared() const
    {
        const size_t n = size();
        Scalar result = 0;
        for (size_t i = 0; i < n; ++i) {
            result += diag_[1][i]*diag_[1][i];
            // A(i, i-1) and A(i-1, i)
            if (i > 0)
                result += diag_[0][i - 1]*diag_[0][i - 1] + diag_[2][i]*diag_[2][i];
        }

        // corner entries only exist for n > 2
        if (n > 2)
            result += diag_[2][0]*diag_[2][0] + diag_[0][n - 1]*diag_[0][n - 1];

        return result;
    }

    //! Calculate the infinity norm (maximum absolute row sum).
    Scalar infinityNorm() const
    {
        const size_t n = size();
        Scalar result = 0;
        for (size_t i = 0; i < n; ++i) {
            Scalar rowSum = std::abs(diag_[1][i]);
            if (i > 0)
                rowSum += std::abs(diag_[0][i - 1]);
            if (i + 1 < n)
                rowSum += std::abs(diag_[2][i + 1]);
            if (n > 2 && i == 0)
                rowSum += std::abs(diag_[2][0]);
            if (n > 2 && i == n - 1)
                rowSum += std::abs(diag_[0][n - 1]);
            result = std::max(result, rowSum);
        }

        return result;
    }

    /*!
     * \brief Solve A*x = b.
     *
     * Uses Gaussian elimination without pivoting, tracking the fill-in caused
     * by the corner entries (in the last column and the last row). The pivots
     * encountered must therefore be non-zero, e.g., for diagonally dominant
     * matrices. b is copied before elimination; x must provide n entries.
     */
    template <class XVector, class BVector>
    void solve(XVector& x, const BVector& b) const
    {
        const size_t n = size();
        if (n == 0)
            return;

        // Rows 0 .. n-2 are kept as: mainDiag[i] = A(i, i),
        // upperDiag[i] = A(i, i+1), and lastColumn[i] = A(i, n-1), the
        // latter only for i < n-2. Row n-1 is stored densely in lastRow.
        std::vector<Scalar> mainDiag(diag_[1]);
        std::vector<Scalar> upperDiag(n), lastColumn(n), lastRow(n);
        std::vector<Scalar> bStar(n);
        std::copy(b.begin(), b.end(), bStar.begin());

        for (size_t i = 0; i + 1 < n; ++i)
            upperDiag[i] = diag_[2][i + 1];

        lastRow[n - 1] = diag_[1][n - 1];
        if (n > 1)
            lastRow[n - 2] += diag_[0][n - 2];
        if (n > 2) {
            lastColumn[0] = diag_[2][0];
            lastRow[0] += diag_[0][n - 1];
        }

        // forward elimination
        for (size_t k = 0; k + 1 < n; ++k) {
            // eliminate A(k+1, k) from row k+1 unless that is the last row
            if (k + 2 < n) {
                const Scalar alpha = diag_[0][k]/mainDiag[k];
                mainDiag[k + 1] -= alpha*upperDiag[k];
                // the last-column fill-in moves down one row; at row n-2 it
                // coincides with the super-diagonal entry
                if (k + 3 < n)
                    lastColumn[k + 1] -= alpha*lastColumn[k];
                else
                    upperDiag[k + 1] -= alpha*lastColumn[k];
                bStar[k + 1] -= alpha*bStar[k];
            }

            // eliminate the entry of the last row in column k
            const Scalar beta = lastRow[k]/mainDiag[k];
            lastRow[k + 1] -= beta*upperDiag[k];
            if (k + 2 < n)
                lastRow[n - 1] -= beta*lastColumn[k];
            bStar[n - 1] -= beta*bStar[k];
        }

        // backward substitution
        x[n - 1] = bStar[n - 1]/lastRow[n - 1];
        for (size_t i = n - 1; i-- > 0; ) {
            Scalar rhs = bStar[i] - upperDiag[i]*x[i + 1];
            if (i + 2 < n)
                rhs -= lastColumn[i]*x[n - 1];
            x[i] = rhs/mainDiag[i];
        }
    }

    /*!
     * \brief Print the matrix to a given output stream.
     *
     * Columns are tab-separated. Positions outside the sparsity pattern are
     * left empty, and each row ends after its last stored column.
     */
    void print(std::ostream& os = std::cout) const
    {
        const size_t n = size();
        for (size_t rowIdx = 0; rowIdx < n; ++rowIdx) {
            size_t lastColIdx = std::min(rowIdx + 1, n - 1);
            if (rowIdx == 0 && n > 2)
                lastColIdx = n - 1; // upper-right corner

            for (size_t colIdx = 0; colIdx <= lastColIdx; ++colIdx) {
                if (colIdx > 0)
                    os << "\t";
                if (isStored_(rowIdx, colIdx))
                    os << at(rowIdx, colIdx);
            }
            os << "\n";
        }
    }

private:
    //! Combiner for dest[i] = value.
    struct Assign_
    {
        template <class Dest, class Value>
        void operator()(Dest& dest, const Value& value) const
        { dest = value; }
    };

    //! Combiner for dest[i] += value.
    struct Add_
    {
        template <class Dest, class Value>
        void operator()(Dest& dest, const Value& value) const
        { dest += value; }
    };

    //! Combiner for dest[i] -= value.
    struct Subtract_
    {
        template <class Dest, class Value>
        void operator()(Dest& dest, const Value& value) const
        { dest -= value; }
    };

    //! Combiner for dest[i] += alpha*value.
    struct AddScaled_
    {
        explicit AddScaled_(Scalar alpha) : alpha_(alpha) {}

        template <class Dest, class Value>
        void operator()(Dest& dest, const Value& value) const
        { dest += alpha_*value; }

        Scalar alpha_;
    };

    //! Storage slot of entry (rowIdx, colIdx). diag_ is mutable, so the slot is
    //! writable even from const methods; the public at() overloads restrict that.
    Scalar& entry_(size_t rowIdx, size_t colIdx) const
    {
        const size_t n = size();
        assert(rowIdx < n && colIdx < n);

        // corner entries only exist if n > 2
        if (n > 2) {
            if (rowIdx == 0 && colIdx == n - 1)
                return diag_[2][0];
            if (rowIdx == n - 1 && colIdx == 0)
                return diag_[0][n - 1];
        }

        // 0: sub-, 1: main, 2: super-diagonal; columns outside the band wrap
        // around in unsigned arithmetic and trip the assertion
        const size_t diagIdx = 1 + colIdx - rowIdx;
        assert(diagIdx < 3);
        return diag_[diagIdx][colIdx];
    }

    //! Whether (rowIdx, colIdx) belongs to the sparsity pattern.
    bool isStored_(size_t rowIdx, size_t colIdx) const
    {
        const size_t n = size();
        if (rowIdx >= n || colIdx >= n)
            return false;
        if (colIdx + 1 >= rowIdx && colIdx <= rowIdx + 1)
            return true;
        return n > 2 && ((rowIdx == 0 && colIdx == n - 1) || (rowIdx == n - 1 && colIdx == 0));
    }

    //! Compute each row value of A*source and combine it into dest[i] via op.
    template <class Vector, class Op>
    void applyProduct_(const Vector& source, Vector& dest, Op op) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());
        assert(size() > 1);

        const size_t n = size();

        // rows 1 .. n-2
        for (size_t i = 1; i + 1 < n; ++i)
            op(dest[i],
               diag_[0][i - 1]*source[i - 1] +
               diag_[1][i]*source[i] +
               diag_[2][i + 1]*source[i + 1]);

        // rows 0 and n-1; corners only take part if n > 2
        if (n > 2) {
            op(dest[0],
               diag_[1][0]*source[0] +
               diag_[2][1]*source[1] +
               diag_[2][0]*source[n - 1]);
            op(dest[n - 1],
               diag_[0][n - 1]*source[0] +
               diag_[0][n - 2]*source[n - 2] +
               diag_[1][n - 1]*source[n - 1]);
        }
        else {
            op(dest[0], diag_[1][0]*source[0] + diag_[2][1]*source[1]);
            op(dest[1], diag_[0][0]*source[0] + diag_[1][1]*source[1]);
        }
    }

    //! Compute each row value of A^T*source and combine it into dest[i] via op.
    template <class Vector, class Op>
    void applyTransposedProduct_(const Vector& source, Vector& dest, Op op) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());
        assert(size() > 1);

        const size_t n = size();

        // rows 1 .. n-2: (A^T)(i, i-1) = A(i-1, i), (A^T)(i, i+1) = A(i+1, i)
        for (size_t i = 1; i + 1 < n; ++i)
            op(dest[i],
               diag_[2][i]*source[i - 1] +
               diag_[1][i]*source[i] +
               diag_[0][i]*source[i + 1]);

        // rows 0 and n-1; corners only take part if n > 2
        if (n > 2) {
            op(dest[0],
               diag_[1][0]*source[0] +
               diag_[0][0]*source[1] +
               diag_[0][n - 1]*source[n - 1]);
            op(dest[n - 1],
               diag_[2][0]*source[0] +
               diag_[2][n - 1]*source[n - 2] +
               diag_[1][n - 1]*source[n - 1]);
        }
        else {
            op(dest[0], diag_[1][0]*source[0] + diag_[0][0]*source[1]);
            op(dest[1], diag_[2][1]*source[0] + diag_[1][1]*source[1]);
        }
    }

    mutable std::vector<Scalar> diag_[3];
};
845 
846 } // namespace Opm
847 
848 template <class Scalar>
849 std::ostream& operator<<(std::ostream& os, const Opm::TridiagonalMatrix<Scalar>& mat)
850 {
851  mat.print(os);
852  return os;
853 }
854 
855 #endif
Provides a tridiagonal matrix that also supports non-zero entries in the upper-right and lower-left corners.
Definition: TridiagonalMatrix.hpp:51
const_iterator end() const
Const iterator positioned one past the last row.
Definition: TridiagonalMatrix.hpp:238
Scalar & at(size_t rowIdx, size_t colIdx)
Access an entry.
Definition: TridiagonalMatrix.hpp:164
TridiagonalMatrix(const TridiagonalMatrix &source)
Copy constructor.
Definition: TridiagonalMatrix.hpp:128
TridiagRow_ operator[](size_t rowIdx)
Row access operator.
Definition: TridiagonalMatrix.hpp:244
size_t cols() const
Return the number of columns of the matrix.
Definition: TridiagonalMatrix.hpp:146
TridiagonalMatrix & axpy(Scalar alpha, const TridiagonalMatrix &other)
Add a scaled copy of another tridiagonal matrix to this one (this += alpha * other).
Definition: TridiagonalMatrix.hpp:310
size_t rows() const
Return the number of rows of the matrix.
Definition: TridiagonalMatrix.hpp:140
Scalar frobeniusNormSquared() const
Calculate the squared Frobenius norm.
Definition: TridiagonalMatrix.hpp:662
void umtv(const Vector &source, Vector &dest) const
Transposed additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:538
const_iterator begin() const
Const iterator for the first row.
Definition: TridiagonalMatrix.hpp:232
iterator begin()
Iterator for the first row.
Definition: TridiagonalMatrix.hpp:226
void print(std::ostream &os=std::cout) const
Print the matrix to a given output stream.
Definition: TridiagonalMatrix.hpp:725
void usmtv(Scalar alpha, const Vector &source, Vector &dest) const
Transposed scaled additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:618
TridiagonalMatrix & operator=(const TridiagonalMatrix &source)
Assignment operator from another tridiagonal matrix.
Definition: TridiagonalMatrix.hpp:204
void mv(const Vector &source, Vector &dest) const
Matrix-vector product.
Definition: TridiagonalMatrix.hpp:335
Scalar infinityNorm() const
Calculate the infinity norm.
Definition: TridiagonalMatrix.hpp:679
Scalar at(size_t rowIdx, size_t colIdx) const
Access an entry.
Definition: TridiagonalMatrix.hpp:185
const TridiagRow_ operator[](size_t rowIdx) const
Row access operator.
Definition: TridiagonalMatrix.hpp:250
void mmv(const Vector &source, Vector &dest) const
Subtractive matrix-vector product.
Definition: TridiagonalMatrix.hpp:415
void usmv(Scalar alpha, const Vector &source, Vector &dest) const
Scaled additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:455
TridiagonalMatrix & operator=(Scalar value)
Assignment operator from a Scalar.
Definition: TridiagonalMatrix.hpp:215
TridiagonalMatrix & operator-=(const TridiagonalMatrix &other)
Subtraction operator.
Definition: TridiagonalMatrix.hpp:287
void mmtv(const Vector &source, Vector &dest) const
Transposed subtractive matrix-vector product.
Definition: TridiagonalMatrix.hpp:578
size_t size() const
Return the number of rows/columns of the matrix.
Definition: TridiagonalMatrix.hpp:134
TridiagonalMatrix & operator/=(Scalar alpha)
Division by a Scalar.
Definition: TridiagonalMatrix.hpp:271
void solve(XVector &x, const BVector &b) const
Calculate the solution for a linear system of equations.
Definition: TridiagonalMatrix.hpp:714
void resize(size_t n)
Change the number of rows of the matrix.
Definition: TridiagonalMatrix.hpp:152
void mtv(const Vector &source, Vector &dest) const
Transposed matrix-vector product.
Definition: TridiagonalMatrix.hpp:498
Scalar frobeniusNorm() const
Calculate the Frobenius norm.
Definition: TridiagonalMatrix.hpp:654
void umv(const Vector &source, Vector &dest) const
Additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:375
TridiagonalMatrix & operator+=(const TridiagonalMatrix &other)
Addition operator.
Definition: TridiagonalMatrix.hpp:293
TridiagonalMatrix & operator*=(Scalar alpha)
Multiplication with a Scalar.
Definition: TridiagonalMatrix.hpp:256