k-Means Clustering and Birth Rates

A common problem in machine learning is to take some kind of data and break it up into “clumps” that best reflect how the data is structured. A set of points which are all collectively close to each other should be in the same clump.

A simple picture will clarify any vagueness in this:

[Figure: points in the plane forming three visibly separate clusters]

Here the data consists of points in the plane. There is an obvious clumping of the data into three pieces, and we want a way to automatically determine which points are in which clumps. The formal name for this problem is the clustering problem. That is, these clumps of points are called clusters, and there are various algorithms which find a “best” way to split the data into appropriate clusters.

The important applications of this are inherently similarity-based: if our data comes from, say, the shopping habits of users of some website, we may want to target a group of shoppers who buy similar products at similar times, and provide them with a coupon for a specific product which is valid during their usual shopping times. Determining exactly who is in that group of shoppers (and more generally, how many groups there are, and what features the groups correspond to) is the main application of clustering.

This is something one can do quite easily as a human on small, visualizable datasets, but the usual digital representation (a list of numeric points with some number of dimensions) doesn’t yield any obvious insights. Moreover, as the data becomes more complicated (be it by dimension increase, data collection errors, or sheer volume) the “human method” can easily fail or become inconsistent. And so we turn to mathematics to formalize the question.

In this post we will derive one possible version of the clustering problem known as the k-means clustering or centroid clustering problem, see that it is a difficult problem to solve exactly, and implement a heuristic algorithm in place of an exact solution.

And as usual, all of the code used in this post is freely available on this blog’s Google code page.

Partitions and Squared Deviations

The process of clustering is really a process of choosing a good partition of the data. Let’s call our data set S, and formalize it as a list of points in space. To be completely transparent and mathematical, we let S be a finite subset of a metric space (X,d), where d is our distance metric.

Definition: We call a partition of a set S a choice of subsets A_1, \dots, A_n of S so that every element of S is in exactly one of the A_i.

A couple of important things to note about partitions are that the union of all the A_i is S, and that any two distinct A_i, A_j are disjoint. These are immediate consequences of the definition, and together provide an equivalent, alternative definition for a partition. As a simple example, the even and odd integers form a partition of the whole set of integers.

There are many different kinds of clustering problems, but every clustering problem seeks to partition a data set in some way depending on the precise formalization of the goal of the problem. We should note that while this section does give one of many possible versions of this problem, it culminates in the fact that this formalization is too hard to solve exactly. An impatient reader can safely skip to the following section where we discuss the primary heuristic algorithm used in place of an exact solution.

In order to properly define the clustering problem, we need to specify the desired features of a cluster, or a desired feature of the set of all clusters combined. Intuitively, we think of a cluster as a bunch of points which are all close to each other. We can measure this explicitly as follows. Let A be a fixed subset of the partition we’re interested in. Then we might use the sum of all of the distances between pairs of points within A as a measure of its “clusterity.” In symbols, this would be

\displaystyle \sum_{x \neq y \in A} d(x, y)

If this quantity is small, then it says that all of the points in the cluster A are close to each other, and A is a good cluster. Of course, we want all clusters to be “good” simultaneously, so we’d want to minimize the sum of these sums over all subsets in the partition.

Note that if there are n points in A, then the above sum involves \binom{n}{2} \sim n^2/2 distance calculations, and so this could get quite inefficient with large data sets. One of the many alternatives is to pick a “center” for each of the clusters, and try to minimize the sum of the squared distances of each point in a cluster from its center. Using the same notation as above, this would be

\displaystyle \sum_{x \in A} d(x, c)^2

where c denotes the center of the cluster A. This only involves n distance calculations, and is perhaps a better measure of “clusterity.” Specifically, if we use the first option and one point in the cluster is far away from the others, we essentially record that single piece of information n - 1 times, whereas in the second we only record it once.

The method we will use to determine the center can be very general. We could use one of a variety of measures of center, like the arithmetic mean, or we could try to force one of the points in A to be considered the “center.” Fortunately, when the points live in Euclidean space and d is the Euclidean distance, the arithmetic mean has the property that it minimizes the above sum of squared distances over all possible choices of c. (This is why we square the distances: the sum of plain, unsquared distances is minimized by the geometric median rather than the mean.) So we’ll stick with that for now.
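A quick numeric check of the claim about the mean, on a hypothetical three-point cluster in one dimension. It assumes squared Euclidean distance, which is what makes the mean optimal; with plain (unsquared) distances the median does better, as the second comparison shows.

```python
# Hypothetical 1-D cluster: two points at 0 and one outlier at 10.
points = [0.0, 0.0, 10.0]
mean = sum(points) / len(points)  # 10/3, about 3.33

def sum_of_distances(c):
    """Sum of plain distances |x - c| from each point to a candidate center c."""
    return sum(abs(x - c) for x in points)

def sum_of_squares(c):
    """Sum of squared distances (x - c)**2 from each point to c."""
    return sum((x - c) ** 2 for x in points)

# Squared distances: the mean beats the median (about 66.7 versus 100).
print(sum_of_squares(mean), sum_of_squares(0.0))
# Plain distances: the median 0.0 beats the mean (10 versus about 13.3).
print(sum_of_distances(mean), sum_of_distances(0.0))
```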

And so the clustering problem is formalized.

Definition: Let S \subset \mathbb{R}^m be a finite set of points, and let d be the Euclidean distance. (The arithmetic mean requires that we can add and scale points, so we work in \mathbb{R}^m rather than an arbitrary metric space.) The centroid clustering problem is the problem of finding, for a given positive integer k, a partition \left \{ A_1, \dots, A_k \right \} of S so that the following quantity is minimized:

\displaystyle \sum_{i=1}^k\sum_{x \in A_i} d(x, c(A_i))^2

where c(A_i) denotes the center of a cluster, defined as the arithmetic mean of the points in A_i:

\displaystyle c(A) = \frac{1}{|A|} \sum_{x \in A} x
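As a minimal sketch, the objective can be written in Python. We assume points are tuples of numbers and use squared Euclidean distances, the usual convention for k-means, under which the arithmetic mean is the best possible center; `centroid`, `squared_distance`, and `clustering_cost` are names chosen here for illustration.

```python
def centroid(cluster):
    """Arithmetic mean c(A) of a non-empty list of equal-length point tuples."""
    n = len(cluster)
    return tuple(sum(coords) / n for coords in zip(*cluster))

def squared_distance(x, y):
    """Squared Euclidean distance d(x, y)**2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))

def clustering_cost(partition):
    """Sum over all clusters A_i of the squared distances from each x in A_i to c(A_i)."""
    return sum(squared_distance(x, centroid(A)) for A in partition for x in A)

# Two hypothetical clusters: the first has mean (1, 0), the second is a single point.
example = [[(0.0, 0.0), (2.0, 0.0)], [(10.0, 10.0)]]
print(clustering_cost(example))  # 2.0
```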

Before we continue, we have a confession to make: the centroid clustering problem is prohibitively difficult. In particular, it falls into a class of problems known as NP-hard problems. For the working programmer, NP-hard means that there is unlikely to be any algorithm which solves every instance exactly in time polynomial in the size of the input; in the worst case, we should not expect to do much better than searching through an enormous number of candidate partitions.

We’ll touch more on this after we see some code, but the salient fact is that a heuristic algorithm is our best bet. That is, all of this preparation with partitions and squared deviations really won’t come into the algorithm design at all. Formalizing this particular problem in terms of sets and a function we want to optimize only allows us to rigorously prove it is difficult to solve exactly. And so, of course, we will develop a naive and intuitive heuristic algorithm to substitute for an exact solution, observing its quality in practice.

Lloyd’s Algorithm

The most common heuristic for the centroid clustering problem is Lloyd’s algorithm, often simply called the k-means clustering algorithm. It was named after its inventor Stuart Lloyd, a University of Chicago graduate and member of the Manhattan project who designed the algorithm in 1957 during his time at Bell Labs.

Heuristics tend to be on the simpler side, and Lloyd’s algorithm is no exception. We start by fixing a number of clusters k and choosing an arbitrary initial partition A = \left \{ A_1, \dots, A_k \right \}. The algorithm then proceeds as follows:

repeat:
   compute the arithmetic mean c[i] of each A[i]
   construct a new partition B:
      each subset B[i] is given a center c[i] computed from A
      each point x is assigned to the subset B[i] whose c[i] is closest
   stop if B is equal to the old partition A, else set A = B

Intuitively, we imagine the center of each cluster being pulled toward the center of mass of the points currently assigned to it, and then each point choosing which center it belongs to by picking the nearest one. (Indeed, precisely because of this the algorithm may not always give sensible results, but more on that later.)

One who is in tune with their inner pseudocode will readily understand the above algorithm. But perhaps the simplest way to think about this algorithm is functionally. That is, we are constructing a partition-updating function f which accepts as input a partition of the data and produces as output a new partition as follows: first compute the center (arithmetic mean) of each subset in the old partition, and then create the new partition by gathering together all the points closest to each center. These are the second through fifth lines of the pseudocode above.

Indeed, the rest of the pseudocode is merely pomp and scaffolding! What we are really after is a fixed point of the partition-updating function f. In other words, we want a partition P such that f(P) = P. We go about finding one in this algorithm by applying f to our initial partition A, and then recursively applying f to its own output until we no longer see a change.

Perhaps we should break away from traditional pseudocode illegibility and rewrite the algorithm as follows:

define updatePartition(A):
   let c[i] = center(A[i])
   return a new partition B:
      each B[i] is given the points which are closest to c[i]

compute a fixed point by recursively applying 
updatePartition to any initial partition.

Of course, the difference between these pseudocode snippets is just the difference between functional and imperative programming. Neither is superior, but the perspective of both is valuable in its own right.

And so we might as well implement Lloyd’s algorithm in two such languages! The first, weighing in at a whopping four lines, is our Mathematica implementation:

closest[x_, means_] :=
  means[[First[Ordering[Map[EuclideanDistance[x, #] &, means]]]]];

partition[points_, means_] := GatherBy[points, closest[#, means]&];
updatePartition[points_, old_] := partition[points, Map[Mean, old]];

kMeans[points_, initialMeans_] := FixedPoint[updatePartition[points, #]&, partition[points, initialMeans]];

While it’s a little bit messy (as nesting 5 function calls and currying by hand will inevitably be), the ideas are simple. The “closest” function computes the closest mean to a given point x. The “partition” function uses Mathematica’s built-in GatherBy function to partition the points by the closest mean; GatherBy[L, f] partitions its input list L by putting together all points which have the same value under f. The “updatePartition” function creates the new partition based on the centers of the old partition. And finally, the “kMeans” function uses Mathematica’s built-in FixedPoint function to iteratively apply updatePartition to the initial partition until there are no more changes in the output.
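For readers without Mathematica, here is a rough Python transliteration of the four definitions. It is a sketch rather than the post’s own code: `gather_by` is a hand-written stand-in for GatherBy (the standard library’s `itertools.groupby` only groups consecutive items), `fixed_point` plays the role of FixedPoint, and points are assumed to be tuples. `math.dist` requires Python 3.8 or later.

```python
import math

def closest(x, means):
    """The mean nearest to x; ties go to the earliest mean, as with First[Ordering[...]]."""
    return min(means, key=lambda m: math.dist(x, m))

def gather_by(items, f):
    """Group items with equal f(item); groups appear in order of first occurrence."""
    groups = {}
    for item in items:
        groups.setdefault(f(item), []).append(item)
    return list(groups.values())

def mean(cluster):
    """Coordinate-wise arithmetic mean of a non-empty list of tuples."""
    return tuple(sum(coords) / len(cluster) for coords in zip(*cluster))

def partition(points, means):
    return gather_by(points, lambda x: closest(x, means))

def update_partition(points, old):
    return partition(points, [mean(A) for A in old])

def fixed_point(f, x):
    """Apply f repeatedly until its output equals its input."""
    while True:
        y = f(x)
        if y == x:
            return x
        x = y

def k_means(points, initial_means):
    return fixed_point(lambda p: update_partition(points, p),
                       partition(points, initial_means))

pts = [(0, 0), (0, 1), (10, 10), (10, 11)]
print(k_means(pts, [(0, 0), (10, 10)]))
# [[(0, 0), (0, 1)], [(10, 10), (10, 11)]]
```

Using the group key as a dictionary key relies on dictionaries preserving insertion order (guaranteed since Python 3.7), which keeps the cluster order stable between iterations so that the equality test in `fixed_point` is meaningful.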

Indeed, this is as close as it gets to the “functional” pseudocode we had above. And applying it to some synthetic data (three randomly-sampled Gaussian clusters that are relatively far apart) gives a good clustering in a mere two iterations:

[Figure: Lloyd’s algorithm applied to three synthetic Gaussian clusters]

Indeed, we rarely see a large number of iterations, and we leave it as an exercise to the reader to test Lloyd’s algorithm on random noise to see just how bad it can get (remember, all of the code used in this post is available on this blog’s Google code page). One will likely see convergence on the order of tens of iterations. On the other hand, there are pathologically complicated sets of points (even in the plane) for which Lloyd’s algorithm takes exponentially many iterations to converge to a fixed point. And even then, the solution is not guaranteed to be optimal. Indeed, the possibility of terrible run time and the lack of any guarantee of optimality are common features of heuristic algorithms; it is the trade-off we must make to overcome the infeasibility of NP-hard problems.

Our second implementation is in Python, and compared to the Mathematica implementation it looks like the lovechild of MUMPS and C++. Sparing the reader too many unnecessary details, here is the main function which loops the partition updating, à la the imperative pseudocode:

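def kMeans(points, k, initialMeans, d=euclideanDistance):
    oldPartition = []
    # Assign every point to its nearest initial mean.
    newPartition = partition(points, k, initialMeans, d)

    # Repeat until the partition stops changing (a fixed point).
    while oldPartition != newPartition:
        oldPartition = newPartition
        # Recompute each cluster's center as the mean of its points...
        newMeans = [mean(S) for S in oldPartition]
        # ...and reassign every point to its nearest new center.
        newPartition = partition(points, k, newMeans, d)

    return newPartition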

We added in the boilerplate functions for euclideanDistance, partition, and mean appropriately, and the reader is welcome to browse the source code for those.
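The helper functions themselves are not shown in the post. The versions below are hypothetical stand-ins, written only so that the loop can be run on its own; they may well differ from the originals in the source code. Points are assumed to be tuples.

```python
import math

def euclideanDistance(x, y):
    # Hypothetical stand-in: straight-line distance between two point tuples.
    return math.dist(x, y)

def mean(S):
    # Hypothetical stand-in: coordinate-wise mean of a non-empty list of points.
    return tuple(sum(coords) / len(S) for coords in zip(*S))

def partition(points, k, means, d):
    # Hypothetical stand-in: one bucket per mean; each point joins its nearest mean.
    # k is accepted only to match the call signature used by kMeans.
    buckets = [[] for _ in means]
    for x in points:
        nearest = min(range(len(means)), key=lambda i: d(x, means[i]))
        buckets[nearest].append(x)
    # Drop empty buckets so mean() is never asked to average nothing.
    return [b for b in buckets if b]

def kMeans(points, k, initialMeans, d=euclideanDistance):
    # The same loop as in the post.
    oldPartition = []
    newPartition = partition(points, k, initialMeans, d)
    while oldPartition != newPartition:
        oldPartition = newPartition
        newMeans = [mean(S) for S in oldPartition]
        newPartition = partition(points, k, newMeans, d)
    return newPartition

print(kMeans([(0, 0), (0, 1), (10, 10), (10, 11)], 2, [(0, 0), (10, 10)]))
```

One design choice worth flagging: this stand-in `partition` drops clusters that end up empty, so the result can contain fewer than k clusters. Other implementations instead re-seed an empty cluster with some data point.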

Birth and Death Rates Clustering

To test our algorithm, let’s apply it to a small set of real-world data. This data will consist of one data point for each country, with two features: birth rate and death rate, measured as the annual number of births/deaths per 1,000 people in the population. Since the population is constantly changing, it is measured at some time in the middle of the year to act as a reasonable estimate of the average population over the year.

The raw data comes directly from the CIA’s World Factbook data estimate for 2012. Formally, we’re collecting the “crude birth rate” and “crude death rate” of each country with known values for both (some minor self-governing principalities had unknown rates). The “crude rate” simply means that the data does not account for anything except pure numbers; there is no compensation for the age distribution and fertility rates. Of course, there are many issues affecting the birth rate and death rate, but we have neither the background nor the stamina to investigate their implications here. Indeed, part of the point of studying learning methods is that we want to extract useful information from the data without too much human intervention (in the form of ambient knowledge).

Here is a plot of the data with some interesting values labeled:

[Figure: countries plotted by birth rate and death rate, with some interesting values labeled]

Specifically, we note that there is a distinct grouping of the data into two clusters (with a slanted line apparently separating the clusters). As a casual aside, it seems that the majority of the countries in the cluster on the right are countries with active conflict.

Applying Lloyd’s algorithm with k=2 to this data results in the following (not quite so good) partition:

[Figure: k = 2 clustering of the unstandardized birth-rate/death-rate data]

Note how some of the points which we would expect to be in the “left” cluster are labeled as being in the right. This is unfortunate, but we’ve seen this issue before in our post on k-nearest-neighbors: the different axes are on different scales. That is, the two rates vary over different ranges and have different expected values, so distances are dominated by whichever coordinate has the larger spread.

Compensating for this is quite simple: we just need to standardize the data. That is, we replace each coordinate of each data point with its deviation from that coordinate’s mean, measured in standard deviations, using the usual formula:

\displaystyle z = \frac{x - \mu}{\sigma}

where \mu and \sigma are the sample mean and sample standard deviation of that coordinate over the whole data set. Doing this in Mathematica is quite easy:

Transpose[Map[Standardize, Transpose[L]]]

where L is a list containing our data. Re-running Lloyd’s algorithm on the standardized data gives a much better picture:
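Outside Mathematica, the same column-wise standardization is a few lines of Python. This sketch assumes the data is a list of tuples (one per country, say (birth rate, death rate)) and uses the sample standard deviation from the standard library’s `statistics` module; the function name is hypothetical.

```python
from statistics import mean, stdev

def standardize_columns(rows):
    """Replace each coordinate by its z-score (x - mu) / sigma within its column."""
    columns = list(zip(*rows))
    stats = [(mean(col), stdev(col)) for col in columns]  # sample statistics per column
    return [tuple((v - mu) / sigma for v, (mu, sigma) in zip(row, stats))
            for row in rows]

print(standardize_columns([(1, 10), (2, 20), (3, 30)]))
# [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)]
```

After this transformation both coordinates have mean 0 and standard deviation 1, so neither one dominates the Euclidean distance.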

[Figure: k = 2 clustering of the standardized birth-rate/death-rate data]

Now the boundary separating one cluster from the other is in line with what our intuition dictates it should be.

Heuristics… The Air Tastes Bitter

We should note at this point that we really haven’t solved the centroid clustering problem yet. There is one glaring omission: the choice of k. This question is central to the problem of finding a good partition; a bad choice can yield bunk insights at best. Below we’ve run Lloyd’s algorithm for varying values of k, again on the birth-rate data set.

[Figure: Lloyd’s algorithm run on the birth-rate/death-rate data set with values of k between 2 and 7]

The problem of finding k has been addressed by many a researcher, and unfortunately the only methods to find a good value for k are heuristic in nature as well. In fact, many believe that determining the correct value of k is a learning problem in and of itself! We will try not to go into too much detail about parameter selection here, but needless to say it is an enormous topic.

And as we’ve already said, even if the correct choice of k is known, there is no guarantee that Lloyd’s algorithm (or any other heuristic for the centroid clustering problem) will converge to a global optimum. In the same fashion as our posts on cryptanalysis and deck-stacking in Texas Hold ‘Em, the search for a minimum can get stuck in a local minimum.

Here is an example with four clusters, where each frame is a step, and the algorithm progresses from left to right:

One way to alleviate the issue of local minima is the same here as in our other posts: simply start the algorithm over again from a different randomly chosen starting point. That is, as in our implementations above, our “initial means” are chosen uniformly at random from among the data set points. Alternatively, one may randomly partition the data (without respect to any center; each data point is assigned to one of the k clusters with probability 1/k). We encourage the reader to try both starting conditions as an exercise, and to implement a repeated version of the algorithm that returns the output which minimizes the objective function (as detailed in the “Partitions and Squared Deviations” section).
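A minimal sketch of the repeated algorithm, assuming points are tuples and that initial means are drawn uniformly at random from the data. The names `lloyd`, `cost`, and `best_of_restarts` are hypothetical, and the `max_iter` cap is a safety net added here rather than something from the post. The cost is the sum of squared distances from each point to its cluster mean, the usual k-means objective.

```python
import math
import random

def mean(cluster):
    """Coordinate-wise arithmetic mean of a non-empty list of point tuples."""
    return tuple(sum(coords) / len(cluster) for coords in zip(*cluster))

def lloyd(points, initial_means, max_iter=1000):
    """Plain Lloyd iteration from the given means; returns a list of non-empty clusters."""
    means = list(initial_means)
    clusters = None
    for _ in range(max_iter):
        buckets = [[] for _ in means]
        for x in points:
            nearest = min(range(len(means)), key=lambda i: math.dist(x, means[i]))
            buckets[nearest].append(x)
        new_clusters = [b for b in buckets if b]
        if new_clusters == clusters:  # fixed point reached
            break
        clusters = new_clusters
        means = [mean(b) for b in clusters]
    return clusters

def cost(clusters):
    """Sum of squared distances from each point to the mean of its cluster."""
    return sum(math.dist(x, mean(b)) ** 2 for b in clusters for x in b)

def best_of_restarts(points, k, trials=10, seed=None):
    """Run Lloyd's algorithm from `trials` random starts and keep the lowest-cost result."""
    rng = random.Random(seed)
    best, best_cost = None, math.inf
    for _ in range(trials):
        initial = rng.sample(points, k)  # k distinct data points as initial means
        clusters = lloyd(points, initial)
        c = cost(clusters)
        if c < best_cost:
            best, best_cost = clusters, c
    return best
```

The random-partition start mentioned above would replace the `rng.sample` line with a random assignment of points to clusters, followed by computing their means.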

And even if the algorithm happened to converge to a global minimum, it might not do so efficiently. As we already mentioned, solving the centroid clustering problem exactly is NP-hard; this remains true even for k = 2 when the dimension of the data can grow, and even for points in the plane when k can grow. And so (assuming \textup{P} \neq \textup{NP}) no algorithm finds a global minimum in polynomial time on every input. Lloyd’s algorithm itself can be slow as well: there are specially constructed sets of points in the plane on which it takes exponentially many iterations to converge.

These kinds of reasons make Lloyd’s algorithm and the centroid clustering problem a bit of a poster child of machine learning. In theory it’s difficult to solve exactly, but it has an efficient and widely employed heuristic used in practice which is often good enough. Moreover, since an exact solution is more or less hopeless, much of the focus has shifted to finding (often randomized) algorithms whose solutions are provably within some approximation factor of the true minimum, at least in expectation.

A Word on Expectation Maximization

This algorithm shares many features with a very famous algorithm called the Expectation-Maximization algorithm. We plan to investigate this after we spend some more time on probability theory on this blog, but the (very rough) idea is that the algorithm alternates between two steps. In the expectation step, the current parameters of a number of statistical models are used to estimate, for each data point, how likely it is to have come from each model (a “soft” version of assigning points to clusters). Then a maximization step chooses new parameters for those statistical models so as to maximize the likelihood of the data, given those estimates. The new models then play the role of the “old” models in the next expectation step.

Continuing the analogy with clustering, one feature of expectation-maximization that makes it nice is that it allows the “clusters” to have varying sizes, whereas Lloyd’s algorithm tends to make its clusters roughly equal in size (as we saw with varying values of k in our birth-rates example above).

And so the ideas involved in this post are readily generalizable, and the applications extend to a variety of fields like image reconstruction, natural language processing, and computer vision. The reader who is interested in the full mathematical details can see this tutorial.

Until next time!


Decision Trees and Political Party Classification

Last time we investigated the k-nearest-neighbors algorithm and the underlying idea that one can learn a classification rule by copying the known classification of nearby data points. This required that we view our data as sitting inside a metric space; that is, we imposed a kind of geometric structure on our data. One glaring problem is that there may be no reasonable way to do this. While we mentioned scaling issues and provided a number of possible metrics in our primer, a more common problem is that the data simply isn’t numeric.

For instance, a poll of US citizens might ask respondents to select which of a number of issues they care most about. There could be 50 choices, and there is no reasonable way to assign each choice a single number so that all choices are equidistant in the resulting metric space.

Another issue is that the quality of the data could be bad. For instance, there may be missing values for some attributes (e.g., a respondent may neglect to answer one or more questions). Alternatively, the attributes or the classification label could be wrong; that is, the data could exhibit noise. For k-nearest-neighbors, missing an attribute means we can’t (or can’t accurately) compute the distance function. And noisy data can interfere with our choice of k. In particular, certain regions might be better with a smaller value of k, while regions with noisier data might require a larger k to achieve the same accuracy rate. Without making the algorithm significantly more complicated to vary k when necessary, our classification accuracy will suffer.

In this post we’ll see how decision trees can alleviate these issues, and we’ll test the decision tree on an imperfect data set of congressional voting records. We’ll implement the algorithm in Python, and test it on the problem of predicting the US political party affiliation of members of Congress based on their votes for a number of resolutions. As usual, we post the entire code and data set on this blog’s Google code page.

Before going on, the reader is encouraged to read our primer on trees. We will assume familiarity with the terminology defined there.

Intuition

Imagine we have a data set where each record is a list of categorical weather conditions on one of a number of randomly selected days, and the labels correspond to whether a girl named Arya went for a horse ride on that day. Let’s also assume she would like to go for a ride every day, and the only thing that might prohibit her from doing so is adverse weather. In this case, the input variables will be the condition in the sky (sunny, cloudy, rainy, and snowy), the temperature (cold, warm, and hot), the relative humidity (low, medium, and high), and the wind speed (low and high). The output variable will be whether Arya goes on a horse ride that day. Some entries in this data set might look like:

                 Arya's Riding Data
Sky     Temperature    Humidity    Wind    Horse Ride
Cloudy  Warm           Low         Low     Yes
Rainy   Cold           Medium      Low     No
Sunny   Warm           Medium      Low     Yes
Sunny   Hot            High        High    No
Snow    Cold           Low         High    No
Rainy   Warm           High        Low     Yes

In this case, one might reasonably guess that certain weather features are more important than others in determining whether Arya can go for a horse ride. For instance, the difference between sun and rain/snow should be a strong indicator, although it is not always correct in this data set. In other words, we’re looking for a weather feature that best separates the data into its respective classes. Of course, we’ll need a rigorous way to measure how good that separation is, but intuitively we can continue.

For example, we might split based on the wind speed feature. In this case, we have two smaller data sets corresponding to the entries where the wind is high and low. The corresponding table might look like:

     Arya's Riding Data, Wind = High
Sky     Temperature    Humidity    Horse Ride
Sunny   Hot            High        No
Snow    Cold           Low         No

     Arya's Riding Data, Wind = Low
Sky     Temperature    Humidity    Horse Ride
Cloudy  Warm           Low         Yes
Rainy   Cold           Medium      No
Sunny   Warm           Medium      Yes
Rainy   Warm           High        Yes

In this case, Arya is never known to ride a horse when the wind speed is high, and there is only one occasion when she doesn’t ride a horse and the wind speed is low. Taking this one step further, we can repeat the splitting process on the “Wind = Low” data in search of a complete split between the two output classes. We can see by visual inspection that the only “no” instance occurs when the temperature is cold. Hence, we should split on the temperature feature.
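Splitting is easy to express in code. A minimal sketch, assuming each day is stored as a Python dict keyed by the column names of the table above; the `split` function is a hypothetical helper, not part of the post’s source.

```python
from collections import defaultdict

# Arya's riding data from the table above, one dict per day.
rows = [
    {"Sky": "Cloudy", "Temperature": "Warm", "Humidity": "Low",    "Wind": "Low",  "Horse Ride": "Yes"},
    {"Sky": "Rainy",  "Temperature": "Cold", "Humidity": "Medium", "Wind": "Low",  "Horse Ride": "No"},
    {"Sky": "Sunny",  "Temperature": "Warm", "Humidity": "Medium", "Wind": "Low",  "Horse Ride": "Yes"},
    {"Sky": "Sunny",  "Temperature": "Hot",  "Humidity": "High",   "Wind": "High", "Horse Ride": "No"},
    {"Sky": "Snow",   "Temperature": "Cold", "Humidity": "Low",    "Wind": "High", "Horse Ride": "No"},
    {"Sky": "Rainy",  "Temperature": "Warm", "Humidity": "High",   "Wind": "Low",  "Horse Ride": "Yes"},
]

def split(rows, feature):
    """Group rows by their value of `feature`, dropping that column from each row."""
    parts = defaultdict(list)
    for row in rows:
        rest = {k: v for k, v in row.items() if k != feature}
        parts[row[feature]].append(rest)
    return dict(parts)

by_wind = split(rows, "Wind")
print([r["Horse Ride"] for r in by_wind["High"]])  # ['No', 'No']
low_by_temp = split(by_wind["Low"], "Temperature")
print({t: [r["Horse Ride"] for r in part] for t, part in low_by_temp.items()})
# {'Warm': ['Yes', 'Yes', 'Yes'], 'Cold': ['No']}
```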

It is not useful to write out another set of tables (one feels the pain already when imagining a data set with a thousand entries), because in fact there is a better representation. The astute reader will have already recognized that our process of picking particular values for the weather features is just the process of traversing a tree.

Let’s investigate this idea closer. Imagine we have a tree where the root node corresponds to the Wind feature, and it has two edges connected to child nodes; one edge corresponds to the value “Low” and the other to “High.” That is, the process of traveling from the root to a child along an edge is the process of selecting only those data points whose “Wind” feature is that edge’s label. We can take the child corresponding to “Low” wind and have it represent the Temperature feature, further adding three child nodes with edges corresponding to the “Cold,” “Warm,” and “Hot” values.

We can stop this process once the choice of features completely splits our data set. Pictorially, our tree would look like this:

We reasonably decide to stop the traversal when all of the examples in the split are in the same class. Moreover, we would not want to include the option for the temperature to be Hot in the right subtree, because the data tells us nothing about such a scenario (as indicated by the “None” in the corresponding leaf).

Now that we have the data organized as a tree, we can try to classify new data with it. Suppose the new example is:

Sky     Temperature    Humidity    Wind    Horse Ride
Rainy   Cold           Low         Low     ?

We first inspect the wind speed feature, and seeing that it is “Low,” we follow the edge to the right subtree and repeat. Seeing that the temperature feature is “Cold,” we further descend down the “Cold” branch, reaching the “All No” leaf. Since this leaf corresponds to examples we’ve seen which are all in the “No” class, we should classify the new data as “No” as well.

Summarizing, given a new piece of data, we can traverse the tree according to the values of its features until we reach a leaf node. If the leaf node is “All No,” then we classify the new set of weather conditions as a “No,” and if it is “All Yes,” we classify as “Yes.”
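Classification by traversal takes only a few lines. This sketch assumes a hypothetical encoding of the tree described above: an interior node is a pair (feature, children), where children maps each edge label to a subtree, and a leaf is just the class label (or None for a branch with no training data).

```python
# Hypothetical encoding of the tree from the example:
# interior node = (feature, {edge value: subtree}), leaf = class label or None.
tree = ("Wind", {
    "High": "No",
    "Low": ("Temperature", {"Cold": "No", "Warm": "Yes", "Hot": None}),
})

def classify(tree, example):
    """Follow the edges matching the example's feature values until reaching a leaf."""
    while isinstance(tree, tuple):
        feature, children = tree
        tree = children[example[feature]]  # KeyError if this value has no edge
    return tree

new_day = {"Sky": "Rainy", "Temperature": "Cold", "Humidity": "Low", "Wind": "Low"}
print(classify(tree, new_day))  # No
```

An example whose value for some feature never appeared on any edge raises a KeyError here; a fuller implementation would need a policy for that case, such as falling back to the majority class.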

Of course, this tree makes it clear that this toy data set is much too small to be useful, and the rules we’ve extrapolated from it are ridiculous. In particular, surely some people ride horses when the wind speed is high, and they would be unlikely to do so in a warm, low-wind thunderstorm. Nevertheless, we might expect a larger data set to yield a more sensible tree, as the data would more precisely reflect the true reasons one might refrain from riding a horse.

Before we generalize this example to any data set, we should note that there is yet another form for our classification rule. In particular, we can write the traversal from the root to the rightmost leaf as a boolean expression of the form:

\displaystyle \textup{Wind = ``Low''} \wedge \textup{Temp = ``Warm''}

An example will be classified as “Yes” if and only if the wind feature is “Low” and the temperature feature is “Warm” (here the wedge symbol \wedge means “and,” and is called a conjunction). If we had multiple such routes leading to leaves labeled with “All Yes,” say a branch for wind being “High” and sky being “Sunny,” we could expand our expression as a disjunction (an “or,” denoted \vee) of the two corresponding conjunctions as follows:

 \displaystyle (\textup{Wind = ``Low''} \wedge \textup{Temp = ``Warm''}) \vee (\textup{Wind = ``High''} \wedge \textup{Sky = ``Sunny''})

In the parlance of formal logic, this kind of expression is said to be in disjunctive normal form, that is, a disjunction of conjunctions. It’s an easy exercise to prove that every propositional statement (in our case, using only and, or, and parentheses for grouping) can be put into disjunctive normal form. That is, any boolean condition that can be used to classify the data can be expressed in disjunctive normal form, and hence as a decision tree.

Such a “boolean condition” is an example of a hypothesis, which is the formal term for the rule an algorithm uses to classify new data points. We call the set of all possible hypotheses expressible by our algorithm the hypothesis space of our algorithm. What we’ve just shown above is that decision trees have a large and well-defined hypothesis space. On the other hand, it is much more difficult to describe the hypothesis space for an algorithm like k-nearest-neighbors. This is one argument in favor of decision trees: they have a well-understood hypothesis space, and that makes them analytically tractable and interpretable.

Using Entropy to Find Optimal Splits

The real problem here is not in using a decision tree, but in constructing one from data alone. At any step in the process we outlined in the example above, we need to determine which feature is the right one to split the data on. That is, we need to choose the labels for the interior nodes so that the resulting data subsets are as homogeneous as possible. In particular, it would be nice to have a quantitative way to measure the quality of a split. Then at each step we could simply choose the feature whose split yields the highest value under this measurement.

While we won’t derive such a measurement in this post, we will use one that has an extensive history of applications: Shannon entropy.

Definition: Let D be a discrete probability distribution (p_1, p_2, \dots, p_n). Then the Shannon entropy of D, denoted E(p_1, \dots, p_n) is

\displaystyle E(p_1, \dots , p_n) = - \sum_{i=1}^n p_i \log(p_i)

where the logarithms are taken in base 2.

In English, there are n possible outcomes numbered 1 to n, and the probability that an instance drawn from D results in the outcome k is p_k. Then Shannon’s entropy function computes a numerical quantity describing how “dispersed” the outcomes are.

While there are many other useful interpretations of Shannon entropy, we only need it to describe how well the data is split into its classes. For our purposes, the probability distribution will simply be the observed proportions of data with respect to their class labels. In the case of Arya’s horse riding, the initial distribution would be (1/2, 1/2), giving an entropy of 1.
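In code, the entropy of the observed class proportions is short. A sketch, assuming class labels are given as a plain list; the `p > 0` filter implements the usual convention that terms with probability zero contribute nothing, and the function names are chosen here for illustration.

```python
import math
from collections import Counter

def entropy(probabilities):
    """Shannon entropy in bits; zero probabilities are skipped (0 log 0 = 0)."""
    return -sum(p * math.log2(p) for p in probabilities if p > 0)

def label_entropy(labels):
    """Entropy of the observed class proportions in a list of labels."""
    n = len(labels)
    return entropy(count / n for count in Counter(labels).values())

arya = ["Yes", "No", "Yes", "No", "No", "Yes"]  # the Horse Ride column
print(label_entropy(arya))  # 1.0
```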

Let’s verify that Shannon’s entropy function makes sense for our problem. Specifically, the best scenario for splitting the data on a feature is a perfect split; that is, each subset only has data from one class. On the other hand, the worst case would be where each subset is uniformly distributed across all classes (if there are n classes, then each subset has 1/n of its data from each class).

Indeed, if we adopt the convention that 0 \log(0) = 0, then the entropy of (1,0, \dots, 0) consists of a single term -1 \log(1) = 0. It is clear that this does not depend on the position of the 1 within the probability distribution. On the other hand, the entropy of (1/n, \dots, 1/n) is

\displaystyle -\sum_{i=1}^n\frac{1}{n}\log \left (\frac{1}{n} \right ) = -\log \left (\frac{1}{n} \right ) = -(0 - \log(n)) = \log(n)

A well-known property of the entropy function tells us that this is in fact the maximum value for this function.
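Both extreme cases are easy to check numerically. The sketch below is a minimal variant of an entropy function, assuming Python 3 (for `math.log2`); skipping the terms with p = 0 is exactly how it adopts the convention 0 \log(0) = 0.

```python
import math

def entropy(dist):
    """Shannon entropy in bits; terms with p = 0 are skipped, which is
    exactly the convention 0 * log(0) = 0."""
    return sum(-p * math.log2(p) for p in dist if p > 0)

# A point mass has zero entropy, wherever the 1 sits.
print(entropy([1, 0, 0]), entropy([0, 0, 1]))   # both are zero

# The uniform distribution on n outcomes has entropy log2(n).
for n in (2, 4, 8):
    print(n, entropy([1 / n] * n), math.log2(n))
```

Any non-uniform distribution on the same outcomes, such as (0.7, 0.2, 0.1), comes out strictly below \log(3).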

Summarizing this, in the best case entropy is minimized after the split, and in the worst case entropy is maximized. But we can’t simply look at the entropy of each subset after splitting. We need a sensible way to combine these entropies and to compare them with the entropy of the data before splitting. In particular, we would like to quantify the “decrease” in entropy caused by a split, and maximize that quantity.

Definition: Let S be a data set and A a feature with values v \in V, and let E denote Shannon’s entropy function. Moreover, let S_v denote the subset of S for which the feature A has the value v. The gain of a split along the feature A, denoted G(S,A) is

\displaystyle G(S,A) = E(S) - \sum_{v \in V} \frac{|S_v|}{|S|} E(S_v)

That is, we take the entropy before the split and subtract off the entropy of each part after splitting, weighting each part by its share of the data. If the parts are about as mixed as the original data, this number will be small. On the other hand, if the split separates the classes nicely, each subset S_v will have small entropy, and hence the value will be large.
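To see the weighting at work, here is a small worked example on made-up data: eight points, four in each class (so E(S) = 1), split two different ways. The helper names `entropy_of_labels` and `gain_of_split` are illustrative only. The mixed split produces parts with distributions (3/4, 1/4) and (1/4, 3/4), each of entropy about 0.8113, so its gain is about 1 - 0.8113 = 0.1887; a perfect split recovers the full gain of 1.

```python
import math
from collections import Counter

def entropy_of_labels(labels):
    """Entropy (base 2) of the empirical class distribution of a list of labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())

def gain_of_split(parent, parts):
    """G(S, A) = E(S) - sum_v |S_v|/|S| * E(S_v), where `parts` are the label lists S_v."""
    total = len(parent)
    return entropy_of_labels(parent) - sum(
        len(part) / total * entropy_of_labels(part) for part in parts)

S = ["yes"] * 4 + ["no"] * 4
mixed = [["yes", "yes", "yes", "no"], ["yes", "no", "no", "no"]]
perfect = [["yes"] * 4, ["no"] * 4]

print(round(gain_of_split(S, mixed), 4))   # 0.1887
print(gain_of_split(S, perfect))           # 1.0
```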

It requires a bit of mathematical tinkering to be completely comfortable that this function actually does what we want it to (for instance, it is not obvious that this function is non-negative; does it make sense to have a negative gain?). We won’t tarry over those details (this author has spent at least a day or two ironing them out), but we can rest assured that this function has been studied extensively, and nothing unexpected happens.

So now the algorithm for building trees is apparent: at each stage, simply pick the feature for which the gain function is maximized, and split the data on that feature. Create a child node for each of the subsets in the split, and connect them via edges with labels corresponding to the chosen feature value for that piece.

This algorithm is classically called ID3, and we’ll implement it in the next section.

Building a Decision Tree in Python

As with our primer on trees, we can use a quite simple data structure to represent the tree, but here we need a few extra pieces of data associated with each node.

class Tree:
   def __init__(self, parent=None):
      self.parent = parent
      self.children = []
      self.splitFeature = None
      self.splitFeatureValue = None
      self.label = None

In particular, now that features can have more than two possible values, we need to allow for an arbitrarily long list of child nodes. In addition, we add three pieces of data (with default values None): splitFeature is the feature on which this node splits, so that each of its children corresponds to a separate value of that feature; splitFeatureValue is the value of the parent’s split feature that this node corresponds to; and label (which is None for all interior nodes) is the final classification label for a leaf.

We also need to nail down our representations for the data. In particular, we will represent a data set as a list of pairs of the form (point, label), where the point is itself a list of the feature values, and the label is a string.

Now given a data set the first thing we need to do is compute its entropy. For that we can first convert it to a distribution (in the sense defined above, a list of probabilities which sum to 1):

def dataToDistribution(data):
   ''' Turn a dataset which has n possible classification labels into a
       probability distribution with n entries. '''
   allLabels = [label for (point, label) in data]
   numEntries = len(allLabels)
   possibleLabels = set(allLabels)

   dist = []
   for aLabel in possibleLabels:
      dist.append(float(allLabels.count(aLabel)) / numEntries)

   return dist

And we can compute the entropy of such a distribution in the obvious way:

import math

def entropy(dist):
   ''' Compute the Shannon entropy of the given probability distribution. '''
   return -sum([p * math.log(p, 2) for p in dist])

Now in order to compute the gain of splitting a data set on a particular feature, we need to be able to split the data set. To do this, we identify features with their index in the list of feature values of a given data point, enumerate all possible values of that feature, and generate the needed subsets one at a time. In particular, we use a Python generator:

def splitData(data, featureIndex):
   ''' Iterate over the subsets of data corresponding to each value
       of the feature at the index featureIndex. '''

   # get possible values of the given feature
   attrValues = [point[featureIndex] for (point, label) in data]

   for aValue in set(attrValues):
      dataSubset = [(point, label) for (point, label) in data
                    if point[featureIndex] == aValue]

      yield dataSubset

So to compute the gain, we iterate over the subsets produced by the split, and subtract the entropy of each subset, weighted by its share of the data (exactly as in the definition of G(S,A) above). In code:

def gain(data, featureIndex):
   ''' Compute the expected gain from splitting the data along all possible
       values of feature. '''

   entropyGain = entropy(dataToDistribution(data))

   for dataSubset in splitData(data, featureIndex):
      # weight each subset's entropy by |S_v| / |S|
      weight = float(len(dataSubset)) / len(data)
      entropyGain -= weight * entropy(dataToDistribution(dataSubset))

   return entropyGain

Of course, the best split (represented as the best feature to split on) is given by such a line of code as the following, where n is the number of features:

bestFeature = max(range(n), key=lambda index: gain(data, index))

We can’t quite use this line exactly though, because while we’re building up the decision tree (which will of course be a recursive process) we need to keep track of which features have been split on previously and which have not; this data is different for each possible traversal of the tree. In the end, our function to build a decision tree requires three pieces of data: the current subset of the data to investigate, the root of the current subtree that we are in the process of building, and the set of features we have yet to split on.

Of course, this raises the obvious question about the base cases. One base case is when there is nothing left to separate; that is, when our input data all have the same classification label. To check for this we implement a function called “homogeneous”:

def homogeneous(data):
   ''' Return True if the data have the same label, and False otherwise. '''
   return len(set([label for (point, label) in data])) <= 1

The other base case is when we run out of good features to split on. Of course, if the true classification function is actually a decision tree then this won’t be the case. But now that we’re in the real world, we can imagine there may be two data points with identical features but different classes. Perhaps the simplest way to remedy this situation is to terminate the tree at that point (when we run out of features to split on, or no split gives positive gain), and use a simple majority vote to label the new leaf. In a sense, this strategy is a sort of nearest-neighbors vote as a default. To implement this, we have a function which simply patches up the leaf appropriately:

def majorityVote(data, node):
   ''' Label node with the majority of the class labels in the given data set. '''
   labels = [label for (pt, label) in data]
   choice = max(set(labels), key=labels.count)
   node.label = choice
   return node

The base cases show up rather plainly in the code to follow, so let us instead focus on the inductive step. We declare our function to accept the data set in question, the root of the subtree to be built, and a list of the remaining allowable features to split on. The function begins with:

def buildDecisionTree(data, root, remainingFeatures):
   ''' Build a decision tree from the given data, appending the children
       to the given root node (which may be the root of a subtree). '''

   if homogeneous(data):
      root.label = data[0][1]
      return root

   if len(remainingFeatures) == 0:
      return majorityVote(data, root)

   # find the index of the best feature to split on
   bestFeature = max(remainingFeatures, key=lambda index: gain(data, index))

   if gain(data, bestFeature) == 0:
      return majorityVote(data, root)

   root.splitFeature = bestFeature

Here we see the base cases, and the selection of the best feature to split on. As a side remark, we observe this is not the most efficient implementation. We admittedly call the gain function and splitData functions more often than necessary, but we feel what is lost in runtime speed is gained in code legibility.

Once we bypass the three base cases, and we have determined the right split, we just do it:

def buildDecisionTree(data, root, remainingFeatures):
   ''' Build a decision tree from the given data, appending the children
       to the given root node (which may be the root of a subtree). '''

   ...

   # add child nodes and process recursively
   for dataSubset in splitData(data, bestFeature):
      aChild = Tree(parent=root)
      aChild.splitFeatureValue = dataSubset[0][0][bestFeature]
      root.children.append(aChild)

      buildDecisionTree(dataSubset, aChild, remainingFeatures - set([bestFeature]))

   return root

Here we iterate over the subsets of data after the split, and create a child node for each. We then assign the child its corresponding feature value in the splitFeatureValue variable, and append the child to the root’s list of children. Next is where we first see the remainingFeatures set come into play, and in particular we note the overloaded minus sign as an operation on sets. This is a feature of Python sets, and it behaves exactly like set difference in mathematics. The astute programmer will note that the minus operation generates a new set, so that further recursive calls to buildDecisionTree will not be affected by what happens in this recursive call.

Now the first call to this function requires some initial parameter setup, so we define a convenience function that only requires a single argument: the data.

def decisionTree(data):
   return buildDecisionTree(data, Tree(), set(range(len(data[0][0]))))

Classifying New Data

The last piece of the puzzle is to classify a new piece of data once we’ve constructed the decision tree. This is a considerably simpler recursive process. If the current node is a leaf, output its label. Otherwise, recursively search the subtree (the child of the current node) whose splitFeatureValue matches the new data’s choice of the feature being split. In code,

def classify(tree, point):
   ''' Classify a data point by traversing the given decision tree. '''

   if tree.children == []:
      return tree.label
   else:
      matchingChildren = [child for child in tree.children
         if child.splitFeatureValue == point[tree.splitFeature]]

      return classify(matchingChildren[0], point)

And we can use this function to naturally test a dataset:

def testClassification(data, tree):
   actualLabels = [label for point, label in data]
   predictedLabels = [classify(tree, point) for point, label in data]

   correctLabels = [(1 if a == b else 0) for a,b in zip(actualLabels, predictedLabels)]
   return float(sum(correctLabels)) / len(actualLabels)

But now we run into the issue of noisy data. What if one wants to classify a point where one of the feature values which is used in the tree is unknown? One can take many approaches to remedy this, and we choose a simple one: simply search both routes, and use a majority vote when reaching a leaf. This requires us to add one additional piece of information to the leaf nodes: the total number of labels in each class used to build that leaf (recall, one of our stopping conditions resulted in a leaf having heterogeneous data). We omit the details here, but the reader is invited to read them on this blog’s Google code page, where as usual we have provided all of the code used in this post.
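As a hedged sketch of that strategy (not the author’s implementation, whose details are omitted above), assume each leaf stores a hypothetical `labelCounts` Counter of the training labels that reached it, and that missing values are encoded as '?' as in the voting data below. When the split feature is missing we follow every child, add up the counts at all reachable leaves, and take the majority. The names `leafCounts` and `classifyWithMissing` are made up for this example.

```python
from collections import Counter

class Tree:
    """Decision tree node; `labelCounts` (hypothetical) holds class counts at a leaf."""
    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        self.splitFeature = None
        self.splitFeatureValue = None
        self.label = None
        self.labelCounts = Counter()

def leafCounts(tree, point, missing="?"):
    """Sum the label counts of every leaf reachable by `point`,
    following all children whenever the split feature is missing."""
    if not tree.children:
        return Counter(tree.labelCounts)
    value = point[tree.splitFeature]
    if value == missing:
        branches = tree.children
    else:
        branches = [c for c in tree.children if c.splitFeatureValue == value]
    total = Counter()
    for child in branches:
        total += leafCounts(child, point, missing)
    return total

def classifyWithMissing(tree, point, missing="?"):
    """Majority vote over the reachable leaves (None if no leaf is reachable)."""
    counts = leafCounts(tree, point, missing)
    return counts.most_common(1)[0][0] if counts else None

# A one-split tree with per-leaf class counts.
root = Tree()
root.splitFeature = 0
for value, counts in [("y", {"democrat": 3, "republican": 1}),
                      ("n", {"republican": 5})]:
    leaf = Tree(parent=root)
    leaf.splitFeatureValue = value
    leaf.labelCounts = Counter(counts)
    leaf.label = leaf.labelCounts.most_common(1)[0][0]
    root.children.append(leaf)

print(classifyWithMissing(root, ["y"]))   # democrat
print(classifyWithMissing(root, ["?"]))   # republican (3 votes against 6)
```

Summing raw counts weights larger leaves more heavily; other weightings are possible, and this is only one reasonable choice.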

Classifying Political Parties Based on Congressional Votes

We now move to a concrete application of decision trees. The data set we will work with comes from the UCI machine learning repository, and it records the votes cast by the US House of Representatives during a particular session of Congress in 1984. The data set has 16 features; that is, there were 16 different measures considered “key” measures that were voted upon during this session. So each point in the dataset represents the 16 votes of a single House member in that session. With a bit of reformatting, we provide the complete data set on this blog’s Google code page.

Our goal is to learn party membership based on the voting records. This data set is rife with missing values; roughly half of the members abstained from voting on some of these measures. So we construct a decision tree from the clean portion of the data, and use it to classify the remainder of the data.

Indeed, this data fits precisely into the algorithm we designed above. The code to construct a tree is almost trivial:

   with open('house-votes-1984.txt', 'r') as inputFile:
      lines = inputFile.readlines()

   data = [line.strip().split(',') for line in lines]
   data = [(x[1:], x[0]) for x in data]

   cleanData = [x for x in data if '?' not in x[0]]
   noisyData = [x for x in data if '?' in x[0]]

   tree = decisionTree(cleanData)

Indeed, the classification accuracy in doing this is around 90%. In addition (though we will revisit the concept of overfitting later), this is stable despite variation in the size of the subset of data used to build the tree. The graph below shows this.

The size of the subset used to construct the tree versus its accuracy in classifying the remainder of the data. Note that the subsets were chosen uniformly at random without replacement. The x-axis is the number of points used to construct the tree, and the y-axis is the proportion of labels correctly classified.

Inspecting the trees generated in this process, it appears that the most prominent feature to split on is the adoption of a new budget resolution. Very few Democrats voted in favor of this, so for many of the random subsets of the data, a split on this feature left one side homogeneously Republican.

Overfitting, Pruning, and Other Issues

Now there are some obvious shortcomings to the method in general. If the data set used to build the decision tree is enormous (in dimension or in number of points), then the resulting decision tree can be arbitrarily large in size. In addition, there is the pitfall of overfitting to this particular data set. For the party classification problem above, the point is to extend the classification to any population of people who might vote on these issues (or, more narrowly, to any politician who might vote on these issues). If we make our decision tree very large, then the hypothesis may be overly specific to the people in the sample used, and hence will not generalize well.

This problem is called overfitting to the data, and it’s a prevalent concern among all machine learning algorithms. There are a number of ways to avoid it for decision trees. Perhaps the most common is the idea of pruning: one tentatively removes each proper subtree in turn and reevaluates the classification accuracy (measured on data held out from building the tree). Whichever removal results in the greatest increase in accuracy is made permanent: that subtree is replaced with a single leaf whose label is the majority label of the data points used to create it. This process is then repeated until there are no possible improvements, or the gain is sufficiently small.
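The pruning loop just described can be sketched as follows. This is a minimal, self-contained illustration rather than code from this post: it assumes a separate validation set, a stripped-down node class, and a hypothetical `majorityLabel` field on each node recording the majority training label of the points that reached it. Ties in accuracy are resolved in favor of the first node examined.

```python
class Tree:
    """Minimal node for this sketch. `majorityLabel` is a hypothetical field:
    the majority training label among the points that reached this node."""
    def __init__(self, splitFeature=None, splitFeatureValue=None,
                 label=None, majorityLabel=None):
        self.children = []
        self.splitFeature = splitFeature
        self.splitFeatureValue = splitFeatureValue
        self.label = label                      # None for internal nodes
        self.majorityLabel = label if majorityLabel is None else majorityLabel

def classify(tree, point):
    if not tree.children:
        return tree.label
    for child in tree.children:
        if child.splitFeatureValue == point[tree.splitFeature]:
            return classify(child, point)
    return tree.majorityLabel                   # unseen value: fall back (assumption)

def accuracy(tree, data):
    return sum(classify(tree, p) == label for p, label in data) / len(data)

def internalNodes(tree):
    if tree.children:
        yield tree
        for child in tree.children:
            yield from internalNodes(child)

def prune(tree, validation):
    """Greedy reduced-error pruning: repeatedly collapse the subtree whose
    replacement by a majority-vote leaf most improves validation accuracy."""
    current = accuracy(tree, validation)
    while True:
        best, bestAccuracy = None, current
        for node in list(internalNodes(tree)):
            saved = node.children
            node.children, node.label = [], node.majorityLabel   # tentative collapse
            candidate = accuracy(tree, validation)
            node.children, node.label = saved, None               # undo it
            if candidate > bestAccuracy:
                best, bestAccuracy = node, candidate
        if best is None:                        # no collapse helps: stop
            return tree
        best.children, best.label = [], best.majorityLabel
        current = bestAccuracy

# Toy tree: split on feature 0, and (for value "n") again on feature 1.
root = Tree(splitFeature=0, majorityLabel="A")
inner = Tree(splitFeature=1, splitFeatureValue="n", majorityLabel="A")
root.children = [Tree(splitFeatureValue="y", label="B"), inner]
inner.children = [Tree(splitFeatureValue="y", label="B"),
                  Tree(splitFeatureValue="n", label="A")]

validation = [(["y", "n"], "B"), (["n", "y"], "A"), (["n", "n"], "A")]
print(accuracy(root, validation))   # about 0.67 before pruning
prune(root, validation)
print(accuracy(root, validation))   # 1.0 after collapsing `inner`
```

Collapsing the root would drop accuracy back to 2/3, so the loop stops after removing the inner split.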

From a statistical point of view one could say this process is that of ignoring outliers. Any points which do not generally agree with the whole trend of the data set (hence, create their own branches in the decision tree) are simply removed. From a theoretical point of view, a smaller decision tree satisfies the principle of Occam’s razor: among hypotheses that explain the data comparably well, the simpler one should be preferred.

We won’t implement a pruning method here (indeed, we didn’t detect any overfitting in the congressional voting example), but it would be a nice exercise for the reader to get their feet wet with the code given above. Finally, there are other algorithms to build decision trees that we haven’t mentioned here. You can see a list of such algorithms on the relevant Wikipedia page. Because of the success of ID3, there is a large body of research on such algorithms.

In any event, next time we’ll investigate yet another machine learning method: that of neural networks. We’ll also start to look at more general frameworks for computational learning theory. That is, we’ll exercise the full might of theoretical mathematics to reason about how hard certain problems are to learn (or whether they can be learned at all).

Until then!

Trees – A Primer

This post comes in preparation for a post on decision trees (a specific type of tree used for classification in machine learning). While most mathematicians and programmers are familiar with trees, we have yet to discuss them on this blog. For completeness, we’ll give a brief overview of the terminology and constructions associated with trees, and describe a few common algorithms on trees. We will assume the reader has read our first primer on graph theory, which is a light assumption. Furthermore, we will use the terms node and vertex interchangeably, as mathematicians use the latter and computer scientists the former.

Definitions

Mathematically, a tree can be described in a very simple way.

Definition: A path (v_1, e_1, v_2, e_2, \dots, v_n) with n > 1 in a graph G is called a cycle if v_1 = v_n. Here we assume no edge is repeated in a path (a sequence that may repeat edges is usually called a walk).

Definition: A graph G is called connected if every pair of vertices has a path between them. Otherwise it is called disconnected.

Definition: A connected graph G is called a tree if it has no cycles. Equivalently, G is a tree if for any two vertices v,w there is a unique path connecting them.
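For finite graphs there is another standard equivalent characterization: a graph on |V| vertices is a tree exactly when it is connected and has |V| - 1 edges. A minimal sketch of a check based on that fact, using breadth-first search for connectivity; the function name `isTree` and the edge-list representation are assumptions of this example.

```python
from collections import deque

def isTree(vertices, edges):
    """A finite graph is a tree iff it is connected and has |V| - 1 edges."""
    vertices = list(vertices)
    if not vertices or len(edges) != len(vertices) - 1:
        return False
    adjacent = {v: [] for v in vertices}
    for v, w in edges:
        adjacent[v].append(w)
        adjacent[w].append(v)
    # breadth-first search from an arbitrary vertex to test connectivity
    seen = {vertices[0]}
    queue = deque([vertices[0]])
    while queue:
        v = queue.popleft()
        for w in adjacent[v]:
            if w not in seen:
                seen.add(w)
                queue.append(w)
    return len(seen) == len(vertices)

print(isTree([1, 2, 3], [(1, 2), (2, 3)]))                 # True: a path
print(isTree([1, 2, 3], [(1, 2), (2, 3), (3, 1)]))         # False: a cycle
print(isTree([1, 2, 3, 4], [(1, 2), (2, 3), (3, 1)]))      # False: right edge count, disconnected
```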

The image at the beginning of this post gives an example of a simple tree. Although the edges of a tree need not be directed (despite the arrows drawn in that image), there is usually a sort of hierarchy associated with trees. One vertex is usually singled out as the root vertex, and the choice of a root depends on the problem. Below are three examples of trees, each drawn from a different perspective. People who work with trees like to joke that trees are supposed to grow upwards from the root, but in mathematics they’re usually drawn with the root on top.

We call a tree with a distinguished root vertex a rooted tree, and we denote it (T,r), where T is the tree and r is the root. The important thing about the hierarchy is that it breaks the tree into discrete “levels” of depth. That is, we call the depth of a vertex v the length of the shortest path from the root r to v. As you can see in the rightmost tree in the above picture, we often draw a tree so that its vertices are horizontally aligned by their depth. Continuing with nature-inspired names, the vertices at the bottom of the tree (more rigorously, non-root vertices of degree 1) are called leaves. A vertex which is neither a leaf nor the root is called an internal node. Extending the metaphor to family trees, given a vertex v of depth n, the adjacent vertices of depth n+1 (if there are any) are called the child nodes (or children) of v. Similarly, v is called the parent node of its children. Extrapolating, any node on the path from v to the root r is an ancestor of v, and v is a descendant of each of them.

As a side note, all of this naming is simply a fancy way of imposing a partial ordering on the vertices of a tree, in which v \leq w if v is on the path from r to w. In this order, a chain is simply a set of vertices lying along a single path down the tree from the root. All of the names simply make this easier to talk about in English: v \leq w if and only if v is an ancestor of w. Of course, there are also useful total orderings on a tree, and we will describe some later in this post.
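Choosing a root turns these definitions into a computation. As a sketch, assuming the tree is given as an adjacency dictionary (a representation chosen just for this example), a breadth-first search from r finds each vertex’s parent, depth, and children, and walking parent pointers decides the partial order. The helper names `rootTree` and `isAncestor` are hypothetical.

```python
from collections import deque

def rootTree(adjacent, root):
    """Orient a tree (given as an adjacency dict) away from `root`, returning
    each vertex's parent, depth, and list of children."""
    parent, depth = {root: None}, {root: 0}
    children = {v: [] for v in adjacent}
    queue = deque([root])
    while queue:
        v = queue.popleft()
        for w in adjacent[v]:
            if w not in depth:          # unvisited, so w is a child of v
                parent[w] = v
                depth[w] = depth[v] + 1
                children[v].append(w)
                queue.append(w)
    return parent, depth, children

def isAncestor(parent, v, w):
    """True if v <= w, i.e. v lies on the path from the root to w
    (the order is reflexive, so every vertex is <= itself)."""
    while w is not None:
        if w == v:
            return True
        w = parent[w]
    return False

adjacent = {'r': ['a', 'b'], 'a': ['r', 'c', 'd'], 'b': ['r'],
            'c': ['a'], 'd': ['a']}
parent, depth, children = rootTree(adjacent, 'r')
leaves = [v for v in adjacent if not children[v]]
print(depth)     # {'r': 0, 'a': 1, 'b': 1, 'c': 2, 'd': 2}
print(leaves)    # ['b', 'c', 'd']
print(isAncestor(parent, 'r', 'd'), isAncestor(parent, 'b', 'd'))   # True False
```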

In applications, there is usually some data associated with the vertices and edges of a tree. For example, in our future post on decision trees, the vertices will represent attributes of the data, and the edges will represent particular values for those attributes. A traversal down the tree from root to a leaf will correspond to an evaluation of the classification function. The meat of the discussion will revolve around how to construct a sensible tree.

The important thing about depth in trees is that, given sufficient bounds on the degree of each vertex, the depth of a tree which is not egregiously unbalanced is logarithmic in the number of leaves. In fact, most trees in practice will satisfy this. Perhaps the most common kind is a so-called binary tree, in which each internal node has degree at most 3 (two children, one parent). To see that this satisfies the logarithmic claim, simply count nodes by depth: the k-th level of the tree can have at most 2^k vertices. And so if all of the levels are filled (the tree is not “unbalanced”) and the tree has depth n, then the number of nodes in the tree is \sum_{i=0}^n 2^i = 2^{n+1} - 1. Taking a logarithm recovers a term that is linear in n, and the same argument holds if we can fix a global bound on the degree of each internal node. The rightmost picture in the image above gives an example of a complete binary tree of 15 nodes.

In other words, if one can organize their data as a balanced binary search tree, then searching through the data takes logarithmic time in the number of data points! For those readers unfamiliar with complexity theory, that is wicked fast. To put things into perspective, it’s commonly estimated that there are fewer than a billion websites on the internet. If one could search through all of these in logarithmic time, it would take roughly 30 steps to find the right site (and that’s using a base of 2; in base 10 it would take 9 steps).
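The counting argument and the search estimate are quick to check with a few lines of arithmetic. Here `depthOfCompleteBinaryTree` is a hypothetical helper that inverts the node-count formula for a tree whose levels are all filled.

```python
import math

# A complete binary tree of depth n has sum_{i=0}^{n} 2^i = 2^(n+1) - 1 nodes.
for n in range(10):
    assert sum(2 ** i for i in range(n + 1)) == 2 ** (n + 1) - 1

def depthOfCompleteBinaryTree(numNodes):
    """Depth n of a completely filled binary tree with 2^(n+1) - 1 nodes."""
    return int(math.log2(numNodes + 1)) - 1

print(depthOfCompleteBinaryTree(15))       # 3, as in the 15-node picture
print(math.ceil(math.log2(10 ** 9)))       # 30 halving steps for a billion items
print(round(math.log10(10 ** 9)))          # 9 steps with ten-way branching
```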

As a result, much work has been invested in algorithms to construct and work with trees. Indeed the crux of many algorithms is simply in translating a problem into a tree. These data structures pop up in nearly every computational field in existence, from operating systems to artificial intelligence and many many more.

Representing a Tree in a Computer

The remainder of this post will be spent designing a tree data structure in Python and writing a few basic algorithms on it. We’re lucky to have chosen Python in that the class representation of a tree is particularly simple. The central compound data type will be called “Node,” and it will have three associated parts:

  1. A list of child nodes, or an empty list if there are none.
  2. A parent node, or “None” if the node is the root.
  3. Some data associated with the node.

In many statically typed languages (like Java), one would need to be much more specific in number 3. That is, one would need to construct a special Tree class for each kind of data associated with a node, or use some clever polymorphism or template programming (in Java lingo, generics), but the end result is often still multiple versions of one class.

In Python we’re lucky, because we can add or remove data from any instance of any class on the fly. So, for instance, we could have our leaf nodes store different internal data than our internal nodes, or have our root contain additional information. In any case, Python will have the smallest amount of code while still being readable, so we think it’s a fine choice.

The node class is simply:

class Node:
   def __init__(self):
      self.parent = None
      self.children = []

That’s it! In particular, we will set up all of the adjacencies between nodes after initializing them, so we don’t need to put anything else in the constructor.

Here’s an example of using the class:

root = Node()
root.value = 10

leftChild = Node()
leftChild.value = 5

rightChild = Node()
rightChild.value = 12

root.children.append(leftChild)
root.children.append(rightChild)
leftChild.parent = root
rightChild.parent = root

We should note that even though we called the variables “leftChild” and “rightChild,” there is no distinction between left and right in this data structure; there is just a list of children. While in some applications the left child and right child have concrete meaning (e.g. in a binary search tree where the left subtree represents values that are less than the current node, and the right subtree is filled with larger elements), in our application to decision trees there is no need to order the children.

But for the examples we are about to give, we require a binary structure. To make this structure more obvious, we’ll ugly the code up a little bit as follows:

class Node:
   def __init__(self):
      self.parent = None
      self.leftChild = None
      self.rightChild = None

In-order, Pre-order, and Post-order Traversals

Now we’ll explore a simple class of algorithms that traverses a tree in a specified order. By “traverse,” we simply mean that it visits each vertex in turn, and performs some pre-specified action on the data associated with each. Those familiar with our post on functional programming can think of these as extensions of the “map” function to operate on trees instead of lists. As we foreshadowed earlier, these represent total orders on the set of nodes of a tree, and in particular they stand out by how they reflect the recursive structure of a tree.

The first is called an in-order traversal, and it is perhaps the most natural way to traverse a tree. The idea is to visit the nodes in the left-to-right order in which they appear in the usual drawing of a tree, ignoring depth. It generalizes easily from a tree with only three nodes: first you visit the left child, then you visit the root, then you visit the right child. Now instead of using the word “child,” we simply say “subtree.” That is, first you recursively process the left subtree, then you process the current node, then you recursively process the right subtree. This translates easily enough into code:

def inorder(root, f):
   ''' traverse the tree "root" in-order calling f on the 
       associated node (i.e. f knows the name of the field to 
       access). '''
   if root.leftChild != None:
      inorder(root.leftChild, f)

   f(root)

   if root.rightChild != None:
      inorder(root.rightChild, f)

For instance, suppose we have a tree consisting of integers. Then we can use this function to check if the tree is a binary search tree. That is, we can check that at every node, the left subtree only contains elements smaller than that node, and the right subtree only contains elements larger than it. This holds exactly when the in-order traversal visits the values in sorted order.

def isBinarySearchTree(root):
   numbers = []
   f = lambda node: numbers.append(node.value)

   inorder(root, f)

   for i in range(1, len(numbers)):
      if numbers[i-1] > numbers[i]:
         return False

   return True

As expected, this takes linear time in the number of nodes in the tree.
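Here is a self-contained run of the in-order check. For brevity this sketch gives `Node` a constructor that takes the value and children, which the class in this post does not, and it uses `is not None` in place of `!= None`. The second tree shows why the in-order approach is useful: every node is correctly ordered relative to its immediate children, yet 12 sits in the left subtree of 10.

```python
class Node:
    def __init__(self, value, leftChild=None, rightChild=None):
        self.value = value
        self.leftChild = leftChild
        self.rightChild = rightChild

def inorder(root, f):
    if root.leftChild is not None:
        inorder(root.leftChild, f)
    f(root)
    if root.rightChild is not None:
        inorder(root.rightChild, f)

def isBinarySearchTree(root):
    numbers = []
    inorder(root, lambda node: numbers.append(node.value))
    # sorted (non-decreasing) in-order sequence <=> binary search tree
    return all(a <= b for a, b in zip(numbers, numbers[1:]))

good = Node(10, Node(5, Node(2), Node(7)), Node(12))
# each node is ordered w.r.t. its children, but 12 is in the left subtree of 10
bad = Node(10, Node(5, Node(2), Node(12)), Node(15))

print(isBinarySearchTree(good))   # True
print(isBinarySearchTree(bad))    # False
```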

The next two examples are essentially the same as in-order; they are just a permutation of the lines of code of the in-order function given above. The first is pre-order, and it simply evaluates the root before either subtree:

def preorder(root, f):
   f(root)
   if root.leftChild != None:
      preorder(root.leftChild, f)

   if root.rightChild != None:
      preorder(root.rightChild, f)

And post-order, which evaluates the root after both subtrees:

def postorder(root, f):
   if root.leftChild != None:
      postorder(root.leftChild, f)

   if root.rightChild != None:
      postorder(root.rightChild, f)

   f(root)

Pre-order does have some nice applications. The first example requires us to have an arithmetical expression represented in a tree:

root = Node()
root.value = '*'

n1 = Node()
n1.value = '1'
n2 = Node()
n2.value = '3'
n3 = Node()
n3.value = '+'
n4 = Node()
n4.value = '3'
n5 = Node()
n5.value = '4'
n6 = Node()
n6.value = '-'

root.leftChild = n3
root.rightChild = n6
n3.leftChild = n1
n3.rightChild = n2
n6.leftChild = n4
n6.rightChild = n5

This is just the expression (1+3)*(3-4), and the tree structure specifies where the parentheses go. Using pre-order traversal in the exact same way we used in-order, we can convert this representation to another common one: Polish notation.

def polish(exprTree):
   exprString = []
   f = lambda node: exprString.append(node.value)

   preorder(exprTree, f)
   return ''.join(exprString)

One could also use a very similar function to create a copy of a binary tree, as one needs to have the root before one can attach any children, and this rule applies recursively to each subtree.
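The copying idea can be sketched directly: create the new node first (the pre-order step), then attach copies of its subtrees. This minimal version uses the binary `Node` fields from above, with `value` set in the constructor for convenience; `copyTree` is a name chosen for this example.

```python
class Node:
    def __init__(self, value=None):
        self.parent = None
        self.leftChild = None
        self.rightChild = None
        self.value = value

def copyTree(root, parent=None):
    """Copy a binary tree in pre-order: create each node first, then
    recursively attach copies of its left and right subtrees."""
    if root is None:
        return None
    newNode = Node(root.value)
    newNode.parent = parent
    newNode.leftChild = copyTree(root.leftChild, newNode)
    newNode.rightChild = copyTree(root.rightChild, newNode)
    return newNode

original = Node('*')
original.leftChild, original.rightChild = Node('+'), Node('-')
original.leftChild.parent = original.rightChild.parent = original

duplicate = copyTree(original)
original.leftChild.value = '/'       # mutate the original only
print(duplicate.leftChild.value)     # +
```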

On the other hand, post-order traversal can represent mathematical expressions in postfix notation (reverse Polish notation), and it can be useful for deleting a tree, since each node is visited only after all of its descendants. This would come up if, say, each node had some specific cleanup actions required before it could be deleted, or alternatively if one is working with dynamic memory allocation (e.g. in C) and must explicitly “free” each node to clear up memory.
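The post-order counterpart of the polish function produces reverse Polish notation, and the same traversal can evaluate the expression with a stack, since both operands of an operator are visited before the operator itself. A self-contained sketch, assuming single-token integer leaves and only the operators +, -, and *; the constructor arguments on `Node` are a convenience of this example.

```python
import operator

class Node:
    def __init__(self, value, leftChild=None, rightChild=None):
        self.value = value
        self.leftChild = leftChild
        self.rightChild = rightChild

def postorder(root, f):
    if root.leftChild is not None:
        postorder(root.leftChild, f)
    if root.rightChild is not None:
        postorder(root.rightChild, f)
    f(root)

def reversePolish(exprTree):
    tokens = []
    postorder(exprTree, lambda node: tokens.append(node.value))
    return ''.join(tokens)

OPS = {'+': operator.add, '-': operator.sub, '*': operator.mul}

def evaluate(exprTree):
    """Evaluate in post-order: operands are pushed, operators pop two and push one."""
    stack = []
    def visit(node):
        if node.value in OPS:
            right = stack.pop()
            left = stack.pop()
            stack.append(OPS[node.value](left, right))
        else:
            stack.append(int(node.value))
    postorder(exprTree, visit)
    return stack.pop()

# (1+3)*(3-4), the same tree as above
expr = Node('*', Node('+', Node('1'), Node('3')), Node('-', Node('3'), Node('4')))
print(reversePolish(expr))   # 13+34-*
print(evaluate(expr))        # -4
```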

So now we’ve seen a few examples of trees and mentioned how they can be represented in a program. Next time we’ll derive and implement a meatier application of trees in the context of machine learning, and in future primers we’ll cover minimum spanning trees and graph searching algorithms.

Until then!

Dynamic Time Warping for Sequence Comparison

Problem: Write a program that compares two sequences of differing lengths for similarity.

Solution: (In Python)

def dynamicTimeWarp(seqA, seqB, d = lambda x,y: abs(x-y)):
    # cost[i][j] is the cheapest warping of seqA[:i+1] onto seqB[:j+1]
    numRows, numCols = len(seqA), len(seqB)
    cost = [[0 for _ in range(numCols)] for _ in range(numRows)]

    # initialize the first row and column
    cost[0][0] = d(seqA[0], seqB[0])
    for i in range(1, numRows):
        cost[i][0] = cost[i-1][0] + d(seqA[i], seqB[0])

    for j in range(1, numCols):
        cost[0][j] = cost[0][j-1] + d(seqA[0], seqB[j])

    # fill in the rest of the matrix
    for i in range(1, numRows):
        for j in range(1, numCols):
            choices = cost[i-1][j], cost[i][j-1], cost[i-1][j-1]
            cost[i][j] = min(choices) + d(seqA[i], seqB[j])

    return cost[-1][-1]

Discussion

Comparing sequences of numbers can be tricky business. The simplest way to do so is component-wise. However, this will often disregard more abstract features of a sequence that we intuitively understand as “shape.”

For example, consider the following two sequences.

0 0 0 3 6 13 25 22 7 2 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 5 12 24 23 8 3 1 0 0 0 0 0

They both have the same characteristics — a bump of height roughly 25 and length 8 — but comparing the two sequences entrywise would imply they are not similar. According to the standard Euclidean norm, they are 52 units apart. For motivation, according to the dynamic time warping function above, they are a mere 7 units apart. Indeed, if the two bumps consisted of the same numbers, the dynamic time warp distance between the entire sequences would be zero.
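Both numbers can be reproduced with a Python 3 version of the solution (the same recurrence, minus any printing). The sequences are rebuilt from the two listed above, each of length 30; the bumps do not overlap, so the Euclidean distance is the square root of 1377 + 1364 = 2741.

```python
import math

def dynamicTimeWarp(seqA, seqB, d=lambda x, y: abs(x - y)):
    numRows, numCols = len(seqA), len(seqB)
    cost = [[0] * numCols for _ in range(numRows)]
    cost[0][0] = d(seqA[0], seqB[0])
    for i in range(1, numRows):
        cost[i][0] = cost[i - 1][0] + d(seqA[i], seqB[0])
    for j in range(1, numCols):
        cost[0][j] = cost[0][j - 1] + d(seqA[0], seqB[j])
    for i in range(1, numRows):
        for j in range(1, numCols):
            cost[i][j] = min(cost[i - 1][j], cost[i][j - 1],
                             cost[i - 1][j - 1]) + d(seqA[i], seqB[j])
    return cost[-1][-1]

a = [0, 0, 0, 3, 6, 13, 25, 22, 7, 2, 1] + [0] * 19
b = [0] * 17 + [4, 5, 12, 24, 23, 8, 3, 1] + [0] * 5

euclidean = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
print(round(euclidean, 2))       # 52.35
print(dynamicTimeWarp(a, b))     # 7
```

The 7 comes from pairing 3 with 4, 6 with 5, 13 with 12, and so on (seven differences of 1), while the surrounding zeros are matched to zeros at no cost.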

These kinds of sequences show up in many applications. For instance, in speech recognition software one often has many samples of a single person speaking, but the instant within each sample at which the person begins speaking differs. Similarly, the rate at which the person speaks may be slightly different. In either event, we want the computed “distance” between two such samples to be small. That is, we’re looking for a measurement which is time-insensitive both in scaling and in position. Historically, dynamic time warping was designed in 1978 to solve this problem.

The idea is similar to the Levenshtein metric we implemented for “edit distance” measurements in our Metrics on Words post. Instead of inserting or deleting sequence elements, we may optionally “pause” our usual elementwise comparison on one sequence while continuing on the other. The trick is in deciding whether to “pause,” and when to switch the “pausing” from one sequence to the other; this will require a dynamic program (as did the Levenshtein metric).

In order to compare two sequences, one needs to have some notion of a local comparison. That is, we need to be able to compare any entry from one sequence to any entry in the other. While there are many such options that depend on the data type used in the sequence, we will use the following simple numeric metric:

\displaystyle d(x,y) = |x-y|

This is just the Euclidean metric in \mathbb{R}. More generally, this assumes that two numbers are similar when they are close together (and this is in fact an important assumption; not all number systems are like this).

Now given two sequences a_i, b_j, we can compare them by comparing their local distance for a specially chosen set of indices given by m_k for a_i and n_k for b_j, where k = 1, \dots, L. That is, the dynamic time warping distance will end up being the quantity:

\displaystyle C(a_i, b_j) = \sum_{k=1}^{L} d(a_{m_k}, b_{n_k})

Of course, we should constrain the indices m_k, n_k so that the result is reasonable. A good way to do that is to describe the conditions we want them to satisfy, and then figure out how to compute such indices. In particular, let us assume that a_i has length M and b_j has length N. Then we have the following definition.

Definition: A warping path for a_i, b_j is a pair of sequences (m_k, n_k), both of some length L, satisfying the following conditions:

  1. 1 \leq m_k \leq M and 1 \leq n_k \leq N for all k.
  2. The sequences have endpoints (m_1, n_1) = (1,1), (m_L, n_L) = (M, N)
  3. The sequences m_k, n_k are monotonically nondecreasing.
  4. (m_k - m_{k-1}, n_k - n_{k-1}) must be one of (1,0), (0,1), (1,1).
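To make the definition concrete, here is a small Python 3 checker for the four conditions. It is a sketch of our own (the name `is_warping_path` is hypothetical), and it keeps the definition's 1-based indices by representing a path as a list of tuples (m_k, n_k).

```python
def is_warping_path(path, M, N):
    """Check the four warping-path conditions for a list of 1-based
    index pairs (m_k, n_k) against sequences of lengths M and N."""
    if not path:
        return False
    # Condition 1: every index lies in range.
    if any(not (1 <= m <= M and 1 <= n <= N) for m, n in path):
        return False
    # Condition 2: the path starts at (1, 1) and ends at (M, N).
    if path[0] != (1, 1) or path[-1] != (M, N):
        return False
    # Condition 4: every step advances one or both indices by exactly one.
    # This already forces condition 3 (monotonicity), so no separate check.
    allowed = {(1, 0), (0, 1), (1, 1)}
    return all((m1 - m0, n1 - n0) in allowed
               for (m0, n0), (m1, n1) in zip(path, path[1:]))

print(is_warping_path([(1, 1), (2, 1), (2, 2), (3, 3)], 3, 3))  # True
print(is_warping_path([(1, 1), (3, 2), (3, 3)], 3, 3))          # False: step (2, 1)
print(is_warping_path([(1, 1), (2, 2)], 3, 3))                  # False: stops short of (3, 3)
```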

The first condition is obvious, or else the m_k, n_k could not be indexing a_i, b_j. The second condition ensures that we use all of the information in both sequences in our comparison. The third condition implies that we cannot “step backward” in time as we compare local sequence entries. We wonder offhand if anything interesting could be salvaged from the mess that would ensue if one left this condition out. And the fourth condition allows the index of one sequence to be stopped while the other continues. This condition creates the “time warping” effect, that some parts of the sequence can be squeezed or stretched in comparison with the other sequence.

We also note as an aside that the fourth condition listed implies the third, since it forces each index to either stay put or increase by one at every step.

In general there are many valid warping paths for two sequences, but we want one which minimizes the sum of the local comparisons (the formula for C(a_i, b_j) above). We denote this optimal value as DTW(a_i, b_j).
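Taken literally, DTW(a_i, b_j) is a minimum over every warping path. The Python 3 sketch below (all function names are our own, and `d` defaults to the metric |x - y| from above) evaluates C for one given path, then brute-forces the minimum by enumerating all paths. The number of paths grows exponentially with the sequence lengths, so this is only a way to check small cases by hand, not a practical way to compute DTW.

```python
def path_cost(a, b, path, d=lambda x, y: abs(x - y)):
    """C(a, b) for one warping path, given as 1-based pairs (m_k, n_k)."""
    return sum(d(a[m - 1], b[n - 1]) for m, n in path)

def all_warping_paths(M, N):
    """Yield every warping path from (1, 1) to (M, N).
    The count grows exponentially, so use this only on tiny inputs."""
    def extend(path):
        m, n = path[-1]
        if (m, n) == (M, N):
            yield list(path)  # yield a copy; `path` is mutated below
            return
        for dm, dn in ((1, 0), (0, 1), (1, 1)):
            if m + dm <= M and n + dn <= N:
                path.append((m + dm, n + dn))
                yield from extend(path)
                path.pop()
    yield from extend([(1, 1)])

def dtw_brute_force(a, b):
    """DTW(a, b) as the literal minimum of C over all warping paths."""
    return min(path_cost(a, b, p) for p in all_warping_paths(len(a), len(b)))

a, b = [1, 2, 3], [1, 1, 2, 3]
print(path_cost(a, b, [(1, 1), (2, 2), (3, 3), (3, 4)]))  # 2
print(dtw_brute_force(a, b))                              # 0
```

The straight diagonal path costs 2 here, while the optimal path (1,1), (1,2), (2,3), (3,4) pairs the single 1 in a with both 1s in b and costs nothing.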

The fourth condition in the list above is the giveaway that computing an optimal path requires a dynamic program. Since the last step of any warping path must arrive at (M, N) from (M-1, N), (M, N-1), or (M-1, N-1), the optimal cost is d(a_M, b_N) plus the smallest of the optimal costs for the three sub-problems of finding an optimal warping path for

\displaystyle (a_{1 \dots M-1}, b_{1 \dots N}), (a_{1 \dots M}, b_{1 \dots N-1}), \textup{ and } (a_{1 \dots M-1}, b_{1 \dots N-1})
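Writing D(i, j) for the optimal cost of warping the prefixes a_{1 \dots i} and b_{1 \dots j}, the three sub-problems give the recurrence below. The boundary values are a common presentational convenience (our choice here, not taken from the solution code) that forces every path to begin at (1, 1).

\displaystyle D(i, j) = d(a_i, b_j) + \min \left \{ D(i-1, j), \; D(i, j-1), \; D(i-1, j-1) \right \}

\displaystyle D(0, 0) = 0, \qquad D(i, 0) = D(0, j) = \infty \textup{ for } i, j \geq 1, \qquad DTW(a_i, b_j) = D(M, N)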

The clueless reader should refer to this blog’s primer on Python and dynamic programming. In any event, the program implementing this dynamic program is given in the solution above.
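For readers on Python 3, here is a minimal sketch of that recurrence. It is not the solution code above: the name `dtw`, the `d` parameter, and the use of `math.inf` for boundary cells are our own choices. After filling the table, it walks backward from (M, N), always stepping to the cheapest of the three predecessor cells, which recovers one optimal warping path (ties are broken by the order the candidates are listed).

```python
import math

def dtw(a, b, d=lambda x, y: abs(x - y)):
    """Dynamic time warping distance via the recurrence
    D(i, j) = d(a_i, b_j) + min(D(i-1, j), D(i, j-1), D(i-1, j-1)).
    Returns the distance and one optimal warping path as 1-based pairs."""
    M, N = len(a), len(b)
    # D[i][j] is the optimal cost for the prefixes a[:i] and b[:j].
    # Row and column 0 are boundary cells: D[0][0] = 0, the rest infinite.
    D = [[math.inf] * (N + 1) for _ in range(M + 1)]
    D[0][0] = 0
    for i in range(1, M + 1):
        for j in range(1, N + 1):
            D[i][j] = d(a[i - 1], b[j - 1]) + min(D[i - 1][j], D[i][j - 1], D[i - 1][j - 1])

    # Walk back from (M, N) to (1, 1), moving to the cheapest predecessor.
    path = [(M, N)]
    i, j = M, N
    while (i, j) != (1, 1):
        i, j = min(((i - 1, j), (i, j - 1), (i - 1, j - 1)),
                   key=lambda cell: D[cell[0]][cell[1]])
        path.append((i, j))
    path.reverse()
    return D[M][N], path

a = [0, 0, 0, 3, 6, 13, 25, 22, 7, 2, 1] + [0] * 19
b = [0] * 17 + [4, 5, 12, 24, 23, 8, 3, 1] + [0] * 5
distance, path = dtw(a, b)
print(distance)  # 7
```

On the two bump sequences from the start of this discussion it returns 7, matching the figure quoted there. The recovered path pairs the bumps entry by entry and absorbs the differing runs of zeros by pausing one sequence.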

A visualization of the dynamic time warp cost matrix for two sequences. The algorithm attempts to find the least expensive path from the bottom left to the top right, where the darker regions correspond to low local cost, and the lighter regions to high local cost. The arrows point in the forward direction along each sequence, showing the monotonicity requirement of an optimal warping path.

The applications of this technique certainly go beyond speech recognition. Dynamic time warping can essentially be used to compare any data which can be represented as one-dimensional sequences. This includes video, graphics, financial data, and plenty of others.

We may also play around with which metric is used in the algorithm. When the elements of the lists are themselves points in Euclidean space, you can swap out the standard Euclidean metric with metrics like the Manhattan metric, the maximum metric, the discrete metric, or your mother’s favorite L^p norm.
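As a sketch of what swapping the local metric looks like, the Python 3 code below passes the metric in as a parameter. The helper names (`manhattan`, `maximum`, `discrete`, `lp`) and the compact `dtw` are our own, and sequence elements are assumed to be equal-length tuples of numbers.

```python
import math

def dtw(a, b, d):
    """DTW distance with local metric d, using an infinite boundary row/column."""
    M, N = len(a), len(b)
    D = [[math.inf] * (N + 1) for _ in range(M + 1)]
    D[0][0] = 0
    for i in range(1, M + 1):
        for j in range(1, N + 1):
            D[i][j] = d(a[i - 1], b[j - 1]) + min(D[i - 1][j], D[i][j - 1], D[i - 1][j - 1])
    return D[M][N]

def manhattan(x, y):
    return sum(abs(s - t) for s, t in zip(x, y))

def maximum(x, y):
    return max(abs(s - t) for s, t in zip(x, y))

def discrete(x, y):
    return 0 if x == y else 1

def lp(p):
    """Return the L^p metric for a fixed p >= 1."""
    def metric(x, y):
        return sum(abs(s - t) ** p for s, t in zip(x, y)) ** (1 / p)
    return metric

x = [(0, 0), (1, 2)]
y = [(0, 0), (2, 1)]
for name, metric in [("manhattan", manhattan), ("maximum", maximum),
                     ("discrete", discrete), ("L^2", lp(2))]:
    print(name, dtw(x, y, metric))
```

Here only the diagonal pairing (1,1), (2,2) is cheap, so each distance is just d((1, 2), (2, 1)) under the chosen metric: 2, 1, 1, and \sqrt{2} respectively.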

While we use a metric for elementwise comparisons in the algorithm above, the reader must note that the dynamic time warping distance is not a metric. In fact, it’s quite far from a metric. The casual reader can easily come up with an example of two non-identical sequences x, y for which DTW(x,y) = 0, hence denying positive-definiteness. The more intrepid reader will come up with three sequences which give a counterexample to the triangle inequality (hint: using the discrete metric as the local comparison metric makes things easier).
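For readers who want to check their answers, the Python 3 sketch below tries one candidate pair and one candidate triple of our own choosing (they are not from the original post), using the discrete metric as the hint suggests. The compact `dtw` fills the usual table of prefix costs, with the local metric passed in as a parameter.

```python
import math

def dtw(a, b, d):
    """DTW distance with local metric d, using an infinite boundary row/column."""
    M, N = len(a), len(b)
    D = [[math.inf] * (N + 1) for _ in range(M + 1)]
    D[0][0] = 0
    for i in range(1, M + 1):
        for j in range(1, N + 1):
            D[i][j] = d(a[i - 1], b[j - 1]) + min(D[i - 1][j], D[i][j - 1], D[i - 1][j - 1])
    return D[M][N]

def discrete(x, y):
    """The discrete metric: 0 for equal entries, 1 otherwise."""
    return 0 if x == y else 1

# Positive-definiteness fails: two different sequences at distance zero,
# because the repeated 1 can be matched against a single 1.
print(dtw([1, 2], [1, 1, 2], discrete))  # 0

# The triangle inequality fails: DTW(x, z) > DTW(x, y) + DTW(y, z).
x, y, z = [0], [0, 1], [1, 1, 1, 1]
print(dtw(x, z, discrete), dtw(x, y, discrete), dtw(y, z, discrete))  # 4 1 1
```

The single 0 in x must be paired with every entry of z, costing 4, while each comparison through y gets away with a single mismatch.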

The failure to satisfy positive-definiteness and the triangle inequality means that the dynamic time warping “distance” is not even a semimetric. To be completely pedantic, it would fit in as a symmetric premetric. Unfortunately, this means that we don’t get the benefits of geometry or topology when analyzing dynamic time warping as a characteristic feature of the space of all numeric sequences. In any event, it has still proven useful in applications, so it belongs here in the program gallery.