HiveBrain v1.2.0

Tail call optimization via translating to CPS

Submitted by: @import:stackexchange-cs

Problem

I am struggling to wrap my head around this compiler technique. Let's say this is my factorial function:
def factorial(value: int) -> int:
    if value == 0:
        return 1
    else:
        return factorial(value - 1) * value


It is recursive but not yet TCO-friendly, so, as the theory goes, the first step is to translate it to continuation-passing style (CPS):
import typing

T = typing.TypeVar("T")

def factorial_cont(value: int, cont: typing.Callable[[int], T]) -> T:
    if value == 0:
        return cont(1)
    else:
        return factorial_cont(value - 1, lambda result: cont(value * result))
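A short usage sketch, assuming the identity function as the initial continuation (the helper name identity is introduced here for illustration). The continuation decides what is done with the final result. Note that CPython does not eliminate tail calls, so this version is still bounded by the recursion limit; only the shape of the code has changed so far.

```python
import typing

T = typing.TypeVar("T")


def factorial_cont(value: int, cont: typing.Callable[[int], T]) -> T:
    # Every recursive call is a tail call; the pending
    # "multiply by value" work is moved into the continuation.
    if value == 0:
        return cont(1)
    return factorial_cont(value - 1, lambda result: cont(value * result))


def identity(x: int) -> int:
    # Hypothetical helper: the "do nothing more" continuation.
    return x


as_int = factorial_cont(5, identity)  # the result itself
as_text = factorial_cont(5, str)      # the same result, formatted by the continuation
print(as_int, as_text)
```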


Now that the function is tail-recursive, I can apply the usual trick of replacing the recursion with a while loop:
def factorial_while(value: int, cont: typing.Callable[[int], T]) -> T:
    current_cont = cont
    current_value = value
    while True:
        if current_value == 0:
            return current_cont(1)
        else:
            # Bind the current continuation and value as default arguments:
            # a plain `lambda result: current_cont(current_value * result)`
            # looks both names up late and would end up calling itself forever.
            current_cont = lambda result, c=current_cont, v=current_value: c(v * result)
            current_value = current_value - 1


This current_cont effectively becomes a long composition chain. In Haskell terms, for value == 3 it would be let resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*), where initial_cont can safely default to id, and, sure enough, resulting_cont 1 evaluates to 3! = 6.
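To make the chain concrete, here is a sketch that builds it explicitly in Python, folding over the same values the while loop visits for value == 3. The helpers compose and times are hypothetical stand-ins for Haskell's (.) and the (n*) section.

```python
from functools import reduce
from typing import Callable

IntFn = Callable[[int], int]


def compose(f: IntFn, g: IntFn) -> IntFn:
    # Haskell's (f . g): apply g first, then f.
    return lambda x: f(g(x))


def times(n: int) -> IntFn:
    # Haskell's operator section (n*).
    return lambda x: n * x


def initial_cont(x: int) -> int:
    # id
    return x


# ((initial_cont . (3*)) . (2*)) . (1*), built left to right,
# one composition per loop iteration.
resulting_cont = reduce(lambda acc, n: compose(acc, times(n)), [3, 2, 1], initial_cont)

print(resulting_cont(1))  # 3! = 6
```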

But I also know the "accumulator" trick:
def factorial_acc(value: int, acc: int = 1) -> int:
    current_acc = acc
    current_value = value
    while True:
        # Stop at 0 (not 1) so that factorial_acc(0) terminates and returns 1.
        if current_value == 0:
            return current_acc
        else:
            current_acc = current_acc * current_value
            current_value = current_value - 1


which looks almost identical to the CPS version once the while loop is introduced.
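The resemblance can be stated as an invariant: with initial_cont = id, the continuation built after processing some numbers behaves exactly like "multiply by their product", and that product is what the accumulator holds. A small sketch (the function name check_invariant is hypothetical) runs both state updates side by side and checks this on a few sample inputs at every iteration.

```python
from typing import Callable

IntFn = Callable[[int], int]


def check_invariant(value: int) -> int:
    cont: IntFn = lambda x: x  # CPS state, starting from id
    acc = 1                    # accumulator state
    current_value = value
    while True:
        # At every step the closure chain behaves like "multiply by acc".
        for probe in (1, 2, 7):
            assert cont(probe) == acc * probe
        if current_value == 0:
            return cont(1)  # equal to acc here
        cont = lambda result, c=cont, v=current_value: c(v * result)
        acc = acc * current_value
        current_value = current_value - 1


print(check_invariant(5))  # 120
```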

The question is: how exactly do I massage the continuation let resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*) into t

Solution

Turn the opaque lambdas

next_cont = lambda result: current_cont(current_value * result)


into transparent data:
next_cont = Compose(current_cont, Mult(current_value))
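A minimal sketch of what that transparent data could look like, assuming Python dataclasses. Compose and Mult are the names used in this answer; Identity (the initial continuation) and call (the single function that interprets a continuation) are hypothetical names added here. Compose(f, g) is read as f . g, so g runs first.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Identity:
    """The initial continuation, id."""


@dataclass(frozen=True)
class Mult:
    """The operator section (n*) as data."""
    n: int


@dataclass(frozen=True)
class Compose:
    """Compose(f, g) stands for f . g: apply g first, then f."""
    f: "Cont"
    g: "Cont"


Cont = Union[Identity, Mult, Compose]


def call(k: Cont, x: int) -> int:
    # One interpreter replaces all the lambda bodies.
    if isinstance(k, Identity):
        return x
    if isinstance(k, Mult):
        return k.n * x
    if isinstance(k, Compose):
        return call(k.f, call(k.g, x))
    raise TypeError(f"not a continuation: {k!r}")


# (id . (3*)) . (2*)
k = Compose(Compose(Identity(), Mult(3)), Mult(2))
print(call(k, 1))  # 6
```

Because continuations are now plain values, they can be inspected and rewritten, which the opaque lambdas never allowed.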


then use the fact that composition is associative and that two multiplications fuse into one:
Compose(Compose(a_cont, Mult(val)), Mult(current_value))
==
Compose(a_cont, Compose(Mult(val), Mult(current_value)))
==
Compose(a_cont, Mult(val * current_value))
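The rewrite can be applied eagerly every time a new Mult is attached, so the structure never grows. In this sketch, extend is a hypothetical name for "build Compose(k, Mult(n)) and simplify it", and the string "init_cont" is only a placeholder for an arbitrary, opaque initial continuation.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Mult:
    n: int


@dataclass(frozen=True)
class Compose:
    f: Any  # outer continuation (applied last)
    g: Any  # inner continuation (applied first)


def extend(k: Any, n: int) -> Any:
    # If k is already Compose(a_cont, Mult(val)), then
    #   Compose(Compose(a_cont, Mult(val)), Mult(n))
    #   == Compose(a_cont, Compose(Mult(val), Mult(n)))   (associativity)
    #   == Compose(a_cont, Mult(val * n))                 (fusion)
    if isinstance(k, Compose) and isinstance(k.g, Mult):
        return Compose(k.f, Mult(k.g.n * n))
    return Compose(k, Mult(n))


k = "init_cont"  # placeholder for the opaque initial continuation
for n in (3, 2, 1):
    k = extend(k, n)
print(k)  # Compose(f='init_cont', g=Mult(n=6))
```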


You end up with Compose(init_cont, Mult(result)), which you then call with 1:
call(Compose(init_cont, Mult(result)), 1)
=
init_cont(call(Mult(result), 1))
=
init_cont(result * 1)
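Putting the steps together, a sketch of the CPS while loop with the defunctionalized, eagerly fused continuation. The name factorial_defunc and the starting value Compose(init_cont, Mult(1)) are assumptions made for illustration, and value is assumed to be non-negative, as in the question.

```python
from dataclasses import dataclass
from typing import Callable, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class Mult:
    n: int


@dataclass(frozen=True)
class Compose:
    f: Callable[[int], object]  # the opaque initial continuation
    g: Mult                     # always a single, already-fused Mult


def factorial_defunc(value: int, init_cont: Callable[[int], T]) -> T:
    # Compose(init_cont, Mult(1)) behaves exactly like init_cont.
    k = Compose(init_cont, Mult(1))
    current_value = value  # assumed non-negative
    while current_value != 0:
        # Compose(Compose(init_cont, Mult(val)), Mult(current_value))
        #   == Compose(init_cont, Mult(val * current_value))
        k = Compose(k.f, Mult(k.g.n * current_value))
        current_value -= 1
    # call(Compose(init_cont, Mult(result)), 1) == init_cont(result * 1)
    return k.f(k.g.n * 1)


print(factorial_defunc(5, lambda x: x))  # 120
```

Since k.f never changes, the only state that varies across iterations is k.g.n, which plays exactly the role of current_acc in factorial_acc.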


Look up "reifying continuations" and "defunctionalization".

Context

StackExchange Computer Science Q#153619, answer score: 2
