Mathias Brandewinder on .NET, F#, VSTO and Excel development, and quantitative analysis / machine learning.
by Mathias 5. August 2012 13:50

Today’s topic will be Chapter 3 of “Machine Learning in Action”, which covers Decision Trees.

Disclaimer: I am new to Machine Learning, and claim no expertise on the topic. I am currently reading “Machine Learning in Action”, and thought it would be a good learning exercise to convert the book’s samples from Python to F#.

The idea behind Decision Trees is similar to the Game of 20 Questions: construct a set of discrete Choices to identify the Class of an item. We will use the following dataset for illustration: imagine that we have 5 cards, each with a major masterpiece of contemporary cinema, classified by genre. Now I hide one, and you can ask up to 2 questions about the genre of the movie to identify the Thespian luminary in the lead role, using as few questions as possible:

Movie         Action  Sci-Fi  Actor
Cliffhanger   Yes     No      Stallone
Rocky         Yes     No      Stallone
Twins         No      No      Schwarzenegger
Terminator    Yes     Yes     Schwarzenegger
Total Recall  Yes     Yes     Schwarzenegger

The questions you would likely ask are:

  • Is this a Sci-Fi movie? If yes, Arnold is the answer; if no,
  • Is this an Action movie? If yes, go for Sylvester; otherwise, Arnold it is.


That’s a Decision Tree in a nutshell: we traverse a Tree, asking about features, and depending on the answer, we draw a conclusion or recurse deeper into more questions. The goal today is to let the computer build the right tree from the dataset, and use that tree to classify “subjects”.
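To make the two questions concrete, here is a minimal sketch that hard-codes them as nested conditionals. The function name guessActor and its two boolean parameters are made up for this illustration; they do not come from the book.

```fsharp
// Hypothetical helper: the two questions above, hard-coded as nested conditionals.
// isSciFi and isAction stand for the answers to each question.
let guessActor isSciFi isAction =
    if isSciFi then "Schwarzenegger"   // Sci-Fi: always Arnold in this dataset
    elif isAction then "Stallone"      // Action, but not Sci-Fi
    else "Schwarzenegger"              // neither: Twins

// Rocky is an Action movie, but not Sci-Fi
let rockyLead = guessActor false true
```

This works for our 5 cards, but the questions are baked into the code: a different dataset would require rewriting the function, which is why we want the tree to be data instead.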

Defining a Tree

Let’s start with the end – the Tree. A common and convenient way to model Trees in F# is to use a discriminated union like this:

type Tree = 
    | Conclusion of string 
    | Choice of string * (string * Tree) []

A Tree is composed of either a Conclusion, described by a string, or a Choice, which is described by a string, and an Array of multiple options, each described by a string and its own Tree, “tupled”.

For instance, we can manually create a tree for our example like this:

let manualTree = 
    Choice("Sci-Fi",
           [|("No",
              Choice("Action",
                     [|("Yes", Conclusion "Stallone");
                       ("No", Conclusion "Schwarzenegger")|]));
             ("Yes", Conclusion "Schwarzenegger")|])

Our tree starts with a Choice, labeled “Sci-Fi”, with 2 options in an Array, “No” or “Yes”. “Yes” leads to a Conclusion (a Leaf node), Arnold, while “No” opens another Choice, “Action”, with 2 Conclusions.
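One way to check that a hand-built tree has the shape just described is to print it. The sketch below is not from the book: describe is a hypothetical helper that walks the tree recursively and renders each question/answer pair on its own line, indenting as it goes deeper. The type and the tree are repeated so that the snippet runs on its own.

```fsharp
type Tree =
    | Conclusion of string
    | Choice of string * (string * Tree) []

let manualTree =
    Choice("Sci-Fi",
           [|("No",
              Choice("Action",
                     [|("Yes", Conclusion "Stallone");
                       ("No", Conclusion "Schwarzenegger")|]));
             ("Yes", Conclusion "Schwarzenegger")|])

// Hypothetical helper: render a tree as a list of indented lines,
// one per question/answer pair or conclusion.
let rec describe indent tree =
    let pad = String.replicate indent " "
    match tree with
    | Conclusion actor -> [ pad + "=> " + actor ]
    | Choice(label, options) ->
        [ for (answer, subTree) in options do
            yield pad + label + "? " + answer
            yield! describe (indent + 2) subTree ]

describe 0 manualTree |> List.iter (printfn "%s")
```

Each Conclusion shows up indented under the answer that leads to it, so the “No” branch of “Sci-Fi” should visibly contain the “Action” question, while the “Yes” branch ends immediately.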

So how can we use this to Classify a “Subject”? We need to traverse down the Tree, check which branch corresponds to the Subject for the current Choice, and continue until we reach a Conclusion, at which point we can return its contents. To that effect, we’ll represent a “Subject” (the thing we are trying to classify) as a collection of Tuples, each Tuple being a key/value pair, representing a Feature and its value:

let test = [| ("Action", "Yes"); ("Sci-Fi", "Yes") |]

We are ready to write a classification function now:

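let rec classify subject tree =
    match tree with
    | Conclusion(c) -> c
    | Choice(label, options) ->
        // value of the subject for the feature this Choice asks about
        let subjectState =
            subject
            |> Seq.find(fun (key, value) -> key = label)
            |> snd
        // follow the branch matching that value, and keep classifying
        options
        |> Array.find (fun (option, tree) -> option = subjectState)
        |> snd
        |> classify subject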

classify is a recursive function: given a subject and a tree, if the Tree is a Conclusion, we are done, otherwise, we retrieve the label of the next Choice, find the value of the Subject for that Choice, and use it to pick the next level of the Tree.
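One thing to be aware of: Seq.find and Array.find throw an exception when nothing matches, so classifying a subject that lacks a feature, or has a value the tree has no branch for, will blow up. As a sketch of one possible alternative (tryClassify and smallTree are hypothetical names, not from the book), we could return an option instead:

```fsharp
type Tree =
    | Conclusion of string
    | Choice of string * (string * Tree) []

// Hypothetical variant of classify: returns None instead of throwing when the
// subject lacks a feature, or has a value the tree has no branch for.
let rec tryClassify subject tree =
    match tree with
    | Conclusion c -> Some c
    | Choice(label, options) ->
        subject
        |> Seq.tryFind (fun (key, _) -> key = label)
        |> Option.bind (fun (_, state) ->
            options |> Array.tryFind (fun (answer, _) -> answer = state))
        |> Option.bind (fun (_, subTree) -> tryClassify subject subTree)

// A one-question tree, just for this example
let smallTree =
    Choice("Sci-Fi",
           [|("No", Conclusion "Stallone");
             ("Yes", Conclusion "Schwarzenegger")|])

// The subject has the feature the tree asks about
let known = tryClassify [| ("Sci-Fi", "Yes") |] smallTree
// The subject says nothing about Sci-Fi, so no conclusion can be reached
let unknown = tryClassify [| ("Action", "Yes") |] smallTree
```

Whether this is preferable depends on the situation: for our small example, where every subject has every feature, the simpler version is probably good enough.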

At this point, using the Tree to classify our subject is as simple as:

> let actor = classify test manualTree;;

val actor : string = "Schwarzenegger"
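As a quick sanity check, we can also run all 5 cards from the table through the tree and verify that each one lands on the expected actor. This is only a sketch: the movies list and allCorrect are names introduced here, and the type, tree and classify function are repeated so that the snippet runs on its own.

```fsharp
type Tree =
    | Conclusion of string
    | Choice of string * (string * Tree) []

let manualTree =
    Choice("Sci-Fi",
           [|("No",
              Choice("Action",
                     [|("Yes", Conclusion "Stallone");
                       ("No", Conclusion "Schwarzenegger")|]));
             ("Yes", Conclusion "Schwarzenegger")|])

let rec classify subject tree =
    match tree with
    | Conclusion c -> c
    | Choice(label, options) ->
        let subjectState =
            subject
            |> Seq.find (fun (key, _) -> key = label)
            |> snd
        options
        |> Array.find (fun (answer, _) -> answer = subjectState)
        |> snd
        |> classify subject

// The five cards from the table: title, features, and expected lead actor.
let movies =
    [ ("Cliffhanger",  [| ("Action", "Yes"); ("Sci-Fi", "No")  |], "Stallone")
      ("Rocky",        [| ("Action", "Yes"); ("Sci-Fi", "No")  |], "Stallone")
      ("Twins",        [| ("Action", "No");  ("Sci-Fi", "No")  |], "Schwarzenegger")
      ("Terminator",   [| ("Action", "Yes"); ("Sci-Fi", "Yes") |], "Schwarzenegger")
      ("Total Recall", [| ("Action", "Yes"); ("Sci-Fi", "Yes") |], "Schwarzenegger") ]

// True when the tree recovers the right actor for every card
let allCorrect =
    movies
    |> List.forall (fun (_, features, actor) -> classify features manualTree = actor)
```

If allCorrect is true, the hand-built tree agrees with the whole dataset.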

Not bad for 14 lines of code. The most painful part is the manual construction of the Tree – let’s see if we can get the computer to build that for us.


