26#ifndef _H_GRB_ALGORITHMS_SPARSE_NN_SINGLE_INFERENCE
27#define _H_GRB_ALGORITHMS_SPARSE_NN_SINGLE_INFERENCE
35 namespace algorithms {
49 bool thresholded,
typename ThresholdType,
50 typename IOType,
typename WeightType,
typename BiasType,
51 class ReluMonoid,
class Ring,
class MinMonoid
53 grb::RC sparse_nn_single_inference(
57 const std::vector< BiasType > &biases,
58 const ThresholdType threshold,
60 const ReluMonoid &relu,
66 std::is_same< IOType, WeightType >::value &&
67 std::is_same< IOType, BiasType >::value
68 ),
"Input containers have different domains even though the no_casting"
69 "descriptor was given"
72 const size_t num_layers = layers.size();
77 if( num_layers == 0 ) {
80 if( biases.size() != num_layers ) {
89 for(
size_t i = 1; i < num_layers; ++i ) {
94 for(
size_t i = 0; i < num_layers; ++i ) {
133 ret = ret ? ret :
grb::vxm( out, in, ( layers[ 0 ] ), ring );
136 ret = ret ? ret : grb::foldl< descriptors::dense >(
137 out, biases[ 1 ], ring.getAdditiveMonoid()
140 for(
size_t i = 1; ret ==
SUCCESS && i < num_layers - 1; ++i ) {
142 ret = ret ? ret : grb::foldl< descriptors::dense >( out, 0, relu );
146 ret = ret ? ret : grb::foldl< descriptors::dense >( out, threshold, min );
151 std::swap( out, temp );
157 ret = ret ? ret : grb::vxm< descriptors::dense >(
158 out, temp, ( layers[ i ] ), ring
161 ret = ret ? ret : grb::foldl< descriptors::dense >(
162 out, biases[ i + 1 ], ring.getAdditiveMonoid()
166 ret = ret ? ret : grb::foldl< descriptors::dense >( out, 0, relu );
170 ret = ret ? ret : grb::foldl< descriptors::dense >( out, threshold, min );
261 class ReluMonoid = Monoid<
265 class Ring = Semiring<
274 const std::vector< BiasType > &biases,
276 const ReluMonoid &relu = ReluMonoid(),
277 const Ring &ring = Ring()
281 std::is_same< IOType, WeightType >::value &&
282 std::is_same< IOType, BiasType >::value
283 ),
"Input containers have different domains even though the no_casting "
284 "descriptor was given" );
287 > dummyTresholdMonoid;
288 return internal::sparse_nn_single_inference<
294 relu, dummyTresholdMonoid, ring
387 typename ThresholdType = IOType,
391 class ReluMonoid =
Monoid<
404 const std::vector< BiasType > &biases,
405 const ThresholdType threshold,
407 const ReluMonoid &relu = ReluMonoid(),
408 const MinMonoid &min = MinMonoid(),
409 const Ring &ring = Ring()
413 std::is_same< IOType, WeightType >::value &&
414 std::is_same< IOType, BiasType >::value
415 ),
"Input containers have different domains even though the no_casting "
416 "descriptor was given" );
417 return internal::sparse_nn_single_inference<
An ALP/GraphBLAS matrix.
Definition: matrix.hpp:71
A generalised monoid.
Definition: monoid.hpp:54
A generalised semiring.
Definition: semiring.hpp:186
A GraphBLAS vector.
Definition: vector.hpp:64
Standard identity for the minimum operator.
Definition: identities.hpp:101
Standard identity for the maximum operator.
Definition: identities.hpp:124
Standard identity for numerical multiplication.
Definition: identities.hpp:79
Standard identity for numerical addition.
Definition: identities.hpp:57
This operator takes the sum of the two input parameters and writes it to the output variable.
Definition: ops.hpp:174
This operator takes the minimum of the two input parameters and writes the result to the output varia...
Definition: ops.hpp:274
This operator multiplies the two input parameters and writes the result to the output variable.
Definition: ops.hpp:208
This operation is equivalent to grb::operators::min.
Definition: ops.hpp:515
The main header to include in order to use the ALP/GraphBLAS API.
RC vxm(Vector< IOType, backend, Coords > &u, const Vector< InputType3, backend, Coords > &u_mask, const Vector< InputType1, backend, Coords > &v, const Vector< InputType4, backend, Coords > &v_mask, const Matrix< InputType2, backend, RIT, CIT, NIT > &A, const Semiring &semiring=Semiring(), const Phase &phase=EXECUTE, typename std::enable_if< grb::is_semiring< Semiring >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< InputType3 >::value &&!grb::is_object< InputType4 >::value &&!grb::is_object< IOType >::value, void >::type *=nullptr)
Left-handed in-place doubly-masked sparse matrix times vector multiplication, .
Definition: blas2.hpp:307
size_t capacity(const Vector< InputType, backend, Coords > &x) noexcept
Queries the capacity of the given ALP/GraphBLAS container.
Definition: io.hpp:388
RC set(Vector< DataType, backend, Coords > &x, const T val, const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< DataType >::value &&!grb::is_object< T >::value, void >::type *const =nullptr) noexcept
Sets all elements of a vector to the given value.
Definition: io.hpp:857
size_t nrows(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the row size of a given matrix.
Definition: io.hpp:286
size_t ncols(const Matrix< InputType, backend, RIT, CIT, NIT > &A) noexcept
Requests the column size of a given matrix.
Definition: io.hpp:339
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
grb::RC sparse_nn_single_inference(grb::Vector< IOType > &out, const grb::Vector< IOType > &in, const std::vector< grb::Matrix< WeightType > > &layers, const std::vector< BiasType > &biases, grb::Vector< IOType > &temp, const ReluMonoid &relu=ReluMonoid(), const Ring &ring=Ring())
Performs an inference step of a single data element through a Sparse Neural Network defined by num_la...
Definition: sparse_nn_single_inference.hpp:270
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54