Selezione ponderata casuale in Java

Voglio scegliere un object casuale da un set, ma la possibilità di scegliere qualsiasi object dovrebbe essere proporzionale al peso associato

Esempi di input:

item weight ---- ------ sword of misery 10 shield of happy 5 potion of dying 6 triple-edged sword 1 

Quindi, se ho 4 possibili articoli, la possibilità di ottenere qualsiasi articolo senza pesi sarebbe 1 su 4.

In questo caso, un utente dovrebbe essere 10 volte più probabilità di ottenere la spada della miseria rispetto alla spada a tre punte.

Come posso effettuare una selezione casuale ponderata in Java?

Vorrei usare una NavigableMap

 public class RandomCollection { private final NavigableMap map = new TreeMap(); private final Random random; private double total = 0; public RandomCollection() { this(new Random()); } public RandomCollection(Random random) { this.random = random; } public RandomCollection add(double weight, E result) { if (weight <= 0) return this; total += weight; map.put(total, result); return this; } public E next() { double value = random.nextDouble() * total; return map.higherEntry(value).getValue(); } } 

Diciamo che ho una lista di animali cane, gatto, cavallo con probabilità del 40%, 35%, 25% rispettivamente

 RandomCollection rc = new RandomCollection<>() .add(40, "dog").add(35, "cat").add(25, "horse"); for (int i = 0; i < 10; i++) { System.out.println(rc.next()); } 

Non troverai una struttura per questo tipo di problema, in quanto la funzionalità richiesta non è altro che una semplice funzione. Fai qualcosa del genere:

 interface Item { double getWeight(); } class RandomItemChooser { public Item chooseOnWeight(List items) { double completeWeight = 0.0; for (Item item : items) completeWeight += item.getWeight(); double r = Math.random() * completeWeight; double countWeight = 0.0; for (Item item : items) { countWeight += item.getWeight(); if (countWeight >= r) return item; } throw new RuntimeException("Should never be shown."); } } 

C’è ora una class per questo in Apache Commons: EnumeratedDistribution

 Item selectedItem = new EnumeratedDistribution(itemWeights).sample(); 

dove itemWeights è una List> , like (assumendo l’interfaccia Item nella risposta di Arne):

 List> itemWeights = Collections.newArrayList(); for (Item i : itemSet) { itemWeights.add(new Pair(i, i.getWeight())); } 

o in Java 8:

 itemSet.stream().map(i -> new Pair(i, i.getWeight())).collect(toList()); 

Nota: la Pair qui deve essere org.apache.commons.math3.util.Pair , non org.apache.commons.lang3.tuple.Pair .

Utilizza un metodo alias

Se farai rotolare un sacco di volte (come in un gioco), dovresti usare un metodo alias.

Il codice seguente è piuttosto lunga implementazione di un tale metodo di alias, anzi. Ma questo è dovuto alla parte di inizializzazione. Il recupero degli elementi è molto veloce (vedi i metodi applyAsInt che non eseguono il ciclo).

uso

 Set items = ... ; ToDoubleFunction weighter = ... ; Random random = new Random(); RandomSelector selector = RandomSelector.weighted(items, weighter); Item drop = selector.next(random); 

Implementazione

Questa implementazione:

  • usa Java 8 ;
  • è progettato per essere il più veloce ansible (beh, almeno, ho provato a farlo usando il micro-benchmarking);
  • è totalmente thread-safe (tieni un Random in ogni thread per le massime prestazioni, usa ThreadLocalRandom ?);
  • recupera elementi in O (1) , a differenza di ciò che si trova principalmente su Internet o su StackOverflow, dove le implementazioni ingenue eseguono in O (n) o O (log (n));
  • mantiene gli oggetti indipendenti dal loro peso , quindi a un object possono essere assegnati pesi diversi in diversi contesti.

Comunque, ecco il codice. (Nota che mantengo una versione aggiornata di questa class .)

 import static java.util.Objects.requireNonNull; import java.util.*; import java.util.function.*; public final class RandomSelector { public static  RandomSelector weighted(Set elements, ToDoubleFunction weighter) throws IllegalArgumentException { requireNonNull(elements, "elements must not be null"); requireNonNull(weighter, "weighter must not be null"); if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); } // Array is faster than anything. Use that. int size = elements.size(); T[] elementArray = elements.toArray((T[]) new Object[size]); double totalWeight = 0d; double[] discreteProbabilities = new double[size]; // Retrieve the probabilities for (int i = 0; i < size; i++) { double weight = weighter.applyAsDouble(elementArray[i]); if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); } discreteProbabilities[i] = weight; totalWeight += weight; } if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); } // Normalize the probabilities for (int i = 0; i < size; i++) { discreteProbabilities[i] /= totalWeight; } return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities)); } private final T[] elements; private final ToIntFunction selection; private RandomSelector(T[] elements, ToIntFunction selection) { this.elements = elements; this.selection = selection; } public T next(Random random) { return elements[selection.applyAsInt(random)]; } private static class RandomWeightedSelection implements ToIntFunction { // Alias method implementation O(1) // using Vose's algorithm to initialize O(n) private final double[] probabilities; private final int[] alias; RandomWeightedSelection(double[] probabilities) { int size = probabilities.length; double average = 1.0d / size; int[] small = new int[size]; int smallSize = 0; int[] large = new int[size]; int largeSize = 0; // Describe a column as either small (below average) or large (above average). for (int i = 0; i < size; i++) { if (probabilities[i] < average) { small[smallSize++] = i; } else { large[largeSize++] = i; } } // For each column, saturate a small probability to average with a large probability. while (largeSize != 0 && smallSize != 0) { int less = small[--smallSize]; int more = large[--largeSize]; probabilities[less] = probabilities[less] * size; alias[less] = more; probabilities[more] += probabilities[less] - average; if (probabilities[more] < average) { small[smallSize++] = more; } else { large[largeSize++] = more; } } // Flush unused columns. while (smallSize != 0) { probabilities[small[--smallSize]] = 1.0d; } while (largeSize != 0) { probabilities[large[--largeSize]] = 1.0d; } } @Override public int applyAsInt(Random random) { // Call random once to decide which column will be used. int column = random.nextInt(probabilities.length); // Call random a second time to decide which will be used: the column or the alias. if (random.nextDouble() < probabilities[column]) { return column; } else { return alias[column]; } } } } 

Se è necessario rimuovere elementi dopo aver scelto, è ansible utilizzare un’altra soluzione. Aggiungi tutti gli elementi in un ‘LinkedList’, ogni elemento deve essere aggiunto tutte le volte che il suo peso è, quindi usa Collections.shuffle() che, secondo JavaDoc

Permette a caso l’elenco specificato usando una fonte predefinita di casualità. Tutte le permutazioni si verificano con approssimativamente uguale probabilità.

Infine, ottieni e rimuovi elementi usando pop() o removeFirst()

 Map map = new HashMap() {{ put("Five", 5); put("Four", 4); put("Three", 3); put("Two", 2); put("One", 1); }}; LinkedList list = new LinkedList<>(); for (Map.Entry entry : map.entrySet()) { for (int i = 0; i < entry.getValue(); i++) { list.add(entry.getKey()); } } Collections.shuffle(list); int size = list.size(); for (int i = 0; i < size; i++) { System.out.println(list.pop()); } 
 public class RandomCollection { private final NavigableMap map = new TreeMap(); private double total = 0; public void add(double weight, E result) { if (weight <= 0 || map.containsValue(result)) return; total += weight; map.put(total, result); } public E next() { double value = ThreadLocalRandom.current().nextDouble() * total; return map.ceilingEntry(value).getValue(); } }