JDFTx  1.2.1
matrix3.h
1 /*-------------------------------------------------------------------
2 Copyright 2011 Ravishankar Sundararaman
3 
4 This file is part of JDFTx.
5 
6 JDFTx is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10 
11 JDFTx is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15 
16 You should have received a copy of the GNU General Public License
17 along with JDFTx. If not, see <http://www.gnu.org/licenses/>.
18 -------------------------------------------------------------------*/
19 
20 #ifndef JDFTX_CORE_MATRIX3_H
21 #define JDFTX_CORE_MATRIX3_H
22 
23 #include <core/vector3.h>
24 
25 template<typename scalar=double> class matrix3
26 {
27  scalar m[3][3];
28 
29 public:
30  //accessors:
31  __hostanddev__ scalar& operator()(int i, int j) { return m[i][j]; }
32  __hostanddev__ const scalar& operator()(int i, int j) const { return m[i][j]; }
33  __hostanddev__ vector3<scalar> row(int i) const { return vector3<scalar>(m[i][0], m[i][1], m[i][2]); }
34  __hostanddev__ vector3<scalar> column(int i) const { return vector3<scalar>(m[0][i], m[1][i], m[2][i]); }
35 
36  __hostanddev__ void set_row(int i, const vector3<scalar>& v) { for(int j=0; j<3; j++) m[i][j] = v[j]; }
37  __hostanddev__ void set_rows(const vector3<scalar>& v0, const vector3<scalar>& v1, const vector3<scalar>& v2)
38  { for(int j=0; j<3; j++) { m[0][j] = v0[j]; m[1][j] = v1[j]; m[2][j] = v2[j]; }
39  }
40  __hostanddev__ void set_col(int j, const vector3<scalar>& v) { for(int i=0; i<3; i++) m[i][j] = v[i]; }
41  __hostanddev__ void set_cols(const vector3<scalar>& v0, const vector3<scalar>& v1, const vector3<scalar>& v2)
42  { for(int i=0; i<3; i++) { m[i][0] = v0[i]; m[i][1] = v1[i]; m[i][2] = v2[i]; }
43  }
44 
45  //constructors:
46  explicit __hostanddev__ matrix3(scalar d0=0, scalar d1=0, scalar d2=0)
47  { m[0][0] = d0; m[1][1] = d1, m[2][2] = d2;
48  m[0][1] = m[0][2] = m[1][0] = m[1][2] = m[2][0] = m[2][1] = 0.0;
49  }
50  __hostanddev__ matrix3(
51  scalar m00, scalar m01, scalar m02,
52  scalar m10, scalar m11, scalar m12,
53  scalar m20, scalar m21, scalar m22 )
54  { m[0][0] = m00; m[0][1] = m01; m[0][2] = m02;
55  m[1][0] = m10; m[1][1] = m11; m[1][2] = m12;
56  m[2][0] = m20; m[2][1] = m21; m[2][2] = m22;
57  }
58  template<typename scalar2> explicit __hostanddev__ matrix3(const matrix3<scalar2>& n)
59  { for(int i=0; i<3; i++)
60  for(int j=0; j<3; j++)
61  m[i][j] = scalar(n(i,j));
62  }
63 
64  //arithmetic operators
65  __hostanddev__ matrix3<scalar> operator-() const
66  { matrix3<scalar> n;
67  for(int i=0; i<3; i++)
68  for(int j=0; j<3; j++)
69  n(i,j) = -m[i][j];
70  return n;
71  }
72  __hostanddev__ matrix3<scalar> operator+(const matrix3<scalar> &n) const
73  { matrix3<scalar> ret;
74  for(int i=0; i<3; i++)
75  for(int j=0; j<3; j++)
76  ret(i,j) = m[i][j] + n(i,j);
77  return ret;
78  }
79  __hostanddev__ matrix3<scalar>& operator+=(const matrix3<scalar> &n)
80  { for(int i=0; i<3; i++)
81  for(int j=0; j<3; j++)
82  m[i][j] += n(i,j);
83  return *this;
84  }
85  __hostanddev__ matrix3<scalar> operator-(const matrix3<scalar> &n) const
86  { matrix3<scalar> ret;
87  for(int i=0; i<3; i++)
88  for(int j=0; j<3; j++)
89  ret(i,j) = m[i][j] - n(i,j);
90  return ret;
91  }
92  __hostanddev__ matrix3<scalar>& operator-=(const matrix3<scalar> &n)
93  { for(int i=0; i<3; i++)
94  for(int j=0; j<3; j++)
95  m[i][j] -= n(i,j);
96  return *this;
97  }
98  __hostanddev__ matrix3<scalar>& operator*=(scalar s)
99  { for(int i=0; i<3; i++)
100  for(int j=0; j<3; j++)
101  m[i][j] *= s;
102  return *this;
103  }
104  __hostanddev__ matrix3<scalar> operator*(scalar s) const
105  { matrix3<scalar> ret;
106  for(int i=0; i<3; i++)
107  for(int j=0; j<3; j++)
108  ret(i,j) = m[i][j] * s;
109  return ret;
110  }
111 
113  #define METRIC_LENGTH_SQUARED \
114  return v[0]*v[0]*m[0][0] + v[1]*v[1]*m[1][1] + v[2]*v[2]*m[2][2] \
115  + 2*(v[0]*v[1]*m[0][1] + v[0]*v[2]*m[0][2] + v[1]*v[2]*m[1][2]);
116  __hostanddev__ double metric_length_squared(const vector3<double> &v) const { METRIC_LENGTH_SQUARED }
117  __hostanddev__ scalar metric_length_squared(const vector3<int> &v) const { METRIC_LENGTH_SQUARED }
118  #undef METRIC_LENGTH_SQUARED
119 
120  __hostanddev__ matrix3<scalar> operator/(scalar s) const { return (*this) * (1.0/s); }
121  __hostanddev__ matrix3<scalar>& operator/(scalar s) { return (*this) *= (1.0/s); }
122 
124  __hostanddev__ matrix3<scalar> operator~() const
125  { matrix3<scalar> ret;
126  for(int i=0; i<3; i++)
127  for(int j=0; j<3; j++)
128  ret(i,j) = m[j][i];
129  return ret;
130  }
131 
132  void print(FILE* fp, const char *format, bool brackets=true) const
133  { for(int i=0; i<3; i++)
134  { if(brackets) fprintf(fp, "[ ");
135  for(int j=0; j<3; j++) fprintf(fp, format, m[i][j]);
136  if(brackets) fprintf(fp, " ]\n"); else fprintf(fp, "\n");
137  }
138  }
139 
140  //Comparison operators
141  __hostanddev__ bool operator==(const matrix3<scalar>& n) const
142  { for(int i=0; i<3; i++)
143  for(int j=0; j<3; j++)
144  if(m[i][j] != n(i,j))
145  return false;
146  return true;
147  }
148  __hostanddev__ bool operator!=(const matrix3<scalar>& n) const
149  { return ! (*this == n);
150  }
151 };
152 
153 //Multiplies:
154 template<typename scalar> __hostanddev__ matrix3<scalar> operator*(scalar s, const matrix3<scalar> &m) { return m*s; }
155 
156 template<typename scalar> __hostanddev__ matrix3<scalar> outer(const vector3<scalar> &a, const vector3<scalar> &b)
157 { matrix3<scalar> m;
158  for(int i=0; i<3; i++)
159  for(int j=0; j<3; j++)
160  m(i,j) = a[i] * b[j];
161  return m;
162 }
163 
164 #define MUL_MAT_VEC(retType) \
165  vector3<retType> ret; \
166  for(int i=0; i<3; i++) \
167  for(int j=0; j<3; j++) \
168  ret[i] += m(i,j) * v[j]; \
169  return ret;
170 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const matrix3<scalar>& m, const vector3<scalar> &v)
171 { MUL_MAT_VEC(scalar)
172 }
173 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const matrix3<scalar>& m, const vector3<int> &v)
174 { MUL_MAT_VEC(scalar)
175 }
176 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const matrix3<int>& m, const vector3<scalar> &v)
177 { MUL_MAT_VEC(scalar)
178 }
179 __hostanddev__ vector3<int> operator*(const matrix3<int>& m, const vector3<int> &v)
180 { MUL_MAT_VEC(int)
181 }
182 #undef MUL_MAT_VEC
183 
184 #define MUL_VEC_MAT(retType) \
185  vector3<retType> ret; \
186  for(int i=0; i<3; i++) \
187  for(int j=0; j<3; j++) \
188  ret[j] += v[i] * m(i,j); \
189  return ret;
190 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const vector3<scalar> &v, const matrix3<scalar>& m)
191 { MUL_VEC_MAT(scalar)
192 }
193 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const vector3<int> &v, const matrix3<scalar>& m)
194 { MUL_VEC_MAT(scalar)
195 }
196 template<typename scalar> __hostanddev__ vector3<scalar> operator*(const vector3<scalar> &v, const matrix3<int>& m)
197 { MUL_VEC_MAT(scalar)
198 }
199 __hostanddev__ vector3<int> operator*(const vector3<int> &v, const matrix3<int>& m)
200 { MUL_VEC_MAT(int)
201 }
202 #undef MUL_VEC_MAT
203 
204 #define MUL_MAT_MAT(retType) \
205  matrix3<retType> ret; \
206  for(int i=0; i<3; i++) \
207  for(int j=0; j<3; j++) \
208  for(int k=0; k<3; k++) \
209  ret(i,j) += m(i,k) * n(k,j); \
210  return ret;
211 template<typename scalar> __hostanddev__ matrix3<scalar> operator*(const matrix3<scalar> &m, const matrix3<scalar>& n)
212 { MUL_MAT_MAT(scalar)
213 }
214 template<typename scalar> __hostanddev__ matrix3<scalar> operator*(const matrix3<scalar> &m, const matrix3<int>& n)
215 { MUL_MAT_MAT(scalar)
216 }
217 template<typename scalar> __hostanddev__ matrix3<scalar> operator*(const matrix3<int> &m, const matrix3<scalar>& n)
218 { MUL_MAT_MAT(scalar)
219 }
220 __hostanddev__ matrix3<int> operator*(const matrix3<int> &m, const matrix3<int>& n)
221 { MUL_MAT_MAT(int)
222 }
223 #undef MUL_MAT_MAT
224 template<typename scalar> __hostanddev__ matrix3<scalar>& operator*=(matrix3<scalar> &m, const matrix3<scalar>& n)
225 { return (m = m * n); }
226 
227 
228 template<typename scalar> __hostanddev__ matrix3<scalar> Diag(vector3<scalar> v)
229 { return matrix3<scalar>(v[0],v[1],v[2]); }
230 
231 template<typename scalar> __hostanddev__ scalar trace(const matrix3<scalar> &m) { return m(0,0)+m(1,1)+m(2,2); }
232 template<typename scalar> __hostanddev__ scalar det(const matrix3<scalar> &m) { return box(m.row(0),m.row(1),m.row(2)); }
233 template<typename scalar> __hostanddev__ matrix3<scalar> adjugate(const matrix3<scalar> &m)
234 { matrix3<scalar> adj;
235  adj.set_cols(cross(m.row(1),m.row(2)), cross(m.row(2),m.row(0)), cross(m.row(0),m.row(1)));
236  return adj;
237 }
238 __hostanddev__ matrix3<> inv(const matrix3<> &m)
239 { return (1./det(m)) * adjugate(m);
240 }
241 __hostanddev__ double nrm2(const matrix3<>& m) { return sqrt(trace((~m)*m)); }
242 
244 __hostanddev__ matrix3<> rotation(double theta, int axis)
245 { double s, c; sincos(theta, &s, &c);
246  switch(axis)
247  { case 0:
248  return matrix3<>(
249  1, 0, 0,
250  0, c, s,
251  0, -s, c );
252  case 1:
253  return matrix3<>(
254  c, 0, -s,
255  0, 1, 0,
256  s, 0, c );
257  default:
258  return matrix3<>(
259  c, s, 0,
260  -s, c, 0,
261  0, 0, 1 );
262  }
263 }
264 
265 #endif //JDFTX_CORE_MATRIX3_H
ScalarField sqrt(const ScalarField &)
Elementwise square root (preserve input)
Definition: matrix3.h:25
__hostanddev__ vector3< scalar > cross(const vector3< scalar > &a, const vector3< scalar > &b)
cross product
Definition: vector3.h:115
__hostanddev__ scalar box(const vector3< scalar > &a, const vector3< scalar > &b, const vector3< scalar > &c)
box product / triple product
Definition: vector3.h:123
ScalarField inv(const ScalarField &)
Elementwise reciprocal (preserve input)
Generic 3-vector.
Definition: vector3.h:33
__hostanddev__ matrix3< scalar > operator~() const
transpose
Definition: matrix3.h:124
double nrm2(const Tptr &X)
Definition: Operators.h:199