34#define NO_CAST_ASSERT( x, y, z ) \
37 "*******************************************************************************************\n" \
38 "* ERROR | " y " " z ".\n" \
39 "*******************************************************************************************\n" \
40 "* Possible fix 1 | Remove no_casting from the template parameters in this call to " y ". \n" \
41 "* Possible fix 2 | For all mismatches in the domains of input and output parameters w.r.t. \n" \
42 "* the semiring domains, as specified in the documentation of the function.\n" \
43 "* supply an input argument of the expected type instead.\n" \
44 "* Possible fix 3 | Provide a compatible semiring where all domains match those of the input\n" \
45 "* parameters, as specified in the documentation of the function.\n" \
46 "*******************************************************************************************\n" );
51 namespace algorithms {
110 typename Ring::D3,
typename Ring::D3,
typename Ring::D4
114 OutputType &similarity,
116 const Ring &ring = Ring(),
const Division &div = Division()
118 static_assert( std::is_floating_point< OutputType >::value,
119 "Cosine similarity requires a floating-point output type." );
123 std::is_same< InputType1, typename Ring::D1 >::value
124 ),
"grb::algorithms::cosine_similarity",
125 "called with a left-hand vector value type that does not match the "
126 "first domain of the given semiring" );
128 std::is_same< InputType2, typename Ring::D2 >::value
129 ),
"grb::algorithms::cosine_similarity",
130 "called with a right-hand vector value type that does not match "
131 "the second domain of the given semiring" );
133 std::is_same< OutputType, typename Ring::D4 >::value
134 ),
"grb::algorithms::cosine_similarity",
135 "called with an output vector value type that does not match the "
136 "fourth domain of the given semiring" );
138 std::is_same< typename Ring::D3, typename Ring::D4 >::value
139 ),
"grb::algorithms::cosine_similarity",
140 "called with a semiring that has unequal additive input domains" );
142 const size_t n =
size( x );
145 if( n !=
size( y ) ) {
150 const bool dense =
nnz( x ) == n &&
nnz( y ) == n;
157 OutputType nominator, denominator;
158 nominator = denominator = ring.template getZero< OutputType >();
159 if( dense && grb::Properties<>::writableCaptured ) {
161 OutputType norm1,
norm2;
162 norm1 =
norm2 = ring.template getZero< OutputType >();
164 [ &x, &y, &nominator, &norm1, &
norm2, &ring ](
const size_t i ) {
165 const auto &mul = ring.getMultiplicativeOperator();
166 const auto &add = ring.getAdditiveOperator();
168 (void)
grb::apply( temp, x[ i ], y[ i ], mul );
170 (void)
grb::apply( temp, x[ i ], x[ i ], mul );
172 (void)
grb::apply( temp, y[ i ], y[ i ], mul );
176 denominator = sqrt( norm1 ) * sqrt(
norm2 );
180 rc = grb::norm2( nominator, x, ring );
182 rc = grb::norm2( denominator, y, ring );
186 ring.getMultiplicativeOperator() );
189 rc =
grb::dot( nominator, x, y, ring );
196 if( denominator == ring.template getZero() ) {
199 if( nominator == ring.template getZero() ) {
202 rc =
grb::apply( similarity, nominator, denominator, div );
A GraphBLAS vector.
Definition: vector.hpp:64
Numerical division of two numbers.
Definition: ops.hpp:328
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3618
RC dot(OutputType &z, const Vector< InputType1, backend, Coords > &x, const Vector< InputType2, backend, Coords > &y, const AddMonoid &addMonoid=AddMonoid(), const AnyOp &anyOp=AnyOp(), const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< OutputType >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&grb::is_monoid< AddMonoid >::value &&grb::is_operator< AnyOp >::value, void >::type *const =nullptr)
Calculates the dot product, $\alpha = (x, y)$, under a given additive monoid and multiplicative operator.
Definition: blas1.hpp:3834
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f using any number of vectors of equal length.
Definition: blas1.hpp:3524
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
RC norm2(OutputType &x, const Vector< InputType, backend, Coords > &y, const Ring &ring=Ring(), const typename std::enable_if< std::is_floating_point< OutputType >::value, void >::type *const =nullptr)
Provides a generic implementation of the 2-norm computation.
Definition: norm.hpp:76
RC cosine_similarity(OutputType &similarity, const Vector< InputType1 > &x, const Vector< InputType2 > &y, const Ring &ring=Ring(), const Division &div=Division())
Computes the cosine similarity.
Definition: cosine_similarity.hpp:113
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match exactly.
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification of the primitive.
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismatching dimensions.
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS call.
Definition: descriptors.hpp:54