Hamiltonian Monte Carlo in Scala
Problem
I'm writing a program in Scala to perform Hamiltonian Monte Carlo (HMC), coupled with Gibbs sampling of some variables. The algorithm, including modifications such as perturbing epsilon and l and performing Gibbs updates on some variables, can be found in the easily available book chapter "MCMC using Hamiltonian Dynamics" by R. Neal. The program works by alternating an HMC update of a set of variables q via the hmcUpdate() method with Gibbs updates of the variables gibbsVar via some unspecified method, then putting the two sets together in a tuple. The hmcRunner() method generates these tuples and prepends them onto a list.

I'm very new to Scala (since last Friday), so I would like some feedback on how idiomatic the code is, as well as general advice on improving readability. I've tried to write this as an abstract class so that my group members and I can use it in our own code by supplying the unspecified methods ourselves. Is this the most appropriate structure to use?
```
import breeze.linalg._
import breeze.stats.distributions.Uniform
import breeze.stats.distributions.Gaussian
import breeze.stats.distributions.Gamma
import math.log, math.sqrt
import scala.annotation.tailrec

abstract class HMC(ul: Int, ue: Double) {
  val uDev = new Uniform(0.0,1.0)
  val nDev = new Gaussian(0.0,1.0)
  val l = ul
  val ep = ue

  final def metCheck(ho: Double, hn: Double): Boolean = log(uDev.sample()) < ho - hn

  def calcU(q: DenseVector[Double], sigs: List[Double]): Double
  def calcUgrad(q: DenseVector[Double], sigs: List[Double]): DenseVector[Double]
  def gibbsUpdate(q: DenseVector[Double]): List[Double]

  final def hmcUpdate(qo: DenseVector[Double], sigs: List[Double], acc: Int): (DenseVector[Double], Int) = {
    val uep = ep * min(sigs) * min(sigs) * (0.9 + 0.2 * uDev.sample()) // perturb ep, l pm 10% to avoid cyclic orbits
    val ul = (l.toDouble * (0.9 + 0.2 * uDev.sample())).toInt
    val po = DenseVector(nDev.sample(qo.length).toArray)
    // …
```
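For reference, the test in `metCheck` is the Metropolis acceptance step of HMC as described in Neal's chapter. Assuming unit masses (kinetic energy $p^\top p / 2$, matching the `(p dot p) / 2.0` used in the refactoring in the answer) and a potential $U$ given by `calcU`, a proposal $(q^*, p^*)$ produced by the leapfrog trajectory is accepted when

$$
H(q, p) = U(q) + \tfrac{1}{2}\, p^\top p, \qquad
\ln u < H(q, p) - H(q^*, p^*), \quad u \sim \mathrm{Uniform}(0, 1),
$$

which is the same as accepting with probability $\min\{1, \exp(H(q, p) - H(q^*, p^*))\}$. Here `ho` plays the role of $H(q, p)$ and `hn` of $H(q^*, p^*)$; the form `u < exp(-energyChange)` with `energyChange` $= H(q^*, p^*) - H(q, p)$ is equivalent.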
Solution
I rewrote parts of your code (see below). I did not have enough time to complete the refactoring, so it is far from what I would really like to see, but at least it is not Fortran anymore.
- You can use a `trait` instead of an `abstract class`. It is more flexible.
- The names, pretty much all names, are very uninformative. Why `uDev` for a distribution?
- You should separate the printing of reports from the actual algorithm (separation of concerns). This advice applies to all languages, not just Scala.
- Avoid using recursion. It is like the assembly language of functional programming. Use `reduce`, `foldLeft` and similar methods instead.
- It looks like you don't know about `"string" + " " + "concatenation"` or, even better, interpolation: `s"Hello, ${name}"`.
- In an `if` expression, when the code at the start of both the `if` and the `else` branches is the same, you should pull that code out.
- You call `min(sigs)` twice in a row. You should store that value in a variable and use it twice.
- You can shorten the imports: `import math.{log, sqrt}`. You should also write out every import explicitly instead of using a wildcard import (`import breeze.linalg._`). I was confused about where `min` was coming from.
- Terminating a recursive loop on `l == 1` looks odd in Scala; it's usually on 0.
- I'm usually a big fan of breaking everything into small functions, but `metCheck` is so important that I think the computation should be left inline where it is called.
- Don't pass `acc` to `hmcUpdate`. That function has no business dealing with that accumulator (separation of concerns).
- The OO/functional design was pretty much nonexistent. I did some, but it is currently very poor due to lack of time. A good design greatly improves readability. You should spend some time properly breaking everything down into classes, then think some more about it and refine it, and so on.
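Several of the points above (a `trait` that group members implement, returning an acceptance flag instead of threading `acc` through the update, `foldLeft` instead of hand-written recursion, and interpolation for the report) can be combined in a small sketch. The names `MarkovStep`, `runChain` and `ToyKernel` are hypothetical, and the toy kernel is plain random-walk Metropolis on a standard normal, not HMC, chosen only so the example runs with the standard library alone:

```scala
import scala.util.Random

// Hypothetical trait: implementers supply `step`, which returns the new state
// and whether the proposal was accepted. No accumulator is passed in.
trait MarkovStep[S] {
  def step(state: S, rng: Random): (S, Boolean)
}

// Runs `n` steps with foldLeft instead of hand-written recursion. Visited
// states are prepended (newest first); acceptances are counted here, outside `step`.
def runChain[S](kernel: MarkovStep[S], init: S, n: Int, rng: Random): (List[S], Int) =
  (1 to n).foldLeft((List(init), 0)) { case ((states, accepted), _) =>
    val (next, ok) = kernel.step(states.head, rng)
    (next :: states, if (ok) accepted + 1 else accepted)
  }

// Toy kernel for illustration only (random-walk Metropolis, NOT HMC):
// target is a standard normal, so log p(x) = -x^2 / 2 up to a constant.
object ToyKernel extends MarkovStep[Double] {
  def step(x: Double, rng: Random): (Double, Boolean) = {
    val proposal = x + rng.nextGaussian()
    val logRatio = (x * x - proposal * proposal) / 2.0
    if (math.log(rng.nextDouble()) < logRatio) (proposal, true) else (x, false)
  }
}

val (states, accepted) = runChain(ToyKernel, 0.0, 1000, new Random(42))
// Reporting stays separate from the sampler and uses string interpolation.
println(s"Accepted $accepted of ${states.length - 1} proposals")
```

An HMC kernel would implement the same `step` signature, so the runner and the reporting code do not need to know how a proposal is generated.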
After some partial refactoring:
```
import breeze.linalg.DenseVector
import breeze.stats.distributions.{Gaussian, Uniform}

val unifDist = new Uniform(0.0, 1.0)
val normDist = new Gaussian(0.0, 1.0)
// Poor names. I don't know what they are.
val l = ul
val ep = ue

type Position = DenseVector[Double]
type Momentum = DenseVector[Double]
case class HamState(q: Position, p: Momentum)

// Full leapfrog step: move the position first, then update the momentum
// with the gradient at the *new* position.
def leapFrogStep(potentialEnergy: PotentialEnergy, uep: Double)(hamState: HamState): HamState = {
  val q = hamState.q + (hamState.p * uep)
  val p = hamState.p - (potentialEnergy.calcGradU(q) * uep)
  HamState(q, p)
}

// Position-only step, needed once before the final momentum half step.
def leapFrogPositionStep(uep: Double)(hamState: HamState): HamState =
  HamState(hamState.q + (hamState.p * uep), hamState.p)

def leapFrogMomentumHalfStep(potentialEnergy: PotentialEnergy, uep: Double)(hamState: HamState): HamState = {
  val p = hamState.p - (potentialEnergy.calcGradU(hamState.q) * uep) / 2.0
  HamState(hamState.q, p)
}

trait PotentialEnergy {
  def calcU(q: Position): Double
  def calcGradU(q: Position): DenseVector[Double]
  def computeTotalEnergy(hamState: HamState): Double =
    calcU(hamState.q) + (hamState.p dot hamState.p) / 2.0
  // I don't have the slightest idea what uep is...
  // I put it in this class, but it certainly does not belong here.
  def uep(): Double
}

case class WhateverSigsIsPotentialEnergy(sigs: List[Double]) extends PotentialEnergy {
  override def calcU(q: Position) = calcU_(q, sigs)
  override def calcGradU(q: Position) = calcUgrad_(q, sigs)
  // uep does not belong here. It should not access "global" ep or uDev.
  override def uep(): Double = {
    val minSig = sigs.min
    ep * minSig * minSig * (0.9 + 0.2 * unifDist.sample()) // perturb ep, l pm 10% to avoid cyclic orbits
  }
}

def hmcStep(potentialEnergy: PotentialEnergy)(q: Position): Option[Position] = {
  val uep = potentialEnergy.uep()
  val nSteps = (l.toDouble * (0.9 + 0.2 * unifDist.sample())).toInt
  val p = DenseVector(normDist.sample(q.length).toArray)
  val initState = HamState(q, p)
  val halfStepper: HamState => HamState = leapFrogMomentumHalfStep(potentialEnergy, uep)
  val fullStepper: HamState => HamState = leapFrogStep(potentialEnergy, uep)
  val positionStepper: HamState => HamState = leapFrogPositionStep(uep)
  // Half momentum step, (nSteps - 1) full steps, one more position step, half momentum step.
  val fullPathTransform = (halfStepper :: List.fill(nSteps - 1)(fullStepper)) :+ positionStepper :+ halfStepper
  val state = fullPathTransform.foldLeft(initState) { (state, stepper) => stepper(state) }
  val energyChange =
    potentialEnergy.computeTotalEnergy(state) - potentialEnergy.computeTotalEnergy(initState)
  // Metropolis acceptance test.
  if (unifDist.sample() < math.exp(-energyChange)) Some(state.q) else None
}

type GibbsSomethings = List[Double]
case class PositionAndGibbs(pos: Position, gibbs: GibbsSomethings)

/**
 * @return the new state, and true if the generated position is different from the previous one.
 */
def hmcAndGibbsStep(potentialEnergy: PotentialEnergy)(positionAndGibbs: PositionAndGibbs): (PositionAndGibbs, Boolean) = {
  val posOption = hmcStep(potentialEnergy)(positionAndGibbs.pos)
  val pos = posOption.getOrElse(positionAndGibbs.pos)
  val gibbs = gibbsUpdate(pos)
  (PositionAndGibbs(pos, gibbs), posOption.isDefined)
}

def hmcRunner(potentialEnergy: PotentialEnergy)(maxI…
  // …
      posAndG :: visitedStates
    }
  visitedStates
  // To print some report every "iter" iterations:
  // 1) change the foldLeft above so it also counts the "isNew"
  // 2) instead of calling ".take(maxIt)" once, write a function that
  //    only takes "iter" at a time and then prints out some report.
  //    Call that method until the total number of iterations reaches "maxIt".
}
```
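To check a leapfrog composition such as `fullPathTransform` in `hmcStep`, it can help to compare it against the loop form from Neal's chapter on a problem where the behaviour is known. The sketch below assumes a one-dimensional standard normal target, U(q) = q²/2, with unit mass, and uses plain `Double`s instead of Breeze vectors; `State`, `leapfrog` and `leapfrogLoop` are hypothetical names. Because both versions perform the same floating-point operations in the same order, they should return identical states, and the Hamiltonian should change only slightly over the trajectory:

```scala
// Assumed target for this sketch: U(q) = q^2 / 2 (so dU/dq = q), K(p) = p^2 / 2.
final case class State(q: Double, p: Double)

def gradU(q: Double): Double = q
def hamiltonian(s: State): Double = s.q * s.q / 2.0 + s.p * s.p / 2.0

// Momentum half step, using the gradient at the current position.
def momentumHalfStep(eps: Double)(s: State): State =
  s.copy(p = s.p - eps * gradU(s.q) / 2.0)

// Position-only full step.
def positionStep(eps: Double)(s: State): State =
  s.copy(q = s.q + eps * s.p)

// Full step: move q first, then update p with the gradient at the NEW q.
def fullStep(eps: Double)(s: State): State = {
  val q = s.q + eps * s.p
  State(q, s.p - eps * gradU(q))
}

// Composed trajectory: half p, (nSteps - 1) full steps, one more q step, half p.
def leapfrog(eps: Double, nSteps: Int)(s: State): State = {
  val half: State => State = momentumHalfStep(eps)
  val full: State => State = fullStep(eps)
  val move: State => State = positionStep(eps)
  val steppers = (half :: List.fill(nSteps - 1)(full)) :+ move :+ half
  steppers.foldLeft(s)((state, step) => step(state))
}

// Reference loop in the shape of the leapfrog loop in Neal's chapter (nSteps >= 1).
// The final momentum negation is omitted: it does not change K(p) = p^2 / 2.
def leapfrogLoop(eps: Double, nSteps: Int)(s: State): State = {
  var q = s.q
  var p = s.p - eps * gradU(q) / 2.0
  for (i <- 1 to nSteps) {
    q = q + eps * p
    if (i != nSteps) p = p - eps * gradU(q)
  }
  p = p - eps * gradU(q) / 2.0
  State(q, p)
}

val start = State(1.0, 0.5)
val finish = leapfrog(0.1, 20)(start)
println(f"H before: ${hamiltonian(start)}%.4f, after: ${hamiltonian(finish)}%.4f")
```

Expressing each sub-step as a `State => State` keeps the trajectory a plain list of functions applied in order by `foldLeft`, the same idea as `fullPathTransform`, while the loop version serves as an independent reference for the step order.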
StackExchange Code Review Q#61668, answer score: 2