26#ifndef _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT 
   27#define _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT 
   33#include <graphblas/utils/iscomplex.hpp> 
   38    namespace algorithms {
 
  149            typename ResidualType,
 
  150            typename NonzeroType,
 
  152            class Ring = Semiring<
 
  156            class Minus = operators::subtract< IOType >,
 
  157            class Divide = operators::divide< IOType >
 
  163            const size_t max_iterations,
 
  166            ResidualType &residual,
 
  170            const Ring &ring = Ring(),
 
  171            const Minus &minus = Minus(),
 
  172            const Divide ÷ = Divide()
 
  175            static_assert( std::is_floating_point< ResidualType >::value,
 
  176                "Can only use the CG algorithm with floating-point residual " 
  179                    std::is_same< IOType, ResidualType >::value &&
 
  180                    std::is_same< IOType, NonzeroType >::value &&
 
  181                    std::is_same< IOType, InputType >::value
 
  182                ), 
"One or more of the provided containers have differing element types " 
  183                "while the no-casting descriptor has been supplied" 
  186                    std::is_same< NonzeroType, typename Ring::D1 >::value &&
 
  187                    std::is_same< IOType, typename Ring::D2 >::value &&
 
  188                    std::is_same< InputType, typename Ring::D3 >::value &&
 
  189                    std::is_same< InputType, typename Ring::D4 >::value
 
  190                ), 
"no_casting descriptor was set, but semiring has incompatible domains " 
  191                "with the given containers." 
  194                    std::is_same< InputType, typename Minus::D1 >::value &&
 
  195                    std::is_same< InputType, typename Minus::D2 >::value &&
 
  196                    std::is_same< InputType, typename Minus::D3 >::value
 
  197                ), 
"no_casting descriptor was set, but given minus operator has " 
  198                "incompatible domains with the given containers." 
  201                    std::is_same< ResidualType, typename Divide::D1 >::value &&
 
  202                    std::is_same< ResidualType, typename Divide::D2 >::value &&
 
  203                    std::is_same< ResidualType, typename Divide::D3 >::value
 
  204                ), 
"no_casting descriptor was set, but given divide operator has " 
  205                "incompatible domains with the given tolerance type." 
  207            static_assert( std::is_floating_point< ResidualType >::value,
 
  208                "Require floating-point residual type." 
  212            const ResidualType zero_residual = ring.template getZero< ResidualType >();
 
  213            const IOType zero = ring.template getZero< IOType >();
 
  219                if( 
size( x ) != n ) {
 
  222                if( 
size( b ) != m ) {
 
  225                if( 
size( r ) != n || 
size( u ) != n || 
size( temp ) != n ) {
 
  226                    std::cerr << 
"Error: provided workspace vectors are not of the correct " 
  231                    std::cerr << 
"Warning: grb::algorithms::conjugate_gradient requires " 
  232                        << 
"square input matrices, but a non-square input matrix was " 
  233                        << 
"given instead.\n";
 
  246                if( tol <= zero_residual ) {
 
  247                    std::cerr << 
"Error: tolerance input to CG must be strictly positive\n";
 
  254            residual = std::numeric_limits< double >::infinity();
 
  257            if( max_iterations == 0 ) {
 
  265                if( 
nnz( x ) != n ) {
 
  266                    rc = set< descriptors::invert_mask | descriptors::structural >(
 
  273                assert( 
nnz( x ) == n );
 
  276            IOType sigma, bnorm, alpha, beta;
 
  283            ret = ret ? ret : grb::mxv< descr_dense >( temp, A, x, ring );
 
  287            ret = ret ? ret : 
grb::set( r, zero );
 
  288            ret = ret ? ret : 
grb::foldl( r, b, ring.getAdditiveMonoid() );
 
  289            assert( 
nnz( r ) == n );
 
  290            assert( 
nnz( temp ) == n );
 
  291            ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
 
  293            assert( 
nnz( r ) == n );
 
  301            if( grb::utils::is_complex< IOType >::value ) {
 
  303                        temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
 
  306                ret = ret ? ret : grb::dot< descr_dense >( sigma, temp, r, ring );
 
  308                ret = ret ? ret : grb::dot< descr_dense >( sigma, r, r, ring );
 
  315            if( grb::utils::is_complex< IOType >::value ) {
 
  317                        temp[ i ] = grb::utils::is_complex< IOType >::conjugate( b[ i ] );
 
  320                ret = ret ? ret : grb::dot< descr_dense >( bnorm, temp, b, ring );
 
  322                ret = ret ? ret : grb::dot< descr_dense >( bnorm, b, b, ring );
 
  327                tol *= sqrt( grb::utils::is_complex< IOType >::modulus( bnorm ) );
 
  333                assert( iter < max_iterations );
 
  337                ret = ret ? ret : 
grb::set( temp, 0 );
 
  341                ret = ret ? ret : grb::mxv< descr_dense >( temp, A, u, ring );
 
  346                if( grb::utils::is_complex< IOType >::value ) {
 
  348                            u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
 
  352                ret = ret ? ret : grb::dot< descr_dense >( beta, temp, u, ring );
 
  353                if( grb::utils::is_complex< IOType >::value ) {
 
  355                            u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
 
  362                ret = ret ? ret : 
grb::apply( alpha, sigma, beta, divide );
 
  366                ret = ret ? ret : grb::eWiseMul< descr_dense >( x, alpha, u, ring );
 
  371                ret = ret ? ret : 
grb::foldr( alpha, temp, ring.getMultiplicativeMonoid() );
 
  375                ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
 
  380                if( grb::utils::is_complex< IOType >::value ) {
 
  382                            temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
 
  385                    ret = ret ? ret : grb::dot< descr_dense >( beta, temp, r, ring );
 
  387                    ret = ret ? ret : grb::dot< descr_dense >( beta, r, r, ring );
 
  389                residual = grb::utils::is_complex< IOType >::modulus( beta );
 
  393                    if( sqrt( residual ) < tol || iter >= max_iterations ) {
 
  399                ret = ret ? ret : 
grb::apply( alpha, beta, sigma, divide );
 
  403                ret = ret ? ret : 
grb::set( temp, r );
 
  405                ret = ret ? ret : grb::eWiseMul< descr_dense >( temp, alpha, u, ring );
 
  407                assert( 
nnz( temp ) == 
size( temp ) );
 
  410                std::swap( u, temp );
 
  420                if( sqrt( residual ) >= tol ) {
 
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
 
A GraphBLAS vector.
Definition: vector.hpp:64
 
Standard identity for numerical multiplication.
Definition: identities.hpp:79
 
Standard identity for numerical addition.
Definition: identities.hpp:57
 
This operator takes the sum of the two input parameters and writes it to the output variable.
Definition: ops.hpp:177
 
This operator multiplies the two input parameters and writes the result to the output variable.
Definition: ops.hpp:210
 
The main header to include in order to use the ALP/GraphBLAS API.
 
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
 
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3943
 
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f on any number of vectors of equal length.
Definition: blas1.hpp:3746
 
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3840
 
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
 
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
 
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
 
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
 
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
 
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
 
grb::RC conjugate_gradient(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, grb::Vector< IOType > &r, grb::Vector< IOType > &u, grb::Vector< IOType > &temp, const Ring &ring=Ring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with $x$ unknown by the Conjugate Gradients (CG) method on general fields.
Definition: conjugate_gradient.hpp:159
 
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
 
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
 
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
 
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:452
 
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
 
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
 
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
 
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
 
@ FAILED
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
 
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54