HiveBrain v1.2.0
patternMinor

Functional translation of dice simulation

Submitted by: @import:stackexchange-codereview

Problem

I have tried my hand at "functionalising" a toy problem I had: find the expected number of throws of a six-sided die until all sides have been seen (the answer is 14.7).
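The 14.7 figure can be checked without simulation, since this is the coupon collector problem. When k distinct faces have already been seen, each throw shows a new face with probability (n - k)/n, so the wait for the next new face has mean n/(n - k). Summing over k gives n(1 + 1/2 + ... + 1/n), which for n = 6 is 6 + 3 + 2 + 1.5 + 1.2 + 1 = 14.7. A minimal sketch of that sum (the name expectedThrows is illustrative, not from the original):

```scala
// Exact expected number of throws of an n-sided die until every face has
// appeared (coupon collector): n * (1 + 1/2 + ... + 1/n).
def expectedThrows(nsides: Int): Double =
  (1 to nsides).map(k => nsides.toDouble / k).sum

println(expectedThrows(6)) // approximately 14.7
```

This gives a fixed target to compare the simulated averages against.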

My starting point is my imperative solution

import scala.collection.mutable.SortedSet
import scala.util.Random

def dice(nsim:Int) : Double = {
    val nsides = 6
    val r = Random
    val throws = SortedSet[Int]()
    var nthrows = 0
    var res = 0.0

    for (i <- 1 to nsim) {
        throws.clear
        nthrows = 0
        while ( throws.size != nsides) {
            nthrows += 1
            throws += r.nextInt(nsides) + 1
        }
        res += nthrows
    }
    res / nsim
}

dice(10000) // ~14.7


I initialise an empty set and keep adding the result of each throw until I have seen all sides (the set size equals six), while keeping track of how many times I have thrown the die. The expected number is then the sum of all these throw counts divided by the number of simulations (repetitions) I perform: approx. 14.7.
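To see the stopping rule in action, it can help to record every throw of a single run. The helper below is hypothetical (throwsOfOneRun is not in the original); it takes the generator as a parameter so that a fixed seed gives a repeatable run, and it uses a plain immutable Set, on the assumption that only membership matters here, not ordering.

```scala
import scala.util.Random

// Hypothetical helper: records every throw of one simulation run.
// The loop stops as soon as all nsides faces have appeared.
def throwsOfOneRun(rnd: Random, nsides: Int): List[Int] = {
  val history = List.newBuilder[Int]
  var seen = Set.empty[Int]
  while (seen.size != nsides) {
    val t = rnd.nextInt(nsides) + 1
    history += t
    seen += t
  }
  history.result()
}

val run = throwsOfOneRun(new Random(42), 6)
println(s"${run.length} throws: ${run.mkString(" ")}")
```

The length of the returned list is exactly the nthrows value that the loop adds to res for that simulation, and the last throw is always the first appearance of the final missing face.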

And here is my functional attempt (almost half a day went into this, including plenty of googling, OutOfMemory errors, etc.), and I am not sure whether it is considered good.

def throwdie(nsides:Int) : Int = {
    val r = Random
    r.nextInt(nsides) + 1 
}

def nthrows(seen:SortedSet[Int], count:Int, nsides:Int) : Int = {
    if (seen.size == nsides)
        return count
    return nthrows(seen + throwdie(nsides), count + 1, nsides)
}

def fdice(nsim:Int, nsides:Int) : Double = {
    val runs = Iterator.fill(nsim)(SortedSet[Int]())
    runs.map( x=> nthrows(x, 0, nsides) ).reduceLeft(_+_) * 1.0 / nsim
}

fdice(10000, 6) // ~14.7


explanation

So I start with a function throwdie that just returns the outcome of rolling a single die.
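A variant of throwdie that receives its generator explicitly, instead of referring to the global scala.util.Random object, makes that dependency visible and lets a seed be fixed when testing. A minimal sketch (throwDie and its rnd parameter are illustrative names):

```scala
import scala.util.Random

// Illustrative variant of throwdie: the caller supplies the generator,
// so the same seed always yields the same sequence of faces in 1..nsides.
def throwDie(rnd: Random, nsides: Int): Int = rnd.nextInt(nsides) + 1

val rnd = new Random(7)
val faces = List.fill(10)(throwDie(rnd, 6))
println(faces)
```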

Next I have tried to rewrite the while loop using recursion. I came up with nthrows which takes an empty set, a zero count, and the number of die sides, and returns the number of throws until all sides have been seen for one "simulation". I use a mut
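For comparison, here is a sketch of the same recursion without the return statements: the recursive call simply becomes the value of the else branch, and @tailrec asks the compiler to confirm it compiles to a loop. It assumes an immutable Set and an explicitly passed generator, neither of which is in the original code, and moves the empty set and zero count into a local helper.

```scala
import scala.annotation.tailrec
import scala.util.Random

// One simulation run: count throws until all nsides faces have appeared.
// The recursion lives in a local helper; @tailrec makes compilation fail
// if the recursive call is not in tail position.
def nthrows(rnd: Random, nsides: Int): Int = {
  @tailrec
  def loop(seen: Set[Int], count: Int): Int =
    if (seen.size == nsides) count
    else loop(seen + (rnd.nextInt(nsides) + 1), count + 1)
  loop(Set.empty, 0)
}

println(nthrows(new Random(1), 6))
```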

Solution

Let me first present my solution and then comment on it below. Note that the code is more verbose than needed, in order to be more descriptive.

import scala.annotation.tailrec
import scala.collection.SortedSet
import scala.util.Random

//=========Neat reusable methods=============

//Infinite iterator of the results of a dice throw
def diceResultIterator(sidesCount: Int): Iterator[Int] = {

  case class Dice(rnd: Random, sidesCount: Int, upside: Int)
  //Takes a die, returns a new die: functional "mutation"! For more, search for the State monad
  def roll(initDice: Dice) = {
    val newSide = initDice.rnd.nextInt(initDice.sidesCount) + 1
    Dice(initDice.rnd, initDice.sidesCount, newSide)
  }
  val someDice = roll(Dice(new Random, sidesCount, 1)) // roll once so the first result is random, not the placeholder 1
  val diceIterator = Iterator.iterate(someDice)(roll)
  diceIterator.map(_.upside)
}

//Infinite iterator of the results of the given function
def simulateManyTimes[T](f: => T) = Iterator.iterate(f)(_ => f)

//A general-purpose function to calculate an average
def average(values: Iterable[Int]): Option[Double] = {
  val (total, len) = values.foldLeft((0, 0)) { case ((sum, cnt), value) => (sum + value, cnt + 1) }
  if (len > 0) Some(total.toDouble / len) else None
}
def averageResultsOfIntSimulations(f: => Int, repeat: Int) = {
  average(simulateManyTimes(f).take(repeat).toIterable)
}
//===================Problem specific part======================

//Similar to the OP's method, just with the recursion encapsulated
def countThrowsUntilAllSidesSeen(diceIter: Iterator[Int], sidesCount: Int): Int = {
  @tailrec
  def recursiveHelper(seen: SortedSet[Int], count: Int): Int = {
    if (seen.size == sidesCount)
      count
    else
      recursiveHelper(seen + diceIter.next(), count + 1)
  }
  recursiveHelper(SortedSet[Int](), 0)
}
def simulationWrapper(sides: Int, simRepeat: Int) = {
  averageResultsOfIntSimulations(countThrowsUntilAllSidesSeen(diceResultIterator(sides), sides), simRepeat).get
}

simulationWrapper(6, 1000000)


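Generally, creating a new Random for each throw is bad practice, since it may produce less random results than expected. (Note that in your throwdie, val r = Random binds the shared scala.util.Random singleton object rather than creating a new generator, so this point applies to the pattern more than to your exact code.) My proposed solution creates one Random per simulation and reuses it, in a functional way, for every throw of that simulation.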
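If results need to be reproducible (for tests, say), one option is to create a single generator with a fixed seed and use it for every throw of every simulation. A hedged sketch under that assumption; averageThrows, its seed parameter and the seed values are illustrative, not part of the solution above:

```scala
import scala.annotation.tailrec
import scala.util.Random

// Average number of throws over `repeat` simulations, all drawing from one
// generator seeded with `seed`, so equal seeds give equal averages.
def averageThrows(seed: Long, sides: Int, repeat: Int): Double = {
  val rnd = new Random(seed)

  // One simulation: throw until every face has been seen.
  def oneRun(): Int = {
    @tailrec
    def loop(seen: Set[Int], count: Int): Int =
      if (seen.size == sides) count
      else loop(seen + (rnd.nextInt(sides) + 1), count + 1)
    loop(Set.empty, 0)
  }

  Iterator.fill(repeat)(oneRun()).sum.toDouble / repeat
}

println(averageThrows(seed = 2016L, sides = 6, repeat = 100000))
```

The trade-off is that the generator becomes an explicit input rather than hidden state, which is also what makes the function easy to test.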

I am in favor of the school that says code should be self documenting, and your second method tries to fit too many things in one place, adding obscurity.

val runs = Iterator.fill(nsim)(SortedSet[Int]())
runs.map( x=> nthrows(x, 0, nsides) ).reduceLeft(_+_) * 1.0 / nsim


For example, there is no good reason to fill an iterator with SortedSets and then map over them, instead of simply writing runs.map(_ => nthrows(SortedSet[Int](), 0, nsides)). Counter-intuitive code is bad code. In production, the developer next to you would likely tap you on the shoulder to ask you to explain it, and would not be too happy about having to.
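Applied to the original fdice, that suggestion might look like the sketch below. It keeps the OP's throwdie and nthrows (switched to the immutable SortedSet so that + returns a new set), replaces reduceLeft(_+_) * 1.0 with sum.toDouble, and uses Iterator.range(0, nsim) for runs, which is an assumed choice rather than anything from the original:

```scala
import scala.collection.immutable.SortedSet
import scala.util.Random

def throwdie(nsides: Int): Int = Random.nextInt(nsides) + 1

def nthrows(seen: SortedSet[Int], count: Int, nsides: Int): Int =
  if (seen.size == nsides) count
  else nthrows(seen + throwdie(nsides), count + 1, nsides)

// Each run starts from a fresh empty set inside the lambda, so there is no
// need to pre-fill an iterator with sets.
def fdice(nsim: Int, nsides: Int): Double = {
  val runs = Iterator.range(0, nsim)
  runs.map(_ => nthrows(SortedSet[Int](), 0, nsides)).sum.toDouble / nsim
}

println(fdice(10000, 6))
```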

As general advice, tune the code to your audience, as with all forms of communication. For example, the diceResultIterator in my solution is purposefully verbose, to make the concept clearer. Once you are comfortable with it, you will see that it can be trivially refactored to about half the lines. Scala makes it easy to chain operations into obscurity; be mindful of this ;)
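One possible compact version, as a sketch: it swaps the Dice case class and Iterator.iterate for Iterator.continually, which re-evaluates its by-name argument for every element, so every result, including the first, is a fresh throw.

```scala
import scala.util.Random

// Infinite iterator of throw results, with one new generator per iterator
// as in the longer version.
def diceResultIterator(sidesCount: Int): Iterator[Int] = {
  val rnd = new Random
  Iterator.continually(rnd.nextInt(sidesCount) + 1)
}

println(diceResultIterator(6).take(10).toList)
```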

When you see recursiveHelper inside countThrowsUntilAllSidesSeen with arguments seen: SortedSet[Int], count: Int, its purpose is already quite clear, so the most important piece of information was to say that it is there just to encapsulate the recursion.
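As an aside, the count can also be obtained without a recursive helper. scanLeft over the throws produces the growing set of seen faces, starting with the empty set at index 0, so the index of the first complete set equals the number of throws taken. A sketch of that alternative technique, keeping the same signature as countThrowsUntilAllSidesSeen:

```scala
// Alternative to the recursive helper: scanLeft yields the set of faces seen
// after 0, 1, 2, ... throws, so the first index whose set holds every face
// is the number of throws taken.
def countThrowsUntilAllSidesSeen(diceIter: Iterator[Int], sidesCount: Int): Int =
  diceIter.scanLeft(Set.empty[Int])(_ + _).indexWhere(_.size == sidesCount)

println(countThrowsUntilAllSidesSeen(Iterator(1, 1, 2, 3, 2, 4, 5, 6, 1), 6)) // 8
```

Whether this reads more clearly than the named recursive helper is a matter of audience, in line with the advice above.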

Lastly, I added explicit return types to the functions where I felt they added clarity. Think of them like training wheels: just remove them when they feel useless.


Context

StackExchange Code Review Q#118173, answer score: 3
