ALP User Documentation  0.8.preview
Algebraic Programming User Documentation
cosine_similarity.hpp
Go to the documentation of this file.
1 
2 /*
3  * Copyright 2021 Huawei Technologies Co., Ltd.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
27 #ifndef _H_GRB_COSSIM
28 #define _H_GRB_COSSIM
29 
30 #include <graphblas.hpp>
32 
33 
// Helper macro for the compile-time domain checks in this header.
// Expands to a static_assert on the condition x; on failure the compiler prints
// a boxed banner in which y names the calling function and z describes the
// mismatch, followed by three suggested fixes.
// y and z must be string literals: they are joined to the banner text through
// adjacent-string-literal concatenation.
// The expansion already ends in a semicolon.
34 #define NO_CAST_ASSERT( x, y, z ) \
35  static_assert( x, \
36  "\n\n" \
37  "*******************************************************************************************\n" \
38  "* ERROR | " y " " z ".\n" \
39  "*******************************************************************************************\n" \
40  "* Possible fix 1 | Remove no_casting from the template parameters in this call to " y ". \n" \
41  "* Possible fix 2 | For all mismatches in the domains of input and output parameters w.r.t. \n" \
42  "* the semiring domains, as specified in the documentation of the function.\n" \
43  "* supply an input argument of the expected type instead.\n" \
44  "* Possible fix 3 | Provide a compatible semiring where all domains match those of the input\n" \
45  "* parameters, as specified in the documentation of the function.\n" \
46  "*******************************************************************************************\n" );
47 
48 
49 namespace grb {
50 
51  namespace algorithms {
52 
/**
 * Computes the cosine similarity of the vectors x and y and stores it in
 * similarity.
 *
 * The result is obtained by applying div to a nominator (the dot product of
 * x and y under ring) and a denominator (the product of the 2-norms of x and
 * y). Two computation strategies are visible below: a single eWiseLambda pass
 * that accumulates all three sums at once, and a fallback that calls
 * grb::norm2 twice followed by grb::dot.
 *
 * OutputType must be a floating-point type (enforced by static_assert).
 * When the no_casting descriptor is set, the value types of x and y must
 * equal Ring::D1 and Ring::D2, OutputType must equal Ring::D4, and Ring::D3
 * must equal Ring::D4; otherwise compilation fails.
 *
 * param[out] similarity  Receives the quotient; only written when SUCCESS is
 *                        about to be returned.
 * param[in]  x, y        The input vectors; they must have equal size.
 * param[in]  ring        Supplies the additive and multiplicative operators
 *                        as well as the zero element.
 * param[in]  div         The operator used for the final division.
 *
 * Returns MISMATCH when the sizes of x and y differ, ILLEGAL when the
 * denominator or the nominator equals the zero of ring, and otherwise the
 * return code of the last primitive that was called.
 *
 * NOTE(review): this rendering of the listing skips its own lines 107, 116
 * and 162 (see the jumps in the embedded numbering), so parts of the
 * signature and one branch condition are not visible here. Verify against
 * the actual header.
 */
106  template<
 // NOTE(review): listing line 107 is absent; the body refers to a descriptor
 // named descr, so that template parameter is presumably declared there.
108  typename OutputType,
109  typename InputType1,
110  typename InputType2,
111  class Ring,
112  class Division = grb::operators::divide<
113  typename Ring::D3, typename Ring::D3, typename Ring::D4
114  >
115  >
 // NOTE(review): listing line 116 is absent; the cross-reference entry further
 // down this listing gives the header as `RC cosine_similarity(`.
117  OutputType &similarity,
118  const Vector< InputType1 > &x, const Vector< InputType2 > &y,
119  const Ring &ring = Ring(), const Division &div = Division()
120  ) {
121  static_assert( std::is_floating_point< OutputType >::value,
122  "Cosine similarity requires a floating-point output type." );
123 
124  // static sanity checks: only enforced when descr contains no_casting
125  NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
126  std::is_same< InputType1, typename Ring::D1 >::value
127  ), "grb::algorithms::cosine_similarity",
128  "called with a left-hand vector value type that does not match the "
129  "first domain of the given semiring" );
130  NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
131  std::is_same< InputType2, typename Ring::D2 >::value
132  ), "grb::algorithms::cosine_similarity",
133  "called with a right-hand vector value type that does not match "
134  "the second domain of the given semiring" );
135  NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
136  std::is_same< OutputType, typename Ring::D4 >::value
137  ), "grb::algorithms::cosine_similarity",
138  "called with an output vector value type that does not match the "
139  "fourth domain of the given semiring" );
140  NO_CAST_ASSERT( ( !(descr & descriptors::no_casting) ||
141  std::is_same< typename Ring::D3, typename Ring::D4 >::value
142  ), "grb::algorithms::cosine_similarity",
143  "called with a semiring that has unequal additive input domains" );
144 
145  const size_t n = size( x );
146 
147  // run-time sanity checks
148  if( n != size( y ) ) {
149  return MISMATCH;
150  }
151 
152  // check whether inputs are dense, i.e., both hold exactly n nonzeroes
153  const bool dense = nnz( x ) == n && nnz( y ) == n;
154 
155  // set return code
156  RC rc = SUCCESS;
157 
158  // compute -- choose the method depending on whether we can stream each
159  // vector once or need to stream multiple times
160  OutputType nominator, denominator;
161  nominator = denominator = ring.template getZero< OutputType >();
 // NOTE(review): listing line 162 is absent; it presumably holds the `if` that
 // opens the branch closed by `} else {` below. `dense` is not used anywhere
 // else in the visible lines, so it is likely part of that condition -- confirm.
163  // lambda works, so we can stream each vector precisely once:
164  OutputType norm1, norm2;
165  norm1 = norm2 = ring.template getZero< OutputType >();
166  rc = grb::eWiseLambda(
167  [ &x, &y, &nominator, &norm1, &norm2, &ring ]( const size_t i ) {
168  const auto &mul = ring.getMultiplicativeOperator();
169  const auto &add = ring.getAdditiveOperator();
170  OutputType temp;
 // accumulate x[i]*y[i], x[i]*x[i] and y[i]*y[i]; the return codes of
 // these scalar primitives are deliberately discarded
171  (void) grb::apply( temp, x[ i ], y[ i ], mul );
172  (void) grb::foldl( nominator, temp, add );
173  (void) grb::apply( temp, x[ i ], x[ i ], mul );
174  (void) grb::foldl( norm1, temp, add );
175  (void) grb::apply( temp, y[ i ], y[ i ], mul );
176  (void) grb::foldl( norm2, temp, add );
177  }, x, y
178  );
 // computed regardless of the return code of eWiseLambda; rc is only
 // inspected further below
179  denominator = sqrt( norm1 ) * sqrt( norm2 );
180  } else {
181  // cannot stream each vector once, stream each one twice instead using
182  // standard grb functions
 // nominator temporarily holds the 2-norm of x so that it can be folded
 // into denominator, which then holds the product of both norms
183  rc = grb::norm2( nominator, x, ring );
184  if( rc == SUCCESS ) {
185  rc = grb::norm2( denominator, y, ring );
186  }
187  if( rc == SUCCESS ) {
188  rc = grb::foldl( denominator, nominator,
189  ring.getMultiplicativeOperator() );
190  }
 // NOTE(review): nominator still holds the 2-norm of x at this point. If
 // grb::dot accumulates into its output instead of overwriting it, the
 // nominator would be off by that amount -- confirm the semantics of
 // grb::dot and reset nominator to zero first if needed.
191  if( rc == SUCCESS ) {
192  rc = grb::dot( nominator, x, y, ring );
193  }
194  }
195 
196  // finalise: divide the nominator by the denominator
197  if( rc == SUCCESS ) {
198  // catch zeroes
199  if( denominator == ring.template getZero() ) {
200  return ILLEGAL;
201  }
 // NOTE(review): a zero nominator (e.g., orthogonal inputs) is rejected
 // as ILLEGAL rather than yielding a similarity of zero -- confirm that
 // this is intended.
202  if( nominator == ring.template getZero() ) {
203  return ILLEGAL;
204  }
205  rc = grb::apply( similarity, nominator, denominator, div );
206  }
207 
208  // done
209  return rc;
210  }
211 
212  } // end namespace algorithms
213 
214 } // end namespace grb
215 
216 #undef NO_CAST_ASSERT
217 
218 #endif // end _H_GRB_COSSIM
219 
A call to a primitive has determined that one of its arguments was illegal as per the specification o...
Definition: rc.hpp:143
RC
Return codes of ALP primitives.
Definition: rc.hpp:47
A GraphBLAS vector.
Definition: vector.hpp:64
static constexpr Descriptor no_casting
Disallows the standard casting of input parameters to a compatible domain in case they did not match ...
Definition: descriptors.hpp:196
size_t nnz(const Vector< DataType, backend, Coords > &x) noexcept
Request the number of nonzeroes in a given vector.
Definition: io.hpp:479
static constexpr Descriptor no_operation
Indicates no additional pre- or post-processing on any of the GraphBLAS function arguments.
Definition: descriptors.hpp:63
unsigned int Descriptor
Descriptors indicate pre- or post-processing for some or all of the arguments to an ALP/GraphBLAS cal...
Definition: descriptors.hpp:54
static constexpr Descriptor dense
Indicates that all input and output vectors to an ALP/GraphBLAS primitive are structurally dense.
Definition: descriptors.hpp:151
static enum RC apply(OutputType &out, const InputType1 &x, const InputType2 &y, const OP &op=OP(), const typename std::enable_if< grb::is_operator< OP >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&!grb::is_object< OutputType >::value, void >::type *=nullptr)
Out-of-place application of the operator OP on two data elements.
Definition: blas0.hpp:179
RC cosine_similarity(OutputType &similarity, const Vector< InputType1 > &x, const Vector< InputType2 > &y, const Ring &ring=Ring(), const Division &div=Division())
Computes the cosine similarity.
Definition: cosine_similarity.hpp:116
Implements the 2-norm.
RC foldl(IOType &x, const Vector< InputType, backend, Coords > &y, const Vector< MaskType, backend, Coords > &mask, const Monoid &monoid=Monoid(), const typename std::enable_if< !grb::is_object< IOType >::value &&!grb::is_object< InputType >::value &&!grb::is_object< MaskType >::value &&grb::is_monoid< Monoid >::value, void >::type *const =nullptr)
Reduces, or folds, a vector into a scalar.
Definition: blas1.hpp:3840
RC norm2(OutputType &x, const Vector< InputType, backend, Coords > &y, const Ring &ring=Ring(), const std::function< OutputType(OutputType) > sqrtX=std_sqrt< OutputType, OutputType >, const typename std::enable_if< std::is_floating_point< OutputType >::value, void >::type *=nullptr)
Provides a generic implementation of the 2-norm computation.
Definition: norm.hpp:93
Collection of various properties on the given ALP/GraphBLAS backend.
Definition: properties.hpp:52
The ALP/GraphBLAS namespace.
Definition: graphblas.hpp:477
The main header to include in order to use the ALP/GraphBLAS API.
RC dot(OutputType &z, const Vector< InputType1, backend, Coords > &x, const Vector< InputType2, backend, Coords > &y, const AddMonoid &addMonoid=AddMonoid(), const AnyOp &anyOp=AnyOp(), const Phase &phase=EXECUTE, const typename std::enable_if< !grb::is_object< OutputType >::value &&!grb::is_object< InputType1 >::value &&!grb::is_object< InputType2 >::value &&grb::is_monoid< AddMonoid >::value &&grb::is_operator< AnyOp >::value, void >::type *const =nullptr)
Calculates the dot product, $\langle x, y \rangle$, under a given additive monoid and multiplicative operator.
Definition: blas1.hpp:4056
RC eWiseLambda(const Func f, const Vector< DataType, backend, Coords > &x, Args...)
Executes an arbitrary element-wise user-defined function f on any number of vectors of equal length.
Definition: blas1.hpp:3746
size_t size(const Vector< DataType, backend, Coords > &x) noexcept
Request the size of a given vector.
Definition: io.hpp:235
Numerical division of two numbers.
Definition: ops.hpp:328
Indicates the primitive has executed successfully.
Definition: rc.hpp:54
One or more of the ALP/GraphBLAS objects passed to the primitive that returned this error have mismat...
Definition: rc.hpp:90