Welcome to day 1 of the OVO tech blog advent calendar 2018!
GADT (Generalized Algebraic Data Type) is a term that crops up quite frequently in discussions about Haskell, OCaml and the more functional end of the Scala spectrum, so it's worth understanding what it's all about.
GADTs are not as scary as they sound, and they are a useful tool for making your programs more type-safe. They allow you to encode invariants in the type system, meaning you can let the compiler take care of checking more stuff so you don't have to handle as many cases at runtime.
I'll start with a (slightly contrived) real-world-ish example of how GADTs can be useful, then explain what exactly a GADT is.
All the code in this post will be in Scala, but the ideas apply equally well to any language with GADT support.
Kafka integration tests
At OVO we use Kafka a lot, so integration tests for our microservices often involve producing and consuming Kafka messages. For example, a test might hit a service's API endpoint, verify that the API response is as expected, and finally check that the service sent the expected message to a certain Kafka topic.
In our hypothetical integration test, we want to test a service which sends events to multiple Kafka topics. For any given topic, the service will always send events of a single type: it sends CreatedWidget events to the created_widget topic, UpdatedWidget events to the updated_widget topic and so on.
In our tests we'd like to use that information - we want to say "I expect the service to send a 'created widget' event", meaning that we should poll the created_widget topic for a message, and then deserialise that message as a CreatedWidget event.
Let's say we have a KafkaConsumer for each topic, which knows how to deserialise messages to the appropriate type:
class KafkaConsumer[A](topicName: String) {
  def consumeEvent(timeoutMs: Long): Option[A] = {
    // - poll the Kafka topic
    // - if an event shows up before the timeout, deserialize it
    None
  }
}
val createdWidgetConsumer =
  new KafkaConsumer[CreatedWidget](topicName = "created_widget")
val updatedWidgetConsumer =
  new KafkaConsumer[UpdatedWidget](topicName = "updated_widget")
A naïve way to write the "expect the service to send an event" function might look like this:
def expectKafkaEvent(eventType: String): Option[AnyRef] = eventType match {
  case "created" => createdWidgetConsumer.consumeEvent(timeoutMs)
  case "updated" => updatedWidgetConsumer.consumeEvent(timeoutMs)
  case _ => throw new IllegalArgumentException("huh?")
}
Here we take the expected event type as a string, which causes a couple of problems:
- The compiler doesn't give us any exhaustivity checks, so we need a catch-all case that just blows up
- Each consumer returns a different type, so we have to return their common supertype, which is AnyRef in this case
A test that uses the expectKafkaEvent function might look like this:
def myAwesomeTest(): Unit = {
  val createWidgetResponse = sendCreateWidgetRequestToAPI()
  assert(createWidgetResponse.hasStatusCode(201))
  val firstEvent = expectKafkaEvent("created")
  assert(firstEvent == Some(CreatedWidget(id = 1)))
  val updateWidgetResponse = sendUpdateWidgetRequestToAPI()
  assert(updateWidgetResponse.hasStatusCode(200))
  val secondEvent = expectKafkaEvent("updated")
  assert(secondEvent == Some(UpdatedWidget(id = 1)))
}
It works, but it's a bit unpleasant. The type of firstEvent and secondEvent is Option[AnyRef], and we could easily make a silly typo when specifying the expected event type and end up with a RuntimeException.
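To make both problems concrete, here is a minimal self-contained sketch. StubConsumer is a hypothetical stand-in for KafkaConsumer that returns a canned event instead of polling Kafka, and the timeoutMs value is an assumption, since the snippets above don't show where it is defined.

```scala
object StringlyTypedDemo {
  final case class CreatedWidget(id: Int)
  final case class UpdatedWidget(id: Int)

  // Hypothetical stand-in for KafkaConsumer: returns a canned event, no Kafka involved.
  class StubConsumer[A](canned: Option[A]) {
    def consumeEvent(timeoutMs: Long): Option[A] = canned
  }

  val createdWidgetConsumer = new StubConsumer[CreatedWidget](Some(CreatedWidget(1)))
  val updatedWidgetConsumer = new StubConsumer[UpdatedWidget](Some(UpdatedWidget(1)))
  val timeoutMs = 1000L // assumed value

  def expectKafkaEvent(eventType: String): Option[AnyRef] = eventType match {
    case "created" => createdWidgetConsumer.consumeEvent(timeoutMs)
    case "updated" => updatedWidgetConsumer.consumeEvent(timeoutMs)
    case _ => throw new IllegalArgumentException(s"unknown event type: $eventType")
  }

  // Problem 2: the static type is lost, so reading a field needs a runtime pattern match.
  val createdId: Option[Int] =
    expectKafkaEvent("created").collect { case CreatedWidget(id) => id }

  // Problem 1: a typo compiles happily and only fails when the test runs.
  val typoResult: scala.util.Try[Option[AnyRef]] =
    scala.util.Try(expectKafkaEvent("craeted"))
}
```

The typo in "craeted" is only caught because the call is wrapped in Try here; in a real test it would surface as an exception halfway through the run.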
Let's see if we can refactor this code to solve the problems described above.
First let's solve the problem of exhaustivity checking by introducing an ADT (Algebraic Data Type) to represent the different types of event we care about:
sealed trait EventType
case object Created extends EventType
case object Updated extends EventType
If we rewrite the expectKafkaEvent function to take an EventType instead of a string, we can remove our catch-all case:
def expectKafkaEvent(eventType: EventType): Option[AnyRef] = eventType match {
  case Created => createdWidgetConsumer.consumeEvent(timeoutMs)
  case Updated => updatedWidgetConsumer.consumeEvent(timeoutMs)
}
That solves our first problem, but the second problem is still present. We are still returning a result of type Option[AnyRef]. This is a shame, as anybody calling expectKafkaEvent(Created) knows that the result should be Option[CreatedWidget], but that type information is being lost.
Let's fix that by turning our plain old ADT into a Generalized ADT:
sealed trait EventType[A]
case object Created extends EventType[CreatedWidget]
case object Updated extends EventType[UpdatedWidget]
EventType is now polymorphic in A, where A represents the type of Kafka event that we expect. And each case object sets A to a different type: CreatedWidget or UpdatedWidget respectively.
Now that EventType is a GADT, our expectKafkaEvent function becomes polymorphic in the event type A:
def expectKafkaEvent[A](eventType: EventType[A]): Option[A] = eventType match {
  case Created => createdWidgetConsumer.consumeEvent(timeoutMs)
  case Updated => updatedWidgetConsumer.consumeEvent(timeoutMs)
}
With this change, the return type of expectKafkaEvent will change appropriately depending on which case it matches:
scala> expectKafkaEvent(Created)
res0: Option[blog.kafka.events.CreatedWidget] = None
scala> expectKafkaEvent(Updated)
res1: Option[blog.kafka.events.UpdatedWidget] = None
If you've ever used path-dependent types in Scala, this might strike you as a similar concept: the return type of the function depends on the argument you pass in.
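For comparison, here is a rough sketch of how the same API might look with a type member and a dependent method type instead of a GADT. This is not the approach used in this post; StubConsumer, the consumer member and the canned events are all hypothetical, and the dispatch moves from a pattern match into the trait itself.

```scala
object TypeMemberDemo {
  final case class CreatedWidget(id: Int)
  final case class UpdatedWidget(id: Int)

  // Hypothetical stand-in for KafkaConsumer: returns a canned event, no Kafka involved.
  class StubConsumer[A](canned: Option[A]) {
    def consumeEvent(timeoutMs: Long): Option[A] = canned
  }

  sealed trait EventType {
    type Event                        // the event type lives in a type member
    def consumer: StubConsumer[Event] // each case supplies its own consumer
  }
  case object Created extends EventType {
    type Event = CreatedWidget
    val consumer = new StubConsumer[CreatedWidget](Some(CreatedWidget(1)))
  }
  case object Updated extends EventType {
    type Event = UpdatedWidget
    val consumer = new StubConsumer[UpdatedWidget](None)
  }

  // The return type is path-dependent on the argument: Option[eventType.Event].
  def expectKafkaEvent(eventType: EventType): Option[eventType.Event] =
    eventType.consumer.consumeEvent(1000L)

  val first: Option[CreatedWidget] = expectKafkaEvent(Created)
  val second: Option[UpdatedWidget] = expectKafkaEvent(Updated)
}
```

Both styles give the caller a precise return type. The GADT version keeps EventType as a plain description of the event and leaves the behaviour in expectKafkaEvent, which is the shape the rest of this post builds on.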
Thanks to the GADT, our integration test is now beautifully type-safe:
def myAwesomeTest(): Unit = {
  val createWidgetResponse = sendCreateWidgetRequestToAPI()
  assert(createWidgetResponse.hasStatusCode(201))
  val firstEvent: Option[CreatedWidget] = expectKafkaEvent(Created)
  assert(firstEvent == Some(CreatedWidget(id = 1)))
  val updateWidgetResponse = sendUpdateWidgetRequestToAPI()
  assert(updateWidgetResponse.hasStatusCode(200))
  val secondEvent: Option[UpdatedWidget] = expectKafkaEvent(Updated)
  assert(secondEvent == Some(UpdatedWidget(id = 1)))
}
Recap
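Now that you've seen an example of GADTs in action, let's look more closely at what exactly a GADT is, and how it differs from a normal ADT.
A GADT is an ADT that is parameterized by a type (often called a phantom type, because it doesn't have to correspond to any value stored in the data), where each case is free to fix that type parameter to a different concrete type.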
sealed trait Foo[A]
case object FirstCase extends Foo[String]
case class SecondCase(x: Int) extends Foo[Boolean]
Because of this, when we pattern match on the GADT, in each case the compiler is able to infer what the concrete type of A (the phantom type) will be:
def example[A](foo: Foo[A]): A = foo match {
  case FirstCase => "hello" // compiler knows here that A == String
  case SecondCase(x) => true // compiler knows here that A == Boolean
}
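Here is a small self-contained sketch of what that buys the caller. It is a slight variation on the snippet above: the SecondCase branch returns x > 0 (my own tweak, just to show the field being used), and the wrapping FooDemo object is only there so the code stands alone.

```scala
object FooDemo {
  sealed trait Foo[A]
  case object FirstCase extends Foo[String]
  final case class SecondCase(x: Int) extends Foo[Boolean]

  def example[A](foo: Foo[A]): A = foo match {
    case FirstCase     => "hello" // A is known to be String in this branch
    case SecondCase(x) => x > 0   // A is known to be Boolean in this branch
  }

  // No casts needed: the result type follows from the argument.
  val s: String  = example(FirstCase)
  val b: Boolean = example(SecondCase(42))
}
```

Swapping the two right-hand sides (returning a Boolean for FirstCase, say) should be rejected by the compiler, which is exactly the kind of check we were missing with Option[AnyRef].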
Another example
As I'm writing a blog post about GADTs, I'm bound by law to introduce an example of using a GADT to build a type-safe DSL for mathematical expressions.
sealed trait Expr[A]
case class Num(x: Int) extends Expr[Int]
case class Bool(x: Boolean) extends Expr[Boolean]
case class Add(x: Expr[Int], y: Expr[Int]) extends Expr[Int]
case class Equals[A](x: Expr[A], y: Expr[A]) extends Expr[Boolean]
This GADT allows us to build a tree representing an expression. The leaf nodes are integers and booleans, and there are branch nodes for adding two integers and for checking equality of two subexpressions.
For example, the expression ((3 == 1 + 2) == true) can be written using our DSL as:
val expression: Expr[Boolean] =
  Equals(
    Equals(
      Num(3),
      Add(
        Num(1),
        Num(2))),
    Bool(true))
(Of course, we could add some syntax sugar to make this look a bit nicer.)
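As a rough idea of what that sugar could look like, here is a sketch using implicit classes. The method names plus and === are my own hypothetical choices, not part of the DSL above; I've avoided naming the method + to steer clear of Scala 2's built-in string concatenation conversion.

```scala
object ExprSyntaxDemo {
  sealed trait Expr[A]
  final case class Num(x: Int) extends Expr[Int]
  final case class Bool(x: Boolean) extends Expr[Boolean]
  final case class Add(x: Expr[Int], y: Expr[Int]) extends Expr[Int]
  final case class Equals[A](x: Expr[A], y: Expr[A]) extends Expr[Boolean]

  // Hypothetical sugar: only integer expressions get `plus`.
  implicit class IntExprOps(self: Expr[Int]) {
    def plus(other: Expr[Int]): Expr[Int] = Add(self, other)
  }

  // Hypothetical sugar: any expression can be compared to one of the same type.
  implicit class ExprOps[A](self: Expr[A]) {
    def ===(other: Expr[A]): Expr[Boolean] = Equals(self, other)
  }

  // ((3 == 1 + 2) == true), built with the sugar
  val expression: Expr[Boolean] =
    (Num(3) === Num(1).plus(Num(2))) === Bool(true)
}
```

The sugar builds exactly the same tree as the hand-written version, so the type-level guarantees are unchanged.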
Note that an Equals(x, y) expression doesn't care what the types of the subexpressions are, but it does specify that they have to be the same type. So you can compare a number to a number and a boolean to a boolean, but you can't compare numbers and booleans - it doesn't make sense. We've used the type system to encode the business rules of our domain.
Similarly, it's not possible to add a boolean to anything. You can only add numbers to other numbers.
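A quick sketch of those rules in practice. The ExprTypingDemo wrapper and the value names are only there to make the snippet stand alone; the two commented-out lines are ones I would expect the compiler to reject. Note that this relies on Expr being invariant in A: with a covariant Expr[+A], the compiler could widen A to a common supertype and accept the mixed comparison.

```scala
object ExprTypingDemo {
  sealed trait Expr[A] // invariant in A, which is what makes the checks below strict
  final case class Num(x: Int) extends Expr[Int]
  final case class Bool(x: Boolean) extends Expr[Boolean]
  final case class Add(x: Expr[Int], y: Expr[Int]) extends Expr[Int]
  final case class Equals[A](x: Expr[A], y: Expr[A]) extends Expr[Boolean]

  // Well-typed: both sides of each comparison have the same type.
  val sameInts: Expr[Boolean]  = Equals(Num(1), Num(2))
  val sameBools: Expr[Boolean] = Equals(Bool(true), Bool(false))

  // Well-typed: adding numbers to numbers.
  val sum: Expr[Int] = Add(Num(1), Add(Num(2), Num(3)))

  // Not well-typed (uncommenting either line should fail to compile):
  // val mixed = Equals(Num(1), Bool(true)) // no single A fits both arguments
  // val bad   = Add(Num(1), Bool(true))    // Bool is not an Expr[Int]
}
```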
Given this GADT, it's trivial to write a function to evaluate an expression. The result will be either an integer or a boolean, depending on the type of the input.
def eval[A](expr: Expr[A]): A = expr match {
  case Num(x) => x
  case Bool(x) => x
  case Add(x, y) => eval(x) + eval(y)
  case Equals(x, y) => eval(x) == eval(y)
}
scala> eval(expression)
res0: Boolean = true
Type-safe DSLs like this one are a really common use case for GADTs.
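One reason this style is popular is that the expression tree is just data, so you can write more than one interpreter for it. As a sketch, here is a hypothetical show function (not part of the original example) that renders an expression as a string instead of evaluating it.

```scala
object ExprShowDemo {
  sealed trait Expr[A]
  final case class Num(x: Int) extends Expr[Int]
  final case class Bool(x: Boolean) extends Expr[Boolean]
  final case class Add(x: Expr[Int], y: Expr[Int]) extends Expr[Int]
  final case class Equals[A](x: Expr[A], y: Expr[A]) extends Expr[Boolean]

  // A second interpreter over the same tree: pretty-print rather than evaluate.
  def show[A](expr: Expr[A]): String = expr match {
    case Num(x)       => x.toString
    case Bool(x)      => x.toString
    case Add(x, y)    => s"(${show(x)} + ${show(y)})"
    case Equals(x, y) => s"(${show(x)} == ${show(y)})"
  }

  val expression: Expr[Boolean] =
    Equals(Equals(Num(3), Add(Num(1), Num(2))), Bool(true))

  val rendered: String = show(expression) // "((3 == (1 + 2)) == true)"
}
```

Because Expr is sealed, the compiler can still warn if show forgets a case, just as it can for eval.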
A nice example in OCaml
At the Curry On! conference in Amsterdam I watched a really good talk by Andreas Garnæs about his GraphQL server implementation. It uses GADTs to guarantee at compile time that the server implementation matches a given GraphQL schema.
The video is available here: https://www.youtube.com/watch?v=jaKcEGkItsY
Further reading
Here's a nice introduction to GADTs in Haskell.
Sample code by Paul Chiusano (@pchiusano) demonstrating GADTs in Scala, including some pitfalls.
A ridiculously detailed post about rank-n types in Haskell and how they relate to GADTs.
If, like me, you count reading about advanced Scala topics in Japanese as a hobby, you might find this post by Kenji Yoshida (@xuwei_k) interesting.