11package com .github .coderodde .util ;
22
3- import java .util .logging .Level ;
4- import java .util .logging .Logger ;
3+ import java .util .ArrayList ;
4+ import java .util .List ;
5+ import java .util .Random ;
56
67/**
78 *
@@ -199,6 +200,8 @@ private static void parallelRadixSortImpl(
199200 BucketSizeCounterThread [] bucketSizeCounterThreads =
200201 new BucketSizeCounterThread [threads ];
201202
203+ // Spawn all but the rightmost bucket size counter thread. The rightmost
204+ // thread will be run in this thread as a mild optimization:
202205 for (int i = 0 ; i != bucketSizeCounterThreads .length - 1 ; i ++) {
203206 BucketSizeCounterThread bucketSizeCounterThread =
204207 new BucketSizeCounterThread (
@@ -219,8 +222,11 @@ private static void parallelRadixSortImpl(
219222 sourceFromIndex + rangeLength ,
220223 recursionDepth );
221224
225+ // Run the last bucket size thread in this thread:
226+ lastBucketSizeCounterThread .run ();
222227 bucketSizeCounterThreads [threads - 1 ] = lastBucketSizeCounterThread ;
223228
229+ // Join all the spawned bucket size counter threads:
224230 for (int i = 0 ; i != threads - 1 ; i ++) {
225231 BucketSizeCounterThread bucketSizeCounterThread =
226232 bucketSizeCounterThreads [i ];
@@ -255,8 +261,207 @@ private static void parallelRadixSortImpl(
255261 }
256262
257263 int spawnDegree = Math .min (numberOfNonemptyBuckets , threads );
264+ int [] startIndexMap = new int [BUCKETS ];
265+
266+ for (int i = 1 ; i != BUCKETS ; i ++) {
267+ startIndexMap [i ] = startIndexMap [i - 1 ]
268+ + globalBucketSizeMap [i - 1 ];
269+ }
270+
271+ int [][] processedMaps = new int [spawnDegree ][BUCKETS ];
272+
273+ // Make the preprocessing map independent of each thread:
274+ for (int i = 1 ; i != spawnDegree ; i ++) {
275+ int [] partialBucketSizeMap =
276+ bucketSizeCounterThreads [i - 1 ].getLocalBucketSizeMap ();
277+
278+ for (int j = 0 ; j != BUCKETS ; j ++) {
279+ processedMaps [i ][j ] = processedMaps [i - 1 ][j ]
280+ + partialBucketSizeMap [j ];
281+ }
282+ }
283+
284+ int sourceStartIndex = sourceFromIndex ;
285+ int targetStartIndex = targetFromIndex ;
286+
287+ BucketInserterThread [] bucketInserterThreads =
288+ new BucketInserterThread [spawnDegree ];
289+
290+ // Spawn all but the rightmost bucket inserter thread. The rightmost
291+ // thread will be run in this thread as a mild optimization:
292+ for (int i = 0 ; i != spawnDegree - 1 ; i ++) {
293+ BucketInserterThread bucketInserterThread =
294+ new BucketInserterThread (
295+ source ,
296+ target ,
297+ sourceStartIndex += subrangeLength ,
298+ targetStartIndex += subrangeLength ,
299+ startIndexMap ,
300+ processedMaps [i ],
301+ subrangeLength ,
302+ recursionDepth );
303+
304+ bucketInserterThread .start ();
305+ bucketInserterThreads [i ] = bucketInserterThread ;
306+ }
307+
308+ BucketInserterThread lastBucketInserterThread =
309+ new BucketInserterThread (
310+ source ,
311+ target ,
312+ sourceStartIndex ,
313+ targetStartIndex ,
314+ startIndexMap ,
315+ processedMaps [spawnDegree - 1 ],
316+ rangeLength - (spawnDegree - 1 ) * subrangeLength ,
317+ recursionDepth );
318+
319+ // Run the last, rightmost bucket inserter thread in this thread:
320+ lastBucketInserterThread .run ();
321+ bucketInserterThreads [threads - 1 ] = lastBucketInserterThread ;
322+
323+ // Join all the spawned bucket inserter threads:
324+ for (int i = 0 ; i != threads - 1 ; i ++) {
325+ BucketInserterThread bucketInserterThread =
326+ bucketInserterThreads [i ];
327+
328+ try {
329+ bucketInserterThread .join ();
330+ } catch (InterruptedException ex ) {
331+ throw new RuntimeException (
332+ "Could not join a bucket inserter thread." ,
333+ ex );
334+ }
335+ }
336+
337+ if (recursionDepth == DEEPEST_RECURSION_DEPTH ) {
338+ // Nowhere to recur, all bytes are processed. Return.
339+ return ;
340+ }
341+
342+ ListOfBucketKeyLists listOfBucketKeyLists =
343+ new ListOfBucketKeyLists (spawnDegree );
344+
345+ for (int i = 0 ; i != spawnDegree ; i ++) {
346+ BucketKeyList bucketKeyList =
347+ new BucketKeyList (numberOfNonemptyBuckets );
348+
349+ listOfBucketKeyLists .addBucketKeyList (bucketKeyList );
350+ }
351+
352+ // Match each thread to the number of threads it may run in:
353+ int [] threadCountMap = new int [spawnDegree ];
354+
355+ // ... basic thread counts...
356+ for (int i = 0 ; i != spawnDegree ; i ++) {
357+ threadCountMap [i ] = threads / spawnDegree ;
358+ }
359+
360+ // ... make sure all threads are in use:
361+ for (int i = 0 ; i != threads % spawnDegree ; i ++) {
362+ threadCountMap [i ]++;
363+ }
364+
365+ // Contains all the keys of all the non-empty buckets:
366+ BucketKeyList nonEmptyBucketIndices =
367+ new BucketKeyList (numberOfNonemptyBuckets );
368+
369+ for (int bucketKey = 0 ; bucketKey != BUCKETS ; bucketKey ++) {
370+ if (globalBucketSizeMap [bucketKey ] != 0 ) {
371+ nonEmptyBucketIndices .addBucketKey (bucketKey );
372+ }
373+ }
374+
375+ // Shuffle the bucket keys:
376+ nonEmptyBucketIndices .shuffle (new Random ());
377+
378+ // Distribute the buckets over the sorter task lists:
379+ int f = 0 ;
380+ int j = 0 ;
381+ int listIndex = 0 ;
382+ int optimalSubrangeLength = rangeLength / spawnDegree ;
383+ int packed = 0 ;
384+ int sz = nonEmptyBucketIndices .size ();
385+
386+ while (j != sz ) {
387+ int bucketKey = nonEmptyBucketIndices .getBucketKey (j ++);
388+ int tmp = globalBucketSizeMap [bucketKey ];
389+ packed += tmp ;
390+
391+ if (packed >= optimalSubrangeLength || j == sz ) {
392+ packed = 0 ;
393+
394+ for (int i = f ; i != j ; i ++) {
395+ int bucketKey2 = nonEmptyBucketIndices .getBucketKey (i );
396+
397+ BucketKeyList bucketKeyList =
398+ listOfBucketKeyLists .getBucketKeyList (listIndex );
399+
400+ bucketKeyList .addBucketKey (bucketKey2 );
401+ }
402+
403+ listIndex ++;
404+ f = j ;
405+ }
406+ }
258407
408+ List <List <SorterTask >> listOfSorterTaskLists =
409+ new ArrayList <>(spawnDegree );
259410
411+ for (int i = 0 ; i != spawnDegree ; i ++) {
412+ List <SorterTask > sorterTaskList =
413+ new ArrayList <>(BUCKETS );
414+
415+ int size = listOfBucketKeyLists .getBucketKeyList (i ).size ();
416+
417+ for (int idx = 0 ; idx != size ; idx ++) {
418+ int bucketKey =
419+ listOfBucketKeyLists
420+ .getBucketKeyList (i )
421+ .getBucketKey (idx );
422+
423+ SorterTask sorterTask =
424+ new SorterTask (
425+ target ,
426+ source ,
427+ targetFromIndex + startIndexMap [bucketKey ],
428+ sourceStartIndex + startIndexMap [bucketKey ],
429+ globalBucketSizeMap [bucketKey ],
430+ recursionDepth + 1 ,
431+ threadCountMap [i ]);
432+
433+ sorterTaskList .add (sorterTask );
434+ }
435+
436+ listOfSorterTaskLists .add (sorterTaskList );
437+ }
438+
439+ SorterThread [] sorterThreads = new SorterThread [spawnDegree - 1 ];
440+
441+ // Recur into deeper depth via multithreading:
442+ for (int i = 0 ; i != sorterThreads .length ; i ++) {
443+ SorterThread sorterThread =
444+ new SorterThread (
445+ listOfSorterTaskLists .get (i ));
446+
447+ sorterThread .start ();
448+ sorterThreads [i ] = sorterThread ;
449+ }
450+
451+ // Run the rightmost sorter thread in this thread:
452+ new SorterThread (
453+ listOfSorterTaskLists .get (spawnDegree - 1 )).run ();;
454+
455+ // Join all the actually spawned sorter threads:
456+ for (SorterThread sorterThread : sorterThreads ) {
457+ try {
458+ sorterThread .join ();
459+ } catch (InterruptedException ex ) {
460+ throw new RuntimeException (
461+ "Could not join a sorter thread." ,
462+ ex );
463+ }
464+ }
260465 }
261466
262467 private static void rangeCheck (
@@ -608,10 +813,112 @@ public void run() {
608813 }
609814
610815 private static final class SorterThread extends Thread {
816+
817+ private final List <SorterTask > sorterTasks ;
818+
819+ SorterThread (List <SorterTask > sorterTasks ) {
820+ this .sorterTasks = sorterTasks ;
821+ }
611822
612823 @ Override
613824 public void run () {
825+ for (SorterTask sorterTask : sorterTasks ) {
826+ if (sorterTask .threads > 1 ) {
827+ parallelRadixSortImpl (sorterTask .source ,
828+ sorterTask .target ,
829+ sorterTask .sourceStartOffset ,
830+ sorterTask .targetStartOffset ,
831+ sorterTask .rangeLength ,
832+ sorterTask .recursionDepth ,
833+ sorterTask .threads );
834+ } else {
835+ radixSortImpl (sorterTask .source ,
836+ sorterTask .target ,
837+ sorterTask .sourceStartOffset ,
838+ sorterTask .targetStartOffset ,
839+ sorterTask .rangeLength ,
840+ sorterTask .recursionDepth );
841+ }
842+ }
843+ }
844+ }
845+
846+ private static final class SorterTask {
847+
848+ final int [] source ;
849+ final int [] target ;
850+ final int sourceStartOffset ;
851+ final int targetStartOffset ;
852+ final int rangeLength ;
853+ final int recursionDepth ;
854+ final int threads ;
855+
856+ SorterTask (int [] source ,
857+ int [] target ,
858+ int sourceStartOffset ,
859+ int targetStartOffset ,
860+ int rangeLength ,
861+ int recursionDepth ,
862+ int threads ) {
614863
864+ this .source = source ;
865+ this .target = target ;
866+ this .sourceStartOffset = sourceStartOffset ;
867+ this .targetStartOffset = targetStartOffset ;
868+ this .rangeLength = rangeLength ;
869+ this .recursionDepth = recursionDepth ;
870+ this .threads = threads ;
871+ }
872+ }
873+
874+ private static final class BucketKeyList {
875+ private final int [] bucketKeys ;
876+ private int size ;
877+
878+ BucketKeyList (int capacity ) {
879+ this .bucketKeys = new int [capacity ];
880+ }
881+
882+ void addBucketKey (int bucketKey ) {
883+ this .bucketKeys [size ++] = bucketKey ;
884+ }
885+
886+ int getBucketKey (int index ) {
887+ return this .bucketKeys [index ];
888+ }
889+
890+ int size () {
891+ return size ;
892+ }
893+
894+ void shuffle (Random random ) {
895+ for (int i = 0 ; i != size - 1 ; i ++) {
896+ int j = i + random .nextInt (size - i );
897+ int temp = bucketKeys [i ];
898+ bucketKeys [i ] = bucketKeys [j ];
899+ bucketKeys [j ] = temp ;
900+ }
901+ }
902+ }
903+
904+ private static final class ListOfBucketKeyLists {
905+ private final BucketKeyList [] lists ;
906+ private int size ;
907+
908+ ListOfBucketKeyLists (int capacity ) {
909+ this .lists = new BucketKeyList [capacity ];
910+ }
911+
912+ void addBucketKeyList (BucketKeyList bucketKeyList ) {
913+ this .lists [this .size ++] = bucketKeyList ;
914+ }
915+
916+ BucketKeyList getBucketKeyList (int index ) {
917+ return this .lists [index ];
918+ }
919+
920+ int size () {
921+ return size ;
615922 }
616923 }
617924}
0 commit comments