Coding pattern for random percentage branching?

Java, Design Patterns, Random

Java Problem Overview


So let's say we have a code block that we want to execute 70% of the time and another one 30% of the time.

if (Math.random() < 0.7)
    method70Percent();
else
    method30Percent();

Simple enough. But what if we want it to be easily extensible to, say, 30%/60%/10%? Then every change would require adding and changing all the if statements, which isn't exactly pleasant to use, and is slow and error-prone.

So far I've found large switches to be decently useful for this use case, for example:

switch (ThreadLocalRandom.current().nextInt(10)) { // 0..9: ten equally likely values
    case 0:
    case 1:
    case 2:
    case 3:
    case 4:
    case 5:
    case 6: method70Percent(); break;
    case 7:
    case 8:
    case 9: method30Percent(); break;
}

Which can be very easily changed to:

switch (ThreadLocalRandom.current().nextInt(10)) {
    case 0: method10Percent(); break;
    case 1:
    case 2:
    case 3:
    case 4:
    case 5:
    case 6: method60Percent(); break;
    case 7:
    case 8:
    case 9: method30Percent(); break;
}

But these have their drawbacks as well: they are cumbersome and split into a predetermined number of divisions.

Ideally, it would be based on a "frequency number" system, I guess, like so:

(1,a),(1,b),(2,c) -> 25% a, 25% b, 50% c

then if you added another one:

(1,a),(1,b),(2,c),(6,d) -> 10% a, 10% b, 20% c, 60% d

In other words: add up the numbers, treat the sum as 100%, and split it accordingly.
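To make that arithmetic concrete, here is a minimal sketch that turns frequency numbers into percentages of their sum. `FrequencySplit` and `toPercentages` are hypothetical names used only for illustration; a LinkedHashMap is assumed so that the output keeps the insertion order.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: converts "frequency numbers" into percentages of their sum.
final class FrequencySplit {
    static Map<String, Double> toPercentages(Map<String, Integer> frequencies) {
        int total = 0;
        for (int f : frequencies.values()) {
            total += f; // the sum of all frequencies is treated as 100%
        }
        Map<String, Double> percentages = new LinkedHashMap<>(); // keeps insertion order
        for (Map.Entry<String, Integer> e : frequencies.entrySet()) {
            percentages.put(e.getKey(), 100.0 * e.getValue() / total);
        }
        return percentages;
    }

    public static void main(String[] args) {
        Map<String, Integer> frequencies = new LinkedHashMap<>();
        frequencies.put("a", 1);
        frequencies.put("b", 1);
        frequencies.put("c", 2);
        frequencies.put("d", 6);
        System.out.println(toPercentages(frequencies)); // {a=10.0, b=10.0, c=20.0, d=60.0}
    }
}
```

Adding another entry only changes the total; none of the existing frequency numbers have to be touched.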

I suppose it wouldn't be much trouble to write a handler for it with a customized HashMap or something, but I'm wondering if there's an established way, pattern, or lambda for it before I go all spaghetti on this.

Java Solutions


Solution 1 - Java

EDIT: See the edit at the end for a more elegant solution. I'll leave this one in, though.

You can use a NavigableMap to store these methods, mapped to their cumulative percentages (each key is the upper bound of that method's range).

NavigableMap<Double, Runnable> runnables = new TreeMap<>();

// keys are cumulative upper bounds: [0.0, 0.3) -> 30%, [0.3, 1.0) -> 70%
runnables.put(0.3, this::method30Percent);
runnables.put(1.0, this::method70Percent);

public static void runRandomly(NavigableMap<Double, Runnable> runnables) {
    double percentage = Math.random();
    // entries are iterated in ascending key order
    for (Map.Entry<Double, Runnable> entry : runnables.entrySet()) {
        if (percentage < entry.getKey()) {
            entry.getValue().run();
            return; // make sure you only call one method
        }
    }
    throw new RuntimeException("map not filled properly for " + percentage);
}

// or, because I'm still practicing streams by using them for everything
public static void runRandomlyWithStreams(NavigableMap<Double, Runnable> runnables) {
    double percentage = Math.random();
    runnables.entrySet().stream()
        .filter(e -> percentage < e.getKey())
        .findFirst()
        .orElseThrow(() ->
                new RuntimeException("map not filled properly for " + percentage))
        .getValue()
        .run();
}

The NavigableMap is sorted by key (a HashMap, by contrast, gives no guarantees about the order of its entries), so you get the entries ordered by their cumulative percentages. This is relevant because two items (3,r1),(7,r2) result in the entries r1 = 0.3 and r2 = 1.0, and they need to be evaluated in this order (if they were evaluated in reverse order, the result would always be r2).
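A small sketch of why the order matters, using the (3,r1),(7,r2) example. `CumulativeLookup` and its methods are hypothetical names, and Strings stand in for the Runnables. It also shows `NavigableMap.higherEntry`, which returns the entry with the smallest key strictly greater than the given one and therefore agrees with the ascending scan.

```java
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical demo: cumulative thresholds for the weights (3, r1), (7, r2).
final class CumulativeLookup {
    static NavigableMap<Double, String> thresholds() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(0.3, "r1"); // r1 covers [0.0, 0.3)
        map.put(1.0, "r2"); // r2 covers [0.3, 1.0)
        return map;
    }

    // Ascending key order: the first key greater than r wins.
    static String pickAscending(NavigableMap<Double, String> map, double r) {
        for (Map.Entry<Double, String> e : map.entrySet()) {
            if (r < e.getKey()) {
                return e.getValue();
            }
        }
        throw new IllegalStateException("map not filled properly for " + r);
    }

    // Descending key order: the largest key is checked first and always matches, so r2 always wins.
    static String pickDescending(NavigableMap<Double, String> map, double r) {
        for (Map.Entry<Double, String> e : map.descendingMap().entrySet()) {
            if (r < e.getKey()) {
                return e.getValue();
            }
        }
        throw new IllegalStateException("map not filled properly for " + r);
    }

    // Same result as pickAscending, using the map's own O(log n) lookup.
    static String pickWithHigherEntry(NavigableMap<Double, String> map, double r) {
        return map.higherEntry(r).getValue();
    }
}
```

With r = 0.1, `pickAscending` returns r1 while `pickDescending` returns r2, which is exactly the reverse-order problem described above.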

As for the splitting, it could go something like this. With a simple tuple class like this:

static class Pair<X, Y>
{
    public Pair(X f, Y s)
    {
        first = f;
        second = s;
    }

    public final X first;
    public final Y second;
}

You can then create the map like this:

// the parameter contains the (1,m1), (1,m2), (3,m3) pairs
private static NavigableMap<Double, Runnable> splitToPercentageMap(Collection<Pair<Integer, Runnable>> runnables)
{
    // this adds all Runnables to lists of the same int value;
    // overall those lists are sorted by that int (so least probable first)
    int total = 0;
    Map<Integer, List<Runnable>> byNumber = new TreeMap<>();
    for (Pair<Integer, Runnable> e : runnables)
    {
        total += e.first;
        byNumber.computeIfAbsent(e.first, k -> new ArrayList<>()).add(e.second);
    }

    // keys are cumulative shares; summing the ints first keeps the last key at exactly 1.0
    NavigableMap<Double, Runnable> targetMap = new TreeMap<>();
    int cumulative = 0;
    for (Map.Entry<Integer, List<Runnable>> e : byNumber.entrySet())
    {
        for (Runnable r : e.getValue())
        {
            cumulative += e.getKey();
            targetMap.put((double) cumulative / total, r);
        }
    }

    return targetMap;
}

And all of this put together in a class:

class RandomRunner {
    private List<Pair<Integer, Runnable>> runnables = new ArrayList<>();
    public void add(int value, Runnable toRun) {
        runnables.add(new Pair<>(value, toRun));
    }
    public void remove(Runnable toRemove) {
        for (Iterator<Pair<Integer, Runnable>> r = runnables.iterator();
            r.hasNext(); ) {
            if (toRemove == r.next().second) {
               r.remove();
               break;
            }
        }
    }
    public void runRandomly() {
        // split the list with splitToPercentageMap, then pick one with runRandomly (code from above)
    }
}

EDIT:
Actually, the above is what you get if you get an idea stuck in your head and don't question it properly. Keeping the RandomRunner class interface, this is much easier:

class RandomRunner {
    private final List<Runnable> runnables = new ArrayList<>();
    public void add(int value, Runnable toRun) {
        // add the methods as often as their weight indicates.
        // this should be fine for smaller numbers;
        // if you get lists with millions of entries, optimize
        for (int i = 0; i < value; i++) {
            runnables.add(toRun);
        }
    }
    public void remove(Runnable r) {
        // removes every copy that add() inserted
        Iterator<Runnable> myRunnables = runnables.iterator();
        while (myRunnables.hasNext()) {
            if (myRunnables.next() == r) {
                myRunnables.remove();
            }
        }
    }
    public void runRandomly() {
        if (runnables.isEmpty()) return;
        // roll an n-sided die: every copy is equally likely
        int runIndex = ThreadLocalRandom.current().nextInt(0, runnables.size());
        runnables.get(runIndex).run();
    }
}

Solution 2 - Java

All these answers seem quite complicated, so I'll just post the keep-it-simple alternative:

double rnd = Math.random();
if ((rnd -= 0.6) < 0)
    method60Percent();
else if ((rnd -= 0.3) < 0)
    method30Percent();
else
    method10Percent();

Adding or adjusting a branch doesn't require changing the other lines, and one can easily see what happens without digging into auxiliary classes. A small downside is that it doesn't enforce that the percentages sum to 100%.

Solution 3 - Java

I am not sure if there is a common name for this, but I think I learned it as the wheel of fortune back in university.

It basically works just as you described: it receives a list of values and "frequency numbers", and one value is chosen according to the weighted probabilities.

list = (1,a),(1,b),(2,c),(6,d)

total = list.sum()
rnd = random(0, total)
sum = 0
for i from 0 to list.size() - 1:
    sum += list[i].weight
    if sum >= rnd:
        return list[i].value
return list.last().value

The list can be a function parameter if you want to generalize this.

This also works with floating point numbers and the numbers don't have to be normalized. If you normalize (to sum up to 1 for example), you can skip the list.sum() part.

EDIT:

Due to demand, here is an actual compiling Java implementation and a usage example:

import java.util.ArrayList;
import java.util.Random;

public class RandomWheel<T>
{
    private static final class RandomWheelSection<T>
    {
        public double weight;
        public T value;

        public RandomWheelSection(double weight, T value)
        {
            this.weight = weight;
            this.value = value;
        }
    }

    private ArrayList<RandomWheelSection<T>> sections = new ArrayList<>();
    private double totalWeight = 0;
    private Random random = new Random();

    public void addWheelSection(double weight, T value)
    {
        sections.add(new RandomWheelSection<T>(weight, value));
        totalWeight += weight;
    }

    public T draw()
    {
        double rnd = totalWeight * random.nextDouble();

        double sum = 0;
        for (int i = 0; i < sections.size(); i++)
        {
            sum += sections.get(i).weight;
            if (sum >= rnd)
                return sections.get(i).value;
        }
        return sections.get(sections.size() - 1).value;
    }

    public static void main(String[] args)
    {
        RandomWheel<String> wheel = new RandomWheel<String>();
        wheel.addWheelSection(1, "a");
        wheel.addWheelSection(1, "b");
        wheel.addWheelSection(2, "c");
        wheel.addWheelSection(6, "d");

        for (int i = 0; i < 100; i++)
            System.out.print(wheel.draw());
    }
}

Solution 4 - Java

While the selected answer works, it is unfortunately asymptotically slow for your use case. Instead, you could use something called alias sampling. Alias sampling (or the alias method) is a technique for selecting elements according to a weighted distribution. If the weights of the elements don't change, you can do each selection in O(1) time! If they do change, you can still get amortized O(1) time as long as the number of selections you make is high relative to the number of changes you make to the alias table (changing the weights). The currently selected answer suggests an O(N) algorithm; the next best thing is O(log(N)) using sorted cumulative probabilities and binary search, but nothing is going to beat the O(1) time I suggested.
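For comparison, the O(log(N)) option mentioned above can be sketched as a binary search over cumulative weights (prefix sums, which are always sorted). This is not the alias method; `PrefixSumSampler` and `indexFor` are hypothetical names, and the sketch assumes strictly positive weights so that the prefix sums are strictly increasing.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical O(log N) sampler: binary search over cumulative (prefix-sum) weights.
final class PrefixSumSampler {
    private final double[] cumulative; // cumulative[i] = weights[0] + ... + weights[i]
    private final Random random;

    PrefixSumSampler(double[] weights, Random random) {
        cumulative = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] <= 0) {
                throw new IllegalArgumentException("weights must be positive");
            }
            sum += weights[i];
            cumulative[i] = sum;
        }
        this.random = random;
    }

    // Returns the index i whose range [cumulative[i-1], cumulative[i]) contains r.
    int indexFor(double r) {
        int pos = Arrays.binarySearch(cumulative, r);
        // Exact hit: r is the upper (exclusive) bound of pos, so it belongs to the next index.
        // Otherwise binarySearch returns -(insertionPoint) - 1.
        int index = pos >= 0 ? pos + 1 : -pos - 1;
        return Math.min(index, cumulative.length - 1); // guard against r == total
    }

    int sampleIndex() {
        double total = cumulative[cumulative.length - 1];
        return indexFor(random.nextDouble() * total);
    }
}
```

Each draw costs one binary search over N prefix sums, which sits between the O(N) linear scan and the O(1) alias table.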

This site provides a good overview of the alias method that is mostly language-agnostic. Essentially, you create a table where each entry represents the outcome of two probabilities. Each entry in the table has a single threshold: below the threshold you get one value, above it you get another. You spread larger probabilities across multiple table entries so that the probability graph of all probabilities combined has an area of one.

Say you have the outcomes A, B, C, and D, with probabilities 0.1, 0.1, 0.1, and 0.7 respectively. The alias method would spread D's probability of 0.7 across all the others. Each outcome gets one index: A's, B's, and C's indices each hold their own 0.1 plus 0.15 taken from D, and D's index holds D's remaining 0.25. You then normalize within each index, so that A's index gives a 0.4 chance of getting A and a 0.6 chance of getting D (0.1/(0.1 + 0.15) and 0.15/(0.1 + 0.15) respectively), B's and C's indices work the same way, and D's index gives a 100% chance of getting D (0.25/0.25 is 1).

Given an unbiased uniform PRNG (Math.random()) for indexing, you get an equal probability of choosing each index, and you then do a coin flip per index, which provides the weighting. You have a 25% chance of landing on the A/D slot, but within that you only have a 40% chance of picking A and a 60% chance of picking D. 0.40 * 0.25 = 0.1, our original probability, and if you add up all of D's probabilities strewn throughout the other indices, you get 0.70 again.
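Written out as a worked check of the arithmetic above, with each of the four indices picked with probability 1/4:

```latex
\begin{aligned}
&\text{Index } A:\quad 0.1\,(A) + 0.15\,(D) = 0.25, \qquad
  \Pr[A \mid \text{index } A] = \tfrac{0.1}{0.25} = 0.4, \qquad
  \Pr[D \mid \text{index } A] = \tfrac{0.15}{0.25} = 0.6 \\
&\text{Index } D:\quad 0.25\,(D), \qquad \Pr[D \mid \text{index } D] = \tfrac{0.25}{0.25} = 1 \\
&\Pr[A] = \tfrac{1}{4}\cdot 0.4 = 0.1 \qquad (\text{likewise for } B \text{ and } C) \\
&\Pr[D] = 3\cdot\tfrac{1}{4}\cdot 0.6 + \tfrac{1}{4}\cdot 1 = 0.45 + 0.25 = 0.7
\end{aligned}
```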

So to do a random selection, you only need to generate a random index from 0 to N - 1 and then do a coin flip. No matter how many items you add, this is very fast and has a constant cost. Building an alias table doesn't take many lines of code either: my Python version takes 80 lines including import statements and blank lines, and the version presented in the Pandas article is similarly sized (and it's C++).
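To illustrate the sampling step on its own, here is a minimal sketch that hard-codes the table from the A/B/C/D example (thresholds 0.4, 0.4, 0.4, 1.0, with D as the alias of A, B, and C) instead of building it. `AliasTableDemo` and its methods are hypothetical names; building the table from arbitrary probabilities is a separate step.

```java
import java.util.Random;

// Hypothetical sketch of the O(1) sampling step, using the table from the example.
// Indices 0..3 stand for A, B, C, D.
final class AliasTableDemo {
    // Chance of keeping the index itself; otherwise its alias is returned.
    static final double[] THRESHOLD = {0.4, 0.4, 0.4, 1.0};
    // A, B and C borrow from D; D's own alias is never used because its threshold is 1.0.
    static final int[] ALIAS = {3, 3, 3, 3};

    static int sample(Random random) {
        int index = random.nextInt(THRESHOLD.length);        // fair die: each index with probability 1/4
        return random.nextDouble() < THRESHOLD[index] ? index // biased coin flip within the index
                                                       : ALIAS[index];
    }

    // Estimates each outcome's frequency over the given number of draws.
    static double[] frequencies(int draws, long seed) {
        Random random = new Random(seed);
        int[] counts = new int[THRESHOLD.length];
        for (int i = 0; i < draws; i++) {
            counts[sample(random)]++;
        }
        double[] result = new double[counts.length];
        for (int i = 0; i < counts.length; i++) {
            result[i] = (double) counts[i] / draws;
        }
        return result;
    }

    public static void main(String[] args) {
        double[] f = frequencies(1_000_000, 42L);
        System.out.printf("A=%.3f B=%.3f C=%.3f D=%.3f%n", f[0], f[1], f[2], f[3]);
    }
}
```

Running `main` should print frequencies close to 0.1, 0.1, 0.1, and 0.7; each draw costs one `nextInt` and one `nextDouble` regardless of how many outcomes there are.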

For your Java implementation, you could keep the functions you need to execute in a list in the same order as the probabilities, so that each sampled index maps directly to the function to run. Alternatively, you could use function objects (functors) that have a method through which you pass parameters to execute them.

List<Function<P, T>> functionList = new ArrayList<>();
// add functions, in the same order as listOfProbabilities
AliasSampler aliasSampler = new AliasSampler(listOfProbabilities);
// somewhere later, with some parameter type P, result type T and some parameter value(s)
int index = aliasSampler.sampleIndex();
T result = functionList.get(index).apply(parameters);

EDIT:

I've created a Java version of the AliasSampler, using classes. It provides the sampleIndex method and should be usable as shown above.

import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

public class AliasSampler {
    private final ArrayList<Double> binaryProbabilityArray;
    private final ArrayList<Integer> aliasIndexList;
    private final Random random = new Random();

    public AliasSampler(ArrayList<Double> probabilities){
        // java 8 needed here (streams); the tolerance allows for floating-point rounding
        assert Math.abs(probabilities.stream().mapToDouble(Double::doubleValue).sum() - 1.0) < 1e-9;
        int n = probabilities.size();
        // probabilityArray is the list of probabilities, this is the incoming probabilities scaled
        // by the number of probabilities.  This allows us to figure out which probabilities need to be spread
        // to others since they are too large, ie [0.1 0.1 0.1 0.7] = [0.4 0.4 0.4 2.80]
        ArrayList<Double> probabilityArray = new ArrayList<Double>();
        for(Double probability : probabilities){
            probabilityArray.add(probability * n);
        }
        binaryProbabilityArray = new ArrayList<Double>(Collections.nCopies(n, 0.0));
        aliasIndexList = new ArrayList<Integer>(Collections.nCopies(n, 0));
        ArrayList<Integer> lessThanOneIndexList = new ArrayList<Integer>();
        ArrayList<Integer> greaterThanOneIndexList = new ArrayList<Integer>();
        for(int index = 0; index < probabilityArray.size(); index++){
            double probability = probabilityArray.get(index);
            if(probability < 1.0){
                lessThanOneIndexList.add(index);
            }
            else{
                greaterThanOneIndexList.add(index);
            }
        }

        // while we still have indices to check for in each list, we attempt to spread the probability of those larger
        // what this ends up doing in our first example is taking greater than one elements (2.80) and removing 0.6,
        // and spreading it to different indices, so (((2.80 - 0.6) - 0.6) - 0.6) will equal 1.0, and the rest will
        // be 0.4 + 0.6 = 1.0 as well.
        while(lessThanOneIndexList.size() != 0 && greaterThanOneIndexList.size() != 0){
            //https://stackoverflow.com/questions/16987727/removing-last-object-of-arraylist-in-java
            // last element removal is equivalent to pop, java does this in constant time
            int lessThanOneIndex = lessThanOneIndexList.remove(lessThanOneIndexList.size() - 1);
            int greaterThanOneIndex = greaterThanOneIndexList.remove(greaterThanOneIndexList.size() - 1);
            double probabilityLessThanOne = probabilityArray.get(lessThanOneIndex);
            binaryProbabilityArray.set(lessThanOneIndex, probabilityLessThanOne);
            aliasIndexList.set(lessThanOneIndex, greaterThanOneIndex);
            probabilityArray.set(greaterThanOneIndex, probabilityArray.get(greaterThanOneIndex) + probabilityLessThanOne - 1);
            if(probabilityArray.get(greaterThanOneIndex) < 1){
                lessThanOneIndexList.add(greaterThanOneIndex);
            }
            else{
                greaterThanOneIndexList.add(greaterThanOneIndex);
            }
        }
        //if there are any probabilities left in either index list, they can't be spread across the other
        //indices, so they are set with probability 1.0. They still have the probabilities they should at this step, it works out mathematically.
        while(greaterThanOneIndexList.size() != 0){
            int greaterThanOneIndex = greaterThanOneIndexList.remove(greaterThanOneIndexList.size() - 1);
            binaryProbabilityArray.set(greaterThanOneIndex, 1.0);
        }
        while(lessThanOneIndexList.size() != 0){
            int lessThanOneIndex = lessThanOneIndexList.remove(lessThanOneIndexList.size() - 1);
            binaryProbabilityArray.set(lessThanOneIndex, 1.0);
        }
    }
    public int sampleIndex(){
        // fair die roll for the index, then a biased coin flip between the index and its alias
        int index = random.nextInt(binaryProbabilityArray.size());
        double r = random.nextDouble();
        if( r < binaryProbabilityArray.get(index)){
            return index;
        }
        else{
            return aliasIndexList.get(index);
        }
    }

}

Solution 5 - Java

You could compute the cumulative probability for each class, pick a random number from [0; 1) and see where that number falls.

import java.util.Random;

class WeightedRandomPicker {

    private static final Random random = new Random();

    public static int choose(double[] probabilities) {
        double randomVal = random.nextDouble();
        double cumulativeProbability = 0;
        for (int i = 0; i < probabilities.length; ++i) {
            cumulativeProbability += probabilities[i];
            if (randomVal < cumulativeProbability) {
                return i;
            }
        }
        return probabilities.length - 1; // to account for numerical errors
    }

    public static void main(String[] args) {
        // should sum to 1; any shortfall is assigned to the last index by the fallback above
        double[] probabilities = new double[]{0.1, 0.1, 0.2, 0.6};
        for (int i = 0; i < 20; ++i) {
            System.out.printf("%d\n", choose(probabilities));
        }
    }
}

Solution 6 - Java

The following is a bit like @daniu's answer, but makes use of the methods provided by TreeMap:

private final NavigableMap<Double, Runnable> map = new TreeMap<>();
{
	map.put(0.3d, this::branch30Percent);
	map.put(1.0d, this::branch70Percent);
}
private final SecureRandom random = new SecureRandom();

private void branch30Percent() {}

private void branch70Percent() {}

public void runRandomly() {
	final Runnable value = map.tailMap(random.nextDouble(), true).firstEntry().getValue();
	value.run();
}

This way there is no need to iterate over the map until the matching entry is found; instead, TreeMap's ability to find the entry whose key is the closest match to a given key is used. This, however, only makes a difference if the map has a large number of entries. Still, it does save a few lines of code.
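The same lookup can also be written with `NavigableMap.ceilingEntry`, which returns the entry with the least key greater than or equal to the given key, i.e. the same entry as `tailMap(r, true).firstEntry()`. The sketch below compares the two; `CeilingLookup` and its methods are hypothetical names, and Strings stand in for the Runnables.

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical demo: two equivalent ways to find the branch for a random number r.
final class CeilingLookup {
    static NavigableMap<Double, String> branches() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(0.3, "branch30Percent");
        map.put(1.0, "branch70Percent");
        return map;
    }

    // As in the answer above: the view of all keys >= r, then its first entry.
    static String viaTailMap(NavigableMap<Double, String> map, double r) {
        return map.tailMap(r, true).firstEntry().getValue();
    }

    // Direct lookup of the least key >= r; no intermediate view is created.
    static String viaCeiling(NavigableMap<Double, String> map, double r) {
        return map.ceilingEntry(r).getValue();
    }
}
```

Both are O(log n) lookups on a TreeMap; `ceilingEntry` simply states the intent more directly.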

Solution 7 - Java

I'd do it something like this:

class RandomMethod {
    private final Runnable method;
    private final int probability;

    RandomMethod(Runnable method, int probability){
        this.method = method;
        this.probability = probability;
    }

    public int getProbability() { return probability; }
    public void run()      { method.run(); }
}

class MethodChooser {
    private final List<RandomMethod> methods;
    private final int total;

    MethodChooser(final List<RandomMethod> methods) {
        this.methods = methods;
        this.total = methods.stream().collect(
            Collectors.summingInt(RandomMethod::getProbability)
        );
    }

    public void chooseMethod() {
        final Random random = new Random();
        final int choice = random.nextInt(total);

        int count = 0;
        for (final RandomMethod method : methods)
        {
            count += method.getProbability();
            if (choice < count) {
                method.run();
                return;
            }
        }
    }
}

Sample usage:

MethodChooser chooser = new MethodChooser(Arrays.asList(
    new RandomMethod(Blah::aaa, 1),
    new RandomMethod(Blah::bbb, 3),
    new RandomMethod(Blah::ccc, 1)
));

IntStream.range(0, 100).forEach(
    i -> chooser.chooseMethod()
);


Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | Moff Kalast | View Question on Stackoverflow
Solution 1 - Java | daniu | View Answer on Stackoverflow
Solution 2 - Java | jpa | View Answer on Stackoverflow
Solution 3 - Java | SteakOverflow | View Answer on Stackoverflow
Solution 4 - Java | Krupip | View Answer on Stackoverflow
Solution 5 - Java | NPE | View Answer on Stackoverflow
Solution 6 - Java | SpaceTrucker | View Answer on Stackoverflow
Solution 7 - Java | Michael | View Answer on Stackoverflow