Main Page | Class List | File List | Class Members

CS_CRS.hpp

00001 #include "Triplet.hpp"
00002 #include "CRS.hpp"
00003 #include "CS_ELEMENT.hpp"
00004 #include "CS_ARRAY.hpp"
00005 
00006 #include <cstdlib>
00007 #include <vector>
00008 
00009 //#define _DEBUG
00010 
00011 #ifndef _H_CS_CRS
00012 #define _H_CS_CRS
00013 
00014 #ifdef _DEBUG
00015 #include<iostream>
00016 #endif
00017 
00021 template< typename T >
00022 class CS_CRS {
00023 
00024   private:
00025 
00027         typedef unsigned long int ULI;
00028 
00029   protected:
00030 
00032         CS_ELEMENT< ULI >* nnz;
00033 
00035         CS_ELEMENT< ULI >* nor;
00036 
00038         CS_ELEMENT< ULI >* noc;
00039 
00041         CS_ARRAY< ULI >* row_start;
00042 
00044         CS_ARRAY< T >* nzs;
00045 
00047         CS_ARRAY< ULI >* col_ind;
00048 
00050         CS_ELEMENT< T >* zero_element;
00051 
00057         static int compareTriplets( const void * left, const void * right ) {
00058                 const Triplet< T > one = **( (Triplet< T > **)left );
00059                 const Triplet< T > two = **( (Triplet< T > **)right );
00060                 if ( one.j() < two.j() )
00061                         return -1;
00062                 if ( one.j() > two.j() )
00063                         return 1;
00064                 return 0;
00065         }       
00066 
00075         bool find( ULI col_index, ULI search_start, ULI search_end, ULI &ret ) {
00076                 for( ULI i=search_start; i<search_end; i++ )
00077                         if( (*col_ind)[ i ] == col_index ) {
00078                                 ret = i;
00079                                 return true;
00080                         }
00081                 return false;
00082         }
00083 
00084   public:
00085 
00095         CS_CRS( const ULI number_of_nonzeros, const ULI number_of_rows, const ULI number_of_cols, const T zero ):
00096                 nnz( number_of_nonzeros ), nor( number_of_rows ) {
00097                 nnz = new CS_ELEMENT< ULI >( number_of_nonzeros );
00098                 nor = new CS_ELEMENT< ULI >( number_of_rows );
00099                 noc = new CS_ELEMENT< ULI >( number_of_cols );
00100                 zero_element = new CS_ELEMENT< T >( zero );
00101                 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00102                 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00103                 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00104         }
00105 
00107         CS_CRS( CRS< T >& toCopy ) {
00108                 zero_element = new CS_ELEMENT< T >( toCopy.zero_element );      
00109                 nnz = new CS_ELEMENT< ULI >( toCopy.nnz );
00110                 nor = new CS_ELEMENT< ULI >( toCopy.nor );
00111                 noc = new CS_ELEMENT< ULI >( toCopy.noc );
00112                 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00113                 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00114                 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00115                 for( ULI i=0; i<(*nnz).const_access(); i++ ) {
00116                         nzs[ i ] = toCopy.nzs->const_access( i );
00117                         col_ind[ i ] = toCopy.col_ind->const_access( i );
00118                 }
00119                 for( ULI i=0; i<(*nor).const_access(); i++ )
00120                         row_start[ i ] = toCopy.row_start->const_access( i );
00121         }
00122 
00126         CS_CRS( CS_CRS< T >& toCopy ) {
00127                 zero_element = toCopy.zero_element;
00128                 nnz = toCopy.nnz;
00129                 nor = toCopy.nor;
00130                 noc = toCopy.noc;
00131                 row_start = new CS_ARRAY< ULI >( nor + 1, "Row start vector (CS_CRS.h)" );
00132                 nzs = new CS_ARRAY< T>( nnz, "Non-zero storage vector (CS_CRS.h" );
00133                 col_ind = new CS_ARRAY< ULI >( nnz, "Columm index vector (CS_CRS.h)" );
00134                 for( ULI i=0; i<(*nnz).const_access(); i++ ) {
00135                         nzs[ i ] = toCopy.nzs->const_access( i );
00136                         col_ind[ i ] = toCopy.col_ind->const_access( i );
00137                 }
00138                 for( ULI i=0; i<(*nor).const_access(); i++ )
00139                         row_start[ i ] = toCopy.row_start->const_access( i );
00140         }
00141 
00152         CS_CRS( std::vector< Triplet< T > > input, unsigned long int currow, unsigned long int curcol, T& zero ) {
00153                 zero_element = new CS_ELEMENT< T >( zero );
00154                 nnz = new CS_ELEMENT< ULI >( input.size() );
00155                 nor = new CS_ELEMENT< ULI >( static_cast< ULI >( currow ) );
00156                 noc = new CS_ELEMENT< ULI >( static_cast< ULI >( curcol ) );
00157         
00158                 //build better datastructure
00159                 std::vector< std::vector< Triplet< T >* > > ds( nor->const_access() );
00160                 
00161                 //move input there
00162                 typename std::vector< Triplet< T > >::iterator in_it;
00163                 in_it = input.begin();
00164                 for( ; in_it != input.end(); in_it++ ) {
00165                         Triplet< T >* cur = &(*in_it);
00166                         const ULI currow = cur->i();
00167                         const T value = cur->value;
00168                         if( value == zero_element->const_access() ) { nnz->access() -= 1; continue; }
00169                         ds.at( currow ).push_back( cur );
00170                 }
00171 
00172                 //allocate arrays
00173                 row_start = new CS_ARRAY< ULI >( nor->const_access() + 1, "Row start vector (CS_CRS.h)" );
00174                 nzs = new CS_ARRAY< T>( nnz->const_access(), "Non-zero storage vector (CS_CRS.h" );
00175                 col_ind = new CS_ARRAY< ULI >( nnz->const_access(), "Columm index vector (CS_CRS.h)" );
00176 
00177                 ULI index = 0;
00178                 for( ULI currow = 0; currow < nor->const_access(); currow++ ) {
00179                         row_start->access( currow ) = index;
00180                         if( ds.at( currow ).size() == 0 ) continue;
00181                         qsort( &( ds.at( currow )[ 0 ] ), ds.at( currow ).size(), sizeof( Triplet< T >* ), &compareTriplets );
00182                         typename std::vector< Triplet< T >* >::iterator row_it = ds.at( currow ).begin();
00183                         for( ; row_it!=ds.at( currow ).end(); row_it++ ) {
00184                                 const Triplet< T > cur = *(*row_it);
00185                                 nzs->access( index ) = cur.value;
00186                                 col_ind->access( index ) = cur.j();
00187                                 index++;
00188                         }
00189                 }
00190 
00191                 row_start->access( nor->const_access() ) = nnz->const_access();
00192         }
00193 
00202         CS_CRS( std::vector< Triplet< T > > input, T& zero ) {
00203                 zero_element = new CS_ELEMENT< T >( zero );
00204                 //find nnz
00205                 nnz = new CS_ELEMENT< ULI >( input.size() );
00206         
00207                 //find number of rows
00208                 nor = new CS_ELEMENT< ULI >( static_cast< ULI >( 0 ) );
00209                 noc = new CS_ELEMENT< ULI >( static_cast< ULI >( 0 ) );
00210                 typename std::vector< Triplet< T > >::iterator in_it;
00211                 in_it = input.begin();
00212                 for( ; in_it!=input.end(); in_it++ ) {
00213                         const ULI currow = ( *in_it ).i();
00214                         const ULI curcol = ( *in_it ).j();
00215                         if ( currow > nor->const_access() ) nor->access() = currow;
00216                         if ( curcol > noc->const_access() ) noc->access() = curcol;
00217                 }
00218 #ifdef _DEBUG
00219                 std::cout << "Max row index found: " << nor->const_access() << std::endl;
00220 #endif
00221                 nor->access() += 1;
00222                 noc->access() += 1;
00223 
00224                 //build better datastructure
00225                 std::vector< std::vector< Triplet< T >* > > ds( nor->const_access() );
00226                 
00227                 //move input there
00228                 in_it = input.begin();
00229                 for( ; in_it != input.end(); in_it++ ) {
00230                         Triplet< T >* cur = &(*in_it);
00231                         const ULI currow = cur->i();
00232                         const T value = cur->value;
00233                         if( value == zero_element->const_access() ) { nnz->access() -= 1; continue; }
00234                         ds.at( currow ).push_back( cur );
00235                 }
00236 
00237                 //allocate arrays
00238                 row_start = new CS_ARRAY< ULI >( nor->const_access() + 1, "Row start vector (CS_CRS.h)" );
00239                 nzs = new CS_ARRAY< T>( nnz->const_access(), "Non-zero storage vector (CS_CRS.h" );
00240                 col_ind = new CS_ARRAY< ULI >( nnz->const_access(), "Columm index vector (CS_CRS.h)" );
00241 
00242                 //make CRS
00243                 ULI index = 0;
00244                 for( ULI currow = 0; currow < nor->const_access(); currow++ ) {
00245 #ifdef _DEBUG
00246                         std::cout << "row_start[ " << currow << " ]=" << index << std::endl;
00247 #endif
00248                         row_start->access( currow ) = index;
00249                         if( ds.at( currow ).size() == 0 ) continue;
00250                         typename std::vector< Triplet< T >* >::iterator row_it = ds.at( currow ).begin();
00251                         for( ; row_it!=ds.at( currow ).end(); row_it++ ) {
00252                                 const Triplet< T > cur = *(*row_it);
00253                                 nzs->access( index ) = cur.value;
00254                                 col_ind->access( index ) = cur.j();
00255                                 index++;
00256                         }
00257                 }
00258                 row_start->access( nor->const_access() ) = nnz->const_access();
00259         }
00260 
00267         T& random_access( ULI i, ULI j ) {
00268                 const ULI found_index;
00269                 if ( find( j, row_start[ i ], row_start[ i+1 ], found_index ) )
00270                         return (*nzs)[ found_index ];
00271                 else
00272                         return zero_element;
00273         }
00274 
00284         T* zax( T* x_orig ) {
00285 
00286                 CS_ARRAY< T >* x = new CS_ARRAY< T >( noc->unrecorded_access(), "x-vector in CS_CRS::zax" );
00287                 x->unrecordedCopyFrom( x_orig );
00288 
00289 //              ULI index = 0;
00290 //              ULI row = 0;
00291 //              ULI row_i = row_start->const_access( 1 );
00292 #ifdef _DEBUG
00293                 std::cout << "Debug: nor=" << nor.unrecorded_access() << std::endl;
00294 #endif
00295                 CS_ARRAY< T >* ret = new CS_ARRAY< T >( nor->unrecorded_access(), "z-vector in CS_CRS::zax" );
00296                 for( ULI i=0; i < nor->unrecorded_access(); ++i )
00297                         ret->access( i ) = zero_element->unrecorded_access();
00298 //              for( ; index < nnz->unrecorded_access(); index++ ) {
00299 //                      while( index == row_i ) {
00300 //                              ++row;
00301 //                              row_i = row_start->const_access( row + 1 );
00302 //                      }
00303 // #ifdef _DEBUG
00304 //                      std::cout << "index: " << index << " nzs[ index ]: " << nzs->unrecorded_access( index ) << " col_ind[ index ]: " << col_ind->unrecorded_access( index ) << " x[ col_ind[ index ] ]: " << x->unrecorded_access( col_ind->unrecorded_access( index ) ) << std::endl;
00305 // #endif
00306 // 
00307 // #ifdef _DEBUG
00308 //                      double toAdd = nzs->const_access( index );
00309 //                      const unsigned long int temp = col_ind->const_access(index );
00310 //                      toAdd *= x->const_access( temp );
00311 // #else
00312 //                      const double toAdd = nzs->const_access( index ) * x->const_access( col_ind->const_access( index ) );
00313 // #endif
00314 //                      ret->access( row ) += toAdd;
00315 //              }
00316                 for( ULI i=0; i < nor->unrecorded_access(); i++ ) {
00317                         for( ULI j=row_start->const_access( i ); j<row_start->const_access( i+1 ); j++ )
00318                                 ret->access( i ) += nzs->const_access( j ) * x->const_access( col_ind->const_access( j ) );
00319                 }
00320                 T* real_ret = new T[ nor->unrecorded_access() ];
00321                 ret->unrecordedCopyTo( real_ret );
00322                 delete ret;
00323                 delete x;
00324                 return real_ret;
00325         }
00326 
00328         ~CS_CRS() {
00329                 delete row_start;
00330                 delete zero_element;
00331                 delete nzs;
00332                 delete col_ind;
00333                 delete nnz;
00334                 delete nor;
00335                 delete noc;
00336         }
00337 
00338 };
00339 
00340 #endif
00341 

Generated on Fri Aug 15 18:12:22 2008 for Run-timeCacheSimulator by  doxygen 1.3.9.1