ALP User Documentation 0.7.alpha
Algebraic Programming User Documentation
Loading...
Searching...
No Matches
conjugate_gradient.hpp
Go to the documentation of this file.
1
2/*
3 * Copyright 2021 Huawei Technologies Co., Ltd.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
26#ifndef _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
27#define _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
28
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdio>
#include <iostream>
#include <limits>
#include <type_traits>
#include <utility>

#include <graphblas.hpp>
#include <graphblas/utils/iscomplex.hpp>
34
35
36namespace grb {
37
38 namespace algorithms {
39
147 template< Descriptor descr = descriptors::no_operation,
148 typename IOType,
149 typename ResidualType,
150 typename NonzeroType,
151 typename InputType,
152 class Ring = Semiring<
155 >,
156 class Minus = operators::subtract< IOType >,
157 class Divide = operators::divide< IOType >
158 >
163 const size_t max_iterations,
164 ResidualType tol,
165 size_t &iterations,
166 ResidualType &residual,
170 const Ring &ring = Ring(),
171 const Minus &minus = Minus(),
172 const Divide &divide = Divide()
173 ) {
174 // static checks
175 static_assert( std::is_floating_point< ResidualType >::value,
176 "Can only use the CG algorithm with floating-point residual "
177 "types." ); // unless some different norm were used: issue #89
178 static_assert( !( descr & descriptors::no_casting ) || (
179 std::is_same< IOType, ResidualType >::value &&
180 std::is_same< IOType, NonzeroType >::value &&
181 std::is_same< IOType, InputType >::value
182 ), "One or more of the provided containers have differing element types "
183 "while the no-casting descriptor has been supplied"
184 );
185 static_assert( !( descr & descriptors::no_casting ) || (
186 std::is_same< NonzeroType, typename Ring::D1 >::value &&
187 std::is_same< IOType, typename Ring::D2 >::value &&
188 std::is_same< InputType, typename Ring::D3 >::value &&
189 std::is_same< InputType, typename Ring::D4 >::value
190 ), "no_casting descriptor was set, but semiring has incompatible domains "
191 "with the given containers."
192 );
193 static_assert( !( descr & descriptors::no_casting ) || (
194 std::is_same< InputType, typename Minus::D1 >::value &&
195 std::is_same< InputType, typename Minus::D2 >::value &&
196 std::is_same< InputType, typename Minus::D3 >::value
197 ), "no_casting descriptor was set, but given minus operator has "
198 "incompatible domains with the given containers."
199 );
200 static_assert( !( descr & descriptors::no_casting ) || (
201 std::is_same< ResidualType, typename Divide::D1 >::value &&
202 std::is_same< ResidualType, typename Divide::D2 >::value &&
203 std::is_same< ResidualType, typename Divide::D3 >::value
204 ), "no_casting descriptor was set, but given divide operator has "
205 "incompatible domains with the given tolerance type."
206 );
207 static_assert( std::is_floating_point< ResidualType >::value,
208 "Require floating-point residual type."
209 );
210
211 constexpr const Descriptor descr_dense = descr | descriptors::dense;
212 const ResidualType zero_residual = ring.template getZero< ResidualType >();
213 const IOType zero = ring.template getZero< IOType >();
214 const size_t n = grb::ncols( A );
215
216 // dynamic checks
217 {
218 const size_t m = grb::nrows( A );
219 if( size( x ) != n ) {
220 return MISMATCH;
221 }
222 if( size( b ) != m ) {
223 return MISMATCH;
224 }
225 if( size( r ) != n || size( u ) != n || size( temp ) != n ) {
226 std::cerr << "Error: provided workspace vectors are not of the correct "
227 << "length.\n";
228 return MISMATCH;
229 }
230 if( m != n ) {
231 std::cerr << "Warning: grb::algorithms::conjugate_gradient requires "
232 << "square input matrices, but a non-square input matrix was "
233 << "given instead.\n";
234 return ILLEGAL;
235 }
236
237 // capacities
238 if( capacity( x ) != n ) {
239 return ILLEGAL;
240 }
241 if( capacity( r ) != n || capacity( u ) != n || capacity( temp ) != n ) {
242 return ILLEGAL;
243 }
244
245 // others
246 if( tol <= zero_residual ) {
247 std::cerr << "Error: tolerance input to CG must be strictly positive\n";
248 return ILLEGAL;
249 }
250 }
251
252 // set pure output fields to neutral defaults
253 iterations = 0;
254 residual = std::numeric_limits< double >::infinity();
255
256 // trivial shortcuts
257 if( max_iterations == 0 ) {
258 return FAILED;
259 }
260
261 // make x and b structurally dense (if not already) so that the remainder
262 // algorithm can safely use the dense descriptor for faster operations
263 {
264 RC rc = SUCCESS;
265 if( nnz( x ) != n ) {
266 rc = set< descriptors::invert_mask | descriptors::structural >(
267 x, x, zero
268 );
269 }
270 if( rc != SUCCESS ) {
271 return rc;
272 }
273 assert( nnz( x ) == n );
274 }
275
276 IOType sigma, bnorm, alpha, beta;
277
278 // temp = 0
279 grb::RC ret = grb::set( temp, 0 );
280 assert( ret == SUCCESS );
281
282 // temp = A * x
283 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, x, ring );
284 assert( ret == SUCCESS );
285
286 // r = b - temp;
287 ret = ret ? ret : grb::set( r, zero );
288 ret = ret ? ret : grb::foldl( r, b, ring.getAdditiveMonoid() );
289 assert( nnz( r ) == n );
290 assert( nnz( temp ) == n );
291 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
292 assert( ret == SUCCESS );
293 assert( nnz( r ) == n );
294
295 // u = r;
296 ret = ret ? ret : grb::set( u, r );
297 assert( ret == SUCCESS );
298
299 // sigma = r' * r;
300 sigma = zero;
301 if( grb::utils::is_complex< IOType >::value ) {
302 ret = ret ? ret : grb::eWiseLambda( [&temp,&r]( const size_t i ) {
303 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
304 }, temp
305 );
306 ret = ret ? ret : grb::dot< descr_dense >( sigma, temp, r, ring );
307 } else {
308 ret = ret ? ret : grb::dot< descr_dense >( sigma, r, r, ring );
309 }
310
311 assert( ret == SUCCESS );
312
313 // bnorm = b' * b;
314 bnorm = zero;
315 if( grb::utils::is_complex< IOType >::value ) {
316 ret = ret ? ret : grb::eWiseLambda( [&temp,&b]( const size_t i ) {
317 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( b[ i ] );
318 }, temp
319 );
320 ret = ret ? ret : grb::dot< descr_dense >( bnorm, temp, b, ring );
321 } else {
322 ret = ret ? ret : grb::dot< descr_dense >( bnorm, b, b, ring );
323 }
324 assert( ret == SUCCESS );
325
326 if( ret == SUCCESS ) {
327 tol *= sqrt( grb::utils::is_complex< IOType >::modulus( bnorm ) );
328 }
329
330 size_t iter = 0;
331
332 do {
333 assert( iter < max_iterations );
334 (void) ++iter;
335
336 // temp = 0
337 ret = ret ? ret : grb::set( temp, 0 );
338 assert( ret == SUCCESS );
339
340 // temp = A * u;
341 ret = ret ? ret : grb::mxv< descr_dense >( temp, A, u, ring );
342 assert( ret == SUCCESS );
343
344 // beta = u' * temp
345 beta = zero;
346 if( grb::utils::is_complex< IOType >::value ) {
347 ret = ret ? ret : grb::eWiseLambda( [&u]( const size_t i ) {
348 u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
349 }, u
350 );
351 }
352 ret = ret ? ret : grb::dot< descr_dense >( beta, temp, u, ring );
353 if( grb::utils::is_complex< IOType >::value ) {
354 ret = ret ? ret : grb::eWiseLambda( [&u]( const size_t i ) {
355 u[ i ] = grb::utils::is_complex< IOType >::conjugate( u[ i ] );
356 }, u
357 );
358 }
359 assert( ret == SUCCESS );
360
361 // alpha = sigma / beta;
362 ret = ret ? ret : grb::apply( alpha, sigma, beta, divide );
363 assert( ret == SUCCESS );
364
365 // x = x + alpha * u;
366 ret = ret ? ret : grb::eWiseMul< descr_dense >( x, alpha, u, ring );
367 assert( ret == SUCCESS );
368
369 // temp = alpha .* temp
370 // Warning: operator-based foldr requires temp be dense
371 ret = ret ? ret : grb::foldr( alpha, temp, ring.getMultiplicativeMonoid() );
372 assert( ret == SUCCESS );
373
374 // r = r - temp;
375 ret = ret ? ret : grb::foldl< descr_dense >( r, temp, minus );
376 assert( ret == SUCCESS );
377
378 // beta = r' * r;
379 beta = zero;
380 if( grb::utils::is_complex< IOType >::value ) {
381 ret = ret ? ret : grb::eWiseLambda( [&temp,&r]( const size_t i ) {
382 temp[ i ] = grb::utils::is_complex< IOType >::conjugate( r[ i ] );
383 }, temp
384 );
385 ret = ret ? ret : grb::dot< descr_dense >( beta, temp, r, ring );
386 } else {
387 ret = ret ? ret : grb::dot< descr_dense >( beta, r, r, ring );
388 }
389 residual = grb::utils::is_complex< IOType >::modulus( beta );
390 assert( ret == SUCCESS );
391
392 if( ret == SUCCESS ) {
393 if( sqrt( residual ) < tol || iter >= max_iterations ) {
394 break;
395 }
396 }
397
398 // alpha = beta / sigma;
399 ret = ret ? ret : grb::apply( alpha, beta, sigma, divide );
400 assert( ret == SUCCESS );
401
402 // temp = r + alpha * u;
403 ret = ret ? ret : grb::set( temp, r );
404 assert( ret == SUCCESS );
405 ret = ret ? ret : grb::eWiseMul< descr_dense >( temp, alpha, u, ring );
406 assert( ret == SUCCESS );
407 assert( nnz( temp ) == size( temp ) );
408
409 // u = temp
410 std::swap( u, temp );
411
412 sigma = beta;
413 } while( ret == SUCCESS );
414
415 // output that is independent of error code
416 iterations = iter;
417
418 // return correct error code
419 if( ret == SUCCESS ) {
420 if( sqrt( residual ) >= tol ) {
421 // did not converge within iterations
422 return FAILED;
423 }
424 }
425 return ret;
426 }
427
428 } // namespace algorithms
429
430} // end namespace grb
431
432#endif // end _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
433
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
A GraphBLAS vector.
Definition: vector.hpp:64
Standard identity for numerical multiplication.
Definition: identities.hpp:79
Standard identity for numerical addition.
Definition: identities.hpp:57
This operator takes the sum of the two input parameters and writes it to the output variable.
Definition: ops.hpp:174
This operator multiplies the two input parameters and writes the result to the output variable.
Definition: ops.hpp:208
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3618
RC foldr(const Vector< InputType, backend, Coords > &x, const Vector< MaskType, backend, Coords > &mask, IOType &y, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Folds a vector into a scalar, right-to-left.
Definition: blas1.hpp:3721
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f using any number of vectors of equal lengt...
Definition: blas1.hpp:3524
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
grb::RC conjugate_gradient(grb::Vector< IOType > &x, const grb::Matrix< NonzeroType > &A, const grb::Vector< InputType > &b, const size_t max_iterations, ResidualType tol, size_t &iterations, ResidualType &residual, grb::Vector< IOType > &r, grb::Vector< IOType > &u, grb::Vector< IOType > &temp, const Ring &ring=Ring(), const Minus &minus=Minus(), const Divide &divide=Divide())
Solves a linear system $b = Ax$ with unknown $x$ by the Conjugate Gradients (CG) method on general fields.
Definition: conjugate_gradient.hpp:159
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
@ FAILED
Indicates when one of the grb::algorithms has failed to achieve its intended result,...
Definition: rc.hpp:154
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54