ALP User Documentation 0.7.0
Algebraic Programming User Documentation
bicgstab.hpp
Go to the documentation of this file.
1
2/*
3 * Copyright 2021 Huawei Technologies Co., Ltd.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
35#ifndef _H_GRB_ALGORITHMS_BICGSTAB
36#define _H_GRB_ALGORITHMS_BICGSTAB
37
38#include <graphblas.hpp>
39
40#include <iostream>
41#include <type_traits>
42
43#ifdef _DEBUG
44 #include <cmath> // for sqrt, making the silent assumption that ResidualType
45 // is a supported type for it
46#endif
47
48
49namespace grb {
50
51 namespace algorithms {
52
154 template<
156 typename IOType, typename NonzeroType, typename InputType,
157 typename ResidualType,
158 class Semiring = Semiring<
159 operators::add< InputType, InputType, InputType >,
160 operators::mul< IOType, NonzeroType, InputType >,
161 identities::zero, identities::one
162 >,
163 class Minus = operators::subtract< ResidualType >,
164 class Divide = operators::divide< ResidualType >
165 >
170 const size_t max_iterations,
171 ResidualType tol,
172 size_t &iterations,
173 ResidualType &residual,
180 const Semiring &semiring = Semiring(),
181 const Minus &minus = Minus(),
182 const Divide &divide = Divide()
183 ) {
184 // static checks
185 static_assert( !( descr & descriptors::no_casting ) || (
186 std::is_same< IOType, NonzeroType >::value &&
187 std::is_same< IOType, InputType >::value &&
188 std::is_same< IOType, ResidualType >::value
189 ), "no_casting descriptor was set but containers with differing domains "
190 "were given."
191 );
192 static_assert( !( descr & descriptors::no_casting ) || (
193 std::is_same< NonzeroType, typename Semiring::D1 >::value &&
194 std::is_same< IOType, typename Semiring::D2 >::value &&
195 std::is_same< InputType, typename Semiring::D3 >::value &&
196 std::is_same< InputType, typename Semiring::D4 >::value
197 ), "no_casting descriptor was set, but semiring has incompatible domains "
198 "with the given containers."
199 );
200 static_assert( !( descr & descriptors::no_casting ) || (
201 std::is_same< InputType, typename Minus::D1 >::value &&
202 std::is_same< InputType, typename Minus::D2 >::value &&
203 std::is_same< InputType, typename Minus::D3 >::value
204 ), "no_casting descriptor was set, but given minus operator has "
205 "incompatible domains with the given containers."
206 );
207 static_assert( !( descr & descriptors::no_casting ) || (
208 std::is_same< ResidualType, typename Divide::D1 >::value &&
209 std::is_same< ResidualType, typename Divide::D2 >::value &&
210 std::is_same< ResidualType, typename Divide::D3 >::value
211 ), "no_casting descriptor was set, but given divide operator has "
212 "incompatible domains with the given tolerance type."
213 );
214 static_assert( std::is_floating_point< ResidualType >::value,
215 "Require floating-point residual type."
216 );
217
218#ifdef _DEBUG
219 std::cout << "Entering bicgstab; "
220 << "tol = " << tol << ", "
221 << "max_iterations = " << max_iterations << "\n";
222#endif
223
224 // descriptor for indiciating dense computations
225 constexpr Descriptor dense_descr = descr | descriptors::dense;
226
227 // get an alias to zero and one in case 1 and 0 can't cast properly
228 const ResidualType zero = semiring.template getZero< ResidualType >();
229 const ResidualType one = semiring.template getOne< ResidualType >();
230
231 // dynamic checks, sizes:
232 const size_t n = nrows( A );
233 if( n != ncols( A ) ) {
234 return MISMATCH;
235 }
236 if( n != size( x ) ) {
237 return MISMATCH;
238 }
239 if( n != size( b ) ) {
240 return MISMATCH;
241 }
242 if( n != size( r ) || n != size( rhat ) || n != size( p ) ||
243 n != size( p ) || n != size( s ) || n != size( t )
244 ) {
245 return MISMATCH;
246 }
247
248 // dynamic checks, capacity:
249 if( n != capacity( x ) ) {
250 return ILLEGAL;
251 }
252 if( n != capacity( r ) || n != capacity( rhat ) || n != capacity( p ) ||
253 n != capacity( p ) || n != capacity( s ) || n != capacity( t )
254 ) {
255 return ILLEGAL;
256 }
257
258 // dynamic checks, others:
259 if( tol <= zero ) {
260 return ILLEGAL;
261 }
262
263#ifdef _DEBUG
264 std::cout << "\t dynamic run-time error checking passed\n";
265#endif
266
267 // prelude
268 ResidualType b_norm_squared = zero;
269 RC ret = dot< dense_descr >( b_norm_squared, b, b, semiring );
270 if( ret ) {
271 std::cerr << "Error: BiCGstab encountered \"" << toString(ret)
272 << "\" during computation of the norm of b\n";
273 return ret;
274 }
275
276 // make it so that we do not need to take square roots when detecting
277 // convergence
278 tol *= tol;
279 tol *= b_norm_squared;
280#ifdef _DEBUG
281 std::cout << "Effective squared relative tolerance is " << tol << "\n";
282#endif
283
284 // ensure that x is structurally dense
285 if( nnz( x ) != n ) {
286 ret = grb::set< descriptors::invert_mask | descriptors::structural >(
287 x, x, zero
288 );
289 assert( nnz( x ) == n );
290 }
291
292 // compute residual (squared), taking into account that b may be sparse
293 residual = zero;
294 ret = ret ? ret : set( t, zero ); // t = Ax
295 ret = ret ? ret : mxv< dense_descr >( t, A, x, semiring );
296 assert( nnz( t ) == n );
297 ret = ret ? ret : set( r, zero ); // r = b - Ax
298 ret = ret ? ret : foldl( r, b, semiring.getAdditiveMonoid() );
299 assert( nnz( r ) == n );
300 ret = ret ? ret : foldl< dense_descr >( r, t, minus );
301 ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring ); // residual
302
303 // check for prelude error
304 if( ret ) {
305 std::cerr << "Error: BiCGstab encountered \"" << toString(ret)
306 << "\" during prelude\n";
307 return ret;
308 }
309
310 // check if the guess was good enough
311 if( residual < tol ) {
312 return SUCCESS;
313 }
314
315#ifdef _DEBUG
316 std::cout << "\t prelude completed\n";
317#endif
318
319 // start iterations
320 ret = ret ? ret : set( rhat, r );
321 ret = ret ? ret : set( p, zero );
322 ret = ret ? ret : set( v, zero );
323 ResidualType rho, rho_old, alpha, beta, omega, temp;
324 rho_old = alpha = omega = one;
325 iterations = 0;
326
327 for( ; ret == SUCCESS && iterations < max_iterations; ++iterations ) {
328
329#ifdef _DEBUG
330 std::cout << "\t iteration " << iterations << " starts\n";
331#endif
332
333 // rho = ( rhat, r )
334 rho = zero;
335 ret = ret ? ret : dot< dense_descr >( rho, rhat, r, semiring );
336#ifdef _DEBUG
337 std::cout << "\t\t rho = " << rho << "\n";
338#endif
339 if( ret == SUCCESS && rho == zero ) {
340 std::cerr << "Error: BiCGstab detects r at iteration " << iterations <<
341 " is orthogonal to r-hat\n";
342 return FAILED;
343 }
344
345 // beta = (rho / rho_old) * (alpha / omega)
346 ret = ret ? ret : apply( beta, rho, rho_old, divide );
347 ret = ret ? ret : apply( temp, alpha, omega, divide );
348 ret = ret ? ret : foldl( beta, temp, semiring.getMultiplicativeOperator() );
349#ifdef _DEBUG
350 std::cout << "\t\t beta = " << beta << "\n";
351#endif
352
353 // p = r + beta ( p - omega * v )
354 ret = ret ? ret : eWiseLambda(
355 [&r,beta,&p,&v,omega,&semiring,&minus] (const size_t i) {
356 InputType tmp;
357 apply( tmp, omega, v[i], semiring.getMultiplicativeOperator() );
358 foldl( p[ i ], tmp, minus );
359 foldr( beta, p[ i ], semiring.getMultiplicativeOperator() );
360 foldr( r[ i ], p[ i ], semiring.getAdditiveOperator() );
361 }, v, p, r
362 );
363
364 // v = Ap
365 ret = ret ? ret : set( v, zero );
366 ret = ret ? ret : mxv< dense_descr >( v, A, p, semiring );
367
368 // alpha = rho / (rhat, v)
369 alpha = zero;
370 ret = ret ? ret : dot< dense_descr >( alpha, rhat, v, semiring );
371 if( alpha == zero ) {
372 std::cerr << "Error: BiCGstab detects rhat is orthogonal to v=Ap "
373 << "at iteration " << iterations << ".\n";
374 return FAILED;
375 }
376 ret = ret ? ret : foldr( rho, alpha, divide );
377#ifdef _DEBUG
378 std::cout << "\t\t alpha = " << alpha << "\n";
379#endif
380
381 // x += alpha * p is post-poned to either the pre-stabilisation exit, or
382 // after the stabilisation step
383 //ret = ret ? ret : eWiseMul( x, alpha, p, semiring );
384
385 // s = r - alpha * v
386 {
387 ResidualType minus_alpha = zero;
388 ret = ret ? ret : foldl( minus_alpha, alpha, minus );
389 ret = ret ? ret : set( s, r );
390 ret = ret ? ret : eWiseMul< dense_descr >( s, minus_alpha, v, semiring );
391 }
392
393 // check residual
394 residual = zero;
395 ret = ret ? ret : dot< dense_descr >( residual, s, s, semiring );
396 assert( residual > zero );
397#ifdef _DEBUG
398 std::cout << "\t\t running residual, pre-stabilisation: " << sqrt(residual)
399 << "\n";
400#endif
401 if( ret == SUCCESS && residual < tol ) {
402 // update result (x += alpha * p) and exit
403 ret = eWiseMul< dense_descr >( x, alpha, p, semiring );
404 return ret;
405 }
406
407 // t = As
408 ret = ret ? ret : set( t, zero );
409 ret = ret ? ret : mxv< dense_descr >( t, A, s, semiring );
410
411 // omega = (t, s) / (t, t);
412 omega = temp = zero;
413 ret = ret ? ret : dot< dense_descr >( temp, t, s, semiring );
414#ifdef _DEBUG
415 std::cout << "\t\t (t, s) = " << temp << "\n";
416#endif
417 if( ret == SUCCESS && temp == zero ) {
418 std::cerr << "Error: BiCGstab detects As at iteration " << iterations <<
419 " is orthogonal to s\n";
420 return FAILED;
421 }
422 ret = ret ? ret : dot< dense_descr >( omega, t, t, semiring );
423#ifdef _DEBUG
424 std::cout << "\t\t (t, t) = " << omega << "\n";
425#endif
426 assert( omega > zero );
427 ret = ret ? ret : foldr( temp, omega, divide );
428#ifdef _DEBUG
429 std::cout << "\t\t omega = " << omega << "\n";
430#endif
431
432 // x += alpha * p + omega * s
433 ret = ret ? ret : eWiseMul< dense_descr >( x, alpha, p, semiring );
434 ret = ret ? ret : eWiseMul< dense_descr >( x, omega, s, semiring );
435
436 // r = s - omega * t
437 {
438 ResidualType minus_omega = zero;
439 ret = ret ? ret : foldl( minus_omega, omega, minus );
440 ret = ret ? ret : set( r, s );
441 ret = ret ? ret : eWiseMul< dense_descr >( r, minus_omega, t, semiring );
442 }
443
444 // check residual
445 residual = zero;
446 ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring );
447 assert( residual > zero );
448#ifdef _DEBUG
449 std::cout << "\t\t running residual, post-stabilisation: "
450 << sqrt(residual) << ". "
451 << "Residual squared: " << residual << ".\n";
452#endif
453 if( ret == SUCCESS ) {
454 if( residual < tol ) { return SUCCESS; }
455
456 // go to next iteration
457 rho_old = rho;
458 }
459 }
460
461 if( ret == SUCCESS ) {
462 // if we are here, then we did not detect convergence
463 std::cerr << "Warning: call to BiCGstab did not converge within "
464 << max_iterations << " iterations. Squared two-norm of the running "
465 << "residual is " << residual << ". "
466 << "Target residual squared: " << tol << ".\n";
467 return FAILED;
468 } else {
469 // if we are here, we exited due to an ALP error code
470 std::cerr << "Error: BiCGstab encountered error \"" << toString(ret)
471 << "\" while iterating to " << iterations << ", ";
472 if( iterations == max_iterations ) {
473 std::cerr << "which also is the maximum number of iterations.\n";
474 } else {
475 std::cerr << "which is below the maximum number of iterations of "
476 << max_iterations << "\n";
477 }
478 return ret;
479 }
480 }
481
482 }
483}
484
485#endif // end _H_GRB_ALGORITHMS_BICGSTAB
486
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
A generalised semiring.
Definition: semiring.hpp:186
A GraphBLAS vector.
Definition: vector.hpp:64
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3943
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f on any number of vectors of equal length.
Definition: blas1.hpp:3746
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3840
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
RC bicgstab(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, Vector< InputType > &r, Vector< InputType > &rhat, Vector< InputType > &p, Vector< InputType > &v, Vector< InputType > &s, Vector< InputType > &t, const Semiring &semiring=Semiring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with $x$ unknown by using the bi-conjugate gradient (bi-CG) stabilised method; i...
Definition: bicgstab.hpp:166
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:452
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
@ FAILED
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54
std::string toString(const RC code)