Searching for RH Counterexamples — Adding a Database

In the last article we set up pytest for a simple application that computes divisor sums \sigma(n) and tries to disprove the Riemann Hypothesis. In this post we’ll show how to extend the application as we add a database dependency. The database stores the computed sums so we can analyze them after our application finishes.

As in the previous post, I’ll link to specific git commits in the final code repository to show how the project evolves. You can browse or check out the repository at each commit to see how it works.

Interface before implementation

The approach we’ll take is one that highlights the principle of good testing and good software design: separate components by thin interfaces so that the implementations of those interfaces can change later without needing to update lots of client code.

We’ll take this to the extreme by implementing and testing the logic for our application before we ever decide what sort of database we plan to use! In other words, the choice of database will be our last choice, making it inherently flexible to change: first we iron out a minimal interface that our application needs, and then we choose the right database based on those needs. This is useful because software engineers often don’t understand how the choice of a dependency (especially a database dependency) will work out long term, particularly as a prototype starts to scale and hit application-specific bottlenecks. Couple this with the industry’s trend of chasing hot new fads, and eventually you realize no choice is sacred. Interface separation is the software engineer’s only defense, and their most potent tool for flexibility. As a side note, Tom Gamon summarizes this attitude well in a recent article, borrowing the analogy from the 1975 investment essay The Loser’s Game by Charles Ellis. Some of Gamon’s other articles reinforce the idea that important decisions should be made as late as possible, since that is the only time you know enough to make those decisions well.

Our application has two parts so far: adding new divisor sums to the database, and loading divisor sums for analysis. Since we’ll be adding to this database over time, it may also be prudent to summarize the contents of the database, e.g., to report the largest integer computed so far. This suggests the following first-pass interface, implemented in this commit.

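from abc import ABC, abstractmethod
from typing import List

# RiemannDivisorSum and SummaryStats are dataclasses defined elsewhere in the project.

class DivisorDb(ABC):
    @abstractmethod
    def load(self) -> List[RiemannDivisorSum]:
        '''Load the entire database.'''
        pass

    @abstractmethod
    def upsert(self, data: List[RiemannDivisorSum]) -> None:
        '''Insert or update data.'''
        pass

    @abstractmethod
    def summarize(self) -> SummaryStats:
        '''Summarize the contents of the database.'''
        pass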

RiemannDivisorSum and SummaryStats are dataclasses. These are special classes that are intended to have restricted behavior: storing data and providing simple derivations on that data. For us this provides a stabler interface because the contents of the return values can change over time without interrupting other code. For example, we might want to eventually store the set of divisors alongside their sum. Compare this to returning a list or tuple, which is brittle when used with things like tuple assignment.
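Here is a small sketch of that brittleness.
- The field names n, divisor_sum, and witness_value are assumptions. They mirror the columns of the SQL table queried later in this post.
- RiemannDivisorSumV2 is a hypothetical later version of the record.

Code that reads fields by name keeps working when a field with a default value is added. Code that unpacks a tuple by position breaks as soon as the tuple grows.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RiemannDivisorSum:
    # Hypothetical fields, mirroring columns of the SQL table queried later.
    n: int
    divisor_sum: int
    witness_value: float


@dataclass
class RiemannDivisorSumV2:
    # The same record after adding a field with a default value.
    n: int
    divisor_sum: int
    witness_value: float
    divisors: List[int] = field(default_factory=list)


def describe(record) -> str:
    # Fields are read by name, so both dataclass versions work unchanged.
    return f"n={record.n} witness={record.witness_value:.4f}"


def describe_tuple(record: tuple) -> str:
    # Positional unpacking hard-codes the number of fields.
    n, divisor_sum, witness_value = record
    return f"n={n} witness={witness_value:.4f}"


old = RiemannDivisorSum(n=10080, divisor_sum=39312, witness_value=1.7558143389253)
new = RiemannDivisorSumV2(n=10080, divisor_sum=39312, witness_value=1.7558143389253)
print(describe(old), describe(new))

try:
    # Appending a fourth element to the tuple breaks every unpacking site.
    describe_tuple((10080, 39312, 1.7558143389253, [1, 2, 3]))
except ValueError as err:
    print("tuple version failed:", err)
```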

The other interesting tidbit about the commit is the use of abstract base classes (“ABC”, an awful name choice). Python has limited support for declaring an “interface” as many other languages do. The pythonic convention was always to use its “duck-typing” feature: you just call whatever methods you want on an object, and any object that has those methods can be used in that spot. The mantra was, “if it walks like a duck and talks like a duck, then it’s a duck.” However, there was no way to say “a duck is any object that has a waddle and quack method, and those are the only allowed duck functions.” As a result, I often saw folks tie their code to one particular duck implementation. That said, there were some mildly cumbersome third-party libraries that enabled interface declarations. Better, Python itself added abstract base classes (the abc module) as a means to enforce interfaces. More recently (Python 3.8), it added structural subtyping via typing.Protocol, which works with type hints when subclassing directly is not feasible (e.g., when the implementation lives in a different codebase).
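To illustrate the difference, here is a sketch using the duck example. All class names are made up for illustration.
- An abstract base class is nominal. Implementations must subclass it, and Python refuses to instantiate a subclass that is missing an abstract method.
- A typing.Protocol is structural. Any object with the right methods matches.
- The runtime_checkable decorator additionally lets isinstance check that the methods are present. It does not check their signatures; that is left to a static type checker such as mypy.

```python
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable


class Duck(ABC):
    '''Nominal interface: implementations must subclass Duck.'''

    @abstractmethod
    def waddle(self) -> str: ...

    @abstractmethod
    def quack(self) -> str: ...


class Mallard(Duck):
    def waddle(self) -> str:
        return "waddling"

    def quack(self) -> str:
        return "quack"


class HalfDuck(Duck):
    # quack is missing, so Python refuses to instantiate this class.
    def waddle(self) -> str:
        return "waddling"


@runtime_checkable
class QuacksLikeADuck(Protocol):
    '''Structural interface: any object with these methods matches.'''

    def waddle(self) -> str: ...

    def quack(self) -> str: ...


class RobotDuck:
    # Subclasses nothing, but has the right methods.
    def waddle(self) -> str:
        return "rolling"

    def quack(self) -> str:
        return "beep"


Mallard()  # fine
try:
    HalfDuck()
except TypeError as err:
    print(err)  # the message names the missing abstract method

print(isinstance(RobotDuck(), QuacksLikeADuck))  # True: structural match
print(isinstance(RobotDuck(), Duck))             # False: not a subclass
```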

Moving on, we can implement an in-memory database that can be used for testing. This is done in this commit. One crucial aspect of these tests is that they do not rely on the knowledge that the in-memory database is secretly a dictionary. That is, the tests use only the DivisorDb interface and never inspect the underlying dict. This allows the same tests to run against all implementations, e.g., using pytest.mark.parametrize. Also note that the in-memory database is neither thread-safe nor atomic, but for us this doesn’t really matter.
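Here is a minimal sketch of what such an in-memory implementation and its interface-only tests might look like. It is not the repository's code: the SummaryStats fields, the check_* helper names, and the IMPLEMENTATIONS list are assumptions.

Because the checks only call load, upsert, and summarize, adding another implementation to IMPLEMENTATIONS is enough to run them against it. Under pytest, the loop at the bottom would typically be replaced by a pytest.mark.parametrize decorator over the same list.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class RiemannDivisorSum:
    n: int
    divisor_sum: int
    witness_value: float


@dataclass
class SummaryStats:
    # Hypothetical fields; populate_db reads largest_computed_n.
    largest_computed_n: Optional[int]
    largest_witness_value: Optional[RiemannDivisorSum]


class DivisorDb(ABC):
    @abstractmethod
    def load(self) -> List[RiemannDivisorSum]: ...

    @abstractmethod
    def upsert(self, data: List[RiemannDivisorSum]) -> None: ...

    @abstractmethod
    def summarize(self) -> SummaryStats: ...


class InMemoryDivisorDb(DivisorDb):
    '''A dict-backed fake for tests. Not thread-safe or atomic.'''

    def __init__(self) -> None:
        self._data: Dict[int, RiemannDivisorSum] = {}

    def load(self) -> List[RiemannDivisorSum]:
        return sorted(self._data.values(), key=lambda r: r.n)

    def upsert(self, data: List[RiemannDivisorSum]) -> None:
        for record in data:
            self._data[record.n] = record  # a repeated n overwrites the old row

    def summarize(self) -> SummaryStats:
        if not self._data:
            return SummaryStats(largest_computed_n=None, largest_witness_value=None)
        return SummaryStats(
            largest_computed_n=max(self._data),
            largest_witness_value=max(self._data.values(),
                                      key=lambda r: r.witness_value))


# Interface-only checks: they never touch _data, so any DivisorDb works.
def check_empty_summary(db: DivisorDb) -> None:
    assert db.load() == []
    assert db.summarize().largest_computed_n is None


def check_upsert_updates_existing_rows(db: DivisorDb) -> None:
    # Witness values here are arbitrary test data.
    db.upsert([RiemannDivisorSum(n=1, divisor_sum=1, witness_value=0.5)])
    db.upsert([RiemannDivisorSum(n=1, divisor_sum=2, witness_value=0.7)])
    assert db.load() == [RiemannDivisorSum(n=1, divisor_sum=2, witness_value=0.7)]


def check_summarize(db: DivisorDb) -> None:
    db.upsert([RiemannDivisorSum(n=2, divisor_sum=3, witness_value=0.9),
               RiemannDivisorSum(n=3, divisor_sum=4, witness_value=0.4)])
    stats = db.summarize()
    assert stats.largest_computed_n == 3
    assert stats.largest_witness_value.n == 2


IMPLEMENTATIONS: List[Callable[[], DivisorDb]] = [InMemoryDivisorDb]
for make_db in IMPLEMENTATIONS:
    for check in (check_empty_summary, check_upsert_updates_existing_rows,
                  check_summarize):
        check(make_db())  # a fresh database for every check
```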

Injecting the Interface

With our first-pass database interface and implementation, we can write the part of the application that populates the database with data. This commit implements a simple serial algorithm that computes divisor sums in batches of 100k until the user hits Ctrl-C.

def populate_db(db: DivisorDb, batch_size: int = 100000) -> None:
    '''Populate the db in batches.'''
    starting_n = (db.summarize().largest_computed_n or 5040) + 1
    while True:
        ending_n = starting_n + batch_size
        db.upsert(compute_riemann_divisor_sums(starting_n, ending_n))
        starting_n = ending_n + 1

I only tested this code manually. The reason is that a single line of it (line 13 of the source file, highlighted in the original abridged snippet) is the only significant behavior not already covered by the InMemoryDivisorDb tests. (Of course, that line had a bug, later fixed in this commit.) I’m also expecting to change it soon, and the tradeoff between spending time on tests and spending time on features should not always fall on the side of testing.
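As a hypothetical alternative (not the code in the repository), the batch bookkeeping can be pulled out into a pure generator, so its arithmetic is testable without a database. This sketch assumes compute_riemann_divisor_sums(start, end) includes both endpoints. Under that assumption, consecutive batches should tile the integers with no gaps and no overlaps, which a two-line test can check.

```python
from itertools import islice
from typing import Iterator, Optional, Tuple


def batch_ranges(largest_computed_n: Optional[int], batch_size: int = 100000,
                 search_start: int = 5040) -> Iterator[Tuple[int, int]]:
    '''Yield inclusive (start, end) ranges forever, resuming after the largest computed n.'''
    start = (largest_computed_n or search_start) + 1
    while True:
        end = start + batch_size - 1  # inclusive end: each batch has batch_size numbers
        yield start, end
        start = end + 1  # the next batch begins right after this one


print(list(islice(batch_ranges(None, batch_size=10), 3)))
# [(5041, 5050), (5051, 5060), (5061, 5070)]
print(next(batch_ranges(20000, batch_size=10)))
# (20001, 20010)
```

With a helper like this, populate_db would loop over batch_ranges(db.summarize().largest_computed_n) and upsert one batch per range.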

Next let’s swap in a SQL database. We’ll use sqlite3, which ships with Python’s standard library, so it needs no dependency management. The implementation in this commit uses the same interface as the in-memory database, but its methods are full of SQL queries. With this, we can upgrade our tests to run identically on both implementations. The commit looks large, but really I just indented all the existing tests into a test class. I then added a pytest.mark.parametrize decorator to the class definition, and the corresponding argument to each test method. Parametrizing every individual test function instead wouldn’t be all that bad, but each new test would require its writer to remember the decorator. Parametrizing the class systematically requires the extra method argument.
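Here is a rough sketch of what a SQLite-backed version of the same three methods could look like, using only the standard library's sqlite3 module.
- The table name RiemannDivisorSums and the columns n and witness_value match the query shown below.
- The divisor_sum column and the exact SQL are assumptions.
- The class is written standalone, rather than subclassing DivisorDb, to keep the sketch self-contained.

INSERT OR REPLACE gives upsert semantics because n is the primary key. Using the connection as a context manager commits each batch as one transaction.

```python
import sqlite3
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RiemannDivisorSum:
    n: int
    divisor_sum: int
    witness_value: float


@dataclass
class SummaryStats:
    largest_computed_n: Optional[int]
    largest_witness_value: Optional[RiemannDivisorSum]


class SqliteDivisorDb:
    '''Hypothetical SQLite-backed implementation of the DivisorDb methods.'''

    def __init__(self, path: str = ":memory:") -> None:
        self.connection = sqlite3.connect(path)
        self.connection.execute(
            '''CREATE TABLE IF NOT EXISTS RiemannDivisorSums (
                 n INTEGER PRIMARY KEY,
                 divisor_sum INTEGER,
                 witness_value REAL
               )''')

    def load(self) -> List[RiemannDivisorSum]:
        rows = self.connection.execute(
            'SELECT n, divisor_sum, witness_value FROM RiemannDivisorSums ORDER BY n')
        return [RiemannDivisorSum(*row) for row in rows]

    def upsert(self, data: List[RiemannDivisorSum]) -> None:
        # INSERT OR REPLACE overwrites any existing row with the same primary key.
        with self.connection:  # commits on success, rolls back on error
            self.connection.executemany(
                'INSERT OR REPLACE INTO RiemannDivisorSums VALUES (?, ?, ?)',
                [(r.n, r.divisor_sum, r.witness_value) for r in data])

    def summarize(self) -> SummaryStats:
        largest_n = self.connection.execute(
            'SELECT MAX(n) FROM RiemannDivisorSums').fetchone()[0]  # None if empty
        row = self.connection.execute(
            'SELECT n, divisor_sum, witness_value FROM RiemannDivisorSums '
            'ORDER BY witness_value DESC LIMIT 1').fetchone()
        best = RiemannDivisorSum(*row) if row else None
        return SummaryStats(largest_computed_n=largest_n, largest_witness_value=best)
```

Passing a file path instead of the default ":memory:" persists the rows between runs.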

And finally, we can switch the database population script to use the SQL database implementation. This is done in this commit. Notice how simple it is, and how it doesn’t require any extra testing.

After running it a few times and getting a database with about 20 million rows, we can apply the simplest possible analysis: showing the top few witness values.

sqlite> select n, witness_value from RiemannDivisorSums where witness_value > 1.7 order by witness_value desc limit 100;
10080|1.7558143389253
55440|1.75124651488749
27720|1.74253672381383
7560|1.73991651920276
15120|1.73855867428903
110880|1.73484901030336
720720|1.73306535623807
1441440|1.72774021157846
166320|1.7269287425473
2162160|1.72557022852613
4324320|1.72354665986337
65520|1.71788900114772
3603600|1.71646721405987
332640|1.71609697536058
10810800|1.71607328780293
7207200|1.71577914933961
30240|1.71395368739173
20160|1.71381061514181
25200|1.71248203640096
83160|1.71210965310318
360360|1.71187211014506
277200|1.71124375582698
2882880|1.7106690212765
12252240|1.70971873843453
12600|1.70953565488377
8648640|1.70941081706371
32760|1.708296575835
221760|1.70824623791406
14414400|1.70288499724944
131040|1.70269370474016
554400|1.70259313608473
1081080|1.70080265951221

We can also check John’s claim that “the winners are all multiples of 2520” over this range: the best non-multiple of 2520 up to 20 million is 18480, whose witness value is only about 1.69.
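The same check can be reproduced on a smaller range without a database. This sketch (function names are illustrative) works in two steps:
1. Compute all divisor sums up to a limit with a sieve.
2. Track the best witness value separately for multiples and non-multiples of 2520.

On the range 5041 to 100,000 it should agree with the numbers above: 10080 among multiples of 2520, and 18480 among the rest.

```python
import math
from typing import List, Tuple


def divisor_sums_up_to(limit: int) -> List[int]:
    '''Return sigma[0..limit] via a divisor sieve (sigma[0] is unused).'''
    sigma = [0] * (limit + 1)
    for d in range(1, limit + 1):
        # d divides exactly the multiples d, 2d, 3d, ...
        for multiple in range(d, limit + 1, d):
            sigma[multiple] += d
    return sigma


def best_witnesses(start: int, limit: int) -> Tuple[Tuple[int, float], Tuple[int, float]]:
    '''Best (n, witness value) among multiples of 2520, and among all other n, in [start, limit].'''
    sigma = divisor_sums_up_to(limit)
    best_multiple = (0, 0.0)
    best_other = (0, 0.0)
    for n in range(start, limit + 1):
        witness = sigma[n] / (n * math.log(math.log(n)))
        if n % 2520 == 0:
            if witness > best_multiple[1]:
                best_multiple = (n, witness)
        elif witness > best_other[1]:
            best_other = (n, witness)
    return best_multiple, best_other


result = best_witnesses(5041, 100000)
print(result)
```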

This multiple-of-2520 pattern is probably because 2520 is a highly composite number, i.e., it has more divisors than all smaller numbers, so its sum-of-divisors will tend to be large. Digging in a bit further, it seems the smallest counterexample, if it exists, is necessarily a superabundant number. Such numbers have a nice structure described here that suggests a search strategy better than trying every number.
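As a sanity check on the definition, the following brute-force sketch lists superabundant numbers up to a bound. A number n is superabundant when σ(n)/n exceeds σ(m)/m for every m < n.

The sketch compares ratios by cross-multiplying integers, which avoids floating-point ties. It only illustrates the definition; it is not the structured search strategy alluded to above.

```python
from typing import List


def divisor_sums_up_to(limit: int) -> List[int]:
    '''Return sigma[0..limit] via a divisor sieve (sigma[0] is unused).'''
    sigma = [0] * (limit + 1)
    for d in range(1, limit + 1):
        for multiple in range(d, limit + 1, d):
            sigma[multiple] += d
    return sigma


def superabundant_up_to(limit: int) -> List[int]:
    '''All n <= limit with sigma(n)/n > sigma(m)/m for every m < n.'''
    sigma = divisor_sums_up_to(limit)
    result = []
    best_sigma, best_n = 0, 1  # the record ratio best_sigma / best_n, initially 0
    for n in range(1, limit + 1):
        # sigma(n)/n > best_sigma/best_n, compared exactly with integers.
        if sigma[n] * best_n > best_sigma * n:
            result.append(n)
            best_sigma, best_n = sigma[n], n
    return result


print(superabundant_up_to(10080))
# [1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 10080]
```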

Next time, we can introduce the concept of a search strategy as a new component to the application, and experiment with different search strategies. Other paths forward include building a front-end component, and deploying the system on a server so that the database can be populated continuously.

Searching for RH Counterexamples — Setting up Pytest

Some mathy-programmy people tell me they want to test their code, but struggle to get set up with a testing framework. I suspect it’s due to a mix of:

  • There are too many choices with a blank slate.
  • Making slightly wrong choices early on causes things to fail in unexpected ways.

I suspect the same concerns apply to general project organization and architecture. Because Python is popular for mathy-programmies, I’ll build a Python project that shows how I organize my projects and test my code, and how that shapes the design and evolution of my software. I will use Python 3.8 and pytest, and you can find the final code on Github.

For this project, we’ll take advice from John Baez and explore a question that glibly aims to disprove the Riemann Hypothesis:

A CHALLENGE:

Let σ(n) be the sum of divisors of n. There are infinitely many n with σ(n)/(n ln(ln(n))) > 1.781. Can you find one? If you can find n > 5040 with σ(n)/(n ln(ln(n))) > 1.782, you’ll have disproved the Riemann Hypothesis.

I don’t expect you can disprove the Riemann Hypothesis this way, but I’d like to see numbers that make σ(n)/(n ln(ln(n))) big. It seems the winners are all multiples of 2520, so try those. The best one between 5040 and a million is n = 10080, which only gives 1.755814.

https://twitter.com/johncarlosbaez/status/1149700802371608576
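For context, the two thresholds in the challenge come from classical results about σ(n). Here γ ≈ 0.5772 is the Euler–Mascheroni constant.
- Gronwall's theorem (1913) says the lim sup below equals e^γ ≈ 1.78107. So infinitely many n beat any constant smaller than e^γ, such as 1.781.
- Robin's theorem (1984) says the Riemann Hypothesis is equivalent to the strict inequality holding for every n > 5040.

\displaystyle \limsup_{n \to \infty} \frac{\sigma(n)}{n \ln \ln n} = e^{\gamma} \approx 1.78107 \qquad \textup{(Gronwall)}

\displaystyle \textup{RH} \iff \sigma(n) < e^{\gamma} \, n \ln \ln n \ \textup{ for all } n > 5040 \qquad \textup{(Robin)}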

Initializing the Project

One of the hardest parts of software is setting up your coding environment. If you use an integrated development environment (IDE), project setup is bespoke to each IDE. I dislike this approach, because what you learn when using the IDE is not useful outside the IDE. When I first learned to program (Java), I was shackled to Eclipse for years because I didn’t know how to compile and run Java programs without it. Instead, we’ll do everything from scratch, using only the terminal/shell and standard Python tools. I will also ignore random extra steps and minutiae I’ve built up over the years to deal with minor issues. If you’re interested in that and why I do them, leave a comment and I might follow up with a second article.

This article assumes you are familiar with the basics of Python syntax, and know how to open a terminal and enter basic commands (like ls, cd, mkdir, rm). Along the way, I will link to specific git commits that show the changes, so that you can see how the project unfolds with each twist and turn.

I’ll start by creating a fresh Python project that does nothing. We set up the base directory riemann-divisor-sum, initialize git, create a readme, and track it in git (git add + git commit).

mkdir riemann-divisor-sum
cd riemann-divisor-sum
git init .
echo "# Divisor Sums for the Riemann Hypothesis" > README.md
git add README.md
git commit -m "add empty README.md"

Next I create a Github project at https://github.com/j2kun/riemann-divisor-sum (the Github project name does not need to match the local directory name, but I think it’s good practice), and push the project up to Github.

git remote add origin git@github.com:j2kun/riemann-divisor-sum.git
# use "main" instead of "master" here if that is your default branch name (see below)
git push -u origin master

Note that, depending on your git version and configuration, the default branch name for a new repository may be “master.” I like “main” because it’s shorter, clearer, and nicer. If you want to change your default branch name, you can update to git version 2.28 or later and add the following to your ~/.gitconfig file.

[init]
    defaultBranch = main

Here is what the project looks like on Github as of this single commit.

Pytest

Next I’ll install the pytest library, which will run our project’s tests. First I’ll show what a failing test looks like by setting up a trivial program with an unimplemented function and a corresponding test. For ultimate simplicity, we’ll use Python’s built-in assert for the test lines. Here’s the commit.

# in the terminal
mkdir riemann
mkdir tests


# create riemann/divisor.py containing:
'''Compute the sum of divisors of a number.'''

def divisor_sum(n: int) -> int:
    raise ValueError("Not implemented.")


# create tests/divisor_test.py containing:
from riemann.divisor import divisor_sum

def test_sum_of_divisors_of_72():
    assert 195 == divisor_sum(72)

Next we install and configure Pytest. At this point, since we’re introducing a dependency, we need a project-specific place to store that dependency. All dependencies related to a project should be explicitly declared and isolated. This page helps explain why. Python’s standard tool is the virtual environment. When you “activate” the virtual environment, it temporarily (for the duration of the shell session or until you run deactivate) points all Python tools and libraries to the virtual environment.

virtualenv -p python3.8 venv
source venv/bin/activate

# shows the location of the overridden python binary path
which python
# outputs: /Users/jeremy/riemann-divisor-sum/venv/bin/python

Now we can use pip as normal and it will install to venv. To declare and isolate the dependency, we write the output of pip freeze to a file called requirements.txt, and it can be reinstalled using pip install -r requirements.txt. Try deleting your venv directory, recreating it, and reinstalling the dependencies this way.

pip install pytest
pip freeze > requirements.txt
git add requirements.txt
git commit -m "requirements: add pytest"

# example to wipe and reinstall
# deactivate
# rm -rf venv
# virtualenv -p python3.8 venv
# source venv/bin/activate
# pip install -r requirements.txt

As an aside, at this step you may notice git mentions venv is an untracked directory. You can ignore this, or add venv to a .gitignore file to tell git to ignore it, as in this commit. We will also have to configure pytest to ignore venv shortly.

When we run pytest (with no arguments) from the base directory, we see our first error:

    from riemann.divisor import divisor_sum
E   ModuleNotFoundError: No module named 'riemann'

Module import issues are a common stumbling block for new Python users. In order to make a directory into a Python package whose modules can be imported this way, it needs an __init__.py file, even if it’s empty. Any code in this file will be run the first time the package is imported in a Python runtime. We add one to both the code and test directories in this commit.

When we run pytest (with no arguments), it recursively searches the directory tree for files named like test_*.py or *_test.py, loads them, and treats every function in those files whose name starts with “test” as a test. Functions without the “test” prefix can be defined and used as helpers to set up complex tests. Pytest then runs the tests and reports the failures. For me this looks like:

[Screenshot: Our first test failure.]

Our implementation is intentionally wrong for demonstration purposes. When a test passes, pytest will report it quietly as a “.” by default. See these docs for more info on different ways to run the pytest binary and configure its output report.

In this basic pytest setup, you can put test files wherever you want, name the files and test functions appropriately, and use assert to implement the tests themselves. As long as your packages are set up properly, your imports are absolute (see this page for gory details on absolute vs. relative imports), and you run pytest from the base directory, pytest will find the tests and run them.

Since pytest searches all directories for tests, this includes venv and __pycache__, which Python creates automatically when it imports modules (I add __pycache__ to .gitignore). Sometimes package developers will include test code, and pytest will then run those tests, which often fail or clutter the output. A virtual environment also gets large as you install big dependencies (like numpy, scipy, pandas), which makes pytest slow to search for tests to run. To avoid this, pytest’s norecursedirs option tells it which directories to skip. Passing that option on the command line every time you run pytest is tedious (e.g., -o norecursedirs='venv __pycache__'). Instead, you can make it the default behavior by storing the option in a configuration file recognized by pytest, such as setup.cfg. I did it in this commit.
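For example, a setup.cfg along these lines stores the option. Pytest reads its settings from the [tool:pytest] section of setup.cfg; pytest.ini, tox.ini, and pyproject.toml are alternatives.

One caveat: setting norecursedirs replaces pytest's built-in list of skipped directories rather than adding to it. Patterns you still want skipped, such as hidden directories like .git (matched by .*), need to be listed too.

```ini
# setup.cfg
[tool:pytest]
norecursedirs = .* venv __pycache__
```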

Some other command line options that I use all the time:

  • pytest test/dir to test only files in that directory, or pytest test/dir/test_file.py to test only tests in that file.
  • pytest -k STR to only run tests whose name contains “STR”
  • pytest -s to see any logs or print statements inside tested code
  • pytest -s to allow the pdb/ipdb debugger to function and step through a failing test.

Building up the project

Now let’s build up the project. My general flow is as follows:

  1. Decide what work to do next.
  2. Sketch out the interface for that work.
  3. Write some basic (failing, usually lightweight) tests that will pass when the work is done.
  4. Do the work.
  5. Add more nuanced tests if needed, based on what is learned during the work.
  6. Repeat until the work is done.

This strategy is sometimes called “the design recipe,” and I first heard about it from my undergraduate programming professor John Clements at Cal Poly, via the book “How to Design Programs.” Even if I don’t always use it, I find it’s a useful mental framework for getting things done.

For this project, I want to search through positive integers, and for each one I want to compute a divisor sum, do some other arithmetic, and compare that against some other number. I suspect divisor sum computations will be the hard/interesting part, but to start I will code up a slow/naive implementation with some working tests, confirm my understanding of the end-to-end problem, and then improve the pieces as needed.

In this commit, I implement the naive divisor sum code and tests. Note the commit also shows how to tell pytest to test for a raised exception. In this commit I implement the main search routine and confirm John’s claim about n=10080 (thanks for the test case!).

These tests already showcase a few testing best practices:

  • Test only one behavior at a time. Each test has exactly one assertion in it. This is good practice because when a test fails you won’t have to dig around to figure out exactly what went wrong.
  • Use the tests to help you define the interface, and then only test against that interface. The hard part about writing clean and clear software is defining clean and clear interfaces that work together well and hide details. Math does this very well, because definitions like \sigma(n) do not depend on how n is represented. In fact, math really doesn’t have “representations” of its objects, or more precisely, switching representations is basically free, so we don’t dwell on it. In software, we have to choose excruciatingly detailed representations for everything, and so we rely on the software to hide those details as much as possible. The easiest way to tell if you did it well is to try to use the interface and only the interface. Tests are an excuse to do exactly that, and the effort is not wasted, because the tests keep checking your work every time they run.
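For concreteness, here is a minimal sketch of a naive divisor_sum and tests in this style, with one behavior per test. It is not necessarily the code in the commit: the error message and the witness_value helper are assumptions.

Pytest's usual way to assert an exception is the pytest.raises context manager. The sketch spells the same check out with try/except so it runs without pytest installed.

```python
import math


def divisor_sum(n: int) -> int:
    '''Compute sigma(n), the sum of the positive divisors of n, by brute force.'''
    if n <= 0:
        raise ValueError("n must be a positive integer, got {}".format(n))
    return sum(d for d in range(1, n + 1) if n % d == 0)


def witness_value(n: int) -> float:
    '''sigma(n) / (n ln ln n), the quantity from the challenge (meaningful for n >= 3).'''
    return divisor_sum(n) / (n * math.log(math.log(n)))


def test_sum_of_divisors_of_72():
    assert 195 == divisor_sum(72)


def test_divisor_sum_rejects_zero():
    # With pytest this is usually: with pytest.raises(ValueError): divisor_sum(0)
    try:
        divisor_sum(0)
    except ValueError:
        return
    raise AssertionError("expected a ValueError")


def test_witness_value_of_10080():
    assert round(witness_value(10080), 6) == 1.755814
```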

Improving Efficiency

Next, I want to confirm John’s claim that n=10080 is the best example between 5041 and a million. However, my existing code is too slow. Running the tests added in this commit seems to take forever.

We profile to confirm our suspected hotspot:

>>> import cProfile
>>> from riemann.counterexample_search import best_witness
>>> cProfile.run('best_witness(10000)')
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
...
54826    3.669    0.000    3.669    0.000 divisor.py:10(<genexpr>)

As expected, computing divisor sums is the bottleneck. No surprise there, because the naive computation makes the search take quadratic time. Before changing the implementation, I want to add a few more tests. I copied data for the first 50 integers from OEIS and used pytest’s parametrize feature (pytest.mark.parametrize), since the test bodies are all the same. This commit does it.

Now I can work on improving the runtime of the divisor sum computation step. Originally, I thought I’d have to compute the prime factorization to use this trick that exploits the multiplicativity of \sigma(n), but then I found this approach due to Euler in 1751 that provides a recursive formula for the sum and skips the prime factorization. Since we’re searching over all integers, this allows us to trade off the runtime of each \sigma(n) computation against the storage cost of past \sigma(n) computations. I tried it in this commit, using python’s built-in LRU-cache wrapper to memoize the computation. The nice thing about this is that our tests are already there, and the interface for divisor_sum doesn’t change. This is on purpose, so that the caller of divisor_sum (in this case tests, also client code in real life) need not update when we improve the implementation. I also ran into a couple of stumbling blocks implementing the algorithm (I swapped the order of the if statements here), and the tests made it clear I messed up.

However, there are two major problems with that implementation.

  1. The code is still too slow. best_witness(100000) takes about 50 seconds to run, almost all of which is in divisor_sum.
  2. Python hits its recursion depth limit, and so the client code needs to eagerly populate the divisor_sum cache, which violates encapsulation. The caller should not know anything about the implementation, nor need to act in a specific way to accommodate hidden implementation details.
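Here is a sketch of the memoized recursive approach. It assumes the recurrence in question is Euler's pentagonal number identity:

σ(n) = σ(n−1) + σ(n−2) − σ(n−5) − σ(n−7) + σ(n−12) + σ(n−15) − ⋯

In this identity:
- the offsets are the generalized pentagonal numbers k(3k−1)/2 and k(3k+1)/2;
- the signs follow the pattern + + − −;
- any σ(0) term is replaced by n itself.

The sketch also reproduces the second problem above. With a cold cache, divisor_sum(n) recurses about n levels deep, so a call like divisor_sum(100000) exceeds Python's default recursion limit unless smaller values are computed first.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def divisor_sum(n: int) -> int:
    '''sigma(n) via Euler's pentagonal number recurrence, memoized with lru_cache.'''
    if n <= 0:
        raise ValueError("n must be a positive integer")
    total = 0
    k = 1
    while k * (3 * k - 1) // 2 <= n:
        sign = 1 if k % 2 == 1 else -1  # signs run +, +, -, -, +, +, ...
        # The generalized pentagonal numbers for k and -k.
        for offset in (k * (3 * k - 1) // 2, k * (3 * k + 1) // 2):
            if offset < n:
                total += sign * divisor_sum(n - offset)
            elif offset == n:
                total += sign * n  # a sigma(0) term is replaced by n itself
        k += 1
    return total


# Table-driven check against the first 20 values of OEIS A000203.
EXPECTED = [1, 3, 4, 7, 6, 12, 8, 15, 13, 18, 12, 28, 14, 24, 24, 31, 18, 39, 20, 42]
for n, expected in enumerate(EXPECTED, start=1):
    assert divisor_sum(n) == expected, (n, expected)

# Warming the cache in increasing order keeps each new call's recursion shallow.
for n in range(1, 5001):
    divisor_sum(n)
print(divisor_sum(5000))  # 11715
```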

I also realized after implementing it that despite the extra storage space, the runtime is still O(n^{3/2}), because each divisor-sum call requires O(n^{1/2}) iterations of the loop. This is just as slow as a naive loop that checks divisibility of integers up to \sqrt{n}. Also, a naive loop allows me to plug in a cool project called numba that automatically speeds up simple Python code by compiling it in place. Incidentally, numba is known to not work with lru_cache, so I can’t tack it on my existing implementation.

So I added numba as a dependency and drastically simplified the implementation. Now the tests run in 8 seconds, and in a few minutes I can upgrade John’s claim that n=10080 is the best example between 5041 and a million, to the best example between 5041 and ten million.
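Here is a sketch of the kind of simple loop described above: trial division up to √n, adding each divisor d together with its partner n // d. It is plain Python so that it runs with only the standard library. The repository's numba-compiled version may be written differently, and the search_start parameter of best_witness is a hypothetical name.

```python
import math
from typing import Tuple


def divisor_sum(n: int) -> int:
    '''sigma(n) by trial division up to sqrt(n), pairing each divisor d with n // d.'''
    if n <= 0:
        raise ValueError("n must be a positive integer")
    total = 0
    for d in range(1, math.isqrt(n) + 1):  # math.isqrt needs Python 3.8+
        if n % d == 0:
            partner = n // d
            total += d if partner == d else d + partner  # count sqrt(n) only once
    return total


def witness_value(n: int) -> float:
    return divisor_sum(n) / (n * math.log(math.log(n)))


def best_witness(max_range: int, search_start: int = 5041) -> Tuple[int, float]:
    '''The (n, witness value) pair with the largest witness for search_start <= n < max_range.'''
    return max(((n, witness_value(n)) for n in range(search_start, max_range)),
               key=lambda pair: pair[1])


result = best_witness(20000)
print(result)  # n = 10080, as in John's challenge
```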

Next up

This should get you started with a solid pytest setup for your own project, but there is a lot more to say about how to organize and run tests, what kinds of tests to write, and how that all changes as your project evolves.

For this project, we now know that the divisor-sum computation is the bottleneck. We also know that the interesting parts of this project are yet to come. We want to explore the patterns in what makes these numbers large. One way we could go about this is to split the project into two components: one that builds/manages a database of divisor sums, and another that analyzes the divisor sums in various ways. The next article will show how the database setup works. When we identify relevant patterns, we can modify the search strategy to optimize for that. As far as testing goes, this would prompt us to have an interface layer between the two systems, and to add fakes or mocks to test the components in isolation.

After that, there’s the process of automating test running, adding tests for code quality/style, computing code coverage, adding a type-hint checker test, writing tests that generate other tests, etc.

If you’re interested, let me know which topics to continue with. I do feel a bit silly putting so much pomp and circumstance around such a simple computation, but hopefully the simplicity of the core logic makes the design and testing aspects of the project clearer and easier to understand.

Visualizing an Assassin Puzzle

Over at Math3ma, Tai-Danae Bradley shared the following puzzle, which she also featured in a fantastic (spoiler-free) YouTube video. If you’re seeing this for the first time, watch the video first.

Consider a square in the xy-plane, and let A (an “assassin”) and T (a “target”) be two arbitrary-but-fixed points within the square. Suppose that the square behaves like a billiard table, so that any ray (a.k.a. “shot”) from the assassin will bounce off the sides of the square, with the angle of incidence equaling the angle of reflection.

Puzzle: Is it possible to block any possible shot from A to T by placing a finite number of points in the square?

This puzzle found its way to me through Tai-Danae’s video, via category theorist Emily Riehl, via a talk by the recently deceased Fields Medalist Maryam Mirzakhani, who studied the problem in more generality. I’m not familiar with her work, but knowing mathematicians it’s probably set in an arbitrary complex n-manifold.

See Tai-Danae’s post for a proof, which left such an impression on me that I had to dig deeper. In this post I’ll discuss a visualization I made, which is posted at the end of Tai-Danae’s article as well as here and below (to avoid spoilers). In the visualization, mouse movement chooses the firing direction for the assassin, and the target is in green. Dragging the target with the mouse updates the position of the guards. The source code is on Github.

Outline

The visualization uses the d3 library, which was made for visualizations that dynamically update with data. I use it because it can draw SVGs real nice.

The meat of the visualization is in two geometric functions.

  1. Decompose a ray into a series of line segments—its path as it bounces off the walls—stopping if it intersects any of the points in the plane.
  2. Compute the optimal position of the guards, given the boundary square and the positions of the assassin and target.

Both of these functions, along with all the geometry that supports them, are in geometry.js. The rest of the demo is defined in main.js, in which I oafishly trample over d3 best practices to arrive miraculously at a working product. Critiques welcome 🙂

As with most programming and software problems, the key to implementing these functions while maintaining your sanity is breaking them down into manageable pieces. Incrementalism is your friend.

Vectors, rays, rectangles, and ray splitting

We start at the bottom with a Vector class with helpful methods for adding, scaling, and computing norms and inner products.

function innerProduct(a, b) {
  return a.x * b.x + a.y * b.y;
}

class Vector {
  constructor(x, y) {
    this.x = x;
    this.y = y;
  }

  normalized() { ... }
  norm() { ... }
  add(vector) { ... }
  subtract(vector) { ... }
  scale(length) { ... }
  distance(vector) { ... }
  midpoint(b) { ... }
}

This allows one to compute the distance between two points, e.g., with vector.subtract(otherVector).norm().

Next we define a class for a ray, which is represented by its center (a vector) and a direction (a vector).

class Ray {
  constructor(center, direction, length=100000) {
    this.center = center;
    this.length = length;

    if (direction.x == 0 && direction.y == 0) {
      throw "Can't have zero direction";
    }
    this.direction = direction.normalized();
  }

  endpoint() {
    return this.center.add(this.direction.scale(this.length));
  }

  intersects(point) {
    let shiftedPoint = point.subtract(this.center);
    let signedLength = innerProduct(shiftedPoint, this.direction);
    let projectedVector = this.direction.scale(signedLength);
    let differenceVector = shiftedPoint.subtract(projectedVector);

    if (signedLength > 0
        && this.length > signedLength
        && differenceVector.norm() < intersectionRadius) {
      return projectedVector.add(this.center);
    } else {
      return null;
    }
  }
}

The ray must be finite for us to draw it, but the length we've chosen is so large that, as you can see in the visualization, it's effectively infinite. Feel free to scale it up even longer.

[Image: assassin-puzzle visualization]

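The interesting bit is the intersection function. We want to compute whether a ray intersects a point. To do this, we use the inner product to project the point onto the ray's direction, which gives the point's signed distance along the ray. We then measure the distance between the point and that projection, i.e., the distance from the point to the ray's line. We say they intersect if three things hold: the projection lies in front of the ray's center, it lies within the ray's length, and that distance is very small.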

In our demo points are not infinitesimal, but rather have a small radius described by intersectionRadius. For the sake of being able to see anything we set this to 3 pixels. If it’s too small the demo will look bad. The ray won’t stop when it should appear to stop, and it can appear to hit the target when it doesn’t.
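For readers following along in Python, here is a rough translation of the Vector and Ray logic above. geometry.js remains the real implementation; method names follow the JavaScript, and INTERSECTION_RADIUS = 3 mirrors the demo's setting.

A point counts as hit when three conditions hold:
- its projection lies ahead of the ray's center;
- the projection lies within the ray's length;
- its perpendicular distance from the ray is below the radius.

```python
import math
from dataclasses import dataclass
from typing import Optional

INTERSECTION_RADIUS = 3.0  # pixels, as in the demo


@dataclass(frozen=True)
class Vector:
    x: float
    y: float

    def add(self, other: "Vector") -> "Vector":
        return Vector(self.x + other.x, self.y + other.y)

    def subtract(self, other: "Vector") -> "Vector":
        return Vector(self.x - other.x, self.y - other.y)

    def scale(self, length: float) -> "Vector":
        return Vector(self.x * length, self.y * length)

    def norm(self) -> float:
        return math.hypot(self.x, self.y)

    def normalized(self) -> "Vector":
        return self.scale(1 / self.norm())


def inner_product(a: Vector, b: Vector) -> float:
    return a.x * b.x + a.y * b.y


@dataclass
class Ray:
    center: Vector
    direction: Vector  # normalized in __post_init__
    length: float = 100000.0

    def __post_init__(self) -> None:
        if self.direction.x == 0 and self.direction.y == 0:
            raise ValueError("Can't have zero direction")
        self.direction = self.direction.normalized()

    def intersects(self, point: Vector) -> Optional[Vector]:
        shifted = point.subtract(self.center)
        signed_length = inner_product(shifted, self.direction)  # distance along the ray
        projected = self.direction.scale(signed_length)          # foot of the perpendicular
        distance_from_ray = shifted.subtract(projected).norm()
        if 0 < signed_length < self.length and distance_from_ray < INTERSECTION_RADIUS:
            return projected.add(self.center)
        return None


ray = Ray(Vector(0, 0), Vector(1, 0))
print(ray.intersects(Vector(10, 2)))  # Vector(x=10.0, y=0.0): within 3 pixels
print(ray.intersects(Vector(10, 5)))  # None: too far from the ray
```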

Next up we have a class for a Rectangle, which is where the magic happens. The boilerplate and helper methods:

class Rectangle {
  constructor(bottomLeft, topRight) {
    this.bottomLeft = bottomLeft;
    this.topRight = topRight;
  }

  topLeft() { ... }
  center() { ... }
  width() { ... }
  height() { ... }
  contains(vector) { ... }

The function rayToPoints that splits a ray into line segments from bouncing depends on three helper functions:

  1. rayIntersection: Compute the intersection point of a ray with the rectangle.
  2. isOnVerticalWall: Determine if a point is on a vertical or horizontal wall of the rectangle, raising an error if neither.
  3. splitRay: Split a ray into a line segment and a shorter ray that’s “bounced” off the wall of the rectangle.

(2) is trivial, computing some x- and y-coordinate distances up to some error tolerance. (1) involves parameterizing the ray and checking four cases, one for each wall. If the bottom left of the rectangle is (x_1, y_1) and the top right is (x_2, y_2) and the ray is written as \{ (c_1 + t v_1, c_2 + t v_2) \mid t > 0 \}, then, with some elbow grease, the following four equations provide all possibilities, with some special cases for vertical or horizontal rays:

\displaystyle \begin{aligned} c_2 + t v_2 &= y_2 & \textup{ and } \hspace{2mm} & x_1 \leq c_1 + t v_1 \leq x_2 & \textup{ (intersects top)} \\ c_2 + t v_2 &= y_1 & \textup{ and } \hspace{2mm} & x_1 \leq c_1 + t v_1 \leq x_2 & \textup{ (intersects bottom)} \\ c_1 + t v_1 &= x_1 & \textup{ and } \hspace{2mm} & y_1 \leq c_2 + t v_2 \leq y_2 & \textup{ (intersects left)} \\ c_1 + t v_1 &= x_2 & \textup{ and } \hspace{2mm} & y_1 \leq c_2 + t v_2 \leq y_2 & \textup{ (intersects right)} \\ \end{aligned}

In code:

  rayIntersection(ray) {
    let c1 = ray.center.x;
    let c2 = ray.center.y;
    let v1 = ray.direction.x;
    let v2 = ray.direction.y;
    let x1 = this.bottomLeft.x;
    let y1 = this.bottomLeft.y;
    let x2 = this.topRight.x;
    let y2 = this.topRight.y;

    // ray is vertically up or down
    if (epsilon > Math.abs(v1)) {
      return new Vector(c1, (v2 > 0 ? y2 : y1));
    }

    // ray is horizontally left or right
    if (epsilon > Math.abs(v2)) {
      return new Vector((v1 > 0 ? x2 : x1), c2);
    }

    let tTop = (y2 - c2) / v2;
    let tBottom = (y1 - c2) / v2;
    let tLeft = (x1 - c1) / v1;
    let tRight = (x2 - c1) / v1;

    // Exactly one t value should be both positive and result in a point
    // within the rectangle

    let tValues = [tTop, tBottom, tLeft, tRight];
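    for (let i = 0; i < tValues.length; i++) {
      let t = tValues[i];
      let intersection = new Vector(c1 + t * v1, c2 + t * v2);
      if (t > epsilon && this.contains(intersection)) {
        return intersection;
      }
    }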

    throw "Unexpected error: ray never intersects rectangle!";
  }
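Here is the same case analysis as a standalone Python function on coordinate tuples, for illustration. EPSILON is an assumed value; the demo defines its own epsilon elsewhere. The containment test is widened by the same tolerance, so a point computed to lie exactly on a wall is not rejected because of rounding.

```python
from typing import Tuple

Point = Tuple[float, float]
EPSILON = 1e-9  # assumed tolerance, not the demo's actual value


def ray_rectangle_intersection(center: Point, direction: Point,
                               bottom_left: Point, top_right: Point) -> Point:
    '''Where a ray starting inside the rectangle first meets its boundary.'''
    c1, c2 = center
    v1, v2 = direction
    x1, y1 = bottom_left
    x2, y2 = top_right

    if abs(v1) < EPSILON:  # ray is vertically up or down
        return (c1, y2 if v2 > 0 else y1)
    if abs(v2) < EPSILON:  # ray is horizontally left or right
        return (x2 if v1 > 0 else x1, c2)

    # t at which the ray's line meets the top, bottom, left, and right walls.
    t_values = [(y2 - c2) / v2, (y1 - c2) / v2, (x1 - c1) / v1, (x2 - c1) / v1]
    for t in t_values:
        px, py = c1 + t * v1, c2 + t * v2
        inside = (x1 - EPSILON <= px <= x2 + EPSILON
                  and y1 - EPSILON <= py <= y2 + EPSILON)
        if t > EPSILON and inside:
            return (px, py)
    raise ValueError("ray never intersects rectangle")


print(ray_rectangle_intersection((0.5, 0.5), (1, 0.5), (0, 0), (1, 1)))  # (1.0, 0.75)
```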

Next, splitRay splits a ray into a single line segment and the “remaining” ray, by computing the ray’s intersection with the rectangle, and having the “remaining” ray mirror the direction of approach with a new center that lies on the wall of the rectangle. The new ray length is appropriately shorter. If we run out of ray length, we simply return a segment with a null ray.

  splitRay(ray) {
    let segment = [ray.center, this.rayIntersection(ray)];
    let segmentLength = segment[0].subtract(segment[1]).norm();
    let remainingLength = ray.length - segmentLength;

    if (remainingLength < 10) {
      return {
        segment: [ray.center, ray.endpoint()],
        ray: null
      };
    }

    let vertical = this.isOnVerticalWall(segment[1]);
    let newRayDirection = null;

    if (vertical) {
      newRayDirection = new Vector(-ray.direction.x, ray.direction.y);
    } else {
      newRayDirection = new Vector(ray.direction.x, -ray.direction.y);
    }

    // JavaScript has no keyword arguments, so pass the length positionally
    let newRay = new Ray(segment[1], newRayDirection, remainingLength);
    return {
      segment: segment,
      ray: newRay
    };
  }

As you have probably guessed, rayToPoints simply calls splitRay over and over again until the ray hits an input “stopping point”—a guard, the target, or the assassin—or else our finite ray length has been exhausted. The output is a list of points, starting from the original ray’s center, for which adjacent pairs are interpreted as line segments to draw.

  rayToPoints(ray, stoppingPoints) {
    let points = [ray.center];
    let remainingRay = ray;

    while (remainingRay) {
      // check if the ray would hit any guards or the target
      if (stoppingPoints) {
        let hardStops = stoppingPoints.map(p => remainingRay.intersects(p))
          .filter(p => p != null);
        if (hardStops.length > 0) {
          // find first intersection and break
          let closestStop = remainingRay.closestToCenter(hardStops);
          points.push(closestStop);
          break;
        }
      }

      let rayPieces = this.splitRay(remainingRay);
      points.push(rayPieces.segment[1]);
      remainingRay = rayPieces.ray;
    } 

    return points;
  }

That’s sufficient to draw the shot emanating from the assassin. This method is called every time the mouse moves.

Optimal guards

The function to compute the optimal position of the guards takes as input the containing rectangle, the assassin, and the target, and produces as output a list of 16 points.

/*
 * Compute the 16 optimal guards to prevent the assassin from hitting the
 * target.
 */
function computeOptimalGuards(square, assassin, target) {
...
}

If you read Tai-Danae’s proof, you’ll know that this construction is to

  1. Compute mirrors of the target across the top, the right, and the top+right of the rectangle. Call this resulting thing the 4-mirrored-targets.
  2. Replicate the 4-mirrored-targets four times, by translating three of the copies left by the entire width of the 4-mirrored-targets shape, down by the entire height, and both left-and-down.
  3. Now you have 16 copies of the target, and one assassin. This gives 16 line segments from assassin-to-target-copy. Place a guard at the midpoint of each of these line segments.
  4. Finally, apply the reverse translation and reverse mirroring to return the guards to the original square.
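For readers who want to run the construction outside the browser, the four steps above can be sketched as a standalone Python function. This is not the demo’s JavaScript: points are plain (x, y) tuples, the rectangle is given by its bottom-left and top-right corners, and the names optimal_guards and fold are hypothetical. Step 4 is written as one "fold" of each coordinate back into the rectangle, which is one way to express undoing the translations and then the mirroring.

```python
def optimal_guards(bottom_left, top_right, assassin, target):
    """Return the 16 guard positions from steps 1-4, as (x, y) tuples."""
    x1, y1 = bottom_left
    x2, y2 = top_right
    width, height = x2 - x1, y2 - y1

    def mirror_top(p):    # reflect across the top edge, y = y2
        return (p[0], 2 * y2 - p[1])

    def mirror_right(p):  # reflect across the right edge, x = x2
        return (2 * x2 - p[0], p[1])

    # Step 1: the target and its three mirror images (the 4-mirrored-targets).
    mirrored = [target, mirror_top(target), mirror_right(target),
                mirror_top(mirror_right(target))]

    # Step 2: copies of that 2w-by-2h block shifted left, down, and both.
    shifts = [(0, 0), (-2 * width, 0), (0, -2 * height), (-2 * width, -2 * height)]
    copies = [(tx + dx, ty + dy) for (tx, ty) in mirrored for (dx, dy) in shifts]

    # Step 3: the midpoint of each assassin-to-copy segment.
    ax, ay = assassin
    midpoints = [((ax + cx) / 2, (ay + cy) / 2) for (cx, cy) in copies]

    # Step 4: fold each coordinate back into the rectangle. Reducing modulo
    # twice the side length undoes the translations, and reflecting the
    # upper half of that period undoes the mirroring.
    def fold(value, low, size):
        r = (value - low) % (2 * size)
        return low + (r if r <= size else 2 * size - r)

    return [(fold(mx, x1, width), fold(my, y1, height)) for (mx, my) in midpoints]


guards = optimal_guards((0, 0), (1, 1), assassin=(0.25, 0.5), target=(0.75, 0.5))
print(len(guards))  # 16
```

The midpoint between the assassin and the unmodified target, here (0.5, 0.5), is always one of the guards, since it needs no folding.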

Step 1 (after adding simple helper functions on Rectangle to do the mirroring):

  // First compute the target copies in the 4 mirrors
  let target1 = target.copy();
  let target2 = square.mirrorTop(target);
  let target3 = square.mirrorRight(target);
  let target4 = square.mirrorTop(square.mirrorRight(target));
  target1.guardLabel = 1;
  target2.guardLabel = 2;
  target3.guardLabel = 3;
  target4.guardLabel = 4;

Step 2:

  // for each mirrored target, compute the four two-square-length translates
  let mirroredTargets = [target1, target2, target3, target4];
  let horizontalShift = 2 * square.width();
  let verticalShift = 2 * square.height();
  let translateLeft = new Vector(-horizontalShift, 0);
  let translateRight = new Vector(horizontalShift, 0);
  let translateUp = new Vector(0, verticalShift);
  let translateDown = new Vector(0, -verticalShift);

  let translatedTargets = [];
  for (let i = 0; i < mirroredTargets.length; i++) {
    let target = mirroredTargets[i];
    translatedTargets.push([
      target,
      target.add(translateLeft),
      target.add(translateDown),
      target.add(translateLeft).add(translateDown),
    ]);
  }

Step 3, computing the midpoints:

  // compute the midpoints between the assassin and each translate
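  let translatedMidpoints = [];
  for (let i = 0; i < translatedTargets.length; i++) {
    let targetCopies = translatedTargets[i];
    translatedMidpoints.push(targetCopies.map(t => t.midpoint(assassin)));
  }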

Step 4, returning the guards back to the original square, is harder than it seems, because the midpoint of an assassin-to-target-copy segment might not be in the same copy of the square as the target-copy being fired at. This means you have to detect which square copy the midpoint lands in, and use that to determine which operations are required to invert. This results in the final block of this massive function.

  // determine which of the four possible translates the midpoint is in
  // and reverse the translation. Since midpoints can end up in completely
  // different copies of the square, we have to check each one for all cases.
  function untranslate(point) {
    if (point.x < square.bottomLeft.x && point.y > square.bottomLeft.y) {
      return point.add(translateRight);
    } else if (point.x >= square.bottomLeft.x && point.y <= square.bottomLeft.y) {
      return point.add(translateUp);
    } else if (point.x < square.bottomLeft.x && point.y <= square.bottomLeft.y) {
      return point.add(translateRight).add(translateUp);
    } else {
      return point;
    }
  }

  // undo the translations to get the midpoints back to the original 4-mirrored square.
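  let untranslatedMidpoints = [];
  for (let i = 0; i < translatedMidpoints.length; i++) {
    let midpoints = translatedMidpoints[i];
    for (let j = 0; j < midpoints.length; j++) {
      untranslatedMidpoints.push(untranslate(midpoints[j]));
    }
  }

  // determine which of the four mirrored copies the midpoint is in
  // and reverse the mirroring.
  function unmirror(point) {
    if (point.x > square.topRight.x && point.y > square.topRight.y) {
      return square.mirrorTop(square.mirrorRight(point));
    } else if (point.x > square.topRight.x && point.y <= square.topRight.y) {
      return square.mirrorRight(point);
    } else if (point.x <= square.topRight.x && point.y > square.topRight.y) {
      return square.mirrorTop(point);
    } else {
      return point;
    }
  }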

  return untranslatedMidpoints.map(unmirror);

And that’s all there is to it!

Improvements, if I only had the time

There are a few improvements I’d like to make to this puzzle, but haven’t made the time (I’m writing a book, after all!).

  1. Be able to drag the guards around.
  2. Create new guards from an empty set of guards, with a button to “reveal” the solution.
  3. Include a toggle that, when pressed, darkens the entire region of the square that can be hit by the assassin. For example, this would allow you to see if the target is in the only possible safe spot, or if there are multiple safe spots for a given configuration.
  4. Perhaps darken the vulnerable spots by the number of possible paths that hit it, up to some limit.
  5. The most complicated one: generalize to an arbitrary polygon (convex or not!), for which there may be no optimal solution. The visualization would allow you to look for a solution using items 2 through 4.

Pull requests are welcome if you attempt any of these improvements.

Until next time!

Binary Search on Graphs

Binary search is one of the most basic algorithms I know. Given a sorted list of comparable items and a target item being sought, binary search looks at the middle of the list and compares it to the target. If the target is larger, we repeat on the half of the list containing the larger items, and vice versa.

With each comparison the binary search algorithm cuts the search space in half. The result is a guarantee of no more than about \log_2(n) comparisons, for a total runtime of O(\log n). Neat, efficient, useful.
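For reference, here is a minimal Python sketch of the list version. The function name binary_search_sorted is hypothetical, and it also counts probes (middle elements examined) so the logarithmic bound can be checked: each unsuccessful probe leaves at most half of the remaining items, so a sorted list of 1000 items never needs more than 10 probes.

```python
def binary_search_sorted(items, target):
    """Return (index, probes): the index of target in the sorted list items
    (None if absent) and how many middle elements were examined."""
    low, high = 0, len(items) - 1
    probes = 0
    while low <= high:
        mid = (low + high) // 2
        probes += 1
        if items[mid] == target:
            return mid, probes
        elif items[mid] < target:
            low = mid + 1   # target is larger: keep the half with larger items
        else:
            high = mid - 1  # target is smaller: keep the half with smaller items
    return None, probes


evens = list(range(0, 2000, 2))  # 1000 sorted items
print(binary_search_sorted(evens, 998))  # (499, 1)
```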

There’s always another angle.

What if we tried to do binary search on a graph? Most graph search algorithms, like breadth- or depth-first search, take linear time, and they were invented by some pretty smart cookies. So if binary search on a graph is going to make any sense, it’ll have to use more information beyond what a normal search algorithm has access to.

For binary search on a list, it’s the fact that the list is sorted, and we can compare against the sought item to guide our search. But really, the key piece of information isn’t related to the comparability of the items. It’s that we can eliminate half of the search space at every step. The “compare against the target” step can be thought of as a black box that replies to queries of the form, “Is this the thing I’m looking for?” with responses of the form, “Yes,” or, “No, but look over here instead.”

If the answers to your queries are sufficiently helpful, meaning they allow you to cut out large portions of your search space at each step, then you probably have a good algorithm on your hands. Indeed, there’s a natural model for graphs, defined in a 2015 paper of Emamjomeh-Zadeh, Kempe, and Singhal, that goes as follows.

You’re given as input an undirected graph G = (V,E) with positive edge weights w_e for e \in E. You can see the entire graph, and you may ask questions of the form, “Is vertex v the target?” Responses will be one of two things:

  • Yes (you win!)
  • No, but e = (v, w) is an edge out of v on a shortest path from v to the true target.

Your goal is to find the target vertex with the minimum number of queries.

Obviously this only works if G is connected, but slight variations of everything in this post work for disconnected graphs. (The same is not true in general for directed graphs.)

When the graph is a line, this “reduces” to binary search in the sense that the same basic idea of binary search works: start in the middle of the graph, and the edge you get in response to a query will tell you in which half of the graph to continue.

And if we make this example only slightly more complicated, the generalization should become obvious.

Here, we again start at the “center vertex,” and the response to our query will eliminate one of the two halves. But then how should we pick the next vertex, now that we no longer have a linear order to rely on? The answer should be clear: choose the “center vertex” of whichever half we end up in. This choice can be formalized into a rule that works even when there’s no such obvious symmetry, and it turns out to always be the right choice.

Definition: A median of a weighted graph G with respect to a subset of vertices S \subseteq V is a vertex v \in V (not necessarily in S) which minimizes the sum of distances to vertices in S. More formally, it minimizes

\Phi_S(v) = \sum_{u \in S} d(v, u),

where d(v,u) is the sum of the edge weights along a shortest path from v to u.
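As a quick sanity check of the definition (a worked example, not from the paper), take the path graph with edges a-b, b-c, c-d, d-e, all of weight 1, and let S = V. Then

\displaystyle \begin{aligned} \Phi_S(a) &= 0 + 1 + 2 + 3 + 4 = 10 \\ \Phi_S(b) &= 1 + 0 + 1 + 2 + 3 = 7 \\ \Phi_S(c) &= 2 + 1 + 0 + 1 + 2 = 6 \\ \end{aligned}

and by symmetry \Phi_S(d) = 7 and \Phi_S(e) = 10. The unique median is c, the same middle element that ordinary binary search would probe first. Medians need not be unique, or lie in S: for S = \{a, e\} every vertex has \Phi_S = 4, so c is a median even though c \notin S.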

And so generalizing binary search to this query-model on a graph results in the following algorithm, which whittles down the search space by querying the median at every step.

Algorithm: Binary search on graphs. Input is a graph G = (V,E).

  • Start with a set of candidates S = V.
  • While we haven’t found the target and |S| > 1:
    • Query the median v of S, and stop if you’ve found the target.
    • Otherwise, let e = (v, w) be the response edge, and compute the set of all vertices x \in V for which e is on a shortest path from v to x. Call this set T.
    • Replace S with S \cap T.
  • Output the only remaining vertex in S.

As we’ll see momentarily, a Python implementation is about as simple. The meat of the work is in computing the median and the set T, both of which are slight variants of Dijkstra’s algorithm for computing shortest paths.

The theorem, whose proof is straightforward and well written by Emamjomeh-Zadeh et al. (only about half a page, on page 5), is that this algorithm requires only O(\log(n)) queries, just like binary search.
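The key step is short enough to sketch in this post’s notation (a paraphrase of the argument, not a quote from the paper). Let v be a median of S, let e = (v, w) be the response edge with weight w_e > 0, and let S' = S \cap T be the candidates that survive. For x \in S', some shortest path from v to x begins with e, so d(w, x) = d(v, x) - w_e. For every other x \in S, the triangle inequality gives d(w, x) \leq d(v, x) + w_e. Summing over S,

\displaystyle \Phi_S(w) \leq \Phi_S(v) - w_e |S'| + w_e (|S| - |S'|) = \Phi_S(v) + w_e (|S| - 2|S'|).

Since v minimizes \Phi_S over all of V, we have \Phi_S(v) \leq \Phi_S(w), and therefore |S'| \leq |S| / 2. Each unsuccessful query at least halves the candidate set, so about \log_2(n) queries suffice before a single candidate remains.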

Before we dive into an implementation, there’s a catch. Even though we are guaranteed only O(\log(n)) queries, each step of our implementation runs Dijkstra’s algorithm from every vertex just to find the median, so we’re definitely not going to get a logarithmic time algorithm. So in what situation would this be useful?

Here’s where we use the “theory” trick of making up a fanciful problem and only later finding applications for it (which, honestly, has been quite successful in computer science). In this scenario we’re treating the query mechanism as a black box. It’s natural to imagine that the queries are expensive, and a resource we want to optimize for. As an example the authors bring up in a followup paper, the graph might be the set of clusterings of a dataset, and the query involves a human looking at the data and responding that a cluster should be split, or that two clusters should be joined. Of course, for clustering the underlying graph is too large to process, so the median-finding algorithm needs to be implicit. But the essential point is clear: sometimes the query is the most expensive part of the algorithm.

Alright, now let’s implement it! The complete code is on GitHub as always.

Always be implementing

We start with a slight variation of Dijkstra’s algorithm. Here we’re given as input a single “starting” vertex, and we produce as output every shortest path from the start to each possible destination vertex, recorded as distances plus predecessor edges.

First, a bare-bones graph data structure.

import heapq
import math
from collections import defaultdict
from collections import namedtuple

Edge = namedtuple('Edge', ('source', 'target', 'weight'))

class Graph:
    # A bare-bones implementation of a weighted, undirected graph
    def __init__(self, vertices, edges=tuple()):
        self.vertices = vertices
        self.incident_edges = defaultdict(list)

        for edge in edges:
            self.add_edge(
                edge[0],
                edge[1],
                1 if len(edge) == 2 else edge[2]  # optional weight
            )

    def add_edge(self, u, v, weight=1):
        self.incident_edges[u].append(Edge(u, v, weight))
        self.incident_edges[v].append(Edge(v, u, weight))

    def edge(self, u, v):
        return [e for e in self.incident_edges[u] if e.target == v][0]

And then, since most of the work in Dijkstra’s algorithm is tracking information that you build up as you search the graph, we define the “output” data structure: a dictionary of shortest distances from the start, paired with a dictionary of predecessor edges (back-pointers) for the discovered shortest paths.

class DijkstraOutput:
    def __init__(self, graph, start):
        self.start = start
        self.graph = graph

        # the smallest distance from the start to the destination v
        self.distance_from_start = {v: math.inf for v in graph.vertices}
        self.distance_from_start[start] = 0

        # a list of predecessor edges for each destination
        # to track a list of possibly many shortest paths
        self.predecessor_edges = {v: [] for v in graph.vertices}

    def found_shorter_path(self, vertex, edge, new_distance):
        # update the solution with a newly found path that is at least as
        # short as the best known one; compare *before* overwriting the distance
        if new_distance < self.distance_from_start[vertex]:
            self.distance_from_start[vertex] = new_distance
            self.predecessor_edges[vertex] = [edge]
        else:  # tie for multiple shortest paths
            self.predecessor_edges[vertex].append(edge)

    def path_to_destination_contains_edge(self, destination, edge):
        predecessors = self.predecessor_edges[destination]
        if edge in predecessors:
            return True
        return any(self.path_to_destination_contains_edge(e.source, edge)
                   for e in predecessors)

    def sum_of_distances(self, subset=None):
        subset = subset or self.graph.vertices
        return sum(self.distance_from_start[x] for x in subset)

The actual Dijkstra algorithm then just does a “breadth-first” (priority-queue-guided) search through G, updating the metadata as it finds shorter paths.

def single_source_shortest_paths(graph, start):
    '''
    Compute the shortest paths and distances from the start vertex to all
    possible destination vertices. Return an instance of DijkstraOutput.
    '''
    output = DijkstraOutput(graph, start)
    visit_queue = [(0, start)]

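    while len(visit_queue) > 0:
        priority, current = heapq.heappop(visit_queue)
        if priority > output.distance_from_start[current]:
            continue  # stale queue entry: current was already processed

        for incident_edge in graph.incident_edges[current]:
            v = incident_edge.target
            weight = incident_edge.weight
            distance_from_current = output.distance_from_start[current] + weight

            if distance_from_current < output.distance_from_start[v]:
                # strictly shorter path: record it and (re)visit v
                output.found_shorter_path(v, incident_edge, distance_from_current)
                heapq.heappush(visit_queue, (distance_from_current, v))
            elif distance_from_current == output.distance_from_start[v]:
                # a tie: another shortest path to v, no need to revisit v
                output.found_shorter_path(v, incident_edge, distance_from_current)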

    return output
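To see why predecessor edges are kept as a list rather than a single back-pointer, here is a small standalone sketch on a graph with tied shortest paths. It uses plain adjacency dicts instead of the Graph and DijkstraOutput classes above, and the name shortest_path_dag is hypothetical. On a 4-cycle, the corner opposite the start should end up with two predecessors.

```python
import heapq
import math


def shortest_path_dag(adjacency, start):
    """Dijkstra's algorithm that keeps *all* shortest-path predecessors.

    adjacency maps each vertex to a list of (neighbor, weight) pairs with
    positive weights. Returns (distance, predecessors), where predecessors[v]
    lists every u such that the edge (u, v) lies on a shortest path to v.
    """
    distance = {v: math.inf for v in adjacency}
    predecessors = {v: [] for v in adjacency}
    distance[start] = 0
    queue = [(0, start)]

    while queue:
        d, current = heapq.heappop(queue)
        if d > distance[current]:
            continue  # stale queue entry; current was already settled
        for neighbor, weight in adjacency[current]:
            candidate = d + weight
            if candidate < distance[neighbor]:
                # strictly shorter: discard the old predecessors
                distance[neighbor] = candidate
                predecessors[neighbor] = [current]
                heapq.heappush(queue, (candidate, neighbor))
            elif candidate == distance[neighbor]:
                # a tie: another shortest path of the same length
                predecessors[neighbor].append(current)
    return distance, predecessors


# A square a-b-d-c-a with unit weights: d has two shortest paths from a.
square = {
    'a': [('b', 1), ('c', 1)],
    'b': [('a', 1), ('d', 1)],
    'c': [('a', 1), ('d', 1)],
    'd': [('b', 1), ('c', 1)],
}
distance, predecessors = shortest_path_dag(square, 'a')
print(distance['d'], sorted(predecessors['d']))  # 2 ['b', 'c']
```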

Finally, we implement the median-finding and T-computing subroutines:

def possible_targets(graph, start, edge):
    '''
    Given an undirected graph G = (V,E), an input vertex v in V, and an edge e
    incident to v, compute the set of vertices w such that e is on a shortest path from
    v to w.
    '''
    dijkstra_output = single_source_shortest_paths(graph, start)
    return set(v for v in graph.vertices
               if dijkstra_output.path_to_destination_contains_edge(v, edge))

def find_median(graph, vertices):
    '''
    Compute as output a vertex in the input graph which minimizes the sum of distances
    to the input set of vertices
    '''
    best_dijkstra_run = min(
         (single_source_shortest_paths(graph, v) for v in graph.vertices),
         key=lambda run: run.sum_of_distances(vertices)
    )
    return best_dijkstra_run.start
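Both subroutines can also be phrased purely in terms of distances. With positive weights, the edge e = (v, w) lies on some shortest path from v to x exactly when d(v, x) = w_e + d(w, x): a shortest path through e is e followed by a shortest path from w, and conversely such a concatenation has length d(v, x). The standalone sketch below (plain adjacency dicts and hypothetical function names, not the post’s repository code) uses that identity instead of walking back-pointers, and checks it against the example tree used later in this post.

```python
import heapq
import math


def distances_from(adjacency, start):
    """Dijkstra's algorithm: distance from start to every vertex."""
    distance = {v: math.inf for v in adjacency}
    distance[start] = 0
    queue = [(0, start)]
    while queue:
        d, current = heapq.heappop(queue)
        if d > distance[current]:
            continue  # stale queue entry
        for neighbor, weight in adjacency[current]:
            if d + weight < distance[neighbor]:
                distance[neighbor] = d + weight
                heapq.heappush(queue, (d + weight, neighbor))
    return distance


def median(adjacency, subset, dist):
    """A vertex of the whole graph minimizing the sum of distances to subset."""
    return min(adjacency, key=lambda v: sum(dist[v][x] for x in subset))


def consistent_targets(adjacency, v, w, weight, dist):
    """Vertices x such that the edge (v, w) lies on a shortest path from v to x.

    Uses exact equality, which is safe for integer weights; floating point
    weights would call for a tolerance.
    """
    return {x for x in adjacency if dist[v][x] == weight + dist[w][x]}


# The example tree from this post, with unit weights.
edges = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'e'), ('c', 'f'), ('f', 'g'),
         ('g', 'h'), ('h', 'i'), ('i', 'j'), ('j', 'k'), ('i', 'l'), ('l', 'm')]
tree = {v: [] for edge in edges for v in edge}
for u, v in edges:
    tree[u].append((v, 1))
    tree[v].append((u, 1))

dist = {v: distances_from(tree, v) for v in tree}
print(median(tree, tree, dist))                             # g
print(sorted(consistent_targets(tree, 'g', 'f', 1, dist)))  # ['a', 'b', 'c', 'd', 'e', 'f']
```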

And then the core algorithm:

QueryResult = namedtuple('QueryResult', ('found_target', 'feedback_edge'))

def binary_search(graph, query):
    '''
    Find a target node in a graph, with queries of the form "Is x the target?"
    and responses either "You found the target!" or "Here is an edge on a shortest
    path to the target."
    '''
    candidate_nodes = set(x for x in graph.vertices)  # copy

    while len(candidate_nodes) > 1:
        median = find_median(graph, candidate_nodes)
        query_result = query(median)

        if query_result.found_target:
            return median
        else:
            edge = query_result.feedback_edge
            legal_targets = possible_targets(graph, median, edge)
            candidate_nodes = candidate_nodes.intersection(legal_targets)

    return candidate_nodes.pop()

Here’s an example of running it on the example graph we used earlier in the post:

'''
Graph looks like this tree, with uniform weights

    a       k
     b     j
      cfghi
     d     l
    e       m
'''
G = Graph(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
           'j', 'k', 'l', 'm'],
          [
               ('a', 'b'),
               ('b', 'c'),
               ('c', 'd'),
               ('d', 'e'),
               ('c', 'f'),
               ('f', 'g'),
               ('g', 'h'),
               ('h', 'i'),
               ('i', 'j'),
               ('j', 'k'),
               ('i', 'l'),
               ('l', 'm'),
          ])

def simple_query(v):
    ans = input("is '%s' the target? [y/N] " % v)
    if ans and ans.lower()[0] == 'y':
        return QueryResult(True, None)
    else:
        print("Please input a vertex on the shortest path between"
              " '%s' and the target. The graph is: " % v)
        for w in G.incident_edges:
            print("%s: %s" % (w, G.incident_edges[w]))

        target = None
        while target not in G.vertices:
            target = input("Input neighboring vertex of '%s': " % v)

    return QueryResult(
        False,
        G.edge(v, target)
    )

output = binary_search(G, simple_query)
print("Found target: %s" % output)

The query function prints a reminder of the graph and asks the user to answer the query with yes or no and, if the answer is no, to name the neighboring vertex that determines the response edge.

An example run:

is 'g' the target? [y/N] n
Please input a vertex on the shortest path between 'g' and the target. The graph is:
e: [Edge(source='e', target='d', weight=1)]
i: [Edge(source='i', target='h', weight=1), Edge(source='i', target='j', weight=1), Edge(source='i', target='l', weight=1)]
g: [Edge(source='g', target='f', weight=1), Edge(source='g', target='h', weight=1)]
l: [Edge(source='l', target='i', weight=1), Edge(source='l', target='m', weight=1)]
k: [Edge(source='k', target='j', weight=1)]
j: [Edge(source='j', target='i', weight=1), Edge(source='j', target='k', weight=1)]
c: [Edge(source='c', target='b', weight=1), Edge(source='c', target='d', weight=1), Edge(source='c', target='f', weight=1)]
f: [Edge(source='f', target='c', weight=1), Edge(source='f', target='g', weight=1)]
m: [Edge(source='m', target='l', weight=1)]
d: [Edge(source='d', target='c', weight=1), Edge(source='d', target='e', weight=1)]
h: [Edge(source='h', target='g', weight=1), Edge(source='h', target='i', weight=1)]
b: [Edge(source='b', target='a', weight=1), Edge(source='b', target='c', weight=1)]
a: [Edge(source='a', target='b', weight=1)]
Input neighboring vertex of 'g': f
is 'c' the target? [y/N] n
Please input a vertex on the shortest path between 'c' and the target. The graph is:
[...]
Input neighboring vertex of 'c': d
is 'd' the target? [y/N] n
Please input a vertex on the shortest path between 'd' and the target. The graph is:
[...]
Input neighboring vertex of 'd': e
Found target: e
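Typing answers by hand gets old quickly. A simulated oracle that knows the target can stand in for the human, which also makes it easy to check the query bound for every possible target. The sketch below is standalone: it reimplements the algorithm with plain dicts and precomputed all-pairs distances, and the names simulated_oracle and graph_binary_search are hypothetical, not from the post’s repository. For target e it queries g, c, and d, the same sequence as the run above.

```python
import heapq
import math


def distances_from(adjacency, start):
    """Dijkstra's algorithm: distance from start to every vertex."""
    distance = {v: math.inf for v in adjacency}
    distance[start] = 0
    queue = [(0, start)]
    while queue:
        d, current = heapq.heappop(queue)
        if d > distance[current]:
            continue  # stale queue entry
        for neighbor, weight in adjacency[current]:
            if d + weight < distance[neighbor]:
                distance[neighbor] = d + weight
                heapq.heappush(queue, (d + weight, neighbor))
    return distance


def graph_binary_search(adjacency, query):
    """Query medians until one candidate is left; return (target, queries used).

    query(v) returns None for "v is the target", else (w, weight) for an
    edge (v, w) on a shortest path from v to the target.
    """
    dist = {v: distances_from(adjacency, v) for v in adjacency}
    candidates = set(adjacency)
    queries = 0
    while len(candidates) > 1:
        median = min(adjacency, key=lambda v: sum(dist[v][x] for x in candidates))
        queries += 1
        answer = query(median)
        if answer is None:
            return median, queries
        w, weight = answer
        # keep candidates x having a shortest path from the median that starts with (median, w)
        candidates = {x for x in candidates if dist[median][x] == weight + dist[w][x]}
    return candidates.pop(), queries


def simulated_oracle(adjacency, target):
    """A truthful stand-in for the human: reply with the first neighbor of v
    that lies on a shortest path from v to the target."""
    to_target = distances_from(adjacency, target)

    def query(v):
        if v == target:
            return None
        for w, weight in adjacency[v]:
            if weight + to_target[w] == to_target[v]:
                return w, weight
        raise ValueError("no shortest-path neighbor found for %r" % (v,))
    return query


edges = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'e'), ('c', 'f'), ('f', 'g'),
         ('g', 'h'), ('h', 'i'), ('i', 'j'), ('j', 'k'), ('i', 'l'), ('l', 'm')]
tree = {v: [] for edge in edges for v in edge}
for u, v in edges:
    tree[u].append((v, 1))
    tree[v].append((u, 1))

print(graph_binary_search(tree, simulated_oracle(tree, 'e')))  # ('e', 3)
```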

A likely story

The binary search we implemented in this post is pretty minimal. In fact, the more interesting part of the work of Emamjomeh-Zadeh et al. is the part where the response to the query can be wrong with some unknown probability.

In this setting, a query can have many valid responses (one for each edge out of the queried vertex that starts a shortest path to the target), in addition to all the invalid responses. This rules out the naive strategy of asking the same query multiple times and taking the majority response. If the error rate is 1/3, and there are two shortest paths to the target, you can get into a situation in which you see three responses equally often and can’t tell which one is the lie.

Instead, the technique Emamjomeh-Zadeh et al. use is based on the Multiplicative Weights Update Algorithm (it strikes again!). Each query applies a multiplicative increase (or decrease) to the weight of each node, depending on whether it is a consistent target under the assumption that the query response is correct. There are a few extra details and some postprocessing to avoid unlikely outcomes, but that’s the basic idea. Implementing it would be an excellent exercise for readers interested in diving deeper into a recent research paper (or to flex their math muscles).

But even deeper, this model of “query and get advice on how to improve” is a classic learning model first formally studied by Dana Angluin (my academic grand-advisor). In her model, one wants to design an algorithm to learn a classifier. The allowed queries are membership and equivalence queries. A membership query is essentially, “What is the label of this element?” and an equivalence query has the form, “Is this the right classifier?” If the answer is no, a mislabeled example is provided.

This is different from the usual machine learning assumption, because the learning algorithm gets to construct an example it wants to get more information about, instead of simply relying on a randomly generated subset of data. The goal is to minimize the number of queries before the target hypothesis is learned exactly. And indeed, as we saw in this post, if you have a little extra time to analyze the problem space, you can craft queries that extract quite a lot of information.
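To make the membership-query idea concrete in the simplest possible case (a toy example, not one of Angluin’s algorithms): suppose the hidden classifier on the integers 0 through n labels x positive exactly when x >= theta, for an unknown threshold theta. Because the learner chooses which points to ask about, it can recover theta with about log2(n) membership queries, since each query is a binary search probe. The name learn_threshold is hypothetical.

```python
def learn_threshold(membership, n):
    """Recover theta for the classifier x -> (x >= theta) on the points 0..n.

    membership(x) returns the true label of x. theta may be n + 1, meaning
    every point is labeled negative. Returns (theta, number of queries).
    """
    low, high = 0, n + 1  # invariant: low <= theta <= high
    queries = 0
    while low < high:
        mid = (low + high) // 2  # low <= mid < high, so mid is a valid point
        queries += 1
        if membership(mid):
            high = mid       # mid is positive, so theta <= mid
        else:
            low = mid + 1    # mid is negative, so theta > mid
    return low, queries


print(learn_threshold(lambda x: x >= 37, 100))  # (37, 7)
```

A passive learner that only sees random labeled points has no such guarantee: it might never see the points right around the threshold.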

Indeed, the model we presented here for binary search on graphs is the natural analogue of an equivalence query for a search problem: instead of a mislabeled counterexample, you get a nudge in the right direction toward the target. Pretty neat!

There are a few directions we could take from here: (1) implement the Multiplicative Weights version of the algorithm, (2) apply this technique to a problem like ranking or clustering, or (3) cover theoretical learning models like membership and equivalence queries in more detail. What interests you?

Until next time!