ALP User Documentation  0.8.preview
Algebraic Programming User Documentation
bicgstab.hpp
Go to the documentation of this file.
1 
2 /*
3  * Copyright 2021 Huawei Technologies Co., Ltd.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
35 #ifndef _H_GRB_ALGORITHMS_BICGSTAB
36 #define _H_GRB_ALGORITHMS_BICGSTAB
37 
38 #include <graphblas.hpp>
39 
40 #include <iostream>
41 #include <type_traits>
42 
43 #ifdef _DEBUG
44  #include <cmath> // for sqrt, making the silent assumption that ResidualType
45  // is a supported type for it
46 #endif
47 
48 
49 namespace grb {
50 
51  namespace algorithms {
52 
154  template<
156  typename IOType, typename NonzeroType, typename InputType,
157  typename ResidualType,
158  class Semiring = Semiring<
159  operators::add< InputType, InputType, InputType >,
160  operators::mul< IOType, NonzeroType, InputType >,
161  identities::zero, identities::one
162  >,
163  class Minus = operators::subtract< ResidualType >,
164  class Divide = operators::divide< ResidualType >
165  >
169  const grb::Vector< InputType > &b,
170  const size_t max_iterations,
171  ResidualType tol,
172  size_t &iterations,
173  ResidualType &residual,
175  Vector< InputType > &rhat,
180  const Semiring &semiring = Semiring(),
181  const Minus &minus = Minus(),
182  const Divide &divide = Divide()
183  ) {
184  // static checks
185  static_assert( !( descr & descriptors::no_casting ) || (
186  std::is_same< IOType, NonzeroType >::value &&
187  std::is_same< IOType, InputType >::value &&
188  std::is_same< IOType, ResidualType >::value
189  ), "no_casting descriptor was set but containers with differing domains "
190  "were given."
191  );
192  static_assert( !( descr & descriptors::no_casting ) || (
193  std::is_same< NonzeroType, typename Semiring::D1 >::value &&
194  std::is_same< IOType, typename Semiring::D2 >::value &&
195  std::is_same< InputType, typename Semiring::D3 >::value &&
196  std::is_same< InputType, typename Semiring::D4 >::value
197  ), "no_casting descriptor was set, but semiring has incompatible domains "
198  "with the given containers."
199  );
200  static_assert( !( descr & descriptors::no_casting ) || (
201  std::is_same< InputType, typename Minus::D1 >::value &&
202  std::is_same< InputType, typename Minus::D2 >::value &&
203  std::is_same< InputType, typename Minus::D3 >::value
204  ), "no_casting descriptor was set, but given minus operator has "
205  "incompatible domains with the given containers."
206  );
207  static_assert( !( descr & descriptors::no_casting ) || (
208  std::is_same< ResidualType, typename Divide::D1 >::value &&
209  std::is_same< ResidualType, typename Divide::D2 >::value &&
210  std::is_same< ResidualType, typename Divide::D3 >::value
211  ), "no_casting descriptor was set, but given divide operator has "
212  "incompatible domains with the given tolerance type."
213  );
214  static_assert( std::is_floating_point< ResidualType >::value,
215  "Require floating-point residual type."
216  );
217 
218 #ifdef _DEBUG
219  std::cout << "Entering bicgstab; "
220  << "tol = " << tol << ", "
221  << "max_iterations = " << max_iterations << "\n";
222 #endif
223 
224  // descriptor for indiciating dense computations
225  constexpr Descriptor dense_descr = descr | descriptors::dense;
226 
227  // get an alias to zero and one in case 1 and 0 can't cast properly
228  const ResidualType zero = semiring.template getZero< ResidualType >();
229  const ResidualType one = semiring.template getOne< ResidualType >();
230 
231  // dynamic checks, sizes:
232  const size_t n = nrows( A );
233  if( n != ncols( A ) ) {
234  return MISMATCH;
235  }
236  if( n != size( x ) ) {
237  return MISMATCH;
238  }
239  if( n != size( b ) ) {
240  return MISMATCH;
241  }
242  if( n != size( r ) || n != size( rhat ) || n != size( p ) ||
243  n != size( p ) || n != size( s ) || n != size( t )
244  ) {
245  return MISMATCH;
246  }
247 
248  // dynamic checks, capacity:
249  if( n != capacity( x ) ) {
250  return ILLEGAL;
251  }
252  if( n != capacity( r ) || n != capacity( rhat ) || n != capacity( p ) ||
253  n != capacity( p ) || n != capacity( s ) || n != capacity( t )
254  ) {
255  return ILLEGAL;
256  }
257 
258  // dynamic checks, others:
259  if( tol <= zero ) {
260  return ILLEGAL;
261  }
262 
263 #ifdef _DEBUG
264  std::cout << "\t dynamic run-time error checking passed\n";
265 #endif
266 
267  // prelude
268  ResidualType b_norm_squared = zero;
269  RC ret = dot< dense_descr >( b_norm_squared, b, b, semiring );
270  if( ret ) {
271  std::cerr << "Error: BiCGstab encountered \"" << toString(ret)
272  << "\" during computation of the norm of b\n";
273  return ret;
274  }
275 
276  // make it so that we do not need to take square roots when detecting
277  // convergence
278  tol *= tol;
279  tol *= b_norm_squared;
280 #ifdef _DEBUG
281  std::cout << "Effective squared relative tolerance is " << tol << "\n";
282 #endif
283 
284  // ensure that x is structurally dense
285  if( nnz( x ) != n ) {
286  ret = grb::set< descriptors::invert_mask | descriptors::structural >(
287  x, x, zero
288  );
289  assert( nnz( x ) == n );
290  }
291 
292  // compute residual (squared), taking into account that b may be sparse
293  residual = zero;
294  ret = ret ? ret : set( t, zero ); // t = Ax
295  ret = ret ? ret : mxv< dense_descr >( t, A, x, semiring );
296  assert( nnz( t ) == n );
297  ret = ret ? ret : set( r, zero ); // r = b - Ax
298  ret = ret ? ret : foldl( r, b, semiring.getAdditiveMonoid() );
299  assert( nnz( r ) == n );
300  ret = ret ? ret : foldl< dense_descr >( r, t, minus );
301  ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring ); // residual
302 
303  // check for prelude error
304  if( ret ) {
305  std::cerr << "Error: BiCGstab encountered \"" << toString(ret)
306  << "\" during prelude\n";
307  return ret;
308  }
309 
310  // check if the guess was good enough
311  if( residual < tol ) {
312  return SUCCESS;
313  }
314 
315 #ifdef _DEBUG
316  std::cout << "\t prelude completed\n";
317 #endif
318 
319  // start iterations
320  ret = ret ? ret : set( rhat, r );
321  ret = ret ? ret : set( p, zero );
322  ret = ret ? ret : set( v, zero );
323  ResidualType rho, rho_old, alpha, beta, omega, temp;
324  rho_old = alpha = omega = one;
325  iterations = 0;
326 
327  for( ; ret == SUCCESS && iterations < max_iterations; ++iterations ) {
328 
329 #ifdef _DEBUG
330  std::cout << "\t iteration " << iterations << " starts\n";
331 #endif
332 
333  // rho = ( rhat, r )
334  rho = zero;
335  ret = ret ? ret : dot< dense_descr >( rho, rhat, r, semiring );
336 #ifdef _DEBUG
337  std::cout << "\t\t rho = " << rho << "\n";
338 #endif
339  if( ret == SUCCESS && rho == zero ) {
340  std::cerr << "Error: BiCGstab detects r at iteration " << iterations <<
341  " is orthogonal to r-hat\n";
342  return FAILED;
343  }
344 
345  // beta = (rho / rho_old) * (alpha / omega)
346  ret = ret ? ret : apply( beta, rho, rho_old, divide );
347  ret = ret ? ret : apply( temp, alpha, omega, divide );
348  ret = ret ? ret : foldl( beta, temp, semiring.getMultiplicativeOperator() );
349 #ifdef _DEBUG
350  std::cout << "\t\t beta = " << beta << "\n";
351 #endif
352 
353  // p = r + beta ( p - omega * v )
354  ret = ret ? ret : eWiseLambda(
355  [&r,beta,&p,&v,omega,&semiring,&minus] (const size_t i) {
356  InputType tmp;
357  apply( tmp, omega, v[i], semiring.getMultiplicativeOperator() );
358  foldl( p[ i ], tmp, minus );
359  foldr( beta, p[ i ], semiring.getMultiplicativeOperator() );
360  foldr( r[ i ], p[ i ], semiring.getAdditiveOperator() );
361  }, v, p, r
362  );
363 
364  // v = Ap
365  ret = ret ? ret : set( v, zero );
366  ret = ret ? ret : mxv< dense_descr >( v, A, p, semiring );
367 
368  // alpha = rho / (rhat, v)
369  alpha = zero;
370  ret = ret ? ret : dot< dense_descr >( alpha, rhat, v, semiring );
371  if( alpha == zero ) {
372  std::cerr << "Error: BiCGstab detects rhat is orthogonal to v=Ap "
373  << "at iteration " << iterations << ".\n";
374  return FAILED;
375  }
376  ret = ret ? ret : foldr( rho, alpha, divide );
377 #ifdef _DEBUG
378  std::cout << "\t\t alpha = " << alpha << "\n";
379 #endif
380 
381  // x += alpha * p is post-poned to either the pre-stabilisation exit, or
382  // after the stabilisation step
383  //ret = ret ? ret : eWiseMul( x, alpha, p, semiring );
384 
385  // s = r - alpha * v
386  {
387  ResidualType minus_alpha = zero;
388  ret = ret ? ret : foldl( minus_alpha, alpha, minus );
389  ret = ret ? ret : set( s, r );
390  ret = ret ? ret : eWiseMul< dense_descr >( s, minus_alpha, v, semiring );
391  }
392 
393  // check residual
394  residual = zero;
395  ret = ret ? ret : dot< dense_descr >( residual, s, s, semiring );
396  assert( residual > zero );
397 #ifdef _DEBUG
398  std::cout << "\t\t running residual, pre-stabilisation: " << sqrt(residual)
399  << "\n";
400 #endif
401  if( ret == SUCCESS && residual < tol ) {
402  // update result (x += alpha * p) and exit
403  ret = eWiseMul< dense_descr >( x, alpha, p, semiring );
404  return ret;
405  }
406 
407  // t = As
408  ret = ret ? ret : set( t, zero );
409  ret = ret ? ret : mxv< dense_descr >( t, A, s, semiring );
410 
411  // omega = (t, s) / (t, t);
412  omega = temp = zero;
413  ret = ret ? ret : dot< dense_descr >( temp, t, s, semiring );
414 #ifdef _DEBUG
415  std::cout << "\t\t (t, s) = " << temp << "\n";
416 #endif
417  if( ret == SUCCESS && temp == zero ) {
418  std::cerr << "Error: BiCGstab detects As at iteration " << iterations <<
419  " is orthogonal to s\n";
420  return FAILED;
421  }
422  ret = ret ? ret : dot< dense_descr >( omega, t, t, semiring );
423 #ifdef _DEBUG
424  std::cout << "\t\t (t, t) = " << omega << "\n";
425 #endif
426  assert( omega > zero );
427  ret = ret ? ret : foldr( temp, omega, divide );
428 #ifdef _DEBUG
429  std::cout << "\t\t omega = " << omega << "\n";
430 #endif
431 
432  // x += alpha * p + omega * s
433  ret = ret ? ret : eWiseMul< dense_descr >( x, alpha, p, semiring );
434  ret = ret ? ret : eWiseMul< dense_descr >( x, omega, s, semiring );
435 
436  // r = s - omega * t
437  {
438  ResidualType minus_omega = zero;
439  ret = ret ? ret : foldl( minus_omega, omega, minus );
440  ret = ret ? ret : set( r, s );
441  ret = ret ? ret : eWiseMul< dense_descr >( r, minus_omega, t, semiring );
442  }
443 
444  // check residual
445  residual = zero;
446  ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring );
447  assert( residual > zero );
448 #ifdef _DEBUG
449  std::cout << "\t\t running residual, post-stabilisation: "
450  << sqrt(residual) << ". "
451  << "Residual squared: " << residual << ".\n";
452 #endif
453  if( ret == SUCCESS ) {
454  if( residual < tol ) { return SUCCESS; }
455 
456  // go to next iteration
457  rho_old = rho;
458  }
459  }
460 
461  if( ret == SUCCESS ) {
462  // if we are here, then we did not detect convergence
463  std::cerr << "Warning: call to BiCGstab did not converge within "
464  << max_iterations << " iterations. Squared two-norm of the running "
465  << "residual is " << residual << ". "
466  << "Target residual squared: " << tol << ".\n";
467  return FAILED;
468  } else {
469  // if we are here, we exited due to an ALP error code
470  std::cerr << "Error: BiCGstab encountered error \"" << toString(ret)
471  << "\" while iterating to " << iterations << ", ";
472  if( iterations == max_iterations ) {
473  std::cerr << "which also is the maximum number of iterations.\n";
474  } else {
475  std::cerr << "which is below the maximum number of iterations of "
476  << max_iterations << "\n";
477  }
478  return ret;
479  }
480  }
481 
482  }
483 }
484 
485 #endif // end _H_GRB_ALGORITHMS_BICGSTAB
486 
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:858
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:72
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
A GraphBLAS vector.
Definition: vector.hpp:64
RC bicgstab(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, Vector< InputType > &r, Vector< InputType > &rhat, Vector< InputType > &p, Vector< InputType > &v, Vector< InputType > &s, Vector< InputType > &t, const Semiring &semiring=Semiring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with $x$ unknown by using the bi-conjugate gradient (bi-CG) stabilised method; i...
Definition: bicgstab.hpp:166
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3840
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3943
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:477
The main header to include in order to use the ALP/GraphBLAS API.
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f on any number of vectors of equal length.
Definition: blas1.hpp:3746
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
A generalised semiring.
Definition: semiring.hpp:190
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
std::string toString(const RC code)