Optimization Models for Subset Cover

In a recent newsletter article I complained about how researchers mislead about the applicability of their work. I gave SAT solvers as an example. People provided interesting examples in response, but what was new to me was the concept of SMT (Satisfiability Modulo Theories), an extension to SAT. SMT seems to have more practical uses than vanilla SAT (see the newsletter for details).

I wanted to take some time to explore SMT solvers, and I landed on Z3, an open-source SMT solver from Microsoft. In particular, I wanted to compare it to ILP (Integer Linear Programming) solvers, which I know relatively well. I picked a problem that I thought would work better for SAT-ish solvers than for ILPs: subset covering (explained in the next section). If ILP still wins against Z3, then that would be not so great for the claim that SMT is a production-strength solver.

All the code used for this post is on Github.

Subset covering

A subset covering is a kind of combinatorial design, which can be explained in terms of magic rings.

An adventurer stumbles upon a chest full of magic rings. Each ring has a magical property, but some pairs of rings, when worn together on the same hand, produce a combined special magical effect distinct to that pair.

The adventurer would like to try all pairs of rings to catalogue the magical interactions. With only five fingers, how can we minimize the time spent trying on rings?

Mathematically, the rings can be described as a set X of size n. We want to choose a family F of subsets of X, with each subset having size 5 (five fingers), such that each subset of X of size 2 (pairs of rings) is contained in some subset of F. And we want F to be as small as possible.

Subset covering is not a “production worthy” problem. Rather, I could imagine it’s useful in some production settings, but I haven’t heard of one where it is actually used. I can imagine, for instance, that a cluster of machines has some bug occurring seemingly at random for some point-to-point RPCs, and in tracking down the problem, you want to deploy a test change to subsets of servers to observe the bug occurring. Something like an experiment design problem.

If you generalize the “5” in “5 fingers” to an arbitrary positive integer k, and the “2” in “2 rings” to l < k, then we have the general subset covering problem. Define M(n, k, l) to be the minimal number of subsets of size k needed to cover all subsets of size l. This problem was studied by Erdős, with a conjecture subsequently proved by Vojtěch Rödl, that asymptotically M(n,k,l) grows like \binom{n}{l} / \binom{k}{l}. Additional work by Joel Spencer showed that a greedy algorithm is essentially optimal.
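To make the ratio \binom{n}{l} / \binom{k}{l} concrete, here is a minimal sketch of the counting argument behind it: each chosen set of size k covers exactly \binom{k}{l} hit sets, so at least that ratio (rounded up) many sets are needed. The function name is my own for illustration, and this is only a lower bound, not a formula for M(n, k, l).

```python
from math import comb


def counting_lower_bound(n: int, k: int, l: int) -> int:
    """Lower bound on M(n, k, l): every k-subset covers comb(k, l)
    of the comb(n, l) subsets of size l."""
    total_hit_sets = comb(n, l)
    hit_sets_per_choice = comb(k, l)
    # ceiling division without floating point
    return -(-total_hit_sets // hit_sets_per_choice)


print(counting_lower_bound(7, 3, 2))   # 21 / 3  -> 7
print(counting_lower_bound(12, 6, 3))  # 220 / 20 -> 11
```

For the magic rings with n=7, k=3, l=2 the bound is 7, which matches the optimal value used in the experiments.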

However, all of the constructive algorithms in these proofs involve enumerating all \binom{n}{k} subsets of X. This wouldn’t scale very well. You can alternatively try a “random” method, incurring a typically \log(r) factor of additional sets required to cover a 1 - 1/r fraction of the needed subsets. This is practical, but imperfect.
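For intuition about why enumeration is the bottleneck, here is a rough sketch of a greedy heuristic of the kind mentioned above. It is my own simplified version, not the algorithm from Spencer's analysis: it materializes all \binom{n}{k} candidate sets and repeatedly takes the one covering the most uncovered hit sets.

```python
from itertools import combinations


def greedy_cover(n, k, l):
    """Greedy heuristic: repeatedly pick the k-subset that covers the most
    still-uncovered l-subsets. Enumerates all comb(n, k) candidates."""
    uncovered = set(combinations(range(n), l))
    candidates = list(combinations(range(n), k))
    chosen = []
    while uncovered:
        best = max(
            candidates,
            key=lambda s: sum(1 for t in combinations(s, l) if t in uncovered),
        )
        chosen.append(best)
        uncovered.difference_update(combinations(best, l))
    return chosen


cover = greedy_cover(7, 3, 2)
print(len(cover), cover)
```

The result is always a valid cover, but nothing guarantees it has the minimum size, and the candidate list grows as \binom{n}{k}.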

To the best of my knowledge, there is no exact algorithm that both achieves the minimum and avoids constructing all \binom{n}{k} subsets. So let’s try using an SMT solver. I’ll be using the Python library for Z3.

Baseline: brute force Z3

For a baseline, let’s start with a simple Z3 model that enumerates all the possible subsets that could be chosen. This leads to an exceedingly simple model to compare the complex models against.

Define boolean variables \textup{Choice}_S, each of which is true if and only if the subset S is chosen (I call this a “choice set”). Define boolean variables \textup{Hit}_T, each of which is true if the subset T (I call this a “hit set”) is contained in a chosen choice set. Then the subset cover problem can be defined by two sets of implications.

First, if \textup{Choice}_S is true, then so must all \textup{Hit}_T for T \subset S. E.g., for S = \{ 1, 2, 3 \} and l=2, we get

\displaystyle \begin{aligned} \textup{Choice}_{(1,2,3)} &\implies \textup{Hit}_{(1,2)} \\ \textup{Choice}_{(1,2,3)} &\implies \textup{Hit}_{(1,3)} \\ \textup{Choice}_{(1,2,3)} &\implies \textup{Hit}_{(2,3)} \end{aligned}

In Python this looks like the following (note this program has some previously created lookups and data structures containing the variables)

for choice_set in choice_sets:
  for hit_set_key in combinations(choice_set.elements, hit_set_size):
    hit_set = hit_set_lookup[hit_set_key]
    implications.append(
      z3.Implies(choice_set.variable, hit_set.variable))

Second, if \textup{Hit}_T is true, it must be that some \textup{Choice}_S is true for some S containing T as a subset. For example,

\displaystyle \begin{aligned} \textup{Hit}_{(1,2)} &\implies \\ & \textup{Choice}_{(1,2,3,4)} \textup{ OR} \\ & \textup{Choice}_{(1,2,3,5)} \textup{ OR} \\ & \textup{Choice}_{(1,2,4,5)} \textup{ OR } \cdots \\ \end{aligned}

In code,

for hit_set in hit_sets.values():
  relevant_choice_set_vars = [
    choice_set.variable
    for choice_set in hit_set_to_choice_set_lookup[hit_set]
  ]
  implications.append(
    z3.Implies(
      hit_set.variable, 
      z3.Or(*relevant_choice_set_vars)))
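The lookups used in both loops are built elsewhere in the code. As a rough idea of what they contain, here is a minimal sketch that uses plain tuples in place of the objects holding Z3 variables. The name build_lookups and its return values are hypothetical, chosen for illustration, and are not the actual code.

```python
from collections import defaultdict
from itertools import combinations


def build_lookups(n, k, l):
    """Enumerate all choice sets and hit sets, and map each hit set
    to the choice sets that contain it."""
    choice_sets = list(combinations(range(n), k))
    hit_sets = list(combinations(range(n), l))
    hit_to_choices = defaultdict(list)
    for choice in choice_sets:
        for hit in combinations(choice, l):
            hit_to_choices[hit].append(choice)
    return choice_sets, hit_sets, hit_to_choices


choice_sets, hit_sets, hit_to_choices = build_lookups(5, 4, 2)
print(hit_to_choices[(0, 1)])
# [(0, 1, 2, 3), (0, 1, 2, 4), (0, 1, 3, 4)]
```

With zero-indexed elements, this is the same disjunction as in the \textup{Hit}_{(1,2)} example: one term per choice set containing the hit set.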

Next, in this experiment we’re allowing the caller to specify the number of choice sets to try, and the solver should either return SAT or UNSAT. From that, we can use a binary search to find the optimal number of sets to pick. Thus, we have to limit the number of \textup{Choice}_S that are allowed to be true and false. Z3 supports boolean cardinality constraints, apparently with a special solver to handle problems that have them. Otherwise, the process of encoding cardinality constraints as SAT formulas is not trivial (and the subject of active research). But the code is simple enough:

args = [cs.variable for cs in choice_sets] + [parameters.num_choice_sets]
choice_sets_at_most = z3.AtMost(*args)
choice_sets_at_least = z3.AtLeast(*args)

Finally, we must assert that every \textup{Hit}_T is true.

solver = z3.Solver()
for hit_set in hit_sets.values():
  solver.add(hit_set.variable)

for impl in implications:
  solver.add(impl)

solver.add(choice_sets_at_most)
solver.add(choice_sets_at_least)

Running it for n=7, k=3, l=2, and seven choice sets (which is optimal), we get

>>> SubsetCoverZ3BruteForce().solve(
  SubsetCoverParameters(
    num_elements=7,
    choice_set_size=3,
    hit_set_size=2,
    num_choice_sets=7))
[(0, 1, 3), (0, 2, 4), (0, 5, 6), (1, 2, 6), (1, 4, 5), (2, 3, 5), (3, 4, 6)]
SubsetCoverSolution(status=<SolveStatus.SOLVED: 1>, solve_time_seconds=0.018305063247680664)
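The outer binary search over the number of choice sets is not shown in this post. A minimal sketch, assuming a hypothetical is_feasible(m) callable that wraps a solve call and reports whether m choice sets suffice, could look like the following. It relies on feasibility being monotone in m, which holds here as long as m stays at most \binom{n}{k}, since extra distinct sets can always be added to a cover.

```python
def smallest_feasible(is_feasible, lo, hi):
    """Smallest m in [lo, hi] for which is_feasible(m) is True.

    Assumes is_feasible is monotone (False ... False True ... True)
    and that is_feasible(hi) is True."""
    while lo < hi:
        mid = (lo + hi) // 2
        if is_feasible(mid):
            hi = mid        # mid works, so the optimum is mid or smaller
        else:
            lo = mid + 1    # mid fails, so the optimum is larger
    return lo


# Stand-in oracle for illustration only: pretend the true optimum is 7.
print(smallest_feasible(lambda m: m >= 7, 1, 35))  # 7
```

Each call to the oracle is a full solver run, so the search costs about \log_2 of the range in solves, and an UNKNOWN (timeout) result would need separate handling.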

Interestingly, Z3 refuses to solve marginally larger instances. For instance, I tried the following, and Z3 slows to a crawl around n=12, k=6 (\binom{12}{6} = 924 choice sets) and times out beyond that:

from math import comb

for n in range(8, 16):
  k = int(n / 2)
  l = 3
  max_num_sets = int(2 * comb(n, l) / comb(k, l))
  params = SubsetCoverParameters(
    num_elements=n,
    choice_set_size=k,
    hit_set_size=l,
    num_choice_sets=max_num_sets)

  # print one table row per instance; only the first row prints the header
  print_table(
    params,
    SubsetCoverZ3BruteForce().solve(params),
    header=(n==8))

After taking a long time to generate the larger models, Z3 exceeds my 15 minute time limit, suggesting exponential growth:

status               solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SolveStatus.SOLVED   0.0271              8             4                3             28
SolveStatus.SOLVED   0.0346              9             4                3             42
SolveStatus.SOLVED   0.0735              10            5                3             24
SolveStatus.SOLVED   0.1725              11            5                3             33
SolveStatus.SOLVED   386.7376            12            6                3             22
SolveStatus.UNKNOWN  900.1419            13            6                3             28
SolveStatus.UNKNOWN  900.0160            14            7                3             20
SolveStatus.UNKNOWN  900.0794            15            7                3             26

An ILP model

Next we’ll see an ILP model for the same problem. Note there are two reasons I expect the ILP model to fall short. First, the best solver I have access to is SCIP, which, despite being quite good, is in my experience about an order of magnitude slower than commercial alternatives like Gurobi. Second, I think this sort of problem is not very well suited to ILPs. It would take quite a bit longer to explain why (maybe another post, if you’re interested), but in short, well-formed ILPs have easily found feasible solutions (this one does not), and the LP-relaxation of the problem should be as tight as possible. I don’t think my formulation is very tight, but it’s possible there is a better formulation.

Anyway, the primary difference in my ILP model from brute force is that the number of choice sets is fixed in advance, and the members of the choice sets are model variables. This allows us to avoid enumerating all choice sets in the model.

In particular, \textup{Member}_{S,i} \in \{ 0, 1 \} is a binary variable that is 1 if and only if element i is assigned to be in set S. And \textup{IsHit}_{T, S} \in \{0, 1\} is 1 if and only if the hit set T is a subset of S. Here “S” is an index over the subsets, rather than the set itself, because we don’t know what elements are in S while building the model.

For the constraints, each choice set S must have size k:

\displaystyle \sum_{i \in X} \textup{Member}_{S, i} = k

Each hit set T must be hit by at least one choice set:

\displaystyle \sum_{S} \textup{IsHit}_{T, S} \geq 1

Now the tricky constraint. If a hit set T is hit by a specific choice set S (i.e., \textup{IsHit}_{T, S} = 1) then all the elements in T must also be members of S.

\displaystyle \sum_{i \in T} \textup{Member}_{S, i} \geq l \cdot \textup{IsHit}_{T, S}

This one works by the fact that the left-hand side (LHS) is bounded from below by 0 and bounded from above by l = |T|. Then \textup{IsHit}_{T, S} acts as a switch. If it is 0, then the constraint is vacuous since the LHS is always non-negative. If \textup{IsHit}_{T, S} = 1, then the right-hand side (RHS) is l = |T| and this forces all variables on the LHS to be 1 to achieve it.
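The switch argument can be checked exhaustively for a small hit set. The sketch below compares the linear constraint with the implication it is meant to encode over every 0-1 assignment; the function names are mine.

```python
from itertools import product


def constraint_holds(members, is_hit):
    """The linear constraint: sum of Member_{S,i} over i in T >= l * IsHit_{T,S}."""
    return sum(members) >= len(members) * is_hit


def implication_holds(members, is_hit):
    """The intended logic: IsHit_{T,S} = 1 implies every Member_{S,i} = 1."""
    return is_hit == 0 or all(m == 1 for m in members)


l = 3
for members in product((0, 1), repeat=l):
    for is_hit in (0, 1):
        assert constraint_holds(members, is_hit) == implication_holds(members, is_hit)
print("constraint matches the implication for all", 2 ** (l + 1), "assignments")
```

Note that the constraint only encodes one direction: a set S may contain all of T while \textup{IsHit}_{T,S} is 0, which is harmless because only the covering constraint needs hits.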

Because we fixed the number of choice sets as a parameter, the objective is the constant 1, and all we’re doing is looking for a feasible solution. The full code is here.

On the same simple example as the brute force

>>> SubsetCoverILP().solve(
  SubsetCoverParameters(
    num_elements=7,
    choice_set_size=3,
    hit_set_size=2,
    num_choice_sets=7))
[(0, 1, 3), (0, 2, 6), (0, 4, 5), (1, 2, 4), (1, 5, 6), (2, 3, 5), (3, 4, 6)]
SubsetCoverSolution(status=<SolveStatus.SOLVED: 1>, solve_time_seconds=0.1065816879272461)

It finds a different, equally valid solution in about 6x the runtime of the brute force Z3 model, though still well under one second.
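Solver output like the two lists printed above is easy to sanity-check independently. Here is a minimal sketch of such a check, with a function name of my own choosing; it confirms that every set has size k and that every hit set is contained in some chosen set.

```python
from itertools import combinations


def is_valid_cover(choice_sets, n, k, l):
    """True if every set has k distinct elements and every l-subset
    of range(n) is contained in at least one of the sets."""
    if any(len(set(s)) != k for s in choice_sets):
        return False
    covered = set()
    for s in choice_sets:
        covered.update(combinations(sorted(s), l))
    return covered == set(combinations(range(n), l))


z3_solution = [(0, 1, 3), (0, 2, 4), (0, 5, 6), (1, 2, 6),
               (1, 4, 5), (2, 3, 5), (3, 4, 6)]
ilp_solution = [(0, 1, 3), (0, 2, 6), (0, 4, 5), (1, 2, 4),
                (1, 5, 6), (2, 3, 5), (3, 4, 6)]

print(is_valid_cover(z3_solution, 7, 3, 2))   # True
print(is_valid_cover(ilp_solution, 7, 3, 2))  # True
print(z3_solution == ilp_solution)            # False
```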

On the “scaling” example, it fares much worse. With a timeout of 15 minutes, it solves n=8 decently fast, n=9 and n=12 slowly, and times out on the rest.

status               solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SolveStatus.SOLVED   1.9969              8             4                3             28
SolveStatus.SOLVED   306.4089            9             4                3             42
SolveStatus.UNKNOWN  899.8842            10            5                3             24
SolveStatus.UNKNOWN  899.4849            11            5                3             33
SolveStatus.SOLVED   406.9502            12            6                3             22
SolveStatus.UNKNOWN  902.7807            13            6                3             28
SolveStatus.UNKNOWN  900.0826            14            7                3             20
SolveStatus.UNKNOWN  900.0731            15            7                3             26

A Z3 Boolean Cardinality Model

The next model uses Z3. It keeps the concept of Member and Hit variables, but they are boolean instead of integer. It also replaces the linear constraints with implications. The constraint that forces a Hit set’s variable to be true when some Choice set contains all its elements is (for each S, T)

\displaystyle \left ( \bigwedge_{i \in T} \textup{Member}_{S, i} \right ) \implies \textup{IsHit}_T

Conversely, a Hit set’s variable being true implies its members are in some choice set.

\displaystyle \textup{IsHit}_T \implies \bigvee_{S} \bigwedge_{i \in T} \textup{Member}_{S, i}

Finally, we again use boolean cardinality constraints AtMost and AtLeast so that each choice set has the right size.

The results are much better than the ILP: it solves all of the instances in under 3 seconds.

status              solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SolveStatus.SOLVED  0.0874              8             4                3             28
SolveStatus.SOLVED  0.1861              9             4                3             42
SolveStatus.SOLVED  0.1393              10            5                3             24
SolveStatus.SOLVED  0.2845              11            5                3             33
SolveStatus.SOLVED  0.2032              12            6                3             22
SolveStatus.SOLVED  1.3661              13            6                3             28
SolveStatus.SOLVED  0.8639              14            7                3             20
SolveStatus.SOLVED  2.4877              15            7                3             26

A Z3 integer model

Z3 supports implications between equalities of integer expressions, so we can try a model that leverages this by essentially converting the boolean model to one where the variables are 0-1 integers, and the constraints are implications on equality of integer formulas (all of the form “variable = 1”).

I expect this to perform worse than the boolean model, even though the formulation is almost identical. The details of the model are here, and it’s so similar to the boolean model above that it needs no extra explanation.

The runtime is much worse, but surprisingly it still does better than the ILP model.

status              solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SolveStatus.SOLVED  2.1129              8             4                3             28
SolveStatus.SOLVED  14.8728             9             4                3             42
SolveStatus.SOLVED  7.6247              10            5                3             24
SolveStatus.SOLVED  25.0607             11            5                3             33
SolveStatus.SOLVED  30.5626             12            6                3             22
SolveStatus.SOLVED  63.2780             13            6                3             28
SolveStatus.SOLVED  57.0777             14            7                3             20
SolveStatus.SOLVED  394.5060            15            7                3             26

Harder instances

So far all the instances we’ve been giving the solvers are “easy” in a sense. In particular, we’ve guaranteed there’s a feasible solution, and it’s easy to find. We’re giving roughly twice as many sets as are needed. There are two ways to make this problem harder. One is to test on unsatisfiable instances, which can be harder because the solver has to prove it can’t work. Another is to test on satisfiable instances that are hard to find, such as those satisfiable instances where the true optimal number of choice sets is given as the input parameter. The hardest unsatisfiable instances are also the ones where the number of choice sets allowed is one less than optimal.

Let’s test those situations. Since M(7, 3, 2) = 7, we can try with 7 choice sets and 6 choice sets.

For 7 choice sets (the optimal value), all the solvers do relatively well

method                    status  solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SubsetCoverILP            SOLVED  0.0843              7             3                2             7
SubsetCoverZ3Integer      SOLVED  0.0938              7             3                2             7
SubsetCoverZ3BruteForce   SOLVED  0.0197              7             3                2             7
SubsetCoverZ3Cardinality  SOLVED  0.0208              7             3                2             7

For 6, the ILP struggles to prove it’s infeasible, and the others do much better (even the slowest of them, the Z3 cardinality model, is about 16x faster).

method                    status      solve_time_seconds  num_elements  choice_set_size  hit_set_size  num_choice_sets
SubsetCoverILP            INFEASIBLE  120.8593            7             3                2             6
SubsetCoverZ3Integer      INFEASIBLE  3.0792              7             3                2             6
SubsetCoverZ3BruteForce   INFEASIBLE  0.3384              7             3                2             6
SubsetCoverZ3Cardinality  INFEASIBLE  7.5781              7             3                2             6

This seems like hard evidence that Z3 is better than ILPs for this problem (and it is), but note that the same test on n=8 fails for all models. They can all quickly prove 8 < M(8, 3, 2) \leq 11, but time out after twenty minutes when trying to determine if M(8, 3, 2) \in \{ 9, 10 \}. Note that k=3, l=2 is the least complex choice for the other parameters, so it seems like there’s not much hope to find M(n, k, l) for any seriously large parameters, like, say, k=6.

Thoughts

These experiments suggest what SMT solvers can offer above and beyond ILP solvers. Disjunctions and implications are notoriously hard to model in an ILP. You often need to add additional special variables, or do tricks like the one I did that only work in some situations and which can mess with the efficiency of the solver. With SMT, implications are trivial to model, and natively supported by the solver.

Even after reading everything I could find on Z3, I found little advice on modeling to help the solver run faster. There is a ton of literature on this for ILP solvers, but if any readers see obvious problems with my SMT models, please chime in! I’d love to hear from you. Even without that, I am pretty impressed by how fast the solves finish for this subset cover problem (which this experiment has shown me is apparently a very hard problem).

However, there’s an elephant in the room. These models are all satisfiability/feasibility checks for a given number of choice sets. What is not tested here is optimization, in the sense of having the number of choice sets used be minimized directly by the solver. In a few experiments on even simpler models, Z3 optimization is quite slow. And while I know how I’d model the ILP version of the optimization problem, given that it’s quite slow to find a feasible instance when the optimal number of sets is given as a parameter, it seems unlikely that it will be fast when asked to optimize. I will have to try that another time to be sure.

Also, I’d like to test the ILP models on Gurobi, but I don’t have a personal license. There’s also the possibility that I can come up with a much better ILP formulation, say, with a tighter LP relaxation. But these will have to wait for another time.

In the end, this experiment has given me some more food for thought, and concrete first-hand experience, on the use of SMT solvers.

Earthmover Distance

Problem: Compute distance between points with uncertain locations (given by samples, or differing observations, or clusters).

For example, if I have the following three “points” in the plane, as indicated by their colors, which is closer, blue to green, or blue to red?

example-points.png

It’s not obvious, and there are multiple factors at work: the red points have fewer samples, but we can be more certain about the position; the blue points are less certain, but the closest non-blue point to a blue point is green; and the green points are equally plausibly “close to red” and “close to blue.” The centers of mass of the three sample sets nearly form an equilateral triangle. In our example the “points” don’t overlap, but of course they could. And in particular, there should probably be a nonzero distance between two points whose sample sets have the same center of mass, as below. The distance quantifies the uncertainty.

same-centers.png

All this is to say that it’s not obvious how to define a distance measure that is consistent with perceptual ideas of what geometry and distance should be.

Solution (Earthmover distance): Treat each sample set A corresponding to a “point” as a discrete probability distribution, so that each sample x \in A has probability mass p_x = 1 / |A|. The distance between A and B is the optimal value of the following linear program.

Each x \in A corresponds to a pile of dirt of height p_x, and each y \in B corresponds to a hole of depth p_y. The cost of moving a unit of dirt from x to y is the Euclidean distance d(x, y) between the points (or whatever hipster metric you want to use).

Let z_{x, y} be a real variable corresponding to an amount of dirt to move from x \in A to y \in B, with cost d(x, y). Then the constraints are:

  • Each z_{x, y} \geq 0, so dirt only moves from x to y.
  • Every pile x \in A must vanish, i.e. for each fixed x \in A, \sum_{y \in B} z_{x,y} = p_x.
  • Likewise, every hole y \in B must be completely filled, i.e. for each fixed y \in B, \sum_{x \in A} z_{x,y} = p_y.

The objective is to minimize the cost of doing this: \sum_{(x, y) \in A \times B} d(x, y) z_{x, y}.

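In Python, using the ortools library (and leaving out a few docstrings, full code on Github):

import math
from collections import Counter, defaultdict

from ortools.linear_solver import pywraplp


def euclidean_distance(x, y):
    # assumed helper, not shown in the original listing:
    # points are equal-length tuples of coordinates
    return math.sqrt(sum((a - b) ** 2 for (a, b) in zip(x, y)))
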

def earthmover_distance(p1, p2):
    dist1 = {x: count / len(p1) for (x, count) in Counter(p1).items()}
    dist2 = {x: count / len(p2) for (x, count) in Counter(p2).items()}
    solver = pywraplp.Solver('earthmover_distance', pywraplp.Solver.GLOP_LINEAR_PROGRAMMING)

    variables = dict()

    # for each pile in dist1, the constraint that says all the dirt must leave this pile
    dirt_leaving_constraints = defaultdict(lambda: 0)

    # for each hole in dist2, the constraint that says this hole must be filled
    dirt_filling_constraints = defaultdict(lambda: 0)

    # the objective
    objective = solver.Objective()
    objective.SetMinimization()

    for (x, dirt_at_x) in dist1.items():
        for (y, capacity_of_y) in dist2.items():
            amount_to_move_x_y = solver.NumVar(0, solver.infinity(), 'z_{%s, %s}' % (x, y))
            variables[(x, y)] = amount_to_move_x_y
            dirt_leaving_constraints[x] += amount_to_move_x_y
            dirt_filling_constraints[y] += amount_to_move_x_y
            objective.SetCoefficient(amount_to_move_x_y, euclidean_distance(x, y))

    for x, linear_combination in dirt_leaving_constraints.items():
        solver.Add(linear_combination == dist1[x])

    for y, linear_combination in dirt_filling_constraints.items():
        solver.Add(linear_combination == dist2[y])

    status = solver.Solve()
    if status not in [solver.OPTIMAL, solver.FEASIBLE]:
        raise Exception('Unable to find feasible solution')

    return objective.Value()
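For tiny inputs, the LP can be cross-checked without a solver. The sketch below is my own addition and assumes both sample lists have the same size: in that case every sample carries the same mass 1/n, and an optimal plan can be taken to be a one-to-one matching, so trying all n! matchings gives the exact distance. It uses math.dist, which needs Python 3.8 or later.

```python
import math
from itertools import permutations


def earthmover_distance_bruteforce(p1, p2):
    """Exact earthmover distance for two equal-size sample lists.

    With uniform mass 1/n on each sample, some optimal transport plan is a
    one-to-one matching, so we take the cheapest of all n! matchings."""
    if len(p1) != len(p2):
        raise ValueError("this sketch requires equally many samples")
    n = len(p1)
    return min(
        sum(math.dist(x, y) for x, y in zip(p1, perm)) / n
        for perm in permutations(p2)
    )


a = [(0, 0), (1, 0)]
b = [(0, 1), (1, 1)]
print(earthmover_distance_bruteforce(a, b))  # 1.0
```

Here the best plan moves each pile straight up by one unit; the crossed matching would cost \sqrt{2} instead. The factorial blowup makes this useful only as a test oracle for a handful of points.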

Discussion: I’ve heard about this metric many times as a way to compare probability distributions. For example, it shows up in an influential paper about fairness in machine learning, and a few other CS theory papers related to distribution testing.

One might ask: why not use other measures of dissimilarity for probability distributions (Chi-squared statistic, Kullback-Leibler divergence, etc.)? One answer is that these other measures only give useful information for pairs of distributions with the same support. An example from a talk of Justin Solomon succinctly clarifies what Earthmover distance achieves

Screen Shot 2018-03-03 at 6.11.00 PM.png
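A toy illustration of the support issue (my own, not taken from the talk): for point masses on the number line, Kullback-Leibler divergence is infinite whenever the supports are disjoint, no matter how far apart the masses are, while the earthmover distance tracks the actual separation.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as {outcome: probability}."""
    total = 0.0
    for outcome, p_mass in p.items():
        if p_mass == 0:
            continue
        q_mass = q.get(outcome, 0.0)
        if q_mass == 0:
            return math.inf  # p puts mass where q has none
        total += p_mass * math.log(p_mass / q_mass)
    return total


def earthmover_point_masses(a, b):
    """Earthmover distance between a unit point mass at a and one at b:
    all of the dirt travels the distance |a - b|."""
    return abs(a - b)


origin, near, far = {0: 1.0}, {1: 1.0}, {10: 1.0}
print(kl_divergence(origin, near), kl_divergence(origin, far))       # inf inf
print(earthmover_point_masses(0, 1), earthmover_point_masses(0, 10))  # 1 10
```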

Also, why not just model the samples using, say, a normal distribution, and then compute the distance based on the parameters of the distributions? That is possible, and in fact makes for a potentially more efficient technique, but you lose some information by doing this. Ignoring that your data might not be approximately normal (it might have some curvature), with Earthmover distance, you get point-by-point details about how each data point affects the outcome.

This kind of attention to detail can be very important in certain situations. One that I’ve been paying close attention to recently is the problem of studying gerrymandering from a mathematical perspective. Justin Solomon of MIT is a champion of the Earthmover distance (see his fascinating talk here for more, with slides) which is just one topic in a field called “optimal transport.”

This has the potential to be useful in redistricting because of the nature of the redistricting problem. As I wrote previously, discussions of redistricting are chock-full of geometry—or at least geometric-sounding language—and people are very concerned with the apparent “compactness” of a districting plan. But the underlying data used to perform redistricting isn’t very accurate. The people who build the maps don’t have precise data on voting habits, or even locations where people live. Census tracts might not be perfectly aligned, and data can just plain have errors and uncertainty in other respects. So the data that district-map-drawers care about is uncertain much like our point clouds. With a theory of geometry that accounts for uncertainty (and the Earthmover distance is the “distance” part of that), one can come up with more robust, better tools for redistricting.

Solomon’s website has a ton of resources about this, under the names of “optimal transport” and “Wasserstein metric,” and his work extends from computing distances to computing important geometric values like the barycenter, and to computational advantages like parallelism.

Others in the field have come up with transparency techniques to make it clearer how the Earthmover distance relates to the geometry of the underlying space. This one is particularly fun because the explanations result in a path traveled from the start to the finish, and by setting up the underlying metric in just such a way, you can watch the distribution navigate a maze to get to its target. I like to imagine tiny ants carrying all that dirt.

Screen Shot 2018-03-03 at 6.15.50 PM.png

Finally, work of Shirdhonkar and Jacobs provides approximation algorithms that allow linear-time computation, instead of the worst-case-cubic runtime of a linear solver.

Binary Search on Graphs

Binary search is one of the most basic algorithms I know. Given a sorted list of comparable items and a target item being sought, binary search looks at the middle of the list, and compares it to the target. If the target is larger, we repeat on the larger half of the list, and vice versa.

With each comparison the binary search algorithm cuts the search space in half. The result is a guarantee of no more than \log(n) comparisons, for a total runtime of O(\log n). Neat, efficient, useful.
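For reference, here is a minimal sketch of the classic list version, instrumented to count how many times the middle element is examined. The counter is my addition for illustration.

```python
def binary_search(items, target):
    """Return (index, comparisons); index is None if target is absent.
    items must be sorted in increasing order."""
    lo, hi = 0, len(items) - 1
    comparisons = 0
    while lo <= hi:
        mid = (lo + hi) // 2
        comparisons += 1
        if items[mid] == target:
            return mid, comparisons
        if items[mid] < target:
            lo = mid + 1   # target is larger: keep the upper half
        else:
            hi = mid - 1   # target is smaller: keep the lower half
    return None, comparisons


print(binary_search(list(range(0, 100, 2)), 42))  # (21, 6)
```

On this 50-element list the target is found after examining 6 middle elements, in line with \log_2(50) \approx 5.6.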

There’s always another angle.

What if we tried to do binary search on a graph? Most graph search algorithms, like breadth- or depth-first search, take linear time, and they were invented by some pretty smart cookies. So if binary search on a graph is going to make any sense, it’ll have to use more information beyond what a normal search algorithm has access to.

For binary search on a list, it’s the fact that the list is sorted, and we can compare against the sought item to guide our search. But really, the key piece of information isn’t related to the comparability of the items. It’s that we can eliminate half of the search space at every step. The “compare against the target” step can be thought of as a black box that replies to queries of the form, “Is this the thing I’m looking for?” with responses of the form, “Yes,” or, “No, but look over here instead.”

binarysearch1

As long as the answers to your queries are sufficiently helpful, meaning they allow you to cut out large portions of your search space at each step, then you probably have a good algorithm on your hands. Indeed, there’s a natural model for graphs, defined in a 2015 paper of Emamjomeh-Zadeh, Kempe, and Singhal that goes as follows.

You’re given as input an undirected, weighted graph G = (V,E), with weights w_e for e \in E. You can see the entire graph, and you may ask questions of the form, “Is vertex v the target?” Responses will be one of two things:

  • Yes (you win!)
  • No, but e = (v, w) is an edge out of v on a shortest path from v to the true target.

Your goal is to find the target vertex with the minimum number of queries.

Obviously this only works if G is connected, but slight variations of everything in this post work for disconnected graphs. (The same is not true in general for directed graphs)

When the graph is a line, this “reduces” to binary search in the sense that the same basic idea of binary search works: start in the middle of the graph, and the edge you get in response to a query will tell you in which half of the graph to continue.

binarysearch2.png

And if we make this example only slightly more complicated, the generalization should become obvious:

binarysearch3

Here, we again start at the “center vertex,” and the response to our query will eliminate one of the two halves. But then how should we pick the next vertex, now that we no longer have a linear order to rely on? It should be clear: choose the “center vertex” of whichever half we end up in. This choice can be formalized into a rule that works even when there’s not such obvious symmetry, and it turns out to always be the right choice.

Definition: A median of a weighted graph G with respect to a subset of vertices S \subset V is a vertex v \in V (not necessarily in S) which minimizes the sum of distances to vertices in S. More formally, it minimizes

\Phi_S(v) = \sum_{u \in S} d(v, u),

where d(v, u) is the sum of the edge weights along a shortest path from v to u.
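As a concrete reading of the definition, here is a minimal sketch that computes a median by running a textbook Dijkstra from every vertex and picking the one with the smallest \Phi_S. The adjacency-dict representation and function names are assumptions for this sketch and differ from the Graph class used in the implementation later in the post.

```python
import heapq


def dijkstra(adj, start):
    """Shortest distances from start. adj maps vertex -> list of (neighbor, weight).
    Vertices must be mutually comparable (e.g. ints) for the heap ordering."""
    dist = {v: float("inf") for v in adj}
    dist[start] = 0
    heap = [(0, start)]
    while heap:
        d, v = heapq.heappop(heap)
        if d > dist[v]:
            continue  # stale heap entry
        for w, weight in adj[v]:
            if d + weight < dist[w]:
                dist[w] = d + weight
                heapq.heappush(heap, (dist[w], w))
    return dist


def median(adj, S):
    """A vertex v minimizing Phi_S(v) = sum of d(v, u) over u in S."""
    def phi(v):
        dist = dijkstra(adj, v)
        return sum(dist[u] for u in S)
    return min(adj, key=phi)


# Path graph 0 - 1 - 2 - 3 - 4 with unit weights
path = {i: [] for i in range(5)}
for i in range(4):
    path[i].append((i + 1, 1))
    path[i + 1].append((i, 1))

print(median(path, S=set(range(5))))  # 2
```

On the path, the middle vertex has \Phi_S = 6 while its neighbors have 7 and the endpoints 10. When several vertices tie, this sketch returns whichever comes first in the dictionary order.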

And so generalizing binary search to this query-model on a graph results in the following algorithm, which whittles down the search space by querying the median at every step.

Algorithm: Binary search on graphs. Input is a graph G = (V,E).

  • Start with a set of candidates S = V.
  • While we haven’t found the target and |S| > 1:
    • Query the median v of S, and stop if you’ve found the target.
    • Otherwise, let e = (v, w) be the response edge, and compute the set of all vertices x \in V for which e is on a shortest path from v to x. Call this set T.
    • Replace S with S \cap T.
  • Output the only remaining vertex in S

Indeed, as we’ll see momentarily, a Python implementation is about as simple. The meat of the work is in computing the median and the set T, both of which are slight variants of Dijkstra’s algorithm for computing shortest paths.

The theorem, which is straightforward and well written by Emamjomeh-Zadeh et al. (only about a half page on page 5), is that this algorithm requires only O(\log(n)) queries, just like binary search.
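To watch the query count on a toy case, here is a compact simulation of the algorithm. It is a sketch under simplifying assumptions, separate from and much less efficient than the Dijkstra-based implementation developed below: distances come from Floyd-Warshall, edge weights are positive integers so exact equality tests are safe, and the responder is simulated by a hypothetical oracle function that knows the target.

```python
from itertools import product


def all_pairs_distances(vertices, weights):
    """Floyd-Warshall. weights maps frozenset({u, v}) -> positive edge weight."""
    d = {(u, v): (0 if u == v else float("inf"))
         for u, v in product(vertices, repeat=2)}
    for edge, w in weights.items():
        u, v = tuple(edge)
        d[u, v] = d[v, u] = min(d[u, v], w)
    # the intermediate vertex m is the outermost loop variable
    for m, u, v in product(vertices, repeat=3):
        d[u, v] = min(d[u, v], d[u, m] + d[m, v])
    return d


def graph_binary_search(vertices, weights, target):
    """Return (found_vertex, number_of_queries) on a connected graph."""
    d = all_pairs_distances(vertices, weights)
    neighbors = {v: [] for v in vertices}
    for edge, w in weights.items():
        u, v = tuple(edge)
        neighbors[u].append((v, w))
        neighbors[v].append((u, w))

    def oracle(v):
        """None means 'yes'; otherwise (w, weight) for an edge (v, w)
        lying on a shortest path from v to the target."""
        if v == target:
            return None
        return next((w, wt) for w, wt in neighbors[v]
                    if wt + d[w, target] == d[v, target])

    S = set(vertices)
    queries = 0
    while len(S) > 1:
        v = min(vertices, key=lambda c: sum(d[c, u] for u in S))  # median of S
        queries += 1
        response = oracle(v)
        if response is None:
            return v, queries
        w, wt = response
        # T: vertices x for which the edge (v, w) is on a shortest path from v to x
        T = {x for x in vertices if wt + d[w, x] == d[v, x]}
        S &= T
    return next(iter(S)), queries


# Path graph 0 - 1 - ... - 6 with unit weights
vertices = list(range(7))
weights = {frozenset({i, i + 1}): 1 for i in range(6)}
print(graph_binary_search(vertices, weights, target=5))  # (5, 2)
```

On the path this behaves exactly like list binary search: the first query at vertex 3 leaves the candidates {4, 5, 6}, and the second query at their median, vertex 5, hits the target.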

Before we dive into an implementation, there’s a catch. Even though we are guaranteed only \log(n) many queries, because of our Dijkstra’s algorithm implementation, we’re definitely not going to get a logarithmic time algorithm. So in what situation would this be useful?

Here’s where we use the “theory” trick of making up a fanciful problem and only later finding applications for it (which, honestly, has been quite successful in computer science). In this scenario we’re treating the query mechanism as a black box. It’s natural to imagine that the queries are expensive, and a resource we want to optimize for. As an example the authors bring up in a followup paper, the graph might be the set of clusterings of a dataset, and the query involves a human looking at the data and responding that a cluster should be split, or that two clusters should be joined. Of course, for clustering the underlying graph is too large to process, so the median-finding algorithm needs to be implicit. But the essential point is clear: sometimes the query is the most expensive part of the algorithm.

Alright, now let’s implement it! The complete code is on Github as always.

Always be implementing

We start with a slight variation of Dijkstra’s algorithm. Here we’re given as input a single “starting” vertex, and we produce as output a list of all shortest paths from the start to all possible destination vertices.

We start with a bare-bones graph data structure.

from collections import defaultdict
from collections import namedtuple

Edge = namedtuple('Edge', ('source', 'target', 'weight'))

class Graph:
    # A bare-bones implementation of a weighted, undirected graph
    def __init__(self, vertices, edges=tuple()):
        self.vertices = vertices
        self.incident_edges = defaultdict(list)

        for edge in edges:
            self.add_edge(
                edge[0],
                edge[1],
                1 if len(edge) == 2 else edge[2]  # optional weight
            )

    def add_edge(self, u, v, weight=1):
        self.incident_edges[u].append(Edge(u, v, weight))
        self.incident_edges[v].append(Edge(v, u, weight))

    def edge(self, u, v):
        return [e for e in self.incident_edges[u] if e.target == v][0]

And then, since most of the work in Dijkstra’s algorithm is tracking information that you build up as you search the graph, we define the “output” data structure, a dictionary of distances from the start paired with back-pointers for the discovered shortest paths.

import math

class DijkstraOutput:
    def __init__(self, graph, start):
        self.start = start
        self.graph = graph

        # the smallest distance from the start to the destination v
        self.distance_from_start = {v: math.inf for v in graph.vertices}
        self.distance_from_start[start] = 0

        # a list of predecessor edges for each destination
        # to track a list of possibly many shortest paths
        self.predecessor_edges = {v: [] for v in graph.vertices}

    def found_shorter_path(self, vertex, edge, new_distance):
        # update the solution with a newly found shorter (or equally short) path;
        # compare against the old distance before overwriting it
        if new_distance < self.distance_from_start[vertex]:
            self.predecessor_edges[vertex] = [edge]
        else:  # tie for multiple shortest paths
            self.predecessor_edges[vertex].append(edge)

        self.distance_from_start[vertex] = new_distance

    def path_to_destination_contains_edge(self, destination, edge):
        predecessors = self.predecessor_edges[destination]
        if edge in predecessors:
            return True
        return any(self.path_to_destination_contains_edge(e.source, edge)
                   for e in predecessors)

    def sum_of_distances(self, subset=None):
        subset = subset or self.graph.vertices
        return sum(self.distance_from_start[x] for x in subset)

The actual Dijkstra algorithm then just does a “breadth-first” (priority-queue-guided) search through G, updating the metadata as it finds shorter paths.

import heapq

def single_source_shortest_paths(graph, start):
    '''
    Compute the shortest paths and distances from the start vertex to all
    possible destination vertices. Return an instance of DijkstraOutput.
    '''
    output = DijkstraOutput(graph, start)
    visit_queue = [(0, start)]

    while len(visit_queue) > 0:
        priority, current = heapq.heappop(visit_queue)

        for incident_edge in graph.incident_edges[current]:
            v = incident_edge.target
            weight = incident_edge.weight
            distance_from_current = output.distance_from_start[current] + weight

            if distance_from_current <= output.distance_from_start[v]:
                output.found_shorter_path(v, incident_edge, distance_from_current)
                heapq.heappush(visit_queue, (distance_from_current, v))

    return output

Finally, we implement the median-finding and T-computing subroutines:

def possible_targets(graph, start, edge):
    '''
    Given an undirected graph G = (V,E), an input vertex v in V, and an edge e
    incident to v, compute the set of vertices w such that e is on a shortest path from
    v to w.
    '''
    dijkstra_output = single_source_shortest_paths(graph, start)
    return set(v for v in graph.vertices
               if dijkstra_output.path_to_destination_contains_edge(v, edge))

def find_median(graph, vertices):
    '''
    Compute as output a vertex in the input graph which minimizes the sum of distances
    to the input set of vertices
    '''
    best_dijkstra_run = min(
         (single_source_shortest_paths(graph, v) for v in graph.vertices),
         key=lambda run: run.sum_of_distances(vertices)
    )
    return best_dijkstra_run.start

And then the core algorithm:

QueryResult = namedtuple('QueryResult', ('found_target', 'feedback_edge'))

def binary_search(graph, query):
    '''
    Find a target node in a graph, with queries of the form "Is x the target?"
    and responses either "You found the target!" or "Here is an edge on a shortest
    path to the target."
    '''
    candidate_nodes = set(x for x in graph.vertices)  # copy

    while len(candidate_nodes) > 1:
        median = find_median(graph, candidate_nodes)
        query_result = query(median)

        if query_result.found_target:
            return median
        else:
            edge = query_result.feedback_edge
            legal_targets = possible_targets(graph, median, edge)
            candidate_nodes = candidate_nodes.intersection(legal_targets)

    return candidate_nodes.pop()

Here’s an example of running it on the example graph we used earlier in the post:

'''
Graph looks like this tree, with uniform weights

    a       k
     b     j
      cfghi
     d     l
    e       m
'''
G = Graph(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
           'j', 'k', 'l', 'm'],
          [
               ('a', 'b'),
               ('b', 'c'),
               ('c', 'd'),
               ('d', 'e'),
               ('c', 'f'),
               ('f', 'g'),
               ('g', 'h'),
               ('h', 'i'),
               ('i', 'j'),
               ('j', 'k'),
               ('i', 'l'),
               ('l', 'm'),
          ])

def simple_query(v):
    ans = input("is '%s' the target? [y/N] " % v)
    if ans and ans.lower()[0] == 'y':
        return QueryResult(True, None)
    else:
        print("Please input a vertex on the shortest path between"
              " '%s' and the target. The graph is: " % v)
        for w in G.incident_edges:
            print("%s: %s" % (w, G.incident_edges[w]))

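        # only accept a vertex adjacent to v, since G.edge(v, target)
        # raises an IndexError when there is no such edge
        neighbors = [e.target for e in G.incident_edges[v]]
        target = None
        while target not in neighbors:
            target = input("Input neighboring vertex of '%s': " % v)

        return QueryResult(False, G.edge(v, target))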

output = binary_search(G, simple_query)
print("Found target: %s" % output)

The query function just prints out a reminder of the graph and asks the user to answer the query with a yes/no and a relevant edge if the answer is no.

An example run:

is 'g' the target? [y/N] n
Please input a vertex on the shortest path between 'g' and the target. The graph is:
e: [Edge(source='e', target='d', weight=1)]
i: [Edge(source='i', target='h', weight=1), Edge(source='i', target='j', weight=1), Edge(source='i', target='l', weight=1)]
g: [Edge(source='g', target='f', weight=1), Edge(source='g', target='h', weight=1)]
l: [Edge(source='l', target='i', weight=1), Edge(source='l', target='m', weight=1)]
k: [Edge(source='k', target='j', weight=1)]
j: [Edge(source='j', target='i', weight=1), Edge(source='j', target='k', weight=1)]
c: [Edge(source='c', target='b', weight=1), Edge(source='c', target='d', weight=1), Edge(source='c', target='f', weight=1)]
f: [Edge(source='f', target='c', weight=1), Edge(source='f', target='g', weight=1)]
m: [Edge(source='m', target='l', weight=1)]
d: [Edge(source='d', target='c', weight=1), Edge(source='d', target='e', weight=1)]
h: [Edge(source='h', target='g', weight=1), Edge(source='h', target='i', weight=1)]
b: [Edge(source='b', target='a', weight=1), Edge(source='b', target='c', weight=1)]
a: [Edge(source='a', target='b', weight=1)]
Input neighboring vertex of 'g': f
is 'c' the target? [y/N] n
Please input a vertex on the shortest path between 'c' and the target. The graph is:
[...]
Input neighboring vertex of 'c': d
is 'd' the target? [y/N] n
Please input a vertex on the shortest path between 'd' and the target. The graph is:
[...]
Input neighboring vertex of 'd': e
Found target: e
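If you'd rather not type responses by hand, here is a self-contained sketch that replaces the human with a simulated, always-truthful responder and counts the queries. It is not the code from the repository: it assumes unit edge weights and uses plain breadth-first search instead of the Dijkstra variant above, and the names `bfs_distances` and `search_with_oracle` are made up for this example.

```python
from collections import deque

def bfs_distances(adjacency, start):
    # Unweighted shortest-path distances from start to every reachable vertex.
    distances = {start: 0}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for w in adjacency[u]:
            if w not in distances:
                distances[w] = distances[u] + 1
                queue.append(w)
    return distances

def search_with_oracle(adjacency, target):
    # Run the search against a truthful simulated responder.
    # Returns (vertex found, number of queries made).
    dist = {v: bfs_distances(adjacency, v) for v in adjacency}
    candidates = set(adjacency)
    queries = 0
    while len(candidates) > 1:
        # The median minimizes the sum of distances to the remaining candidates.
        median = min(sorted(adjacency),
                     key=lambda v: sum(dist[v][x] for x in candidates))
        queries += 1
        if median == target:
            return median, queries
        # The responder names a neighbor one step closer to the target.
        neighbor = next(w for w in sorted(adjacency[median])
                        if dist[w][target] == dist[median][target] - 1)
        # Keep the candidates with a shortest path from the median through that edge.
        candidates = {x for x in candidates
                      if dist[neighbor][x] == dist[median][x] - 1}
    return candidates.pop(), queries

# The same 13-vertex tree as in the interactive example.
edges = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'e'), ('c', 'f'), ('f', 'g'),
         ('g', 'h'), ('h', 'i'), ('i', 'j'), ('j', 'k'), ('i', 'l'), ('l', 'm')]
adjacency = {}
for u, v in edges:
    adjacency.setdefault(u, set()).add(v)
    adjacency.setdefault(v, set()).add(u)

results = {t: search_with_oracle(adjacency, t) for t in adjacency}
worst_case = max(count for _, count in results.values())
print("worst-case number of queries:", worst_case)
```

Running this over every possible target is a cheap sanity check that the search always lands on the right vertex and that the query count stays logarithmic in the number of vertices.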

A likely story

The binary search we implemented in this post is pretty minimal. In fact, the more interesting part of the work of Emamjomeh-Zadeh et al. is the part where the response to the query can be wrong with some unknown probability.

In this case, there can be many shortest paths that are valid responses to a query, in addition to all the invalid responses. In particular, this rules out the strategy of asking the same query multiple times and taking the majority response. If the error rate is 1/3, and there are two shortest paths to the target, you can get into a situation in which you see three responses equally often and can’t choose which one is the liar.

Instead, the technique Emamjomeh-Zadeh et al. use is based on the Multiplicative Weights Update Algorithm (it strikes again!). Each query gives a multiplicative increase (or decrease) on the set of nodes that are consistent targets under the assumption that query response is correct. There are a few extra details and some postprocessing to avoid unlikely outcomes, but that’s the basic idea. Implementing it would be an excellent exercise for readers interested in diving deeper into a recent research paper (or to flex their math muscles).
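To give a flavor of the reweighting step, here is a minimal sketch of one multiplicative update as I understand the idea. It is not the paper's full algorithm: it leaves out the weighted median and the postprocessing, and both the function name `reweight` and the parameter `p` (an assumed probability, greater than 1/2, that a response is correct) are my own choices for illustration.

```python
def reweight(weights, consistent_nodes, p):
    # Nodes consistent with the response are scaled by p, the rest by 1 - p,
    # and the result is renormalized to sum to 1.
    updated = {
        node: weight * (p if node in consistent_nodes else 1 - p)
        for node, weight in weights.items()
    }
    total = sum(updated.values())
    return {node: weight / total for node, weight in updated.items()}

# Four equally likely candidates; one response is consistent with 'a' and 'b'.
weights = {'a': 0.25, 'b': 0.25, 'c': 0.25, 'd': 0.25}
new_weights = reweight(weights, consistent_nodes={'a', 'b'}, p=0.75)
print(new_weights)
```

Unlike the noiseless algorithm, nothing is ever eliminated outright. An inconsistent node just loses weight, so a single lying response can be outvoted by later ones.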

But even deeper, this model of “query and get advice on how to improve” is a classic learning model first formally studied by Dana Angluin (my academic grand-advisor). In her model, one wants to design an algorithm to learn a classifier. The allowed queries are membership and equivalence queries. A membership query is essentially, “What is the label of this element?” and an equivalence query has the form, “Is this the right classifier?” If the answer is no, a mislabeled example is provided.

This is different from the usual machine learning assumption, because the learning algorithm gets to construct an example it wants to get more information about, instead of simply relying on a randomly generated subset of data. The goal is to minimize the number of queries before the target hypothesis is learned exactly. And indeed, as we saw in this post, if you have a little extra time to analyze the problem space, you can craft queries that extract quite a lot of information.

Indeed, the model we presented here for binary search on graphs is the natural analogue of an equivalence query for a search problem: instead of a mislabeled counterexample, you get a nudge in the right direction toward the target. Pretty neat!

There are a few directions we could take from here: (1) implement the Multiplicative Weights version of the algorithm, (2) apply this technique to a problem like ranking or clustering, or (3) cover theoretical learning models like membership and equivalence queries in more detail. What interests you?

Until next time!

Bayesian Ranking for Rated Items

Problem: You have a catalog of items with discrete ratings (thumbs up/thumbs down, or 5-star ratings, etc.), and you want to display them in the “right” order.

Solution: In Python

'''
  score: [int], [int], [float] -> float

  Return the expected value of the rating for an item with known
  ratings specified by `ratings`, prior belief specified by
  `rating_prior`, and a utility function specified by `rating_utility`,
  assuming the ratings are a multinomial distribution and the prior
  belief is a Dirichlet distribution.
'''
def score(ratings, rating_prior, rating_utility):
    ratings = [r + p for (r, p) in zip(ratings, rating_prior)]
    score = sum(r * u for (r, u) in zip(ratings, rating_utility))
    return score / sum(ratings)

Discussion: This deceptively short solution can lead you on a long and winding path into the depths of statistics. I will do my best to give a short, clear version of the story.

As a working example I chose merely because I recently listened to a related podcast, say you’re selling mass-market romance novels—which, by all accounts, is a predictable genre. You have a list of books, each of which has been rated on a scale of 0-5 stars by some number of users. You want to display the top books first, so that time-constrained readers can experience the most titillating novels first, and newbies to the genre can get the best first time experience and be incentivized to buy more.

The setup required to arrive at the above code is the following, which I’ll phrase as a story.

Users’ feelings about a book, and subsequent votes, are independent draws from a known distribution (with unknown parameters). I will just call these distributions “discrete” distributions. So given a book and user, there is some unknown list (p_0, p_1, p_2, p_3, p_4, p_5) of probabilities (\sum_i p_i = 1) for each possible rating a user could give for that book.

But how do users get these probabilities? In this story, the probabilities are the output of a randomized procedure that generates distributions. That modeling assumption is called a “Dirichlet prior,” with Dirichlet meaning it generates discrete distributions, and prior meaning it encodes domain-specific information (such as the fraction of 4-star ratings for a typical romance novel).

So the story is you have a book, and that book gets a Dirichlet distribution (unknown to us), and then when a user comes along they sample from the Dirichlet distribution to get a discrete distribution, which they then draw from to choose a rating. We observe the ratings, and we need to find the book’s underlying Dirichlet. We start by assigning it some default Dirichlet (the prior) and update that Dirichlet as we observe new ratings. Some other assumptions:

  1. Books are indistinguishable except in the parameters of their Dirichlet distribution.
  2. The parameters of a book’s Dirichlet distribution don’t change over time, and inherently reflect the book’s value.

So a Dirichlet distribution is a process that produces discrete distributions. For simplicity, in this post we will say a Dirichlet distribution is parameterized by a list of six integers (n_0, \dots, n_5), one for each possible star rating. These values represent our belief in the “typical” distribution of votes for a new book. We’ll discuss more about how to set the values later. Sampling a value (a book’s list of probabilities) from the Dirichlet distribution is not trivial, but we don’t need to do that for this program. Rather, we need to be able to interpret a fixed Dirichlet distribution, and update it given some observed votes.

The interpretation we use for a Dirichlet distribution is its expected value, which, recall, is the parameters of a discrete distribution. In particular if n = \sum_i n_i, then the expected value is a discrete distribution whose probabilities are

\displaystyle \left (  \frac{n_0}{n}, \frac{n_1}{n}, \dots, \frac{n_5}{n} \right )

So you can think of each integer in the specification of a Dirichlet as “ghost ratings,” sometimes called pseudocounts, and we’re saying the probability is proportional to the count.
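As a quick illustration of the "ghost ratings" reading, the sketch below turns a list of pseudocounts into the expected discrete distribution. The function name and the example counts are made up for this post.

```python
def expected_distribution(pseudocounts):
    # The expected value of the Dirichlet: each probability is proportional
    # to its pseudocount.
    total = sum(pseudocounts)
    return [count / total for count in pseudocounts]

# Hypothetical pseudocounts for ratings of 0 through 5 stars.
example_prior = [1, 1, 2, 4, 8, 4]
probabilities = expected_distribution(example_prior)
print(probabilities)
```

With these counts, 8 of the 20 ghost ratings are 4-star, so a new book is expected to get a 4-star vote 40% of the time.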

This is great, because if we knew the true Dirichlet distribution for a book, we could compute its ranking without a second thought. The ranking would simply be the expected star rating:

def simple_score(distribution):
   return sum(i * p for (i, p) in enumerate(distribution))

Putting books with the highest score on top would maximize the expected happiness of a user visiting the site, provided that happiness matches the user’s voting behavior, since the simple_score is just the expected vote.

Also note that all the rating system needs to make this work is that the rating options are linearly ordered. So a thumbs up/down (heaving bosom/flaccid member?) would work, too. We don’t need to know how happy it makes them to see a 5-star vs 4-star book. However, as we’ll see next, we have to approximate the distribution, so the scores of books with only a few ratings are uncertain. For that reason it helps to incorporate numerical utility values (we’ll see this at the end).

Next, to update a given Dirichlet distribution with the results of some observed ratings, we have to dig a bit deeper into Bayes rule and the formulas for sampling from a Dirichlet distribution. Rather than do that, I’ll point you to this nice writeup by Jonathan Huang, where the core of the derivation is in Section 2.3 (page 4), and remark that the rule for updating for a new observation is to just add it to the existing counts.

Theorem: Given a Dirichlet distribution with parameters (n_1, \dots, n_k) and a new observation of outcome i, the updated Dirichlet distribution has parameters (n_1, \dots, n_{i-1}, n_i + 1, n_{i+1}, \dots, n_k). That is, you just update the i-th entry by adding 1 to it.

This particular arithmetic to do the update is a mathematical consequence (derived in the link above) of the philosophical assumption that Bayes rule is how you should model your beliefs about uncertainty, coupled with the assumption that the Dirichlet process is how the users actually arrive at their votes.
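In code, the update rule is about as small as it sounds. This sketch indexes the counts by star rating starting at 0 (the theorem above indexes from 1), and the function name and example counts are made up for illustration.

```python
def update(pseudocounts, observed_rating):
    # Return new Dirichlet parameters after observing a single rating;
    # the input list is left unmodified.
    updated = list(pseudocounts)
    updated[observed_rating] += 1
    return updated

# Hypothetical prior for 0 through 5 stars, then three observed votes.
prior = [1, 1, 2, 4, 8, 4]
posterior = prior
for vote in [5, 5, 3]:
    posterior = update(posterior, vote)
print(posterior)  # [1, 1, 2, 5, 8, 6]
```

Since the update is just addition, the order in which votes arrive doesn't matter, and updating on a batch of votes is the same as adding the batch's histogram to the prior.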

The initial values (n_0, \dots, n_5) for star ratings should be picked so that they represent the average rating distribution among all prior books, since this is used as the default voting distribution for a new, unknown book. If you have more information about whether a book is likely to be popular, you can use a different prior. For example, if JK Rowling wrote a Harry Potter Romance novel that was part of the canon, you could pretty much guarantee it would be popular, and set n_5 high compared to n_0. Of course, if it were actually popular you could just wait for the good ratings to stream in, so tinkering with these values on a per-book basis might not help much. On the other hand, most books by unknown authors are bad, and n_5 should be close to zero. Selecting a prior dictates how influential ratings of new items are compared to ratings of items with many votes. The more pseudocounts you add to the prior, the less new votes count.

This gets us to the following code for star ratings.

def score(ratings, rating_prior):
    ratings = [r + p for (r, p) in zip(ratings, rating_prior)]
    score = sum(i * u for (i, u) in enumerate(ratings))
    return score / sum(ratings)
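To see how the size of the prior changes the influence of new votes, here is a small self-contained experiment. The function `star_score` is a standalone restatement of the star-rating score, and the two priors are invented numbers: both have a prior-only score of 2.5 stars, but one carries ten times as many pseudocounts.

```python
def star_score(ratings, rating_prior):
    # Expected star rating after adding the prior pseudocounts to the votes.
    combined = [r + p for (r, p) in zip(ratings, rating_prior)]
    return sum(stars * count for (stars, count) in enumerate(combined)) / sum(combined)

three_five_star_votes = [0, 0, 0, 0, 0, 3]
weak_prior = [1, 1, 1, 1, 1, 1]          # 6 pseudocounts in total
strong_prior = [10, 10, 10, 10, 10, 10]  # 60 pseudocounts in total

weak = star_score(three_five_star_votes, weak_prior)      # 30 / 9
strong = star_score(three_five_star_votes, strong_prior)  # 165 / 63
print(round(weak, 2), round(strong, 2))
```

The same three votes pull the score from 2.5 up to about 3.33 under the weak prior, but only to about 2.62 under the strong one.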

The only thing missing from the solution at the beginning is the utilities. The utilities are useful for two reasons. First, because books with few ratings encode a lot of uncertainty, having an idea about how extreme a feeling is implied by a specific rating allows one to give better rankings of new books.

Second, for many services, such as taxi rides on Lyft, the default star rating tends to be a 5-star, and 4-star or lower means something went wrong. For books, 3-4 stars is a default while 5-star means you were very happy.

The utilities parameter allows you to weight rating outcomes appropriately. So if you are in a Lyft-like scenario, you might specify utilities like [-10, -5, -3, -2, 1] to denote that a 4-star rating has the same negative impact as two 5-star ratings would positively contribute. On the other hand, for books the gap between 4-star and 5-star is much less than the gap between 3-star and 4-star. The utilities simply allow you to calibrate how the votes should be valued in comparison to each other, instead of using their literal star counts.
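Here is a small check of the Lyft-style utilities from the paragraph above, on a 1-to-5-star scale. The function is a standalone restatement of the score with utilities, and I use an all-zero prior purely to isolate the effect of the utilities (a real system would want a nonzero prior, not least to avoid dividing by zero for an item with no ratings).

```python
def score_with_utilities(ratings, rating_prior, rating_utility):
    # Expected utility after adding the prior pseudocounts to the votes.
    combined = [r + p for (r, p) in zip(ratings, rating_prior)]
    return sum(c * u for (c, u) in zip(combined, rating_utility)) / sum(combined)

lyft_utilities = [-10, -5, -3, -2, 1]   # utilities for 1 through 5 stars
no_prior = [0, 0, 0, 0, 0]              # assumption: only to isolate the utilities
driver_ratings = [0, 0, 0, 1, 2]        # one 4-star and two 5-star ratings

utility_score = score_with_utilities(driver_ratings, no_prior, lyft_utilities)
literal_stars = sum((i + 1) * r for (i, r) in enumerate(driver_ratings)) / sum(driver_ratings)
print(utility_score, round(literal_stars, 2))
```

The literal star average looks great at about 4.67, while the utility-weighted score is exactly zero, because the single 4-star rating cancels out the two 5-star ratings.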