26#ifndef _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
27#define _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
33#include <graphblas/utils/iscomplex.hpp>
38 namespace algorithms {
149 typename ResidualType,
150 typename NonzeroType,
152 class Ring = Semiring<
156 class Minus = operators::subtract< IOType >,
157 class Divide = operators::divide< IOType >
163 const size_t max_iterations,
166 ResidualType &residual,
170 const Ring &ring = Ring(),
171 const Minus &minus = Minus(),
172 const Divide ÷ = Divide()
175 static_assert( std::is_floating_point< ResidualType >::value,
176 "Can only use the CG algorithm with floating-point residual "
179 std::is_same< IOType, ResidualType >::value &&
180 std::is_same< IOType, NonzeroType >::value &&
181 std::is_same< IOType, InputType >::value
182 ),
"One or more of the provided containers have differing element types "
183 "while the no-casting descriptor has been supplied"
186 std::is_same< NonzeroType, typename Ring::D1 >::value &&
187 std::is_same< IOType, typename Ring::D2 >::value &&
188 std::is_same< InputType, typename Ring::D3 >::value &&
189 std::is_same< InputType, typename Ring::D4 >::value
190 ),
"no_casting descriptor was set, but semiring has incompatible domains "
191 "with the given containers."
194 std::is_same< InputType, typename Minus::D1 >::value &&
195 std::is_same< InputType, typename Minus::D2 >::value &&
196 std::is_same< InputType, typename Minus::D3 >::value
197 ),
"no_casting descriptor was set, but given minus operator has "
198 "incompatible domains with the given containers."
201 std::is_same< ResidualType, typename Divide::D1 >::value &&
202 std::is_same< ResidualType, typename Divide::D2 >::value &&
203 std::is_same< ResidualType, typename Divide::D3 >::value
204 ),
"no_casting descriptor was set, but given divide operator has "
205 "incompatible domains with the given tolerance type."
207 static_assert( std::is_floating_point< ResidualType >::value,
208 "Require floating-point residual type."
212 const ResidualType zero_residual = ring.template getZero< ResidualType >();
213 const IOType zero = ring.template getZero< IOType >();
219 if(
size( x ) != n ) {
222 if(
size( b ) != m ) {
225 if(
size( r ) != n ||
size( u ) != n ||
size( temp ) != n ) {
226 std::cerr <<
"Error: provided workspace vectors are not of the correct "
231 std::cerr <<
"Warning: grb::algorithms::conjugate_gradient requires "
232 <<
"square input matrices, but a non-square input matrix was "
233 <<
"given instead.\n";
246 if( tol <= zero_residual ) {
247 std::cerr <<
"Error: tolerance input to CG must be strictly positive\n";
254 residual = std::numeric_limits< double >::infinity();
257 if( max_iterations == 0 ) {
265 if(
nnz( x ) != n ) {
266 rc = set< descriptors::invert_mask | descriptors::structural >(
273 assert(
nnz( x ) == n );
276 IOType sigma, bnorm, alpha, beta;
283 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, x, ring );
287 ret = ret ? ret :
grb::set( r, zero );
288 ret = ret ? ret :
grb::foldl( r, b, ring.getAdditiveMonoid() );
289 assert(
nnz( r ) == n );
290 assert(
nnz( temp ) == n );
291 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
293 assert(
nnz( r ) == n );
301 if( grb::utils::is_complex< IOType >::value ) {
303 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
306 ret = ret ? ret : grb::dot< descr_dense >( sigma, temp, r, ring );
308 ret = ret ? ret : grb::dot< descr_dense >( sigma, r, r, ring );
315 if( grb::utils::is_complex< IOType >::value ) {
317 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( b[ i ] );
320 ret = ret ? ret : grb::dot< descr_dense >( bnorm, temp, b, ring );
322 ret = ret ? ret : grb::dot< descr_dense >( bnorm, b, b, ring );
327 tol *= sqrt( grb::utils::is_complex< IOType >::modulus( bnorm ) );
333 assert( iter < max_iterations );
337 ret = ret ? ret :
grb::set( temp, 0 );
341 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, u, ring );
346 if( grb::utils::is_complex< IOType >::value ) {
348 u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
352 ret = ret ? ret : grb::dot< descr_dense >( beta, temp, u, ring );
353 if( grb::utils::is_complex< IOType >::value ) {
355 u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
362 ret = ret ? ret :
grb::apply( alpha, sigma, beta, divide );
366 ret = ret ? ret : grb::eWiseMul< descr_dense >( x, alpha, u, ring );
371 ret = ret ? ret :
grb::foldr( alpha, temp, ring.getMultiplicativeMonoid() );
375 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
380 if( grb::utils::is_complex< IOType >::value ) {
382 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
385 ret = ret ? ret : grb::dot< descr_dense >( beta, temp, r, ring );
387 ret = ret ? ret : grb::dot< descr_dense >( beta, r, r, ring );
389 residual = grb::utils::is_complex< IOType >::modulus( beta );
393 if( sqrt( residual ) < tol || iter >= max_iterations ) {
399 ret = ret ? ret :
grb::apply( alpha, beta, sigma, divide );
403 ret = ret ? ret :
grb::set( temp, r );
405 ret = ret ? ret : grb::eWiseMul< descr_dense >( temp, alpha, u, ring );
407 assert(
nnz( temp ) ==
size( temp ) );
410 std::swap( u, temp );
420 if( sqrt( residual ) >= tol ) {
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
A GraphBLAS vector.
Definition: vector.hpp:64
Standard identity for numerical multiplication.
Definition: identities.hpp:79
Standard identity for numerical addition.
Definition: identities.hpp:57
This operator takes the sum of the two input parameters and writes it to the output variable.
Definition: ops.hpp:174
This operator multiplies the two input parameters and writes the result to the output variable.
Definition: ops.hpp:208
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3618
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3721
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f using any number of vectors of equal lengt...
Definition: blas1.hpp:3524
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
grb::RC conjugate_gradient(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, grb::Vector< IOType > &r, grb::Vector< IOType > &u, grb::Vector< IOType > &temp, const Ring &ring=Ring(), const Minus &minus=Minus(), const Divide ÷=Divide())
Solves a linear system with unknown by the Conjugate Gradients (CG) method on general fields.
Definition: conjugate_gradient.hpp:159
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
@ FAILED
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54