diff --git a/src/atomic.h b/src/atomic.h index 5deb82c6d77522ad2d05a24fd937c5e239b3b91f..df6e5aaeed4db12653530d6a5dec8ee3042f02f7 100644 --- a/src/atomic.h +++ b/src/atomic.h @@ -23,3 +23,4 @@ #define atomic_add(v,i) __sync_fetch_and_add( v , i ) #define atomic_inc(v) atomic_add( v , 1 ) #define atomic_dec(v) atomic_add( v , -1 ) +#define atomic_cas(v,o,n) __sync_val_compare_and_swap( v , o , n ) diff --git a/src/space.c b/src/space.c index 9f10edd6cb4b82ebf11401f546c0913e28b7f9ea..5dc36b3c93a6a5cf9f8204ef0df37a950466c5a0 100644 --- a/src/space.c +++ b/src/space.c @@ -339,9 +339,9 @@ void space_rebuild ( struct space *s , double cell_max ) { void parts_sort ( struct part *parts , int *ind , int N , int min , int max ) { struct { - int i, j, min, max; + int i, j, min, max, ready; } qstack[space_qstack]; - int first, last, waiting; + volatile int first, last, waiting; int pivot; int i, ii, j, jj, temp_i, qid; @@ -352,7 +352,10 @@ void parts_sort ( struct part *parts , int *ind , int N , int min , int max ) { qstack[0].j = N-1; qstack[0].min = min; qstack[0].max = max; - first = 0; last = 0; waiting = 1; + qstack[0].ready = 1; + for ( i = 1 ; i < space_qstack ; i++ ) + qstack[i].ready = 0; + first = 0; last = 1; waiting = 1; /* Parallel bit. */ #pragma omp parallel default(shared) private(pivot,i,ii,j,jj,min,max,temp_i,qid,temp_p) @@ -361,26 +364,17 @@ void parts_sort ( struct part *parts , int *ind , int N , int min , int max ) { /* Main loop. */ while ( waiting > 0 ) { - /* Try to get an index off the stack. */ - qid = -1; - while ( waiting > 0 && qid < 0 ) { - #pragma omp critical (stack) - { - if ( last - first > space_qstack ) - error( "Sorting stack overflow." ); - if ( first <= last ) { - qid = first; - first += 1; - } - } - } + /* Grab an interval off the queue. */ + qid = atomic_inc( &first ) % space_qstack; - /* Did we get an index? */ - if ( qid < 0 ) - continue; - + /* Wait for the interval to be ready. */ + while ( waiting > 0 && atomic_cas( &qstack[qid].ready , 1 , 0 ) != 1 ); + + /* Broke loop for all the wrong reasons? */ + if ( waiting == 0 ) + break; + /* Get the stack entry. */ - qid %= space_qstack; i = qstack[qid].i; j = qstack[qid].j; min = qstack[qid].min; @@ -415,43 +409,36 @@ void parts_sort ( struct part *parts , int *ind , int N , int min , int max ) { /* Recurse on the left? */ if ( jj > i && pivot > min ) { - #pragma omp critical (stack) - { - last += 1; - qid = last % space_qstack; - qstack[qid].i = i; - qstack[qid].j = jj; - qstack[qid].min = min; - qstack[qid].max = pivot; - waiting += 1; - } + qid = atomic_inc( &last ) % space_qstack; + qstack[qid].i = i; + qstack[qid].j = jj; + qstack[qid].min = min; + qstack[qid].max = pivot; + qstack[qid].ready = 1; + atomic_inc( &waiting ); } /* Recurse on the right? */ if ( jj+1 < j && pivot+1 < max ) { - #pragma omp critical (stack) - { - last += 1; - qid = last % space_qstack; - qstack[qid].i = jj+1; - qstack[qid].j = j; - qstack[qid].min = pivot+1; - qstack[qid].max = max; - waiting += 1; - } + qid = atomic_inc( &last ) % space_qstack; + qstack[qid].i = jj+1; + qstack[qid].j = j; + qstack[qid].min = pivot+1; + qstack[qid].max = max; + qstack[qid].ready = 1; + atomic_inc( &waiting ); } - #pragma omp critical (stack) - waiting -= 1; + atomic_dec( &waiting ); } /* main loop. */ } /* parallel bit. */ /* Verify sort. */ - /* for ( i = 1 ; i < N ; i++ ) + for ( i = 1 ; i < N ; i++ ) if ( ind[i-1] > ind[i] ) - error( "Sorting failed!" ); */ + error( "Sorting failed!" ); } diff --git a/src/space.h b/src/space.h index 4e2667de1fe110d858569d0c1c8f2c987dbd750b..88b012a2279c1dcc679f482f46d039e53251af54 100644 --- a/src/space.h +++ b/src/space.h @@ -98,7 +98,6 @@ struct space { /* function prototypes. */ void parts_sort ( struct part *parts , int *ind , int N , int min , int max ); -void parts_sort_par ( struct part *parts , int *ind , int N , int min , int max ); struct cell *space_getcell ( struct space *s ); int space_getsid ( struct space *s , struct cell **ci , struct cell **cj , double *shift ); void space_init ( struct space *s , double dim[3] , struct part *parts , int N , int periodic , double h_max );