Monday, April 18, 2016

Curried Functions in Scala

In Scala, we define a function like this:
def product(a: Int, b: Int): Int = a * b
Then, we call it like:
product(2, 8)
Other than some minor syntax details, this is pretty much the way functions are defined in most of the mainstream languages, such as Java and C#. However, in Scala there is an alternative way of defining this same function:
def product(a: Int)(b: Int): Int = a * b
If we define it this way, we need to call it like:
product(2)(8)
I know, we've just replaced a comma and a space with a closing and an opening parenthesis. Harder to type, so what's the point? It turns out there are a number of reasons to use this alternate form. These include:
  • Partial function application
  • Type inference
  • Block syntax
  • Implicit parameters 
The first form of the function definition can be thought of as "a function that takes 2 integers". The parentheses that determine precedence are exactly where they appear in the definition:
def product(a: Int, b: Int): Int
The 2 parameters are bundled together and passed to the function. The second form, however, is left-associative. Written with explicit grouping (illustrative pseudo-syntax, not valid Scala), it looks like this:
def (product(a: Int))(b: Int): Int
This can be thought of as "a function that takes one integer, and returns a function that takes another integer".
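To make the "returns a function" reading concrete, here is a small sketch comparing the curried method with an explicit function value of type Int => Int => Int. The names productF and times2 are illustrative, not from the talk. Strictly speaking, a Scala method with two parameter lists is not itself a function object, but it behaves like this nested function when called.

```scala
object CurriedShape {
  // Curried method: two parameter lists
  def product(a: Int)(b: Int): Int = a * b

  // An explicit function value with the same shape: it takes an Int and
  // returns another function Int => Int (the type arrow associates to the right)
  val productF: Int => Int => Int = a => b => a * b

  // Applying one argument at a time; application associates to the left,
  // so productF(2)(8) means (productF(2))(8)
  val times2: Int => Int = productF(2)

  def main(args: Array[String]): Unit = {
    println(product(2)(8))  // 16
    println(productF(2)(8)) // 16
    println(times2(8))      // 16
  }
}
```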

According to Wikipedia, currying is "The technique of translating the evaluation of a function that takes multiple arguments (or a tuple of arguments) into evaluating a sequence of functions, each with a single argument." So, the second form of the function definition is a curried function.
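The translation in that definition can be written out by hand. The sketch below uses hypothetical helpers named curry and uncurry (they are not from the post); the standard library's curried method on two-argument function values does the same job as curry.

```scala
object CurryByHand {
  // Hypothetical helper: turn a two-argument function into a chain of
  // one-argument functions (currying in the Wikipedia sense)
  def curry[A, B, C](f: (A, B) => C): A => B => C =
    a => b => f(a, b)

  // Hypothetical inverse: collapse the chain back into a two-argument function
  def uncurry[A, B, C](f: A => B => C): (A, B) => C =
    (a, b) => f(a)(b)

  val product: (Int, Int) => Int = (a, b) => a * b
  val curriedProduct: Int => Int => Int = curry(product)
  val roundTrip: (Int, Int) => Int = uncurry(curriedProduct)

  def main(args: Array[String]): Unit = {
    println(curriedProduct(2)(8)) // 16
    println(roundTrip(2, 8))      // 16
  }
}
```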

Partial Function Application

With the curried function in the above example, we can do the following:
val times3 = product(3) _
times3 is a function that takes an integer and multiplies it by 3. In the REPL:
scala> times3(7)
res6: Int = 21
What we have now is a function that is poised to multiply any other integer by 3.

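Note that in some languages, such as Haskell, all functions are curried by default, and partial application happens any time you call a function with fewer than the full number of parameters. In Scala, however, the trailing underscore is needed as a placeholder to allow for partial application, unless the compiler already expects a function type at that point (for example, when the result is assigned to a val declared as Int => Int, or passed to a higher-order function such as map).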

A partially applied function can be useful in a number of circumstances. For example, if a higher-order function accepts a function with fewer parameters than the function you have, you can partially apply yours before passing it as an argument. Partial application can also be used to specialize more general functions, which makes it a mechanism for function reuse.
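A short sketch of both uses, assuming Scala 2.12 or later. The applyRate function is a hypothetical example of a "more general" function being specialized. Note that no trailing underscore is needed when the expected type is already a function type.

```scala
object PartialApplicationDemo {
  def product(a: Int)(b: Int): Int = a * b

  // The explicit function type tells the compiler to turn the remaining
  // parameter list into a function, so no trailing underscore is needed
  val times3: Int => Int = product(3)

  // Passing a partially applied method straight to a higher-order function:
  // map expects a function from Int, so product(3) is converted automatically
  val tripled: List[Int] = List(1, 2, 3).map(product(3))

  // Specializing a more general (hypothetical) function
  def applyRate(rate: Double)(amount: Double): Double = amount * rate
  val withMarkup: Double => Double = applyRate(1.25)

  def main(args: Array[String]): Unit = {
    println(times3(7))       // 21
    println(tripled)         // List(3, 6, 9)
    println(withMarkup(8.0)) // 10.0
  }
}
```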

If you have an uncurried function, like def sum(a: Int, b: Int): Int = a + b, and need a curried version of it, you can get one like this:
val sumCurry = (sum _).curried
You can also partially apply an uncurried function by using placeholder syntax for the parameters:
val sumPartial = sum(_: Int, 4)
The upside of this technique is that you can partially apply any of the parameters. The downside is that you almost always need to specify the type of the missing parameters, even if they seem obvious.
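Here is a self-contained sketch of both techniques. Instead of (sum _).curried it eta-expands sum through an explicitly typed val, which has the same effect. The three-parameter clamp function is a hypothetical example showing that a placeholder can stand in for a parameter in any position.

```scala
object UncurriedDemo {
  def sum(a: Int, b: Int): Int = a + b

  // Eta-expand the method into a function value, then curry it
  val sumF: (Int, Int) => Int = sum
  val sumCurry: Int => Int => Int = sumF.curried

  // Placeholder syntax: the missing parameter's type is given explicitly
  val sumPartial: Int => Int = sum(_: Int, 4)

  // Placeholders work in any position, e.g. the middle of three parameters
  def clamp(lo: Int, x: Int, hi: Int): Int = math.max(lo, math.min(x, hi))
  val clamp0to10: Int => Int = clamp(0, _: Int, 10)

  def main(args: Array[String]): Unit = {
    println(sumCurry(1)(2)) // 3
    println(sumPartial(1))  // 5
    println(clamp0to10(42)) // 10
  }
}
```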

Many times, the terms currying and partial application are used as though they are synonymous. They are not. Currying is the process described by the Wikipedia definition above, while partial application refers to providing only a subset of the parameters to a function and getting back a function that takes the remaining parameters. So, a curried function can be partially applied, but, at least in Scala, a curried function is not required for partial application. This means that in Scala the remaining reasons for currying that we will discuss are arguably more important than partial application.

Type Inference 

While partial application is something that is common to many languages using the functional paradigm, this topic is specific to Scala. Let's take this function:
def transmogrify[A,B](a: A, f: A=>B):B = f(a)
 Let's see what happens when we call this in the REPL:
scala> transmogrify(3, a => a * 4) 
<console>:9: error: missing parameter type
              transmogrify(3, a => a * 4)
                              ^
What happened? Surely the Scala compiler can see that the parameter of the function passed as the second argument must be an Int! But it can't, so we need to help the compiler out and call it like this:
scala> transmogrify(3, (a: Int) => a * 4) 
res1: Int = 12
While this works, it is an unfortunate amount of boilerplate. The solution is to instead define the function like so:
def transmogrify[A,B](a: A)(f: A => B): B = f(a)
Now we can successfully call it like:
transmogrify(3)(a => a * 4) 
Or, even more succinctly using placeholder syntax:
transmogrify(3)(_ * 4) 
How does this work? In languages like Haskell, the compiler is very good at inferring types; in essence, type information flows in all directions. In Scala, however, type information flows from left to right across parameter groups (but not within a parameter group) and down into the body of the function. In the case of method calls on objects, the object is treated as its own (and first) parameter group.

So, in the transmogrify example, Scala sees the argument in the first parameter group and can determine the type of A, so it knows that the function in the second parameter group must accept an Int. I think of it in terms of partial application: Scala first applies the 3 to transmogrify, learns what A is, and is left with a function that accepts another function that accepts an A. So it knows the type of that function's argument without us explicitly stating it.
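The standard library relies on the same left-to-right flow. List's foldLeft is declared with two parameter lists, def foldLeft[B](z: B)(op: (B, A) => B): B, where A comes from the list itself (the receiver acting as the first group) and B is fixed by z. The sketch below contrasts it with a hypothetical single-group wrapper, foldLeftOneGroup, whose lambda has to be annotated in Scala 2.

```scala
object InferenceDemo {
  // A is known from the receiver, B from z, so the lambda needs no annotations
  val total: Int = List(1, 2, 3).foldLeft(0)((acc, x) => acc + x)

  // Hypothetical single-group version: in Scala 2, an unannotated lambda here
  // fails with "missing parameter type", so the lambda's parameters are annotated
  def foldLeftOneGroup[A, B](xs: List[A], z: B, op: (B, A) => B): B =
    xs.foldLeft(z)(op)

  val total2: Int =
    foldLeftOneGroup(List(1, 2, 3), 0, (acc: Int, x: Int) => acc + x)

  def main(args: Array[String]): Unit = {
    println(total)  // 6
    println(total2) // 6
  }
}
```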

There are, however, instances where currying hurts type inference. For example:
scala> def fold[A,B](o: Option[A])(b: => B)(f: A=>B): B =
  o.fold(b)(f) 
defined function fold

scala> fold(Option(3))(Nil)(List(_)) 
Main.scala:2261: type mismatch;
 found   : List[Int]
 required: scala.collection.immutable.Nil.type
fold(Option(3))(Nil)(List(_))
                         ^
This code is trying to convert an Option to a List. If the Option were None, the List would be empty; otherwise it would be a List with a single item containing the value that was in the Option. However, Scala sees the Nil in the second parameter group and fixes the type of B to be Nil.type (the singleton type of Nil). When it reaches the last parameter group, it sees List(_), knows this produces a List[Int], and reports an error. Why? Because of the way List is defined, Nil is a subtype of every List type, but List[Int] is not a subtype of Nil.type, so a function returning List[Int] cannot be used where a function returning B = Nil.type is expected. We need to help out with a type annotation here:
scala> fold(Option(3))(Nil: List[Int])(List(_)) 
res44: List[Int] = List(3)
However, if we combine the last 2 parameter groups like this:
scala> def fold[A,B](o: Option[A])(b: => B, f: A=>B): B =
  o.fold(b)(f) 
defined function fold

scala> fold(Option(3))(Nil, List(_)) 
res45: List[Int] = List(3)
In this case, Scala sees both Nil and List(_) at the same time, knows they are both supposed to evaluate to the same type (B), and widens Nil to List[Int]. (Thanks to Rob Norris for this example.) Note that this also shows that even with multiple parameter groups, the groups can have more than one parameter in them, which doesn't strictly fit the definition of currying but is similar in intent.
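Both versions side by side, as a self-contained sketch. The post's fold is renamed here to the hypothetical foldCurried and foldPaired so the two can coexist; List.empty[Int] is shown as an alternative to the Nil: List[Int] annotation.

```scala
object FoldInferenceDemo {
  // Three groups: B is fixed by the second group alone
  def foldCurried[A, B](o: Option[A])(b: => B)(f: A => B): B = o.fold(b)(f)

  // Last two groups merged: B is inferred from b and f together
  def foldPaired[A, B](o: Option[A])(b: => B, f: A => B): B = o.fold(b)(f)

  // foldCurried(Option(3))(Nil)(List(_)) does not compile (B becomes Nil.type),
  // so the empty case needs an explicit type
  val viaCurried: List[Int] = foldCurried(Option(3))(Nil: List[Int])(List(_))
  val viaCurriedEmpty: List[Int] =
    foldCurried(Option.empty[Int])(List.empty[Int])(List(_))

  // No annotation needed when both arguments share a parameter group
  val viaPaired: List[Int] = foldPaired(Option(3))(Nil, List(_))

  def main(args: Array[String]): Unit = {
    println(viaCurried)      // List(3)
    println(viaCurriedEmpty) // List()
    println(viaPaired)       // List(3)
  }
}
```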

Block Syntax

Any time the last parameter group consists of a single function, block syntax can be used. Let's use the curried version of transmogrify from above as an example:
scala> transmogrify(2){ a =>
     | if (a < 2) "Small"
     | else "Large"
     | }
res7: String = Large
Curly braces are used instead of parentheses, and the function can be split across multiple lines with no semicolons. Here is an alternate version using a pattern-matching anonymous function, the syntactic shortcut for a function whose body is a match expression:
scala> transmogrify(2) {
     | case a if a < 2 => "Small"
     | case _ => "Large"
     | }
res8: String = Large
In this case, pattern matching is a bit of overkill, but it serves as an illustration.
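The REPL sessions above can be collected into one compilable sketch. The last value spells out what the { case ... } shortcut stands for: a one-parameter function whose body matches on that parameter.

```scala
object BlockSyntaxDemo {
  def transmogrify[A, B](a: A)(f: A => B): B = f(a)

  // Braces instead of parentheses for the single-argument last group
  val size: String = transmogrify(2) { a =>
    if (a < 2) "Small"
    else "Large"
  }

  // Pattern-matching anonymous function
  val viaCases: String = transmogrify(1) {
    case a if a < 2 => "Small"
    case _          => "Large"
  }

  // The same function written out as an explicit match
  val viaMatch: String = transmogrify(1) { a =>
    a match {
      case n if n < 2 => "Small"
      case _          => "Large"
    }
  }

  def main(args: Array[String]): Unit = {
    println(size)     // Large
    println(viaCases) // Small
    println(viaMatch) // Small
  }
}
```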

Block syntax, in combination with by-name parameters, is particularly useful for things like creating custom control structures.
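As a sketch of such a control structure, here is a hypothetical whileLoop built from two parameter groups of by-name parameters. Because cond and body are by-name, they are re-evaluated on every iteration rather than once at the call site, and block syntax makes the call read like the built-in while.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object ControlStructureDemo {
  // Hypothetical home-made loop; both parameters are evaluated lazily, each time
  @tailrec
  def whileLoop(cond: => Boolean)(body: => Unit): Unit =
    if (cond) {
      body
      whileLoop(cond)(body)
    }

  // Using it with block syntax for the last parameter group
  def countUpTo(n: Int): List[Int] = {
    val seen = ListBuffer.empty[Int]
    var i = 0
    whileLoop(i < n) {
      seen += i
      i += 1
    }
    seen.toList
  }

  def main(args: Array[String]): Unit =
    println(countUpTo(3)) // List(0, 1, 2)
}
```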

Implicit Parameters

It is beyond the scope of this post to go into the details of implicit parameters, so we'll just touch on them briefly. The last parameter group of a function can be marked implicit, like so:
def binaryOp[A](a1: A, a2: A)(implicit f: (A,A) => A): A = f(a1, a2)
Then, if we have a value like this in scope:
implicit val tacit = (a1: Int, a2: Int) => a1 + a2 
We can call our function without including that implicit parameter group:
scala> binaryOp(3, 4) 
res6: Int = 7
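A self-contained version of this example, with a second implicit value (a hypothetical joinStrings) to show that the compiler picks the implicit by type. The last value shows that the implicit group can still be passed explicitly, which overrides whatever is in scope.

```scala
object ImplicitDemo {
  def binaryOp[A](a1: A, a2: A)(implicit f: (A, A) => A): A = f(a1, a2)

  // Implicit values in scope; explicit types keep the implicit search predictable
  implicit val addInts: (Int, Int) => Int = (a1, a2) => a1 + a2
  implicit val joinStrings: (String, String) => String = (s1, s2) => s1 + s2

  val seven: Int = binaryOp(3, 4)              // uses addInts
  val joined: String = binaryOp("foo", "bar")  // uses joinStrings

  // Supplying the implicit group explicitly
  val twelve: Int = binaryOp(3, 4)((a1, a2) => a1 * a2)

  def main(args: Array[String]): Unit = {
    println(seven)  // 7
    println(joined) // foobar
    println(twelve) // 12
  }
}
```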
I'll just note that implicits are very powerful, and as such come with great responsibility. They can be used by library authors to simplify user code, and for implementing such useful things as type classes. But, they can also be used to write extremely obfuscated code. Use with caution.

Note: This blog post is essentially the same information as a beginner's talk I gave at the PDX Scala Meetup on 19 April, 2016.