35#ifndef _H_GRB_ALGORITHMS_BICGSTAB
36#define _H_GRB_ALGORITHMS_BICGSTAB
51 namespace algorithms {
156 typename IOType,
typename NonzeroType,
typename InputType,
157 typename ResidualType,
158 class Semiring = Semiring<
159 operators::add< InputType, InputType, InputType >,
160 operators::mul< IOType, NonzeroType, InputType >,
161 identities::zero, identities::one
163 class Minus = operators::subtract< ResidualType >,
164 class Divide = operators::divide< ResidualType >
170 const size_t max_iterations,
173 ResidualType &residual,
181 const Minus &minus = Minus(),
182 const Divide &divide = Divide()
186 std::is_same< IOType, NonzeroType >::value &&
187 std::is_same< IOType, InputType >::value &&
188 std::is_same< IOType, ResidualType >::value
189 ),
"no_casting descriptor was set but containers with differing domains "
193 std::is_same< NonzeroType, typename Semiring::D1 >::value &&
194 std::is_same< IOType, typename Semiring::D2 >::value &&
195 std::is_same< InputType, typename Semiring::D3 >::value &&
196 std::is_same< InputType, typename Semiring::D4 >::value
197 ),
"no_casting descriptor was set, but semiring has incompatible domains "
198 "with the given containers."
201 std::is_same< InputType, typename Minus::D1 >::value &&
202 std::is_same< InputType, typename Minus::D2 >::value &&
203 std::is_same< InputType, typename Minus::D3 >::value
204 ),
"no_casting descriptor was set, but given minus operator has "
205 "incompatible domains with the given containers."
208 std::is_same< ResidualType, typename Divide::D1 >::value &&
209 std::is_same< ResidualType, typename Divide::D2 >::value &&
210 std::is_same< ResidualType, typename Divide::D3 >::value
211 ),
"no_casting descriptor was set, but given divide operator has "
212 "incompatible domains with the given tolerance type."
214 static_assert( std::is_floating_point< ResidualType >::value,
215 "Require floating-point residual type."
219 std::cout <<
"Entering bicgstab; "
220 <<
"tol = " << tol <<
", "
221 <<
"max_iterations = " << max_iterations <<
"\n";
228 const ResidualType zero = semiring.template getZero< ResidualType >();
229 const ResidualType one = semiring.template getOne< ResidualType >();
232 const size_t n =
nrows( A );
233 if( n !=
ncols( A ) ) {
236 if( n !=
size( x ) ) {
239 if( n !=
size( b ) ) {
242 if( n !=
size( r ) || n !=
size( rhat ) || n !=
size( p ) ||
264 std::cout <<
"\t dynamic run-time error checking passed\n";
268 ResidualType b_norm_squared = zero;
269 RC ret = dot< dense_descr >( b_norm_squared, b, b, semiring );
271 std::cerr <<
"Error: BiCGstab encountered \"" <<
toString(ret)
272 <<
"\" during computation of the norm of b\n";
279 tol *= b_norm_squared;
281 std::cout <<
"Effective squared relative tolerance is " << tol <<
"\n";
285 if(
nnz( x ) != n ) {
286 ret = grb::set< descriptors::invert_mask | descriptors::structural >(
289 assert(
nnz( x ) == n );
294 ret = ret ? ret :
set( t, zero );
295 ret = ret ? ret : mxv< dense_descr >( t, A, x, semiring );
296 assert(
nnz( t ) == n );
297 ret = ret ? ret :
set( r, zero );
298 ret = ret ? ret :
foldl( r, b, semiring.getAdditiveMonoid() );
299 assert(
nnz( r ) == n );
300 ret = ret ? ret : foldl< dense_descr >( r, t, minus );
301 ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring );
305 std::cerr <<
"Error: BiCGstab encountered \"" <<
toString(ret)
306 <<
"\" during prelude\n";
311 if( residual < tol ) {
316 std::cout <<
"\t prelude completed\n";
320 ret = ret ? ret :
set( rhat, r );
321 ret = ret ? ret :
set( p, zero );
322 ret = ret ? ret :
set( v, zero );
323 ResidualType rho, rho_old, alpha, beta, omega, temp;
324 rho_old = alpha = omega = one;
327 for( ; ret ==
SUCCESS && iterations < max_iterations; ++iterations ) {
330 std::cout <<
"\t iteration " << iterations <<
" starts\n";
335 ret = ret ? ret : dot< dense_descr >( rho, rhat, r, semiring );
337 std::cout <<
"\t\t rho = " << rho <<
"\n";
339 if( ret ==
SUCCESS && rho == zero ) {
340 std::cerr <<
"Error: BiCGstab detects r at iteration " << iterations <<
341 " is orthogonal to r-hat\n";
346 ret = ret ? ret :
apply( beta, rho, rho_old, divide );
347 ret = ret ? ret :
apply( temp, alpha, omega, divide );
348 ret = ret ? ret :
foldl( beta, temp, semiring.getMultiplicativeOperator() );
350 std::cout <<
"\t\t beta = " << beta <<
"\n";
355 [&r,&beta,&p,&v,&omega,&semiring,&minus] (
const size_t i) {
357 apply( tmp, omega, v[i], semiring.getMultiplicativeOperator() );
358 foldl( p[ i ], tmp, minus );
359 foldr( beta, p[ i ], semiring.getMultiplicativeOperator() );
360 foldr( r[ i ], p[ i ], semiring.getAdditiveOperator() );
365 ret = ret ? ret :
set( v, zero );
366 ret = ret ? ret : mxv< dense_descr >( v, A, p, semiring );
370 ret = ret ? ret : dot< dense_descr >( alpha, rhat, v, semiring );
371 if( alpha == zero ) {
372 std::cerr <<
"Error: BiCGstab detects rhat is orthogonal to v=Ap "
373 <<
"at iteration " << iterations <<
".\n";
376 ret = ret ? ret :
foldr( rho, alpha, divide );
378 std::cout <<
"\t\t alpha = " << alpha <<
"\n";
387 ResidualType minus_alpha = zero;
388 ret = ret ? ret :
foldl( minus_alpha, alpha, minus );
389 ret = ret ? ret :
set( s, r );
390 ret = ret ? ret : eWiseMul< dense_descr >( s, minus_alpha, v, semiring );
395 ret = ret ? ret : dot< dense_descr >( residual, s, s, semiring );
396 assert( residual > zero );
398 std::cout <<
"\t\t running residual, pre-stabilisation: " << sqrt(residual)
401 if( ret ==
SUCCESS && residual < tol ) {
403 ret = eWiseMul< dense_descr >( x, alpha, p, semiring );
408 ret = ret ? ret :
set( t, zero );
409 ret = ret ? ret : mxv< dense_descr >( t, A, s, semiring );
413 ret = ret ? ret : dot< dense_descr >( temp, t, s, semiring );
415 std::cout <<
"\t\t (t, s) = " << temp <<
"\n";
417 if( ret ==
SUCCESS && temp == zero ) {
418 std::cerr <<
"Error: BiCGstab detects As at iteration " << iterations <<
419 " is orthogonal to s\n";
422 ret = ret ? ret : dot< dense_descr >( omega, t, t, semiring );
424 std::cout <<
"\t\t (t, t) = " << omega <<
"\n";
426 assert( omega > zero );
427 ret = ret ? ret :
foldr( temp, omega, divide );
429 std::cout <<
"\t\t omega = " << omega <<
"\n";
433 ret = ret ? ret : eWiseMul< dense_descr >( x, alpha, p, semiring );
434 ret = ret ? ret : eWiseMul< dense_descr >( x, omega, s, semiring );
438 ResidualType minus_omega = zero;
439 ret = ret ? ret :
foldl( minus_omega, omega, minus );
440 ret = ret ? ret :
set( r, s );
441 ret = ret ? ret : eWiseMul< dense_descr >( r, minus_omega, t, semiring );
446 ret = ret ? ret : dot< dense_descr >( residual, r, r, semiring );
447 assert( residual > zero );
449 std::cout <<
"\t\t running residual, post-stabilisation: "
450 << sqrt(residual) <<
". "
451 <<
"Residual squared: " << residual <<
".\n";
454 if( residual < tol ) {
return SUCCESS; }
463 std::cerr <<
"Warning: call to BiCGstab did not converge within "
464 << max_iterations <<
" iterations. Squared two-norm of the running "
465 <<
"residual is " << residual <<
". "
466 <<
"Target residual squared: " << tol <<
".\n";
470 std::cerr <<
"Error: BiCGstab encountered error \"" <<
toString(ret)
471 <<
"\" while iterating to " << iterations <<
", ";
472 if( iterations == max_iterations ) {
473 std::cerr <<
"which also is the maximum number of iterations.\n";
475 std::cerr <<
"which is below the maximum number of iterations of "
476 << max_iterations <<
"\n";
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
A generalised semiring.
Definition: semiring.hpp:186
A GraphBLAS vector.
Definition: vector.hpp:64
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3618
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3721
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f using any number of vectors of equal lengt...
Definition: blas1.hpp:3524
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
RC bicgstab(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, Vector< InputType > &r, Vector< InputType > &rhat, Vector< InputType > &p, Vector< InputType > &v, Vector< InputType > &s, Vector< InputType > &t, const Semiring &semiring=Semiring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with $x$ unknown by using the bi-conjugate gradient (bi-CG) stabilised method; i...
Definition: bicgstab.hpp:166
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
std::string toString(const RC code)
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
@ FAILED
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54