ALP User Documentation 0.7.alpha
Algebraic Programming User Documentation
Loading...
Searching...
No Matches
cosine_similarity.hpp
Go to the documentation of this file.
1
2/*
3 * Copyright 2021 Huawei Technologies Co., Ltd.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
27#ifndef _H_GRB_COSSIM
28#define _H_GRB_COSSIM
29
#include <cmath>

#include <graphblas.hpp>
32
33
/**
 * Compile-time guard: expands to a static assertion that prints a framed,
 * multi-line diagnostic when \a cond evaluates to false.
 *
 * @param cond   Compile-time boolean expression that is required to hold.
 * @param callee String literal naming the function that performs the check.
 * @param reason String literal describing the detected mismatch.
 *
 * The expansion already ends with a semicolon.
 */
#define NO_CAST_ASSERT( cond, callee, reason ) \
	static_assert( cond, \
		"\n\n" \
		"*******************************************************************************************\n" \
		"* ERROR | " callee " " reason ".\n" \
		"*******************************************************************************************\n" \
		"* Possible fix 1 | Remove no_casting from the template parameters in this call to " callee ". \n" \
		"* Possible fix 2 | For all mismatches in the domains of input and output parameters w.r.t. \n" \
		"* the semiring domains, as specified in the documentation of the function.\n" \
		"* supply an input argument of the expected type instead.\n" \
		"* Possible fix 3 | Provide a compatible semiring where all domains match those of the input\n" \
		"* parameters, as specified in the documentation of the function.\n" \
		"*******************************************************************************************\n" );
47
48
49namespace grb {
50
51 namespace algorithms {
52
104 template< Descriptor descr = descriptors::no_operation,
105 typename OutputType,
106 typename InputType1,
107 typename InputType2,
108 class Ring,
109 class Division = grb::operators::divide<
110 typename Ring::D3, typename Ring::D3, typename Ring::D4
111 >
112 >
114 OutputType &similarity,
116 const Ring &ring = Ring(), const Division &div = Division()
117 ) {
118 static_assert( std::is_floating_point< OutputType >::value,
119 "Cosine similarity requires a floating-point output type." );
120
121 // static sanity checks
122 NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
123 std::is_same< InputType1, typename Ring::D1 >::value
124 ), "grb::algorithms::cosine_similarity",
125 "called with a left-hand vector value type that does not match the "
126 "first domain of the given semiring" );
127 NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
128 std::is_same< InputType2, typename Ring::D2 >::value
129 ), "grb::algorithms::cosine_similarity",
130 "called with a right-hand vector value type that does not match "
131 "the second domain of the given semiring" );
132 NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
133 std::is_same< OutputType, typename Ring::D4 >::value
134 ), "grb::algorithms::cosine_similarity",
135 "called with an output vector value type that does not match the "
136 "fourth domain of the given semiring" );
137 NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
138 std::is_same< typename Ring::D3, typename Ring::D4 >::value
139 ), "grb::algorithms::cosine_similarity",
140 "called with a semiring that has unequal additive input domains" );
141
142 const size_t n = size( x );
143
144 // run-time sanity checks
145 if( n != size( y ) ) {
146 return MISMATCH;
147 }
148
149 // check whether inputs are dense
150 const bool dense = nnz( x ) == n && nnz( y ) == n;
151
152 // set return code
153 RC rc = SUCCESS;
154
155 // compute-- choose method depending on we can stream once or need to stream
156 // multiple times
157 OutputType nominator, denominator;
158 nominator = denominator = ring.template getZero< OutputType >();
159 if( dense && grb::Properties<>::writableCaptured ) {
160 // lambda works, so we can stream each vector precisely once:
161 OutputType norm1, norm2;
162 norm1 = norm2 = ring.template getZero< OutputType >();
163 rc = grb::eWiseLambda(
164 [ &x, &y, &nominator, &norm1, &norm2, &ring ]( const size_t i ) {
165 const auto &mul = ring.getMultiplicativeOperator();
166 const auto &add = ring.getAdditiveOperator();
167 OutputType temp;
168 (void)grb::apply( temp, x[ i ], y[ i ], mul );
169 (void)grb::foldl( nominator, temp, add );
170 (void)grb::apply( temp, x[ i ], x[ i ], mul );
171 (void)grb::foldl( norm1, temp, add );
172 (void)grb::apply( temp, y[ i ], y[ i ], mul );
173 (void)grb::foldl( norm2, temp, add );
174 },
175 x, y );
176 denominator = sqrt( norm1 ) * sqrt( norm2 );
177 } else {
178 // cannot stream each vector once, stream each one twice instead using
179 // standard grb functions
180 rc = grb::norm2( nominator, x, ring );
181 if( rc == SUCCESS ) {
182 rc = grb::norm2( denominator, y, ring );
183 }
184 if( rc == SUCCESS ) {
185 rc = grb::foldl( denominator, nominator,
186 ring.getMultiplicativeOperator() );
187 }
188 if( rc == SUCCESS ) {
189 rc = grb::dot( nominator, x, y, ring );
190 }
191 }
192
193 // accumulate
194 if( rc == SUCCESS ) {
195 // catch zeroes
196 if( denominator == ring.template getZero() ) {
197 return ILLEGAL;
198 }
199 if( nominator == ring.template getZero() ) {
200 return ILLEGAL;
201 }
202 rc = grb::apply( similarity, nominator, denominator, div );
203 }
204
205 // done
206 return rc;
207 }
208
209 } // end namespace algorithms
210
211} // end namespace grb
212
213#undef NO_CAST_ASSERT
214
215#endif // end _H_GRB_COSSIM
216
A GraphBLAS vector.
Definition: vector.hpp:64
Numerical division of two numbers.
Definition: ops.hpp:328
The main header to include in order to use the ALP/GraphBLAS API.
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3618
RC dot(OutputType &z, const Vector< InputType1, backend, Coords > &x, const Vector< InputType2, backend, Coords > &y, const AddMonoid &addMonoid=AddMonoid(), const AnyOp &anyOp=AnyOp(), const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< OutputType >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&grb::is_monoid< AddMonoid >::value &&grb::is_operator< AnyOp >::value, void >::type *const =nullptr)
Calculates the dot product, \( \alpha = (x,y) \), under a given additive monoid and multiplicative operator.
Definition: blas1.hpp:3834
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f using any number of vectors of equal length...
Definition: blas1.hpp:3524
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
RC norm2(OutputType &x, const Vector< InputType, backend, Coords > &y, const Ring &ring=Ring(), const typename std::enable_if< std::is_floating_point< OutputType >::value, void >::type *const =nullptr)
Provides a generic implementation of the 2-norm computation.
Definition: norm.hpp:76
RC cosine_similarity(OutputType &similarity, const Vector< InputType1 > &x, const Vector< InputType2 > &y, const Ring &ring=Ring(), const Division &div=Division())
Computes the cosine similarity.
Definition: cosine_similarity.hpp:113
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:450
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
@ ILLEGAL
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
@ MISMATCH
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90
@ SUCCESS
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54
Implements the 2-norm.