#include "Triplet.hpp"
#include "CRS.hpp"
#include "CS_ELEMENT.hpp"
#include "CS_ARRAY.hpp"

#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>
00008
00009
00010
00011 #ifndef _H_CS_CRS
00012 #define _H_CS_CRS
00013
00014 #ifdef _DEBUG
00015 #include<iostream>
00016 #endif
00017
00021 template< typename T >
00022 class CS_CRS {
00023
00024 private:
00025
00027 typedef unsigned long int ULI;
00028
00029 protected:
00030
00032 CS_ELEMENT< ULI >* nnz;
00033
00035 CS_ELEMENT< ULI >* nor;
00036
00038 CS_ELEMENT< ULI >* noc;
00039
00041 CS_ARRAY< ULI >* row_start;
00042
00044 CS_ARRAY< T >* nzs;
00045
00047 CS_ARRAY< ULI >* col_ind;
00048
00050 CS_ELEMENT< T >* zero_element;
00051
00057 static int compareTriplets( const void * left, const void * right ) {
00058 const Triplet< T > one = **( (Triplet< T > **)left );
00059 const Triplet< T > two = **( (Triplet< T > **)right );
00060 if ( one.j() < two.j() )
00061 return -1;
00062 if ( one.j() > two.j() )
00063 return 1;
00064 return 0;
00065 }
00066
00075 bool find( ULI col_index, ULI search_start, ULI search_end, ULI &ret ) {
00076 for( ULI i=search_start; i<search_end; i++ )
00077 if( (*col_ind)[ i ] == col_index ) {
00078 ret = i;
00079 return true;
00080 }
00081 return false;
00082 }
00083
00084 public:
00085
00095 CS_CRS( const ULI number_of_nonzeros, const ULI number_of_rows, const ULI number_of_cols, const T zero ):
00096 nnz( number_of_nonzeros ), nor( number_of_rows ) {
00097 nnz = new CS_ELEMENT< ULI >( number_of_nonzeros );
00098 nor = new CS_ELEMENT< ULI >( number_of_rows );
00099 noc = new CS_ELEMENT< ULI >( number_of_cols );
00100 zero_element = new CS_ELEMENT< T >( zero );
00101 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00102 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00103 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00104 }
00105
00107 CS_CRS( CRS< T >& toCopy ) {
00108 zero_element = new CS_ELEMENT< T >( toCopy.zero_element );
00109 nnz = new CS_ELEMENT< ULI >( toCopy.nnz );
00110 nor = new CS_ELEMENT< ULI >( toCopy.nor );
00111 noc = new CS_ELEMENT< ULI >( toCopy.noc );
00112 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00113 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00114 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00115 for( ULI i=0; i<(*nnz).const_access(); i++ ) {
00116 nzs[ i ] = toCopy.nzs->const_access( i );
00117 col_ind[ i ] = toCopy.col_ind->const_access( i );
00118 }
00119 for( ULI i=0; i<(*nor).const_access(); i++ )
00120 row_start[ i ] = toCopy.row_start->const_access( i );
00121 }
00122
00126 CS_CRS( CS_CRS< T >& toCopy ) {
00127 zero_element = toCopy.zero_element;
00128 nnz = toCopy.nnz;
00129 nor = toCopy.nor;
00130 noc = toCopy.noc;
00131 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00132 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00133 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00134 for( ULI i=0; i<(*nnz).const_access(); i++ ) {
00135 nzs[ i ] = toCopy.nzs->const_access( i );
00136 col_ind[ i ] = toCopy.col_ind->const_access( i );
00137 }
00138 for( ULI i=0; i<(*nor).const_access(); i++ )
00139 row_start[ i ] = toCopy.row_start->const_access( i );
00140 }
00141
00152 CS_CRS( std::vector< Triplet< T > > input, unsigned long int currow, unsigned long int curcol, T& zero ) {
00153 zero_element = new CS_ELEMENT< T >( zero );
00154 nnz = new CS_ELEMENT< ULI >( input.size() );
00155 nor = new CS_ELEMENT< ULI >( static_cast< ULI >( currow ) );
00156 noc = new CS_ELEMENT< ULI >( static_cast< ULI >( curcol ) );
00157
00158
00159 std::vector< std::vector< Triplet< T >* > > ds( nor->const_access() );
00160
00161
00162 typename std::vector< Triplet< T > >::iterator in_it;
00163 in_it = input.begin();
00164 for( ; in_it != input.end(); in_it++ ) {
00165 Triplet< T >* cur = &(*in_it);
00166 const ULI currow = cur->i();
00167 const T value = cur->value;
00168 if( value == zero_element->const_access() ) { nnz->access() -= 1; continue; }
00169 ds.at( currow ).push_back( cur );
00170 }
00171
00172
00173 row_start = new CS_ARRAY< ULI >( nor->const_access() + 1, "Row start vector (CS_CRS.h)" );
00174 nzs = new CS_ARRAY< T>( nnz->const_access(), "Non-zero storage vector (CS_CRS.h" );
00175 col_ind = new CS_ARRAY< ULI >( nnz->const_access(), "Columm index vector (CS_CRS.h)" );
00176
00177 ULI index = 0;
00178 for( ULI currow = 0; currow < nor->const_access(); currow++ ) {
00179 row_start->access( currow ) = index;
00180 if( ds.at( currow ).size() == 0 ) continue;
00181 qsort( &( ds.at( currow )[ 0 ] ), ds.at( currow ).size(), sizeof( Triplet< T >* ), &compareTriplets );
00182 typename std::vector< Triplet< T >* >::iterator row_it = ds.at( currow ).begin();
00183 for( ; row_it!=ds.at( currow ).end(); row_it++ ) {
00184 const Triplet< T > cur = *(*row_it);
00185 nzs->access( index ) = cur.value;
00186 col_ind->access( index ) = cur.j();
00187 index++;
00188 }
00189 }
00190
00191 row_start->access( nor->const_access() ) = nnz->const_access();
00192 }
00193
00202 CS_CRS( std::vector< Triplet< T > > input, T& zero ) {
00203 zero_element = new CS_ELEMENT< T >( zero );
00204
00205 nnz = new CS_ELEMENT< ULI >( input.size() );
00206
00207
00208 nor = new CS_ELEMENT< ULI >( static_cast< ULI >( 0 ) );
00209 noc = new CS_ELEMENT< ULI >( static_cast< ULI >( 0 ) );
00210 typename std::vector< Triplet< T > >::iterator in_it;
00211 in_it = input.begin();
00212 for( ; in_it!=input.end(); in_it++ ) {
00213 const ULI currow = ( *in_it ).i();
00214 const ULI curcol = ( *in_it ).j();
00215 if ( currow > nor->const_access() ) nor->access() = currow;
00216 if ( curcol > noc->const_access() ) noc->access() = curcol;
00217 }
00218 #ifdef _DEBUG
00219 std::cout << "Max row index found: " << nor->const_access() << std::endl;
00220 #endif
00221 nor->access() += 1;
00222 noc->access() += 1;
00223
00224
00225 std::vector< std::vector< Triplet< T >* > > ds( nor->const_access() );
00226
00227
00228 in_it = input.begin();
00229 for( ; in_it != input.end(); in_it++ ) {
00230 Triplet< T >* cur = &(*in_it);
00231 const ULI currow = cur->i();
00232 const T value = cur->value;
00233 if( value == zero_element->const_access() ) { nnz->access() -= 1; continue; }
00234 ds.at( currow ).push_back( cur );
00235 }
00236
00237
00238 row_start = new CS_ARRAY< ULI >( nor->const_access() + 1, "Row start vector (CS_CRS.h)" );
00239 nzs = new CS_ARRAY< T>( nnz->const_access(), "Non-zero storage vector (CS_CRS.h" );
00240 col_ind = new CS_ARRAY< ULI >( nnz->const_access(), "Columm index vector (CS_CRS.h)" );
00241
00242
00243 ULI index = 0;
00244 for( ULI currow = 0; currow < nor->const_access(); currow++ ) {
00245 #ifdef _DEBUG
00246 std::cout << "row_start[ " << currow << " ]=" << index << std::endl;
00247 #endif
00248 row_start->access( currow ) = index;
00249 if( ds.at( currow ).size() == 0 ) continue;
00250 typename std::vector< Triplet< T >* >::iterator row_it = ds.at( currow ).begin();
00251 for( ; row_it!=ds.at( currow ).end(); row_it++ ) {
00252 const Triplet< T > cur = *(*row_it);
00253 nzs->access( index ) = cur.value;
00254 col_ind->access( index ) = cur.j();
00255 index++;
00256 }
00257 }
00258 row_start->access( nor->const_access() ) = nnz->const_access();
00259 }
00260
00267 T& random_access( ULI i, ULI j ) {
00268 const ULI found_index;
00269 if ( find( j, row_start[ i ], row_start[ i+1 ], found_index ) )
00270 return (*nzs)[ found_index ];
00271 else
00272 return zero_element;
00273 }
00274
00284 T* zax( T* x_orig ) {
00285
00286 CS_ARRAY< T >* x = new CS_ARRAY< T >( noc->unrecorded_access(), "x-vector in CS_CRS::zax" );
00287 x->unrecordedCopyFrom( x_orig );
00288
00289
00290
00291
00292 #ifdef _DEBUG
00293 std::cout << "Debug: nor=" << nor.unrecorded_access() << std::endl;
00294 #endif
00295 CS_ARRAY< T >* ret = new CS_ARRAY< T >( nor->unrecorded_access(), "z-vector in CS_CRS::zax" );
00296 for( ULI i=0; i < nor->unrecorded_access(); ++i )
00297 ret->access( i ) = zero_element->unrecorded_access();
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316 for( ULI i=0; i < nor->unrecorded_access(); i++ ) {
00317 for( ULI j=row_start->const_access( i ); j<row_start->const_access( i+1 ); j++ )
00318 ret->access( i ) += nzs->const_access( j ) * x->const_access( col_ind->const_access( j ) );
00319 }
00320 T* real_ret = new T[ nor->unrecorded_access() ];
00321 ret->unrecordedCopyTo( real_ret );
00322 delete ret;
00323 delete x;
00324 return real_ret;
00325 }
00326
00328 ~CS_CRS() {
00329 delete row_start;
00330 delete zero_element;
00331 delete nzs;
00332 delete col_ind;
00333 delete nnz;
00334 delete nor;
00335 delete noc;
00336 }
00337
00338 };
00339
00340 #endif
00341