26 #ifndef _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT 27 #define _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT 33 #include <graphblas/utils/iscomplex.hpp> 38 namespace algorithms {
178 bool preconditioned =
true,
180 typename ResidualType,
181 typename NonzeroType,
183 class Ring = Semiring<
187 class Minus = operators::subtract< IOType >,
188 class Divide = operators::divide< IOType >,
189 typename RSI,
typename NZI,
Backend backend
201 const size_t max_iterations,
204 ResidualType &residual,
209 const Ring &ring = Ring(),
210 const Minus &minus = Minus(),
211 const Divide ÷ = Divide()
214 static_assert( std::is_floating_point< ResidualType >::value,
215 "Can only use the CG algorithm with floating-point residual " 218 std::is_same< IOType, ResidualType >::value &&
219 std::is_same< IOType, NonzeroType >::value &&
220 std::is_same< IOType, InputType >::value
221 ),
"One or more of the provided containers have differing element types " 222 "while the no-casting descriptor has been supplied" 225 std::is_same< NonzeroType, typename Ring::D1 >::value &&
226 std::is_same< IOType, typename Ring::D2 >::value &&
227 std::is_same< InputType, typename Ring::D3 >::value &&
228 std::is_same< InputType, typename Ring::D4 >::value
229 ),
"no_casting descriptor was set, but semiring has incompatible domains " 230 "with the given containers." 233 std::is_same< InputType, typename Minus::D1 >::value &&
234 std::is_same< InputType, typename Minus::D2 >::value &&
235 std::is_same< InputType, typename Minus::D3 >::value
236 ),
"no_casting descriptor was set, but given minus operator has " 237 "incompatible domains with the given containers." 240 std::is_same< ResidualType, typename Divide::D1 >::value &&
241 std::is_same< ResidualType, typename Divide::D2 >::value &&
242 std::is_same< ResidualType, typename Divide::D3 >::value
243 ),
"no_casting descriptor was set, but given divide operator has " 244 "incompatible domains with the given tolerance type." 246 static_assert( std::is_floating_point< ResidualType >::value,
247 "Require floating-point residual type." 251 const ResidualType zero_residual = ring.template getZero< ResidualType >();
252 const IOType zero = ring.template getZero< IOType >();
262 std::cerr <<
"Error: initial solution guess and output vector length (" 263 <<
size( x ) <<
") does not match matrix size (" << m <<
").\n";
267 std::cerr <<
"Error: right-hand side size (" <<
grb::size( b ) <<
") does " 268 <<
"not match matrix size (" << m <<
").\n";
272 std::cerr <<
"Error: provided workspace vectors are not of the correct " 276 if( preconditioned &&
grb::size( temp_precond ) != n ) {
277 std::cerr <<
"Error: (left) preconditioner workspace vector does not have " 278 <<
"the correct length.\n";
282 std::cerr <<
"Warning: grb::algorithms::conjugate_gradient requires " 283 <<
"square input matrices, but a non-square input matrix was " 284 <<
"given instead.\n";
302 if( tol <= zero_residual ) {
303 std::cerr <<
"Error: tolerance input to CG must be strictly positive\n";
306 if( max_iterations == 0 ) {
307 std::cerr <<
"Error: at least one CG iteration must be requested\n";
314 residual = std::numeric_limits< double >::infinity();
320 if(
nnz( x ) != n ) {
321 rc = grb::set< descriptors::invert_mask | descriptors::structural >(
326 assert(
nnz( x ) == n );
329 IOType sigma, bnorm, alpha, beta;
333 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, x, ring );
337 ret = ret ? ret :
grb::foldl( r, b, ring.getAdditiveMonoid() );
339 assert(
nnz( r ) == n );
340 assert(
nnz( temp ) == n );
341 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
345 if( preconditioned ) {
347 ret = ret ? ret : Minv( z, r );
358 ret = ret ? ret : grb::dot< descr_dense >(
361 ring.getAdditiveMonoid(),
369 ret = ret ? ret : grb::dot< descr_dense >(
372 ring.getAdditiveMonoid(),
378 tol *= std::sqrt( grb::utils::is_complex< IOType >::modulus( bnorm ) );
380 std::cerr <<
"Warning: preconditioned CG caught error during prelude (" 388 assert( iter < max_iterations );
392 ret = ret ? ret : grb::set< descr_dense >( temp, 0 );
394 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, u, ring );
399 ret = ret ? ret : grb::dot< descr_dense >(
402 ring.getAdditiveMonoid(),
408 ret = ret ? ret :
grb::apply( alpha, sigma, beta, divide );
412 ret = ret ? ret : grb::eWiseMul< descr_dense >( x, alpha, u, ring );
416 ret = ret ? ret : grb::foldr< descr_dense >( alpha, temp,
417 ring.getMultiplicativeMonoid() );
419 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
426 ret = ret ? ret : grb::dot< descr_dense >(
429 ring.getAdditiveMonoid(),
433 residual = grb::utils::is_complex< IOType >::modulus( alpha );
437 if( sqrt( residual ) < tol || iter >= max_iterations ) {
break; }
444 if( preconditioned ) {
446 ret = ret ? ret : Minv( z, r ); assert( ret ==
grb::SUCCESS );
447 ret = ret ? ret : grb::dot< descr_dense >(
450 ring.getAdditiveMonoid(),
459 ret = ret ? ret :
grb::apply( alpha, beta, sigma, divide );
463 ret = ret ? ret : grb::foldr< descr_dense >( alpha, u,
464 ring.getMultiplicativeMonoid() );
466 ret = ret ? ret : grb::foldr< descr_dense >( z, u,
467 ring.getAdditiveMonoid() );
478 if( std::sqrt( residual ) >= tol ) {
496 typename ResidualType,
497 typename NonzeroType,
505 typename RSI,
typename NZI,
Backend backend
511 const size_t max_iterations,
514 ResidualType &residual,
518 const Ring &ring = Ring(),
519 const Minus &minus = Minus(),
520 const Divide ÷ = Divide()
523 static_assert( std::is_floating_point< ResidualType >::value,
524 "Can only use the CG algorithm with floating-point residual " 527 std::is_same< IOType, ResidualType >::value &&
528 std::is_same< IOType, NonzeroType >::value &&
529 std::is_same< IOType, InputType >::value
530 ),
"One or more of the provided containers have differing element types " 531 "while the no-casting descriptor has been supplied" 534 std::is_same< NonzeroType, typename Ring::D1 >::value &&
535 std::is_same< IOType, typename Ring::D2 >::value &&
536 std::is_same< InputType, typename Ring::D3 >::value &&
537 std::is_same< InputType, typename Ring::D4 >::value
538 ),
"no_casting descriptor was set, but semiring has incompatible domains " 539 "with the given containers." 542 std::is_same< InputType, typename Minus::D1 >::value &&
543 std::is_same< InputType, typename Minus::D2 >::value &&
544 std::is_same< InputType, typename Minus::D3 >::value
545 ),
"no_casting descriptor was set, but given minus operator has " 546 "incompatible domains with the given containers." 549 std::is_same< ResidualType, typename Divide::D1 >::value &&
550 std::is_same< ResidualType, typename Divide::D2 >::value &&
551 std::is_same< ResidualType, typename Divide::D3 >::value
552 ),
"no_casting descriptor was set, but given divide operator has " 553 "incompatible domains with the given tolerance type." 555 static_assert( std::is_floating_point< ResidualType >::value,
556 "Require floating-point residual type." 565 > dummy_preconditioner =
575 return preconditioned_conjugate_gradient< descr, false >(
577 dummy_preconditioner,
579 iterations, residual,
580 r, u, temp, dummy_buffer,
589 #endif // end _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:858
Standard identity for numerical addition.
Definition: identities.hpp:57
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:72
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
A GraphBLAS vector.
Definition: vector.hpp:64
Standard identity for numerical multiplication.
Definition: identities.hpp:79
Numerical subtraction of two numbers.
Definition: ops.hpp:301
Conjugate-multiply operator that conjugates the left-hand operand before multiplication.
Definition: ops.hpp:969
grb::RC conjugate_gradient(grb::Vector< IOType, backend > &x, const grb::Matrix< NonzeroType, backend, RSI, RSI, NZI > &A, const grb::Vector< InputType, backend > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, grb::Vector< IOType, backend > &r, grb::Vector< IOType, backend > &u, grb::Vector< IOType, backend > &temp, const Ring &ring=Ring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with unknown $x$ by the Conjugate Gradients (CG) method on general fields.
Definition: conjugate_gradient.hpp:507
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
grb::RC preconditioned_conjugate_gradient(grb::Vector< IOType, backend > &x, const grb::Matrix< NonzeroType, backend, RSI, RSI, NZI > &A, const grb::Vector< InputType, backend > &b, const std::function< grb::RC(grb::Vector< IOType, backend > &, const grb::Vector< IOType, backend > &) > &Minv, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, grb::Vector< IOType, backend > &r, grb::Vector< IOType, backend > &u, grb::Vector< IOType, backend > &temp, grb::Vector< IOType, backend > &temp_precond, const Ring &ring=Ring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a preconditioned linear system $b = Ax$ with unknown $x$ by the Conjugate Gradients (CG) method on general fields.
Definition: conjugate_gradient.hpp:191
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
This operator multiplies the two input parameters and writes the result to the output variable.
Definition: ops.hpp:208
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
Conjugate-multiply operator that conjugates the right-hand operand before multiplication.
Definition: ops.hpp:918
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3840
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
Backend
A collection of all backends.
Definition: backends.hpp:49
This operator takes the sum of the two input parameters and writes it to the output variable.
Definition: ops.hpp:175
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:477
The main header to include in order to use the ALP/GraphBLAS API.
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
Numerical division of two numbers.
Definition: ops.hpp:328
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
A generalised semiring.
Definition: semiring.hpp:190
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
std::string toString(const RC code)