UniversalRandom: George Marsaglia's Xorshift generator (Python and Java versions)
September 21, 2016 · View on GitHub
/* ---------------------------------------------------------------------
- Numenta Platform for Intelligent Computing (NuPIC)
- Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
- with Numenta, Inc., for a separate license for this software code, the
- following terms and conditions apply:
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero Public License version 3 as
- published by the Free Software Foundation.
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- See the GNU Affero Public License for more details.
- You should have received a copy of the GNU Affero Public License
- along with this program. If not, see http://www.gnu.org/licenses.
- http://numenta.org/licenses/
-
*/ package org.numenta.nupic.util;
import java.math.BigDecimal; import java.math.BigInteger; import java.math.MathContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream;
import gnu.trove.list.array.TIntArrayList; import gnu.trove.set.hash.TIntHashSet;
/**
-
-
This also has a Python version which is guaranteed to output the same random
-
numbers if given the same initial seed value.
-
-
Implementation of George Marsaglia's elegant Xorshift random generator
-
30% faster and better quality than the built-in java.util.random.
-
-
see http://www.javamex.com/tutorials/random_numbers/xorshift.shtml.
-
@author cogmission / public class UniversalRandom extends Random { /* serial version */ private static final long serialVersionUID = 1L;
private static final MathContext MATH_CONTEXT = new MathContext(9);
long seed;
static final String BadBound = "bound must be positive";
public UniversalRandom(long seed) { this.seed = seed; }
/**
- Sets the long value used as the initial seed
- @param seed the value with which to be initialized */ @Override public void setSeed(long seed) { this.seed = seed; }
/**
- Returns the long value used as the initial seed
- @return the initial seed value */ public long getSeed() { return seed; }
/*
- Internal method used for testing
*/
private int[] sampleWithPrintout(TIntArrayList choices, int[] selectedIndices, List
collectedRandoms) { TIntArrayList choiceSupply = new TIntArrayList(choices); int upperBound = choices.size(); for (int i = 0; i < selectedIndices.length; i++) { int randomIdx = nextInt(upperBound); //System.out.println("randomIdx: " + randomIdx); collectedRandoms.add(randomIdx); selectedIndices[i] = (choiceSupply.removeAt(randomIdx)); upperBound--; } Arrays.sort(selectedIndices); return selectedIndices; }
/**
- Returns a random, sorted, and unique list of the specified sample size of
- selections from the specified list of choices.
- @param choices
- @param selectedIndices
- @return an array containing a sampling of the specified choices */ public int[] sample(TIntArrayList choices, int[] selectedIndices) { TIntArrayList choiceSupply = new TIntArrayList(choices); int upperBound = choices.size(); for (int i = 0; i < selectedIndices.length; i++) { int randomIdx = nextInt(upperBound); selectedIndices[i] = (choiceSupply.removeAt(randomIdx)); upperBound--; } Arrays.sort(selectedIndices); //System.out.println("sample: " + Arrays.toString(selectedIndices)); return selectedIndices; }
/**
- Fisher-Yates implementation which shuffles the array contents.
- @param array the array of ints to shuffle.
- @return shuffled array */ public int[] shuffle(int[] array) { int index; for (int i = array.length - 1; i > 0; i--) { index = nextInt(i + 1); if (index != i) { array[index] ^= array[i]; array[i] ^= array[index]; array[index] ^= array[i]; } } return array; }
/**
- Returns an array of floating point values of the specified shape
- @param rows the number of rows
- @param cols the number of cols
- @return */ public double[][] rand(int rows, int cols) { double[][] retval = new double[rows][cols]; for(int i = 0;i < rows;i++) { for(int j = 0;j < cols;j++) { retval[i][j] = nextDouble(); } } return retval; }
/**
-
Returns an array of binary values of the specified shape whose
-
total number of "1's" will reflect the sparsity specified.
-
@param rows the number of rows
-
@param cols the number of cols
-
@param sparsity number between 0 and 1, indicating percentage
-
of "on" bits -
@return */ public int[][] binDistrib(int rows, int cols, double sparsity) { double[][] rand = rand(rows, cols);
for(int i = 0;i < rand.length;i++) { TIntArrayList sub = new TIntArrayList( ArrayUtils.where(rand[i], new Condition.Adapter
() { @Override public boolean eval(double d) { return d >= sparsity; } })); int sublen = sub.size(); int target = (int)(sparsity * cols); if(sublen < target) { int[] full = IntStream.range(0, cols).toArray(); TIntHashSet subSet = new TIntHashSet(sub); TIntArrayList toFill = new TIntArrayList( Arrays.stream(full) .filter(d -> !subSet.contains(d)) .toArray()); int cnt = toFill.size(); for(int x = 0;x < target - sublen;x++, cnt--) { int ind = nextInt(cnt); int item = toFill.removeAt(ind); rand[i][item] = sparsity; } }else if(sublen > target) { int cnt = sublen; for(int x = 0;x < sublen - target;x++, cnt--) { int ind = nextInt(cnt); int item = sub.removeAt(ind); rand[i][item] = 0.0; } }}
int[][] retval = Arrays.stream(rand) .map(da -> Arrays.stream(da).mapToInt(d -> d >= sparsity ? 1 : 0).toArray()) .toArray(int[][]::new); return retval; }
@Override public double nextDouble() { int nd = nextInt(10000); double retVal = new BigDecimal(nd * .0001d, MATH_CONTEXT).doubleValue(); //System.out.println("nextDouble: " + retVal); return retVal; }
@Override public int nextInt() { int retVal = nextInt(Integer.MAX_VALUE); //System.out.println("nextIntNB: " + retVal); return retVal; }
@Override public int nextInt(int bound) { if (bound <= 0) throw new IllegalArgumentException(BadBound);
int r = next(31); int m = bound - 1; if ((bound & m) == 0) // i.e., bound is a power of 2 r = (int)((bound * (long)r) >> 31); else { r = r % bound; /* THIS CODE IS COMMENTED TO WORK IDENTICALLY WITH THE PYTHON VERSION for (int u = r; u - (r = u % bound) + m < 0; u = next(31)) ; */ } //System.out.println("nextInt(" + bound + "): " + r); return r;}
/**
-
Implementation of George Marsaglia's elegant Xorshift random generator
-
30% faster and better quality than the built-in java.util.random see also
-
see http://www.javamex.com/tutorials/random_numbers/xorshift.shtml */ protected int next(int nbits) { long x = seed; x ^= (x << 21) & 0xffffffffffffffffL; x ^= (x >>> 35); x ^= (x << 4); seed = x; x &= ((1L << nbits) - 1);
return (int) x; }
BigInteger bigSeed; /**
-
PYTHON COMPATIBLE (Protected against overflows)
-
Implementation of George Marsaglia's elegant Xorshift random generator
-
30% faster and better quality than the built-in java.util.random see also
-
see http://www.javamex.com/tutorials/random_numbers/xorshift.shtml */ protected int nextX(int nbits) { long x = seed; BigInteger bigX = bigSeed == null ? BigInteger.valueOf(seed) : bigSeed; bigX = bigX.shiftLeft(21).xor(bigX).and(new BigInteger("ffffffffffffffff", 16)); bigX = bigX.shiftRight(35).xor(bigX).and(new BigInteger("ffffffffffffffff", 16)); bigX = bigX.shiftLeft(4).xor(bigX).and(new BigInteger("ffffffffffffffff", 16)); bigSeed = bigX; bigX = bigX.and(BigInteger.valueOf(1L).shiftLeft(nbits).subtract(BigInteger.valueOf(1))); x = bigX.intValue();
//System.out.println("x = " + x + ", seed = " + seed); return (int)x; }
public static void main(String[] args) { UniversalRandom random = new UniversalRandom(42);
long s = 2858730232218250L; long e = (s >>> 35); System.out.println("e = " + e); int x = random.nextInt(50); System.out.println("x = " + x); x = random.nextInt(50); System.out.println("x = " + x); x = random.nextInt(50); System.out.println("x = " + x); x = random.nextInt(50); System.out.println("x = " + x); x = random.nextInt(50); System.out.println("x = " + x); for(int i = 0;i < 10;i++) { int o = random.nextInt(50); System.out.println("x = " + o); } random = new UniversalRandom(42); for(int i = 0;i < 10;i++) { double o = random.nextDouble(); System.out.println("d = " + o); } /////////////////////////////////// // Values Seen in Python // /////////////////////////////////// /* * e = 83200 x = 0 x = 26 x = 14 x = 15 x = 38 x = 47 x = 13 x = 9 x = 15 x = 31 x = 6 x = 3 x = 0 x = 21 x = 45 d = 0.945 d = 0.2426 d = 0.5214 d = 0.0815 d = 0.0988 d = 0.5497 d = 0.4013 d = 0.4559 d = 0.5415 d = 0.2381 */ random = new UniversalRandom(42); TIntArrayList choices = new TIntArrayList(new int[] { 1,2,3,4,5,6,7,8,9 }); int sampleSize = 6; int[] selectedIndices = new int[sampleSize]; List<Integer> collectedRandoms = new ArrayList<>(); int[] expectedSample = {1,2,3,7,8,9}; List<Integer> expectedRandoms = Arrays.stream(new int[] {0,0,0,5,3,3}).boxed().collect(Collectors.toList()); random.sampleWithPrintout(choices, selectedIndices, collectedRandoms); System.out.println("samples are equal ? " + Arrays.equals(expectedSample, selectedIndices)); System.out.println("used randoms are equal ? " + collectedRandoms.equals(expectedRandoms)); random = new UniversalRandom(42); int[] coll = ArrayUtils.range(0, 10); int[] before = Arrays.copyOf(coll, coll.length); random.shuffle(coll); System.out.println("collection before: " + Arrays.toString(before)); System.out.println("collection shuffled: " + Arrays.toString(coll)); int[] expected = { 5, 1, 8, 6, 2, 4, 7, 3, 9, 0 }; System.out.println(Arrays.equals(expected, coll)); System.out.println(!Arrays.equals(expected, before)); // not equal}
}