Earthmover Distance

Problem: Compute distance between points with uncertain locations (given by samples, or differing observations, or clusters).

For example, if I have the following three “points” in the plane, as indicated by their colors, which is closer, blue to green, or blue to red?

example-points.png

It’s not obvious, and there are multiple factors at work: the red points have fewer samples, but we can be more certain about the position; the blue points are less certain, but the closest non-blue point to a blue point is green; and the green points are equally plausibly “close to red” and “close to blue.” The centers of masses of the three sample sets are close to an equilateral triangle. In our example the “points” don’t overlap, but of course they could. And in particular, there should probably be a nonzero distance between two points whose sample sets have the same center of mass, as below. The distance quantifies the uncertainty.

same-centers.png

All this is to say that it’s not obvious how to define a distance measure that is consistent with perceptual ideas of what geometry and distance should be.

Solution (Earthmover distance): Treat each sample set A corresponding to a “point” as a discrete probability distribution, so that each sample x \in A has probability mass p_x = 1 / |A|. The distance between A and B is the optimal value of the following linear program.

Each x \in A corresponds to a pile of dirt of height p_x, and each y \in B corresponds to a hole of depth p_y. The cost of moving a unit of dirt from x to y is the Euclidean distance d(x, y) between the points (or whatever hipster metric you want to use).

Let z_{x, y} be a real variable corresponding to an amount of dirt to move from x \in A to y \in B, with cost d(x, y). Then the constraints are:

  • Each z_{x, y} \geq 0, so dirt only moves from x to y.
  • Every pile x \in A must vanish, i.e. for each fixed x \in A, \sum_{y \in B} z_{x,y} = p_x.
  • Likewise, every hole y \in B must be completely filled, i.e. for each fixed y \in B, \sum_{x \in A} z_{x,y} = p_y.

The objective is to minimize the cost of doing this: \sum_{(x, y) \in A \times B} d(x, y) z_{x, y}.

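In Python, using the ortools library (leaving out a few docstrings; full code on Github):

import math
from collections import Counter
from collections import defaultdict

from ortools.linear_solver import pywraplp


def euclidean_distance(x, y):
    # Assumed helper (not shown in this excerpt): points are equal-length tuples of numbers.
    return math.sqrt(sum((a - b) ** 2 for (a, b) in zip(x, y)))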

def earthmover_distance(p1, p2):
    dist1 = {x: count / len(p1) for (x, count) in Counter(p1).items()}
    dist2 = {x: count / len(p2) for (x, count) in Counter(p2).items()}
    solver = pywraplp.Solver('earthmover_distance', pywraplp.Solver.GLOP_LINEAR_PROGRAMMING)

    variables = dict()

    # for each pile in dist1, the constraint that says all the dirt must leave this pile
    dirt_leaving_constraints = defaultdict(lambda: 0)

    # for each hole in dist2, the constraint that says this hole must be filled
    dirt_filling_constraints = defaultdict(lambda: 0)

    # the objective
    objective = solver.Objective()
    objective.SetMinimization()

    for (x, dirt_at_x) in dist1.items():
        for (y, capacity_of_y) in dist2.items():
            amount_to_move_x_y = solver.NumVar(0, solver.infinity(), 'z_{%s, %s}' % (x, y))
            variables[(x, y)] = amount_to_move_x_y
            dirt_leaving_constraints[x] += amount_to_move_x_y
            dirt_filling_constraints[y] += amount_to_move_x_y
            objective.SetCoefficient(amount_to_move_x_y, euclidean_distance(x, y))

    for x, linear_combination in dirt_leaving_constraints.items():
        solver.Add(linear_combination == dist1[x])

    for y, linear_combination in dirt_filling_constraints.items():
        solver.Add(linear_combination == dist2[y])

    status = solver.Solve()
    if status not in [solver.OPTIMAL, solver.FEASIBLE]:
        raise Exception('Unable to find feasible solution')

    return objective.Value()
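As a sanity check that needs no solver, here is a minimal sketch for the special case where both sample lists have the same size n. Every pile and hole then has mass 1/n, and a standard fact about assignment problems says the linear program has an optimal solution that moves each pile entirely into a single hole, so the distance is the cost of the cheapest perfect matching divided by n. The name brute_force_earthmover and the sample points are made up for illustration, and trying all n! matchings is only sensible for a handful of points.

```python
import itertools
import math


def euclidean(x, y):
    # Straight-line distance between two equal-length tuples.
    return math.sqrt(sum((a - b) ** 2 for (a, b) in zip(x, y)))


def brute_force_earthmover(samples_a, samples_b):
    """Earthmover distance for two equal-size sample lists, by trying every matching."""
    n = len(samples_a)
    if n == 0 or n != len(samples_b):
        raise ValueError("this sketch needs two nonempty sample lists of equal size")
    # Each permutation of samples_b pairs pile i with exactly one hole.
    best = min(
        sum(euclidean(x, y) for (x, y) in zip(samples_a, matching))
        for matching in itertools.permutations(samples_b)
    )
    # Every pile carries mass 1/n, so the total cost is the matching cost over n.
    return best / n


blue = [(0, 0), (1, 0)]
green = [(0, 1), (1, 1)]
red = [(0, 3), (1, 3)]
print(brute_force_earthmover(blue, green))  # 1.0: each pile moves straight up by 1
print(brute_force_earthmover(blue, red))    # 3.0: each pile moves straight up by 3
```

On inputs this small, the result should agree with the linear program above up to floating point error.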

Discussion: I’ve heard about this metric many times as a way to compare probability distributions. For example, it shows up in an influential paper about fairness in machine learning, and a few other CS theory papers related to distribution testing.

One might ask: why not use other measures of dissimilarity for probability distributions (Chi-squared statistic, Kullback-Leibler divergence, etc.)? One answer is that these other measures only give useful information for pairs of distributions with the same support. An example from a talk of Justin Solomon succinctly clarifies what Earthmover distance achieves:

Screen Shot 2018-03-03 at 6.11.00 PM.png
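To make the support issue concrete, here is a small sketch comparing Kullback-Leibler divergence with Earthmover distance on point masses on the real line. It is an illustration under simplifying assumptions, not the code from the post: distributions are dictionaries mapping a position to its probability, and the one-dimensional Earthmover distance is computed with the fact that on a line it equals the area between the two cumulative distribution functions. The function names are hypothetical.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as {position: probability}."""
    total = 0.0
    for x, px in p.items():
        if px == 0:
            continue
        qx = q.get(x, 0.0)
        if qx == 0:
            # p puts mass where q has none: the divergence is infinite.
            return math.inf
        total += px * math.log(px / qx)
    return total


def earthmover_1d(p, q):
    """Earthmover distance on the real line: area between the two CDFs."""
    points = sorted(set(p) | set(q))
    cdf_gap = 0.0
    total = 0.0
    for left, right in zip(points, points[1:]):
        cdf_gap += p.get(left, 0.0) - q.get(left, 0.0)
        total += abs(cdf_gap) * (right - left)
    return total


origin = {0: 1.0}
near = {1: 1.0}
far = {10: 1.0}
print(kl_divergence(origin, near), kl_divergence(origin, far))  # inf inf
print(earthmover_1d(origin, near), earthmover_1d(origin, far))  # 1.0 10.0
```

The divergence cannot tell the near distribution from the far one, while the Earthmover distance grows with how far the mass has to travel.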

Also, why not just model the samples using, say, a normal distribution, and then compute the distance based on the parameters of the distributions? That is possible, and in fact makes for a potentially more efficient technique, but you lose some information by doing this. Ignoring that your data might not be approximately normal (it might have some curvature), with Earthmover distance, you get point-by-point details about how each data point affects the outcome.

This kind of attention to detail can be very important in certain situations. One that I’ve been paying close attention to recently is the problem of studying gerrymandering from a mathematical perspective. Justin Solomon of MIT is a champion of the Earthmover distance (see his fascinating talk here for more, with slides) which is just one topic in a field called “optimal transport.”

This has the potential to be useful in redistricting because of the nature of the redistricting problem. As I wrote previously, discussions of redistricting are chock-full of geometry—or at least geometric-sounding language—and people are very concerned with the apparent “compactness” of a districting plan. But the underlying data used to perform redistricting isn’t very accurate. The people who build the maps don’t have precise data on voting habits, or even locations where people live. Census tracts might not be perfectly aligned, and data can just plain have errors and uncertainty in other respects. So the data that district-map-drawers care about is uncertain much like our point clouds. With a theory of geometry that accounts for uncertainty (and the Earthmover distance is the “distance” part of that), one can come up with more robust, better tools for redistricting.

Solomon’s website has a ton of resources about this, under the names of “optimal transport” and “Wasserstein metric,” and his work extends from computing distances to computing important geometric values like the barycenter, along with computational advantages like parallelism.

Others in the field have come up with transparency techniques to make it clearer how the Earthmover distance relates to the geometry of the underlying space. This one is particularly fun because the explanations result in a path traveled from the start to the finish, and by setting up the underlying metric in just such a way, you can watch the distribution navigate a maze to get to its target. I like to imagine tiny ants carrying all that dirt.

Screen Shot 2018-03-03 at 6.15.50 PM.png

Finally, work of Shirdhonkar and Jacobs provides approximation algorithms that allow linear-time computation, instead of the worst-case-cubic runtime of a linear solver.

Bayesian Ranking for Rated Items

Problem: You have a catalog of items with discrete ratings (thumbs up/thumbs down, or 5-star ratings, etc.), and you want to display them in the “right” order.

Solution: In Python

def score(ratings, rating_prior, rating_utility):
    '''
      score: [int], [int], [float] -> float

      Return the expected value of the rating for an item with known
      ratings specified by `ratings`, prior belief specified by
      `rating_prior`, and a utility function specified by `rating_utility`,
      assuming the ratings are a multinomial distribution and the prior
      belief is a Dirichlet distribution.
    '''
    # Add the prior pseudocounts to the observed counts for each rating.
    ratings = [r + p for (r, p) in zip(ratings, rating_prior)]
    # Weight each rating's count by the utility of that rating.
    total = sum(r * u for (r, u) in zip(ratings, rating_utility))
    return total / sum(ratings)

Discussion: This deceptively short solution can lead you on a long and winding path into the depths of statistics. I will do my best to give a short, clear version of the story.

As a working example I chose merely because I recently listened to a related podcast, say you’re selling mass-market romance novels—which, by all accounts, is a predictable genre. You have a list of books, each of which has been rated on a scale of 0-5 stars by some number of users. You want to display the top books first, so that time-constrained readers can experience the most titillating novels first, and newbies to the genre can get the best first time experience and be incentivized to buy more.

The setup required to arrive at the above code is the following, which I’ll phrase as a story.

Users’ feelings about a book, and subsequent votes, are independent draws from a known distribution (with unknown parameters). I will just call these distributions “discrete” distributions. So given a book and user, there is some unknown list (p_0, p_1, p_2, p_3, p_4, p_5) of probabilities (\sum_i p_i = 1) for each possible rating a user could give for that book.

But how do users get these probabilities? In this story, the probabilities are the output of a randomized procedure that generates distributions. That modeling assumption is called a “Dirichlet prior,” with Dirichlet meaning it generates discrete distributions, and prior meaning it encodes domain-specific information (such as the fraction of 4-star ratings for a typical romance novel).

So the story is you have a book, and that book gets a Dirichlet distribution (unknown to us), and then when a user comes along they sample from the Dirichlet distribution to get a discrete distribution, which they then draw from to choose a rating. We observe the ratings, and we need to find the book’s underlying Dirichlet. We start by assigning it some default Dirichlet (the prior) and update that Dirichlet as we observe new ratings. Some other assumptions:

  1. Books are indistinguishable except in the parameters of their Dirichlet distribution.
  2. The parameters of a book’s Dirichlet distribution don’t change over time, and inherently reflect the book’s value.

So a Dirichlet distribution is a process that produces discrete distributions. For simplicity, in this post we will say a Dirichlet distribution is parameterized by a list of six integers (n_0, \dots, n_5), one for each possible star rating. These values represent our belief in the “typical” distribution of votes for a new book. We’ll discuss more about how to set the values later. Sampling a value (a book’s list of probabilities) from the Dirichlet distribution is not trivial, but we don’t need to do that for this program. Rather, we need to be able to interpret a fixed Dirichlet distribution, and update it given some observed votes.

The interpretation we use for a Dirichlet distribution is its expected value, which, recall, is the parameters of a discrete distribution. In particular if n = \sum_i n_i, then the expected value is a discrete distribution whose probabilities are

\displaystyle \left (  \frac{n_0}{n}, \frac{n_1}{n}, \dots, \frac{n_5}{n} \right )

So you can think of each integer in the specification of a Dirichlet as “ghost ratings,” sometimes called pseudocounts, and we’re saying the probability is proportional to the count.
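In code, the expected value is a one-line normalization. This is a small sketch; the name dirichlet_mean and the pseudocounts are made up for illustration.

```python
def dirichlet_mean(pseudocounts):
    """Expected discrete distribution of a Dirichlet with the given pseudocounts."""
    total = sum(pseudocounts)
    return [count / total for count in pseudocounts]


# Hypothetical "ghost ratings" for 0 through 5 stars.
prior = [1, 1, 2, 4, 8, 4]
print(dirichlet_mean(prior))  # [0.05, 0.05, 0.1, 0.2, 0.4, 0.2]
```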

This is great, because if we knew the true Dirichlet distribution for a book, we could compute its ranking without a second thought. The ranking would simply be the expected star rating:

def simple_score(distribution):
   return sum(i * p for (i, p) in enumerate(distribution))

Putting books with the highest score on top would maximize the expected happiness of a user visiting the site, provided that happiness matches the user’s voting behavior, since the simple_score is just the expected vote.

Also note that all the rating system needs to make this work is that the rating options are linearly ordered. So a thumbs up/down (heaving bosom/flaccid member?) would work, too. We don’t need to know how happy it makes them to see a 5-star vs 4-star book. However, as we’ll see next, we have to approximate the distribution, and so the scores of books with only a few ratings are uncertain. For that reason it helps to incorporate numerical utility values (we’ll see this at the end).

Next, to update a given Dirichlet distribution with the results of some observed ratings, we have to dig a bit deeper into Bayes rule and the formulas for sampling from a Dirichlet distribution. Rather than do that, I’ll point you to this nice writeup by Jonathan Huang, where the core of the derivation is in Section 2.3 (page 4), and remark that the rule for updating for a new observation is to just add it to the existing counts.

Theorem: Given a Dirichlet distribution with parameters (n_1, \dots, n_k) and a new observation of outcome i, the updated Dirichlet distribution has parameters (n_1, \dots, n_{i-1}, n_i + 1, n_{i+1}, \dots, n_k). That is, you just update the i-th entry by adding 1 to it.
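The update rule in the theorem can be sketched in a few lines. The function name and the pseudocounts are hypothetical, and the list is indexed by star rating starting at 0, whereas the theorem numbers outcomes from 1.

```python
def update_dirichlet(pseudocounts, observed_ratings):
    """Return new Dirichlet parameters after observing a list of star ratings."""
    updated = list(pseudocounts)  # copy so the prior itself is left untouched
    for rating in observed_ratings:
        updated[rating] += 1  # one more count for the observed outcome
    return updated


prior = [1, 1, 2, 4, 8, 4]  # made-up pseudocounts for 0 through 5 stars
posterior = update_dirichlet(prior, [5, 5, 3])
print(posterior)  # [1, 1, 2, 5, 8, 6]
```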

This particular arithmetic to do the update is a mathematical consequence (derived in the link above) of the philosophical assumption that Bayes rule is how you should model your beliefs about uncertainty, coupled with the assumption that the Dirichlet model described above is how the users actually arrive at their votes.

The initial values (n_0, \dots, n_5) for star ratings should be picked so that they represent the average rating distribution among all prior books, since this is used as the default voting distribution for a new, unknown book. If you have more information about whether a book is likely to be popular, you can use a different prior. For example, if JK Rowling wrote a Harry Potter Romance novel that was part of the canon, you could pretty much guarantee it would be popular, and set n_5 high compared to n_0. Of course, if it were actually popular you could just wait for the good ratings to stream in, so tinkering with these values on a per-book basis might not help much. On the other hand, most books by unknown authors are bad, and n_5 should be close to zero. Selecting a prior dictates how influential ratings of new items are compared to ratings of items with many votes. The more pseudocounts you add to the prior, the less new votes count.

This gets us to the following code for star ratings.

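def score(ratings, rating_prior):
    # Add the prior pseudocounts to the observed counts for each star rating.
    ratings = [r + p for (r, p) in zip(ratings, rating_prior)]
    # The index of each entry is its star value, so this is the total number of stars.
    total = sum(stars * count for (stars, count) in enumerate(ratings))
    return total / sum(ratings)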
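To see how the size of the prior controls the influence of new votes, here is a small sketch that restates the star-rating computation as a standalone function (expected_stars is a hypothetical name) and scores the same three 5-star votes under a weak and a strong prior. The priors are made up, and both have the same average of 2.5 stars.

```python
def expected_stars(ratings, rating_prior):
    """Expected star rating after adding prior pseudocounts to observed counts."""
    counts = [r + p for (r, p) in zip(ratings, rating_prior)]
    return sum(stars * count for (stars, count) in enumerate(counts)) / sum(counts)


observed = [0, 0, 0, 0, 0, 3]   # three 5-star votes, nothing else
no_votes = [0] * 6
weak_prior = [1] * 6            # 6 pseudocounts in total
strong_prior = [10] * 6         # 60 pseudocounts in total

print(expected_stars(no_votes, weak_prior))    # 2.5
print(expected_stars(observed, weak_prior))    # about 3.33
print(expected_stars(observed, strong_prior))  # about 2.62
```

With the strong prior, the three votes barely move the score away from the prior average.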

The only thing missing from the solution at the beginning is the utilities. The utilities are useful for two reasons. First, because books with few ratings encode a lot of uncertainty, having an idea about how extreme a feeling is implied by a specific rating allows one to give better rankings of new books.

Second, for many services, such as taxi rides on Lyft, the default star rating tends to be a 5-star, and a rating of 4 stars or lower means something went wrong. For books, 3-4 stars is a default while 5-star means you were very happy.

The utilities parameter allows you to weight rating outcomes appropriately. So if you are in a Lyft-like scenario, you might specify utilities like [-10, -5, -3, -2, 1] (one entry for each rating from 1 to 5 stars) to denote that a 4-star rating has the same negative impact as two 5-star ratings would positively contribute. On the other hand, for books the gap between 4-star and 5-star is much less than the gap between 3-star and 4-star. The utilities simply allow you to calibrate how the votes should be valued in comparison to each other, instead of using their literal star counts.
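Here is a quick check of the Lyft-style arithmetic, as a sketch. The score function is restated as a standalone function so the example is self-contained, and the all-zero prior is an assumption made only to isolate the effect of the utilities (a real prior would have positive pseudocounts).

```python
def score(ratings, rating_prior, rating_utility):
    """Expected utility after adding prior pseudocounts to observed counts."""
    counts = [r + p for (r, p) in zip(ratings, rating_prior)]
    return sum(c * u for (c, u) in zip(counts, rating_utility)) / sum(counts)


lyft_utilities = [-10, -5, -3, -2, 1]   # utilities for 1 through 5 stars
no_prior = [0, 0, 0, 0, 0]              # illustration only

one_four_two_fives = [0, 0, 0, 1, 2]    # one 4-star ride, two 5-star rides
three_fives = [0, 0, 0, 0, 3]           # three 5-star rides

print(score(one_four_two_fives, no_prior, lyft_utilities))  # 0.0
print(score(three_fives, no_prior, lyft_utilities))         # 1.0
```

The single 4-star rating exactly cancels the two 5-star ratings, as the utilities were designed to express.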

Load Balancing and the Power of Hashing

Here’s a bit of folklore I often hear (and retell) that’s somewhere between a joke and deep wisdom: if you’re doing a software interview that involves some algorithms problem that seems hard, your best bet is to use hash tables.

More succinctly put: Google loves hash tables.

As someone with a passion for math and theoretical CS, I find it kind of silly and reductionist. But if you actually work with terabytes of data that can’t fit on a single machine, it also makes sense.

But to understand why hash tables are so applicable, you should have at least a fuzzy understanding of the math that goes into them, which is surprisingly unrelated to the actual act of hashing. Instead it’s the guarantees that a “random enough” hash provides that make it so useful. The basic intuition is that if you have an algorithm that works well assuming the input data is completely random, then you can probably get a good guarantee by preprocessing the input by hashing.

In this post I’ll explain the details, and show the application to an important problem that one often faces in dealing with huge amounts of data: how to allocate resources efficiently (load balancing). As usual, all of the code used in the making of this post is available on Github.

Next week, I’ll follow this post up with another application of hashing to estimating the number of distinct items in a set that’s too large to store in memory.

Families of Hash Functions

To emphasize which specific properties of hash functions are important for a given application, we start by introducing an abstraction: a hash function is just some computable function that accepts strings as input and produces numbers between 1 and n as output. We call the set of allowed inputs U (for “Universe”). A family of hash functions is just a set of possible hash functions to choose from. We’ll use a scripty \mathscr{H} for our family, and so every hash function h in \mathscr{H} is a function h : U \to \{ 1, \dots, n \}.

You can use a single hash function h to maintain an unordered set of objects in a computer. This is a problem that needs solving because if you store items sequentially in a list and want to determine whether a specific item is already in the list, you potentially need to check every item (or do something fancier). In any event, without hashing you have to spend some non-negligible amount of time searching. With hashing, you can choose the location of an element x \in U based on the value of its hash h(x). If you pick your hash function well, then you’ll have very few collisions and can deal with them efficiently. The relevant section on Wikipedia has more about the various techniques to deal with collisions in hash tables specifically, but we want to move beyond that in this post.

Here we have a family of random hash functions. So what’s the use of having many hash functions? You can pick a hash randomly from a “good” family of hash functions. While this doesn’t seem so magical, it has the informal property that it makes arbitrary data “random enough,” so that an algorithm which you designed to work with truly random data will also work with the hashes of arbitrary data. Moreover, even if an adversary knows \mathscr{H} and knows that you’re picking a hash function at random, there’s no way for the adversary to manufacture problems by feeding bad data. With overwhelming probability the worst-case scenario will not occur. Our first example of this is in load-balancing.

Load balancing and 2-universality

You can imagine load balancing in two ways, concretely and mathematically. In the concrete version you have a public-facing server that accepts requests from users, and forwards them to a back-end server which processes them and sends a response to the user. When you have a billion users and a million servers, you want to forward the requests in such a way that no server gets too many requests, or else the users will experience delays. Moreover, you’re worried that the League of Tanzanian Hackers is trying to take down your website by sending you requests in a carefully chosen order so as to screw up your load balancing algorithm.

The mathematical version of this problem usually goes with the metaphor of balls and bins. You have some collection of m balls and n bins in which to put the balls, and you want to put the balls into the bins. But there’s a twist: an adversary is throwing balls at you, and you have to put them into the bins before the next ball comes, so you don’t have time to remember (or count) how many balls are in each bin already. You only have time to do a small bit of mental arithmetic, sending ball i to bin f(i) where f is some simple function. Moreover, whatever rule you pick for distributing the balls in the bins, the adversary knows it and will throw balls at you in the worst order possible.

silk-balls.jpg

A young man applying his knowledge of balls and bins. That’s totally what he’s doing.

There is one obvious approach: why not just pick a uniformly random bin for each ball? The problem here is that we need the choice to be persistent. That is, if the adversary throws the same ball at us a second time, we need to put it in the same bin as the first time, and it doesn’t count toward the overall load. This is where the ball/bin metaphor breaks down. In the request/server picture, there is data specific to each user stored on the back-end server between requests (a session), and you need to make sure that data is not lost for some reasonable period of time. And if we were to save a uniform random choice after each request, we’d need to store a number for every request, which is too much. In short, we need the mapping to be persistent, but we also want it to be “like random” in effect.

So what do you do? The idea is to take a “good” family of hash functions \mathscr{H}, pick one h \in \mathscr{H} uniformly at random for the whole game, and when you get a request/ball x \in U send it to server/bin h(x). Note that in this case, the adversary knows your universal family \mathscr{H} ahead of time, and it knows your algorithm of committing to some single randomly chosen h \in \mathscr{H}, but the adversary does not know which particular h you chose.

The property of a family of hash functions that makes this strategy work is called 2-universality.

Definition: A family of functions \mathscr{H} from some universe U \to \{ 1, \dots, n \} is called 2-universal if, for every two distinct x, y \in U, the probability over the random choice of a hash function h from \mathscr{H} that h(x) = h(y) is at most 1/n. In notation,

\displaystyle \Pr_{h \in \mathscr{H}}[h(x) = h(y)] \leq \frac{1}{n}

I’ll give an example of such a family shortly, but let’s apply this to our load balancing problem. Our load-balancing algorithm would fail if, with even some modest probability, there is some server that receives many more than its fair share (m/n) of the m requests. If \mathscr{H} is 2-universal, then we can compute an upper bound on the expected load of a given server, say server 1. Specifically, pick any element x which hashes to 1 under our randomly chosen h. Then we can compute an upper bound on the expected number of other elements that hash to 1. In this computation we’ll only use the fact that expectation splits over sums, and the definition of 2-universal. Call \mathbf{1}_{h(y) = 1} the random variable which is zero when h(y) \neq 1 and one when h(y) = 1, and call X = \sum_{y \in U} \mathbf{1}_{h(y) = 1}. In words, X simply represents the number of inputs that hash to 1. Then

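\displaystyle \mathbb{E}[X] = \mathbb{E}\left[ \sum_{y \in U} \mathbf{1}_{h(y) = 1} \right] = \sum_{y \in U} \Pr[h(y) = 1] = 1 + \sum_{y \neq x} \Pr[h(y) = h(x)] \leq 1 + \frac{m-1}{n}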

So in expectation we can expect server 1 gets its fair share of requests. And clearly this doesn’t depend on the output hash being 1; it works for any server. There are two obvious questions.

  1. How do we measure the risk that, despite the expectation we computed above, some server is overloaded?
  2. If it seems like (1) is on track to happen, what can you do?

For 1 we’re asking to compute, for a given deviation t, the probability that X - \mathbb{E}[X] > t. This makes more sense if we jump to multiplicative factors, since it’s usually okay for a server to bear twice or three times its usual load, but not like \sqrt{n} times more than its usual load. (Industry experts, please correct me if I’m wrong! I’m far from an expert on the practical details of load balancing.)

So we want to know what is the probability that X - \mathbb{E}[X] > t \cdot \mathbb{E}[X] for some small number t, and we want this to get small quickly as t grows. This is where the Chebyshev inequality becomes useful. For those who don’t want to click the link, for our situation Chebyshev’s inequality is the statement that, for any random variable X

\displaystyle \Pr[|X - \mathbb{E}[X]| > t\mathbb{E}[X]] \leq \frac{\textup{Var}[X]}{t^2 \mathbb{E}^2[X]}.

So all we need to do is compute the variance of the load of a server. It’s a bit of a hairy calculation to write down, but rest assured it doesn’t use anything fancier than the linearity of expectation and 2-universality. Let’s dive in. We start by writing the definition of variance as an expectation, and then we split X up into its parts, expand the product and group the parts.

\displaystyle \textup{Var}[X] = \mathbb{E}[(X - \mathbb{E}[X])^2] = \mathbb{E}[X^2] - (\mathbb{E}[X])^2

The easy part is (\mathbb{E}[X])^2, it’s just (1 + (m-1)/n)^2, and the hard part is \mathbb{E}[X^2]. So let’s compute that

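\displaystyle \mathbb{E}[X^2] = \mathbb{E}\left[ \sum_{x \in U} \sum_{y \in U} \mathbf{1}_{h(x) = 1} \mathbf{1}_{h(y) = 1} \right] = \sum_{x \in U} \Pr[h(x) = 1] + 2 \sum_{x < y} \Pr[h(x) = h(y) = 1]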

In order to continue (and get a reasonable bound) we need an additional property of our hash family which is not immediately spelled out by 2-universality. Specifically, we need that for every h and i, \Pr_x[h(x) = i] = O(\frac{1}{n}). In other words, each hash function should evenly split the inputs across servers.

The reason this helps is because we can split \Pr[h(x) = h(y) = 1] into \Pr[h(x) = h(y) \mid h(x) = 1] \cdot \Pr[h(x) = 1]. Using 2-universality to bound the left term by 1/n and the even-splitting property to bound the right term by O(1/n), this quantity is O(1/n^2), and since there are \binom{m}{2} total terms in the double sum above, the whole thing is at most O(m/n + m^2 / n^2) = O(m^2 / n^2). Note that in our big-O analysis we’re assuming m is much bigger than n.

Sweeping some of the details inside the big-O, this means that our variance is O(m^2/n^2). Plugging this into Chebyshev’s inequality, where \mathbb{E}^2[X] is also on the order of m^2/n^2, the probability that X deviates from its expectation by a multiplicative factor of t is at most O(1/t^2).

Now we computed a bound on the probability that a single server is overloaded, but if we want to extend that to the worst-case server, the typical probability technique is to take the union bound over all servers. This means we just add up all the individual bounds and ignore how they relate. So the probability that some server has a load more than a multiplicative factor of t above its expectation is bounded from above by O(n/t^2). This is only less than one when t = \Omega(\sqrt{n}), so all we can say with this analysis is that (with some small constant probability) no server will have a load worse than \sqrt{n} times more than the expected load.

So we have this analysis that seems not so good. If we have a million servers then the worst load on one server could potentially be a thousand times higher than the expected load. This doesn’t scale, and the problem could be in any (or all) of three places:

  1. Our analysis is weak, and we should use tighter bounds because the true max load is actually much smaller.
  2. Our hash families don’t have strong enough properties, and we should beef those up to get tighter bounds.
  3. The whole algorithm sucks and needs to be improved.

It turns out all three are true. One heuristic solution is easy and avoids all math. Have some second server (which does not process requests) count hash collisions. When some server exceeds a factor of t more than the expected load, send a message to the load balancer to randomly pick a new hash function from \mathscr{H} and for any requests that don’t have existing sessions (this is included in the request data), use the new hash function. Once the old sessions expire, switch any new incoming requests from those IPs over to the new hash function.

But there are much better solutions out there. Unfortunately their analyses are too long for a blog post (they fill multiple research papers). Fortunately their descriptions and guarantees are easy to describe, and they’re easy to program. The basic idea goes by the name “the power of two choices,” which we explored on this blog in a completely different context of random graphs.

In more detail, the idea is that you start by picking two random hash functions h_1, h_2 \in \mathscr{H}, and when you get a new request, you compute both hashes, inspect the load of the two servers indexed by those hashes, and send the request to the server with the smaller load.

This has the disadvantage of requiring bidirectional talk between the load balancer and the server, rather than obliviously forwarding requests. But the advantage is an exponential decrease in the worst-case maximum load. In particular, the following theorem holds for the case where the hashes are fully random.

Theorem: Suppose one places m balls into n bins in order according to the following procedure: for each ball pick two uniformly random and independent integers 1 \leq i,j \leq n, and place the ball into whichever of bins i and j has the smaller current size. If there are ties pick the bin with the smaller index. Then with high probability the largest bin has no more than \Theta(m/n) + O(\log \log (n)) balls.
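The procedure in the theorem is short enough to simulate directly. The sketch below uses fully random choices (as in the theorem, not a hash family), compares it against throwing each ball into a single random bin, and uses hypothetical function names with an arbitrary seed and problem size.

```python
import random


def one_choice(num_balls, num_bins, rng):
    """Throw each ball into one uniformly random bin."""
    bins = [0] * num_bins
    for _ in range(num_balls):
        bins[rng.randrange(num_bins)] += 1
    return bins


def two_choices(num_balls, num_bins, rng):
    """For each ball, sample two bins and use the less loaded one."""
    bins = [0] * num_bins
    for _ in range(num_balls):
        i = rng.randrange(num_bins)
        j = rng.randrange(num_bins)
        # Ties go to the bin with the smaller index, as in the theorem.
        if bins[i] < bins[j] or (bins[i] == bins[j] and i <= j):
            chosen = i
        else:
            chosen = j
        bins[chosen] += 1
    return bins


rng = random.Random(0)
num_balls = num_bins = 10000
single = one_choice(num_balls, num_bins, rng)
double = two_choices(num_balls, num_bins, rng)
print(max(single), max(double))
```

With as many balls as bins, the two-choice maximum typically comes out at about half of the one-choice maximum, and the gap widens as the number of bins grows.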

This theorem appears to have been proved in a few different forms, with the best analysis being by Berenbrink et al. You can improve the constant on the \log \log n by computing more than 2 hashes. How does this relate to a good family of hash functions, which is not quite fully random? Let’s explore the answer by implementing the algorithm in Python.

An example of universal hash functions, and the load balancing algorithm

In order to implement the load balancer, we need to have some good hash functions under our belt. We’ll go with the simplest example of a hash function that’s easy to prove nice properties for. Specifically each hash in our family just performs some arithmetic modulo a random prime.

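Definition: Pick any prime p > m, and for any 1 \leq a < p and 0 \leq b < p define h_{a,b}(x) = ((ax + b) \mod p) \mod m. Let \mathscr{H} = \{ h_{a,b} \mid 0 \leq b < p, 1 \leq a < p \}. (In this definition and the proof below, m denotes the size of the output range, i.e. the number of bins.)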

This family of hash functions is 2-universal.

Theorem: For every x \neq y \in \{0, \dots, p-1\},

\Pr_{h \in \mathscr{H}}[h(x) = h(y)] \leq 1/m

Proof. To say that h(x) = h(y) is to say that ax+b = ay+b + i \cdot m \mod p for some integer i. I.e., the two remainders of ax+b and ay+b are equivalent mod m. The b’s cancel and we can solve for a

a = im (x-y)^{-1} \mod p

Since a \neq 0, there are p-1 possible choices for a. Moreover, there is no point to pick i bigger than p/m since we’re working modulo p. So there are (p-1)/m possible values for the right hand side of the above equation. So if we chose them uniformly at random, (remember, x-y is fixed ahead of time, so the only choice is a, i), then there is a (p-1)/m out of p-1 chance that the equality holds, which is at most 1/m. (To be exact you should account for taking a floor of (p-1)/m when m does not evenly divide p-1, but it only decreases the overall probability.)

\square
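For small parameters the theorem can be checked by brute force, enumerating every pair (a, b) in the family. This is only a sanity check under assumed values p = 17 and m = 5 (chosen arbitrarily), with a hypothetical function name, and it uses exact fractions to avoid rounding.

```python
from fractions import Fraction


def collision_probability(x, y, p, m):
    """Exact probability, over all h_{a,b} in the family, that h(x) == h(y)."""
    collisions = sum(
        1
        for a in range(1, p)   # 1 <= a < p
        for b in range(p)      # 0 <= b < p
        if ((a * x + b) % p) % m == ((a * y + b) % p) % m
    )
    return Fraction(collisions, (p - 1) * p)


p, m = 17, 5
worst = max(
    collision_probability(x, y, p, m)
    for x in range(p)
    for y in range(p)
    if x != y
)
print(worst, worst <= Fraction(1, m))  # 21/136 True
```

The worst pair collides with probability 21/136, roughly 0.154, which is below the bound 1/m = 0.2.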

If m and p were equal then this would be even more trivial: it’s just the fact that there is a unique line passing through any two distinct points. While that’s obviously true from standard geometry, it is also true when you work with arithmetic modulo a prime. In fact, it works using arithmetic over any field.

Implementing these hash functions is easier than shooting fish in a barrel.

import random

def draw(p, m):
   a = random.randint(1, p-1)
   b = random.randint(0, p-1)

   return lambda x: ((a*x + b) % p) % m

To encapsulate the process a little bit we implemented a UniversalHashFamily class which computes a random probable prime to use as the modulus and stores m. The interested reader can see the Github repository for more.

If we try to run this and feed in a large range of inputs, we can see how the outputs are distributed. In this example m is a hundred thousand and n is a hundred (it’s not two terabytes, but give me some slack it’s a demo and I’ve only got my desktop!). So the expected bin size for any 2-universal family is just about 1,000.

>>> m = 100000
>>> n = 100
>>> H = UniversalHashFamily(numBins=n, primeBounds=[n, 2*n])
>>> results = []
>>> for simulation in range(100):
...    bins = [0] * n
...    h = H.draw()
...    for i in range(m):
...       bins[h(i)] += 1
...    results.append(max(bins))
...
>>> max(bins) # a single run
1228
>>> min(bins)
613
>>> max(results) # the max bin size over all runs
1228
>>> min(results)
1227

Indeed, the max is very close to the expected value.

But this example is misleading, because the point of this was that some adversary would try to screw us over by picking a worst-case input. If the adversary knew exactly which h was chosen (which it doesn’t) then the worst case input would be the set of all inputs that have the given hash output value. Let’s see it happen live.

>>> h = H.draw()
>>> badInputs = [i for i in range(m) if h(i) == 9]
>>> len(badInputs)
1227
>>> testInputs(n,m,badInputs,hashFunction=h)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1227, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

The expected size of a bin is about 12, but as predicted the worst bin here is 100 times worse (linearly worse in n). But if we instead pick a random h after the bad inputs are chosen, the result is much better.

>>> testInputs(n,m,badInputs) # randomly picks a hash
[19, 20, 20, 19, 18, 18, 17, 16, 16, 16, 16, 17, 18, 18, 19, 20, 20, 19, 18, 17, 17, 16, 16, 16, 16, 17, 18, 18, 19, 20, 20, 19, 18, 17, 17, 16, 16, 16, 16, 8, 8, 9, 9, 10, 10, 10, 10, 9, 9, 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 10, 9, 9, 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 10, 9, 8, 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 10, 9, 8, 8, 8, 8, 8, 8, 8, 9, 9, 10]

However, if you re-ran this test many times, you’d eventually get unlucky and draw the hash function for which this actually is the worst input, and get a single huge bin. Other times you can get a bad hash in which two or three bins have all the inputs.

An interesting question is, what is really the worst-case input for this algorithm? I suspect it’s characterized by some choice of hash output values, taking all inputs for the chosen outputs. If this is the case, then there’s a tradeoff between the number of inputs you pick and how egregious the worst bin is. As an exercise to the reader, empirically estimate this tradeoff and find the best worst-case input for the adversary. Also, for your choice of parameters, estimate by simulation the probability that the max bin is three times larger than the expected value.

Now that we’ve played around with the basic hashing algorithm and made a family of 2-universal hashes, let’s see the power of two choices. Recall, this algorithm picks two random hash functions and sends each input to whichever of its two candidate bins currently has the smaller size. This obviously generalizes to k choices, although the theoretical guarantee only improves by a constant factor, so let’s implement the more generic version.

class ChoiceHashFamily(object):
   def __init__(self, hashFamily, queryBinSize, numChoices=2):
      self.queryBinSize = queryBinSize
      self.hashFamily = hashFamily
      self.numChoices = numChoices

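   def draw(self):
      # draw numChoices independent hash functions from the underlying family
      hashes = [self.hashFamily.draw()
                   for _ in range(self.numChoices)]

      def h(x):
         # candidate bins for x, one per hash function
         indices = [hashFn(x) for hashFn in hashes]
         counts = [self.queryBinSize(i) for i in indices]
         # send x to the least-loaded candidate (ties go to the smaller index)
         count, index = min(zip(counts, indices))
         return index

      return h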
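To actually run this, queryBinSize has to be wired to a live list of bin sizes. The post doesn’t show that glue code, so the following is a sketch of one way to do it. SimpleFamily and maxLoad are hypothetical names, the modulus 2^31 - 1 is an assumption, and ChoiceHashFamily is repeated in condensed form only so the block runs on its own.

```python
import random

class SimpleFamily(object):
    """Hypothetical stand-in for the 2-universal family in this post."""
    def __init__(self, numBins, p=2**31 - 1):
        self.m, self.p = numBins, p

    def draw(self):
        a = random.randint(1, self.p - 1)
        b = random.randint(0, self.p - 1)
        p, m = self.p, self.m
        return lambda x: ((a * x + b) % p) % m

class ChoiceHashFamily(object):
    # condensed copy of the class from the post
    def __init__(self, hashFamily, queryBinSize, numChoices=2):
        self.queryBinSize = queryBinSize
        self.hashFamily = hashFamily
        self.numChoices = numChoices

    def draw(self):
        hashes = [self.hashFamily.draw() for _ in range(self.numChoices)]
        def h(x):
            candidates = [hashFn(x) for hashFn in hashes]
            # least-loaded candidate, ties broken by smaller index
            return min((self.queryBinSize(i), i) for i in candidates)[1]
        return h

def maxLoad(inputs, numBins, numChoices):
    bins = [0] * numBins
    family = ChoiceHashFamily(SimpleFamily(numBins),
                              queryBinSize=lambda i: bins[i],
                              numChoices=numChoices)
    h = family.draw()
    for x in inputs:
        bins[h(x)] += 1   # bin sizes change as we go, so h depends on history
    return max(bins), bins

if __name__ == "__main__":
    inputs = range(10000)
    print(maxLoad(inputs, 100, 1)[0], maxLoad(inputs, 100, 2)[0])
```

Note that the drawn h is not a function in the mathematical sense: its output depends on the current bin sizes, so hashing the same input twice can give different answers. A real lookup structure would have to check every candidate bin.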

And if we test this with the bad inputs (as used previously, all the inputs that hash to 9), as a typical output we get

>>> bins
[15, 16, 15, 15, 16, 14, 16, 14, 16, 15, 16, 15, 15, 15, 17, 14, 16, 14, 16, 16, 15, 16, 15, 16, 15, 15, 17, 15, 16, 15, 15, 15, 15, 16, 15, 14, 16, 14, 16, 15, 15, 15, 14, 16, 15, 15, 15, 14, 17, 14, 15, 15, 14, 16, 13, 15, 14, 15, 15, 15, 14, 15, 13, 16, 14, 16, 15, 15, 15, 16, 15, 15, 13, 16, 14, 15, 15, 16, 14, 15, 15, 15, 11, 13, 11, 12, 13, 14, 13, 11, 11, 12, 14, 14, 13, 10, 16, 12, 14, 10]

And a typical list of bin maxima is

>>> results
[16, 16, 16, 18, 17, 365, 18, 16, 16, 365, 18, 17, 17, 17, 17, 16, 16, 17, 18, 16, 17, 18, 17, 16, 17, 17, 18, 16, 18, 17, 17, 17, 17, 18, 18, 17, 17, 16, 17, 365, 17, 18, 16, 16, 18, 17, 16, 18, 365, 16, 17, 17, 16, 16, 18, 17, 17, 17, 17, 17, 18, 16, 18, 16, 16, 18, 17, 17, 365, 16, 17, 17, 17, 17, 16, 17, 16, 17, 16, 16, 17, 17, 16, 365, 18, 16, 17, 17, 17, 17, 17, 18, 17, 17, 16, 18, 18, 17, 17, 17]

Those big bumps are the times when we picked an unlucky hash function, and they are scarily large, although this bad event would be proportionally less likely as you scale up. But in the good case the load is clearly more even than the previous example, and the max load would get linearly smaller as you pick between a larger set of randomly chosen hashes (obviously).

Couple this with the technique of switching hash functions when you start to observe a large deviation, and you have yourself an elegant solution.

In addition to load balancing, hashing has a ton of applications. Remember, the main reason you may want to use hashing is that you have an algorithm that works well when the input data is random. This comes up in streaming and sublinear algorithms, in data structure design and analysis, and many other places. We’ll be covering those applications in future posts on this blog.

Until then!

The Boosting Margin, or Why Boosting Doesn’t Overfit

There’s a well-understood phenomenon in machine learning called overfitting. The idea is best shown by a graph:

overfitting

Let me explain. The vertical axis represents the error of a hypothesis. The horizontal axis represents the complexity of the hypothesis. The blue curve represents the error of a machine learning algorithm’s output on its training data, and the red curve represents the generalization of that hypothesis to the real world. The overfitting phenomenon is marked in the middle of the graph, before which the training error and generalization error both go down, but after which the training error continues to fall while the generalization error rises.

The explanation is a sort of numerical version of Occam’s Razor that says more complex hypotheses can model a fixed data set better and better, but at some point a simpler hypothesis better models the underlying phenomenon that generates the data. To optimize a particular learning algorithm, one wants to set parameters of their model to hit the minimum of the red curve.

This is where things get juicy. Boosting, which we covered in gruesome detail previously, has a natural measure of complexity represented by the number of rounds you run the algorithm for. Each round adds one additional “weak learner” weighted vote. So running for a thousand rounds gives a vote of a thousand weak learners. Despite this, boosting doesn’t overfit on many datasets. In fact, and this is a shocking fact, researchers observed that Boosting would hit zero training error, they kept running it for more rounds, and the generalization error kept going down! It seemed like the complexity could grow arbitrarily without penalty.

Schapire, Freund, Bartlett, and Lee proposed a theoretical explanation for this based on the notion of a margin, and the goal of this post is to go through the details of their theorem and proof. Remember that the standard AdaBoost algorithm produces a set of weak hypotheses h_i(x) and a corresponding weight \alpha_i \in [-1,1] for each round i=1, \dots, T. The classifier at the end is a weighted majority vote of all the weak learners (roughly: weak learners with high error on “hard” data points get less weight).

Definition: The signed confidence of a labeled example (x,y) is the weighted sum:

\displaystyle \textup{conf}(x) = \sum_{i=1}^T \alpha_i h_i(x)

The margin of (x,y) is the quantity \textup{margin}(x,y) = y \textup{conf}(x). The notation implicitly depends on the outputs of the AdaBoost algorithm via “conf.”
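To make the definition concrete, here is a small sketch that computes the confidence and margin for a toy ensemble. The three decision stumps and their weights are made up for illustration; they are not the output of an actual AdaBoost run.

```python
def conf(x, hypotheses, alphas):
    """Signed confidence: the weighted sum of the weak learners' votes."""
    return sum(a * h(x) for a, h in zip(alphas, hypotheses))

def margin(x, y, hypotheses, alphas):
    """Margin of the labeled example (x, y), with y in {-1, +1}."""
    return y * conf(x, hypotheses, alphas)

# three hypothetical decision stumps on the real line, with made-up weights
hypotheses = [
    lambda x: 1 if x > 0 else -1,
    lambda x: 1 if x > 2 else -1,
    lambda x: 1 if x > -1 else -1,
]
alphas = [0.5, 0.3, 0.2]

if __name__ == "__main__":
    for x, y in [(1, 1), (1, -1), (3, 1), (-2, 1)]:
        print(x, y, margin(x, y, hypotheses, alphas))
```

At x = 1 two of the three stumps vote +1, so the confidence is positive but well short of 1. The margin is positive when the label agrees with the vote and negative when it doesn’t, and its magnitude says how lopsided the vote was.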

We use the product of the label and the confidence because y \cdot \textup{conf}(x) \leq 0 if and only if the classifier is incorrect. The theorem we’ll prove in this post is

Theorem: With high probability over a random choice of training data, for any 0 < \theta < 1 the generalization error of boosting is bounded from above by

\displaystyle \Pr_{\textup{train}}[\textup{margin}(x,y) \leq \theta] + O \left ( \frac{1}{\theta} (\textup{typical error terms}) \right )

In words, the generalization error of the boosting hypothesis is bounded by the distribution of margins observed on the training data. To state and prove the theorem more generally we have to return to the details of PAC-learning. Here and in the rest of this post, \Pr_D denotes \Pr_{x \sim D}, the probability over a random example drawn from the distribution D, and \Pr_S denotes the probability over a random (training) set of examples drawn from D.

Theorem: Let S be a set of m random examples chosen from the distribution D generating the data. Assume the weak learner corresponds to a finite hypothesis space H of size |H|, and let \delta > 0. Then with probability at least 1 - \delta (over the choice of S), every weighted-majority vote function f satisfies the following generalization bound for every \theta > 0.

\displaystyle \Pr_D[y f(x) \leq 0] \leq \Pr_S[y f(x) \leq \theta] + O \left ( \frac{1}{\sqrt{m}} \sqrt{\frac{\log m \log |H|}{\theta^2} + \log(1/\delta)} \right )

In other words, this phenomenon is a fact about voting schemes, not boosting in particular. From now on, a “majority vote” function f(x) will mean to take the sign of a sum of the form \sum_{i=1}^N a_i h_i(x), where a_i \geq 0 and \sum_i a_i = 1. This is the “convex hull” of the set of weak learners H. If H is infinite (in our proof it will be finite, but we’ll state a generalization afterward), then only finitely many of the a_i in the sum may be nonzero.

To prove the theorem, we’ll start by defining a class of functions corresponding to “unweighted majority votes with duplicates:”

Definition: Let C_N be the set of functions f(x) of the form \frac{1}{N} \sum_{i=1}^N h_i(x) where h_i \in H and the h_i may contain duplicates (some of the h_i may be equal to some other of the h_j).

Now every majority vote function f can be written as a weighted sum of h_i with weights a_i (I’m using a instead of \alpha to distinguish arbitrary weights from those weights arising from Boosting). So any such f(x) defines a natural distribution over H where you draw function h_i with probability a_i. I’ll call this distribution A_f. If we draw from this distribution N times and take an unweighted sum, we’ll get a function g(x) \in C_N. Call the random process (distribution) generating functions in this way Q_f. In diagram form, the logic goes

f \to weights a_i \to distribution A_f over H \to function in C_N by drawing N times according to A_f.

The main fact about the relationship between f and Q_f is that each is completely determined by the other. Obviously Q_f is determined by f because we defined it that way, but f is also completely determined by Q_f as follows:

\displaystyle f(x) = \mathbb{E}_{g \sim Q_f}[g(x)]

Proving the equality is an exercise for the reader.
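If you’d like a sanity check before attempting the proof, the equality can be verified exactly on a toy example by enumerating every possible sequence of N draws. This is only a sketch: the stumps, the weights, and the function names below are all made up, and a check on one example is of course not a proof.

```python
from fractions import Fraction
from itertools import product

def f(x, hypotheses, weights):
    """The weighted vote sum_i a_i h_i(x), before taking any sign."""
    return sum(a * h(x) for a, h in zip(weights, hypotheses))

def expected_g(x, hypotheses, weights, N):
    """E_{g ~ Q_f}[g(x)], computed by enumerating all |H|^N draws."""
    total = Fraction(0)
    for draw in product(range(len(hypotheses)), repeat=N):
        prob = Fraction(1)
        for i in draw:
            prob *= weights[i]          # the draws are independent
        g_x = Fraction(sum(hypotheses[i](x) for i in draw), N)
        total += prob * g_x
    return total

# hypothetical weak learners and convex weights (nonnegative, summing to 1)
hypotheses = [
    lambda x: 1 if x > 0 else -1,
    lambda x: 1 if x > 2 else -1,
    lambda x: 1 if x > -1 else -1,
]
weights = [Fraction(1, 2), Fraction(3, 10), Fraction(1, 5)]

if __name__ == "__main__":
    for N in (1, 2, 4):
        print(N, expected_g(1, hypotheses, weights, N), f(1, hypotheses, weights))
```

Using Fraction keeps the arithmetic exact, so the two numbers can be compared with == rather than a tolerance.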

Proof of Theorem. First we’ll split the probability \Pr_D[y f(x) \leq 0] into two pieces, and then bound each piece.

First a probability reminder. If we have two events A and B (in what’s below, these will be events like yg(x) \leq \theta/2 and yf(x) \leq 0), we can split up \Pr[A] into \Pr[A \textup{ and } B] + \Pr[A \textup{ and } \overline{B}] (where \overline{B} is the opposite of B). This is called the law of total probability. Moreover, because \Pr[A \textup{ and } B] = \Pr[A \mid B] \Pr[B] and because these quantities are all at most 1, it’s true that \Pr[A \textup{ and } B] \leq \Pr[A \mid B] (the conditional probability) and that \Pr[A \textup{ and } B] \leq \Pr[B].

Back to the proof. Notice that for any g(x) \in C_N and any \theta > 0, we can write \Pr_D[y f(x) \leq 0] as a sum:

\displaystyle \Pr_D[y f(x) \leq 0] =\\ \Pr_D[yg(x) \leq \theta/2 \textup{ and } y f(x) \leq 0] + \Pr_D[yg(x) > \theta/2 \textup{ and } y f(x) \leq 0]

Now I’ll loosen the first term by removing the second event (that only makes the whole probability bigger) and loosen the second term by relaxing it to a conditional:

\displaystyle \Pr_D[y f(x) \leq 0] \leq \Pr_D[y g(x) \leq \theta / 2] + \Pr_D[yg(x) > \theta/2 \mid yf(x) \leq 0]

Now because the inequality is true for every g(x) \in C_N, it’s also true if we take an expectation of the RHS over any distribution we choose. We’ll choose the distribution Q_f to get

\displaystyle \Pr_D[yf(x) \leq 0] \leq T_1 + T_2

And T_1 (term 1) is

\displaystyle T_1 = \Pr_{x \sim D, g \sim Q_f} [yg(x) \leq \theta /2] = \mathbb{E}_{g \sim Q_f}[\Pr_D[yg(x) \leq \theta/2]]

And T_2 is

\displaystyle T_2 = \Pr_{x \sim D, g \sim Q_f}[yg(x) > \theta/2 \mid yf(x) \leq 0] = \mathbb{E}_D[\Pr_{g \sim Q_f}[yg(x) > \theta/2 \mid yf(x) \leq 0]]

We can rewrite the probabilities using expectations because (1) the variables being drawn in the distributions are independent, and (2) the probability of an event is the expectation of the indicator function of the event.

Now we’ll bound the terms T_1, T_2 separately. We’ll start with T_2.

Fix (x,y) and look at the quantity inside the expectation of T_2.

\displaystyle \Pr_{g \sim Q_f}[yg(x) > \theta/2 \mid yf(x) \leq 0]

This should intuitively be very small for the following reason. We’re sampling g according to a distribution whose expectation is f, and we know that yf(x) \leq 0. Of course yg(x) is unlikely to be large.

Mathematically we can prove this by transforming the thing inside the probability to a form suitable for the Chernoff bound. Since \mathbb{E}[yg(x)] = yf(x) \leq 0, saying yg(x) > \theta / 2 implies |yg(x) - \mathbb{E}[yg(x)]| > \theta /2, i.e. that some random variable which is a sum of independent random variables (the h_i) deviates from its expectation by more than \theta/2. Since the y’s are all \pm 1 and constant inside the expectation, they can be removed from the absolute value to get

\displaystyle \leq \Pr_{g \sim Q_f}[|g(x) - \mathbb{E}[g(x)]| > \theta/2]

The Chernoff bound allows us to bound this by an exponential in the number of random variables in the sum, i.e. N. It turns out the bound is e^{-N \theta^2 / 8}.
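For readers who want to see where the constant comes from, here is one way to get it, assuming the additive (Hoeffding) form of the Chernoff bound; the post only states the result. Write yg(x) = \frac{1}{N} \sum_{j=1}^N X_j, where the X_j = y h_{i_j}(x) \in \{-1, 1\} are the N independent draws. Only the upper tail is needed, because \mathbb{E}[yg(x)] = yf(x) \leq 0.

```latex
% Hoeffding's inequality: for independent X_j \in [a_j, b_j] and any t > 0,
\Pr\left[ \sum_{j=1}^N X_j - \mathbb{E}\left[ \sum_{j=1}^N X_j \right] \geq t \right]
  \leq \exp\left( \frac{-2t^2}{\sum_{j=1}^N (b_j - a_j)^2} \right)

% Here b_j - a_j = 2, and yg(x) - \mathbb{E}[yg(x)] > \theta/2 says the
% unnormalized sum deviates by more than t = N\theta/2, so
\Pr_{g \sim Q_f}\left[ yg(x) - \mathbb{E}[yg(x)] > \theta/2 \right]
  \leq \exp\left( \frac{-2 (N\theta/2)^2}{4N} \right)
  = e^{-N\theta^2/8}
```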

Now recall T_1

\displaystyle T_1 = \Pr_{x \sim D, g \sim Q_f} [yg(x) \leq \theta /2] = \mathbb{E}_{g \sim Q_f}[\Pr_D[yg(x) \leq \theta/2]]

For T_1, we don’t want to bound it absolutely like we did for T_2, because there is nothing stopping the classifier f from being a bad classifier and having lots of error. Rather, we want to bound it in terms of the probability that yf(x) \leq \theta. We’ll do this in two steps. In step 1, we’ll go from \Pr_D of the g’s to \Pr_S of the g’s.

Step 1: For any fixed g, \theta, if we take a sample S of size m, then consider the event in which the sample probability deviates from the true distribution by some value \varepsilon_N, i.e. the event

\displaystyle \Pr_D[yg(x) \leq \theta /2] > \Pr_{S, x \sim S}[yg(x) \leq \theta/2] + \varepsilon_N

The claim is this happens with probability at most e^{-2m\varepsilon_N^2}. This is again the Chernoff bound in disguise, because the expected value of \Pr_S is \Pr_D, and the probability over S is an average of random variables (it’s a slightly different form of the Chernoff bound; see this post for more). From now on we’ll drop the x \sim S when writing \Pr_S.

The bound above holds true for any fixed g,\theta, but we want a bound over all g and \theta. To do that we use the union bound. Note that there are only (N+1) possible choices for a nonnegative \theta because g(x) is a sum of N values each of which is either \pm1. And there are only |C_N| \leq |H|^N possibilities for g(x). So the union bound says the above event will occur with probability at most (N+1)|H|^N e^{-2m\varepsilon_N^2}.

If we want the event to occur with probability at most \delta_N, we can judiciously pick

\displaystyle \varepsilon_N = \sqrt{(1/2m) \log ((N+1)|H|^N / \delta_N)}
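To get a feel for the size of \varepsilon_N, here is a small sketch that evaluates the formula and plugs it back into the union bound to confirm the failure probability comes out to \delta_N. The function names and the sample values of m, N, |H| and \delta_N are made up for illustration.

```python
import math

def epsilon_N(m, N, H_size, delta_N):
    """sqrt((1/2m) * log((N+1) |H|^N / delta_N)), computed in log space."""
    log_term = math.log(N + 1) + N * math.log(H_size) - math.log(delta_N)
    return math.sqrt(log_term / (2.0 * m))

def union_bound_failure(m, N, H_size, eps):
    """(N+1) |H|^N e^{-2 m eps^2}, also computed in log space."""
    log_p = math.log(N + 1) + N * math.log(H_size) - 2.0 * m * eps ** 2
    return math.exp(log_p)

if __name__ == "__main__":
    m, N, H_size, delta_N = 10**6, 50, 1000, 0.01   # made-up values
    eps = epsilon_N(m, N, H_size, delta_N)
    print(eps, union_bound_failure(m, N, H_size, eps))
```

Working with logarithms avoids forming |H|^N directly, which overflows a float for even moderate N.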

And since the bound holds in general, we can take expectation with respect to Q_f and nothing changes. This means that for any \delta_N, our chosen \varepsilon_N ensures that the following is true with probability at least 1-\delta_N:

\displaystyle \Pr_{D, g \sim Q_f}[yg(x) \leq \theta/2] \leq \Pr_{S, g \sim Q_f}[yg(x) \leq \theta/2] + \varepsilon_N

Now for step 2, we bound the probability that yg(x) \leq \theta/2 on a sample by the probability that yf(x) \leq \theta on a sample.

Step 2: The first claim is that

\displaystyle \Pr_{S, g \sim Q_f}[yg(x) \leq \theta / 2] \leq \Pr_{S} [yf(x) \leq \theta] + \mathbb{E}_{S}[\Pr_{g \sim Q_f}[yg(x) \leq \theta/2 \mid yf(x) > \theta]]

What we did was break up the LHS into two “and”s, when yf(x) > \theta and yf(x) \leq \theta (this was still an equality). Then we loosened the first term to \Pr_{S}[yf(x) \leq \theta] since that is only more likely than both yg(x) \leq \theta/2 and yf(x) \leq \theta. Then we loosened the second term again using the fact that a probability of an “and” is bounded by the conditional probability.

Now we have the probability of yg(x) \leq \theta / 2 bounded by the probability that yf(x) \leq \theta plus some stuff. We just need to bound the “plus some stuff” absolutely and then we’ll be done. The argument is the same as our previous use of the Chernoff bound: we assume yf(x) > \theta, and yet yg(x) \leq \theta / 2. So the deviation of yg(x) from its expectation is large, and the probability that happens is exponentially small in the amount of deviation. The bound you get is

\displaystyle \Pr_{g \sim Q_f}[yg(x) \leq \theta/2 \mid yf(x) > \theta] \leq e^{-N\theta^2 / 8}.

And again we use the union bound to ensure that the probability this bound fails for any N is very small. Specifically, if we want the total failure probability to be at most \delta, then we need to pick some \delta_N’s so that \delta = \sum_{N=1}^{\infty} \delta_N. Choosing \delta_N = \frac{\delta}{N(N+1)} works.
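To see why this choice works, note that N ranges over the positive integers (a function in C_N needs at least one draw), and the sum telescopes:

```latex
\sum_{N=1}^{\infty} \frac{\delta}{N(N+1)}
  = \delta \sum_{N=1}^{\infty} \left( \frac{1}{N} - \frac{1}{N+1} \right)
  = \delta \lim_{K \to \infty} \left( 1 - \frac{1}{K+1} \right)
  = \delta
```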

Putting everything together, we get that with probability at least 1-\delta, for every \theta and every N, the following bound on the failure probability of f(x) holds:

\displaystyle \Pr_{x \sim D}[yf(x) \leq 0] \leq \Pr_{S, x \sim S}[yf(x) \leq \theta] + 2e^{-N \theta^2 / 8} + \sqrt{\frac{1}{2m} \log \left ( \frac{N(N+1)^2 |H|^N}{\delta} \right )}.

This claim is true for every N, so we can pick N that minimizes it. Doing a little bit of behind-the-scenes calculus that is left as an exercise to the reader, a tight choice of N is (4/ \theta)^2 \log(m/ \log |H|). And this gives the statement of the theorem.
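If you would rather explore the tradeoff numerically than do the calculus, here is a rough sketch. It evaluates the two N-dependent terms of the bound and compares the closed-form N with a brute-force minimizer over a range of integers. The function names and parameter values are made up, and the closed form is chosen to make the asymptotics clean, so it need not coincide with the numerical minimizer for any particular parameters.

```python
import math

def margin_bound_slack(N, m, H_size, theta, delta):
    """The terms added to Pr_S[y f(x) <= theta] in the combined bound."""
    chernoff = 2.0 * math.exp(-N * theta ** 2 / 8.0)
    # log(N (N+1)^2 |H|^N / delta), computed in log space
    log_term = (math.log(N) + 2.0 * math.log(N + 1)
                + N * math.log(H_size) - math.log(delta))
    return chernoff + math.sqrt(log_term / (2.0 * m))

def suggested_N(m, H_size, theta):
    """The closed-form choice (4/theta)^2 log(m / log|H|)."""
    return (4.0 / theta) ** 2 * math.log(m / math.log(H_size))

if __name__ == "__main__":
    m, H_size, theta, delta = 10**6, 1000, 0.5, 0.01   # made-up values
    N_closed = max(1, int(round(suggested_N(m, H_size, theta))))
    N_best = min(range(1, 5 * N_closed),
                 key=lambda N: margin_bound_slack(N, m, H_size, theta, delta))
    print(N_closed, margin_bound_slack(N_closed, m, H_size, theta, delta))
    print(N_best, margin_bound_slack(N_best, m, H_size, theta, delta))
```

The first term shrinks with N while the second grows, which is the whole reason an intermediate N is best.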

\square

We proved this for finite hypothesis classes, and if you know what VC-dimension is, you’ll know that it’s a central tool for reasoning about the complexity of infinite hypothesis classes. An analogous theorem can be proved in terms of the VC dimension. In that case, calling d the VC-dimension of the weak learner’s output hypothesis class, the bound is

\displaystyle \Pr_D[yf(x) \leq 0] \leq \Pr_S[yf(x) \leq \theta] + O \left ( \frac{1}{\sqrt{m}} \sqrt{\frac{d \log^2(m/d)}{\theta^2} + \log(1/\delta)} \right )

How can we interpret these bounds with so many parameters floating around? That’s where asymptotic notation comes in handy. If we fix \theta \leq 1/2 and \delta = 0.01, then the big-O part of the theorem simplifies to \sqrt{(\log |H| \cdot \log m) / m}, which is easier to think about since (\log m)/m goes to zero very fast.

Now the theorem we just proved was about any weighted majority function. The question still remains: why is AdaBoost good? That follows from another theorem, which we’ll state and leave as an exercise (it essentially follows by unwrapping the definition of the AdaBoost algorithm from last time).

Theorem: Suppose that during AdaBoost the weak learners produce hypotheses with training errors \varepsilon_1, \dots , \varepsilon_T. Then for any \theta,

\displaystyle \Pr_{(x,y) \sim S} [yf(x) \leq \theta] \leq 2^T \prod_{t=1}^T \sqrt{\varepsilon_t^{(1-\theta)} (1-\varepsilon_t)^{(1+\theta)}}

Let’s interpret this for some concrete numbers. Say that \theta = 0 and \varepsilon_t is any fixed value less than 1/2. In this case the term inside the product becomes \sqrt{\varepsilon (1-\varepsilon)} < 1/2 and the whole bound tends exponentially quickly to zero in the number of rounds T. On the other hand, if we raise \theta to about 1/3, then in order to keep the bound (and hence the LHS) tending to zero we would need \varepsilon < \frac{1}{4} ( 3 - \sqrt{5} ) which is about 20% error.
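These numbers are easy to check by machine. The sketch below folds the 2^T into the product, so each round contributes a factor 2 \sqrt{\varepsilon_t^{1-\theta} (1-\varepsilon_t)^{1+\theta}}, and the bound decays exactly when that factor is below 1. The function names are mine, and the error rates are made-up constants rather than the output of a real AdaBoost run.

```python
import math

def per_round_factor(eps, theta):
    """2 * sqrt(eps^(1-theta) * (1-eps)^(1+theta)): one round's contribution."""
    return 2.0 * math.sqrt(eps ** (1.0 - theta) * (1.0 - eps) ** (1.0 + theta))

def margin_bound(eps_list, theta):
    """The right hand side of the theorem for training errors eps_list."""
    bound = 1.0
    for eps in eps_list:
        bound *= per_round_factor(eps, theta)
    return bound

if __name__ == "__main__":
    threshold = (3.0 - math.sqrt(5.0)) / 4.0
    print(threshold)
    for eps in (0.15, 0.19, 0.20, 0.30):
        print(eps, per_round_factor(eps, 0.0), per_round_factor(eps, 1.0 / 3.0))
    print(margin_bound([0.15] * 100, 1.0 / 3.0))
```

For \theta = 1/3 the condition that the factor is below 1 simplifies to \varepsilon (1-\varepsilon)^2 < 1/8, a cubic with a root at \varepsilon = 1/2; factoring that root out leaves a quadratic whose smaller root is the threshold quoted above.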

If you’re interested in learning more about Boosting, there is an excellent book by Schapire and Freund (the inventors of AdaBoost) called Boosting: Foundations and Algorithms. There they include a tighter analysis based on the idea of Rademacher complexity. The bound I presented in this post is nice because the proof doesn’t require any machinery past basic probability, but if you want to reach the cutting edge of knowledge about boosting you need to invest in the technical stuff.

Until next time!