How I solved a cool probability problem

2021-09

A couple of weeks ago I came across a very interesting math puzzle when a friend told me to check out the current Jane Street puzzle of the month. At first glance it might seem like a simple problem, but it gets quite tricky. In the end, my solution turned out to be weird enough that I thought it would be cool to write about it.

The puzzle was the Robot Tug-of-War from August 2021. Here’s the problem statement:


The Robot Weightlifting World Championship was such a huge success that the organizers have hired you to help design its sequel: a Robot Tug-of-War Competition!

In each one-on-one matchup, two robots are tied together with a rope. The center of the rope has a marker that begins above position 0 on the ground. The robots then alternate pulling on the rope. The first robot pulls in the positive direction towards 1; the second robot pulls in the negative direction towards -1. Each pull moves the marker a uniformly random draw from [0,1] towards the pulling robot. If the marker first leaves the interval [‑½,½] past ½, the first robot wins. If instead it first leaves the interval past -½, the second robot wins.

However, the organizers quickly noticed that the robot going second is at a disadvantage. They want to handicap the first robot by changing the initial position of the marker on the rope to be at some negative real number. Your job is to compute the position of the marker that makes each matchup a 50-50 competition between the robots. Find this position to seven significant digits—the integrity of the Robot Tug-of-War Competition hangs in the balance!


If you like this kind of puzzle, you might want to stop reading here and try to solve it yourself before I spoil it for you. 🙂

First Steps

When I first read the problem, it screamed recursive probability function to me: when it’s player1’s turn, they have a certain chance to win. If they don’t, player2 plays, and if they also don’t win, it’s player1’s turn again and we’re back to the same initial situation (or sort of, because the marker position player1 is now playing from has likely changed).

In order to proceed, let’s describe a couple of functions that I’ll start referring to in my explanations:

p1(x): probability that player1 wins the game if it’s currently their turn and they’re 
       playing from marker position x.

p2(x): probability that player2 wins the game if it’s currently their turn and they’re
       playing from marker position x.

So basically we need to find a start marker position s such that p1(s) = 0.5.
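
Just as an aside, a quick brute-force Monte Carlo estimate of p1(x) is a handy sanity check for whatever answer we end up with. A minimal sketch of that (it assumes the rand crate for the uniform draws, and it's not part of the actual solution) could look like this:

use rand::random;

// Simulate one match with the marker starting at `start` and player1 about to pull.
// Returns true if player1 wins.
fn player1_wins(start: f64) -> bool {
    let mut marker = start;
    loop {
        marker += random::<f64>(); // player1 pulls towards +1
        if marker > 0.5 {
            return true;
        }
        marker -= random::<f64>(); // player2 pulls towards -1
        if marker < -0.5 {
            return false;
        }
    }
}

// Estimate p1(start) by playing many matches.
fn estimate_p1(start: f64, n_matches: u32) -> f64 {
    let wins = (0..n_matches).filter(|_| player1_wins(start)).count();
    wins as f64 / n_matches as f64
}

With enough matches, estimate_p1(s, n) should hover around 0.5 when s is close to the fair starting position we're looking for.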

I quickly found a simple (and wrong) solution for it. It sounded too easy, and it was. My mistake was to assume that the mean value of p1(x) (and p2(x) as well) in an interval [a, b] was p1((a + b) / 2). Once I realized that was not the case, the problem became much harder.

I kept trying to find an analytical definition for p1(x), and after thinking about it for some time, I arrived at the following decomposition:

p1(x) = “chance player1 wins immediately” + 
        “chance player1 wins later in another turn”

The first part is pretty simple: after the pull, the marker is uniformly distributed over [x, x + 1], so the chance that it ends up past 0.5 is (x + 1) - 0.5 = x + 0.5.

The second part is much trickier. If player1 doesn’t win immediately, they leave the marker at some position y and it becomes player2’s turn. The chance of player1 winning from there is the chance that player2 eventually loses, which is 1 - p2(y). The problem is that y varies according to player1’s movement distribution, so we need to integrate 1 - p2(y) over the possible values of y, which run from x to 0.5. The integrand needs no extra weighting because player1’s movement distribution is uniform over an interval of length 1, so this part is ∫_x^0.5 (1 - p2(y)) dy.

We note that there’s a symmetry between p1(x) and p2(x): the chance of player1 winning when they start from a marker position x is the same chance of player2 winning when they start from position -x. So p1(x) = p2(-x). Using that, we can replace p2(y) in the integral expression by p1(-y).

Putting everything together:

p1(x) = (x + 1 - 0.5) + ∫_x^0.5 (1 - p1(-y)) dy

p1(x) = (x + 0.5) + ∫_x^0.5 1 dy - ∫_x^0.5 p1(-y) dy

p1(x) = (x + 0.5) + (0.5 - x) - ∫_x^0.5 p1(-y) dy

p1(x) = 1 - ∫_x^0.5 p1(-y) dy

So yeah, we found a recursive function definition after all! Although one with an unexpected integral getting in the way.

At this point, my differential equation solving skills failed me (it’s been a long time since I sharpened them in college) and I got stuck: differentiating both sides of that equation turns it into a differential equation, p1'(x) = p1(-x), but I couldn’t see how to solve it. This is also where my solution started to get more weird (or interesting, depending on your point of view haha 😅). As I couldn’t find a way to continue solving the problem analytically, I started looking for ways to approximate this integral.

Solving p1(x) using discrete integral approximation

My idea came from numerical integral approximation: basically, I would break the integration interval into blocks of size step, i.e. [x, x + step], [x + step, x + 2*step], and so on, and approximate the area of each block by the value of the function at its beginning times the step size, i.e. f(x) * step. As step gets smaller, the approximation gets better.

∫_x^0.5 p1(-y) dy ≈ Σ_{i=0..n-1} [step * p1(-(x + i * step))], where n = (0.5 - x) / step
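
Just to make that concrete, here's a tiny sketch of the same left-endpoint rectangle rule applied to an ordinary, known function (the function and names here are purely illustrative):

// Left-endpoint rectangle approximation of the integral of f over [a, b].
fn left_riemann_sum(f: impl Fn(f64) -> f64, a: f64, b: f64, step: f64) -> f64 {
    let n = ((b - a) / step).round() as usize;
    (0..n).map(|i| step * f(a + (i as f64) * step)).sum()
}

// e.g. left_riemann_sum(|y| y * y, 0.0, 0.5, 1e-6) comes out close to 0.5^3 / 3 ≈ 0.0416667

The twist in our case is that the integrand itself involves p1, the very function we're trying to approximate, which is what leads to the system of equations below.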

For illustration purposes, let’s approximate p1(x) using step = 0.1. For that we will evaluate p1(x) at the following set of points: {0.4, 0.3, 0.2, 0.1, 0, -0.1, -0.2, -0.3, -0.4}.

p1(x) = 1 - Σ_{i=0..n-1} [0.1 * p1(-(x + i * 0.1))], where n = (0.5 - x) / 0.1

p1(0.4)  = 1 - 0.1 * (p1(-0.4)) 
p1(0.3)  = 1 - 0.1 * (p1(-0.4) + p1(-0.3)) 
p1(0.2)  = 1 - 0.1 * (p1(-0.4) + p1(-0.3) + p1(-0.2))
p1(0.1)  = 1 - 0.1 * (p1(-0.4) + p1(-0.3) + p1(-0.2) + p1(-0.1))
p1(0)    = 1 - 0.1 * (p1(-0.4) + p1(-0.3) + p1(-0.2) + p1(-0.1) + p1(0))
p1(-0.1) = 1 - 0.1 * (p1(-0.4) + p1(-0.3) + p1(-0.2) + p1(-0.1) + p1(0) + p1(0.1))
p1(-0.2) = 1 - 0.1 * (p1(-0.4) + ... + p1(0.2))
p1(-0.3) = 1 - 0.1 * (p1(-0.4) + ... + p1(0.3))
p1(-0.4) = 1 - 0.1 * (p1(-0.4) + ... + p1(0.4))

Now we have a system of 9 equations in 9 unknowns. We can solve it and take as an approximate answer the x among those nine points for which p1(x) evaluates closest to 0.5! Then we can just keep reducing the step size to get a better and better approximation until we have enough confidence in the seven significant digits that we need. 🙂
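
As a cross-check for this small example, here's a sketch (my own, separate from the solution further down) that builds the 9x9 system for step = 0.1 explicitly and solves it with plain Gaussian elimination. Unknown j holds p1(0.4 - 0.1 * j), so j = 0 is p1(0.4) and j = 8 is p1(-0.4):

const N: usize = 9;

fn solve_small_system() -> [f64; N] {
    let step = 0.1;
    let mut m = [[0.0f64; N + 1]; N]; // augmented matrix [A | b]

    for j in 0..N {
        // Equation j: p1(x_j) + step * sum over i of p1(-(x_j + i * step)) = 1.
        // The summed terms correspond to unknowns k = (N - 1 - j) ..= (N - 1).
        m[j][j] += 1.0;
        for k in (N - 1 - j)..N {
            m[j][k] += step;
        }
        m[j][N] = 1.0; // right-hand side
    }

    // Forward elimination with partial pivoting.
    for col in 0..N {
        let pivot = (col..N)
            .max_by(|&r, &s| m[r][col].abs().partial_cmp(&m[s][col].abs()).unwrap())
            .unwrap();
        m.swap(col, pivot);
        for row in (col + 1)..N {
            let factor = m[row][col] / m[col][col];
            for k in col..=N {
                m[row][k] -= factor * m[col][k];
            }
        }
    }

    // Back substitution.
    let mut p = [0.0f64; N];
    for row in (0..N).rev() {
        let mut acc = m[row][N];
        for k in (row + 1)..N {
            acc -= m[row][k] * p[k];
        }
        p[row] = acc / m[row][row];
    }
    p
}

Each entry p[j] then approximates p1(0.4 - 0.1 * j) and can be compared against what the O(N) approach below produces. A dense solve like this is O(N³) in the number of unknowns, though, which is worth keeping in mind for what comes next.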

Initially that system of equations looked a little ugly, and I worried it could be inefficient to solve (such that if I reduced the step a lot, I’d have so many variables that it would become too slow to compute), but then I got quite excited when I found a way to solve it in O(N)!

To see that linear-time solution, starting from the second equation we subtract from each equation the one right above it (and we keep the first equation as is):

p1(0.4) = 1 - 0.1 * (p1(-0.4))
p1(0.3)  - p1(0.4)  = - 0.1 * p1(-0.3)
p1(0.2)  - p1(0.3)  = - 0.1 * p1(-0.2)
p1(0.1)  - p1(0.2)  = - 0.1 * p1(-0.1)
p1(0)    - p1(0.1)  = - 0.1 * p1(0)
p1(-0.1) - p1(0)    = - 0.1 * p1(0.1)
p1(-0.2) - p1(-0.1) = - 0.1 * p1(0.2)
p1(-0.3) - p1(-0.2) = - 0.1 * p1(0.3)
p1(-0.4) - p1(-0.3) = - 0.1 * p1(0.4)

Starting at the equation in the middle, and then alternating between the next unused equation below and the next unused one above, we can quickly express all variables in terms of p1(0) (i.e., let’s represent p1(x) as c(x) * p1(0), and find c(x) for every x, starting with c(0) = 1). We do this for all equations except the first, which will be left out:

p1(0) - p1(0.1) = - 0.1 * p1(0)
=> p1(0.1) = (0.1 * c(0) + c(0)) * p1(0) 
=> c(0.1) = 0.1 * c(0) + c(0)
=> c(0.1) = 1.1 
  
p1(-0.1) - p1(0) = - 0.1 * p1(0.1)
=> p1(-0.1) = (-0.1 * c(0.1) + c(0)) * p1(0)
=> c(-0.1) = -0.1 * c(0.1) + c(0)
  
p1(0.1) - p1(0.2) = - 0.1 * p1(-0.1)
=> p1(0.2) = (0.1 * c(-0.1) + c(0.1)) * p1(0)
=> c(0.2) = 0.1 * c(-0.1) + c(0.1)
  
p1(-0.2) - p1(-0.1) = - 0.1 * p1(0.2)
=> p1(-0.2) = (-0.1 * c(0.2) + c(-0.1)) * p1(0)
=> c(-0.2) = -0.1 * c(0.2) + c(-0.1)
  
p1(0.2) - p1(0.3) = - 0.1 * p1(-0.2)
=> p1(0.3) = (0.1 * c(-0.2) + c(0.2)) * p1(0)
=> c(0.3) = 0.1 * c(-0.2) + c(0.2)
  
p1(-0.3) - p1(-0.2) = - 0.1 * p1(0.3)
=> p1(-0.3) = (-0.1 * c(0.3) + c(-0.2)) * p1(0)
=> c(-0.3) = -0.1 * c(0.3) + c(-0.2)
  
p1(0.3) - p1(0.4) = - 0.1 * p1(-0.3)
=> p1(0.4) = (0.1 * c(-0.3) + c(0.3)) * p1(0)
=> c(0.4) = 0.1 * c(-0.3) + c(0.3)
  
p1(-0.4) - p1(-0.3) = - 0.1 * p1(0.4)
=> p1(-0.4) = (-0.1 * c(0.4) + c(-0.3)) * p1(0)
=> c(-0.4) = -0.1 * c(0.4) + c(-0.3)

Now that we finally have the values for every c(x), we can just use the first equation to find p1(0):

p1(0.4)  = 1 - 0.1 * (p1(-0.4))
=> (c(0.4) + 0.1 * c(-0.4)) * p1(0) = 1
=> p1(0) = 1 / (c(0.4) + 0.1 * c(-0.4))

Then we use p1(0) combined with each c(x) to find every other p1(x) and pick the one that’s closest to 0.5.

Looking back at how we solved for each c(x), we can see there’s a rather simple pattern to determine them in pairs iteratively. Starting from i = 0 and incrementing it by one every loop:

c((i + 1) * step) = step * c(-i * step) + c(i * step)
c(-(i + 1) * step) = -step * c((i + 1) * step) + c(-i * step)

As I’ve been playing with Rust lately, of course I went ahead and implemented this solution in it haha. Initially, I thought I had to store every c(x) in memory to solve the problem. That would limit how much I could decrease the step: for example, using a step of 1.0e-9, I'd have to store about 1.0e+9 64-bit floating point values (one for each c(x)), which already amounts to 8GB of RAM.

It was then that I realized that my iteration method for finding the c(x) values only depended on the last two values computed! So the loop could actually run in constant memory (independent of the step size). I would just have to run it twice: once to find p1(0), and a second time to inspect every p1(x) and find the one closest to 0.5. This insight allowed me to use even smaller step sizes to build more confidence in my solution, like 1.0e-11 (which would require 800GB of RAM if I had to store each c(x) in memory). I even tried using a smaller step once (it starts to take a lot of time to run; 1.0e-11 already takes a few minutes), but it looked like I started hitting some underflow issues with the 64-bit floating point representation.

Here's my Rust code that implements this:

fn approximate_fair_p1_start_position(n_iterations: usize) {
    let step: f64 = 1.0 / ((2 * n_iterations) as f64);

    // Find p1(0)
    let mut a: f64 = 1.0; // c(i * step)
    let mut b: f64 = 1.0; // c(-i * step)
    for _ in 0..n_iterations {
        a = a + b * step;
        b = b - a * step;
    }
    let p1_at_0: f64 = 1.0 / (a + b * step);

    // Find x such that p1(x) is the closest one to 0.5
    let mut best_x_index = 0;
    let mut best_p1_x_dist = (p1_at_0 - 0.5).abs();
    a = p1_at_0; // now reused to track p1(i * step)
    b = p1_at_0; // now reused to track p1(-i * step)
    for i in 0..n_iterations {
        a = a + b * step;
        b = b - a * step;

        let b_dist = (b - 0.5).abs();
        if b_dist < best_p1_x_dist {
            best_x_index = i + 1;
            best_p1_x_dist = b_dist;
        }
    }

    let best_x = -(best_x_index as f64) * step;
    println!(
        "best_x: {:.12},  step: {:<8.2e},  n_iterations: {}",
        best_x, step, n_iterations
    );
}

fn main() {
    approximate_fair_p1_start_position(5);
    approximate_fair_p1_start_position(50);
    approximate_fair_p1_start_position(500);
    approximate_fair_p1_start_position(5000);
    approximate_fair_p1_start_position(50000);
    approximate_fair_p1_start_position(500000);
    approximate_fair_p1_start_position(5000000);
    approximate_fair_p1_start_position(50000000);
    approximate_fair_p1_start_position(500000000);
    approximate_fair_p1_start_position(5000000000);
    approximate_fair_p1_start_position(50000000000);
}

Results:

best_x: -0.200000000000,  step: 1.00e-1 ,  n_iterations: 5
best_x: -0.280000000000,  step: 1.00e-2 ,  n_iterations: 50
best_x: -0.285000000000,  step: 1.00e-3 ,  n_iterations: 500
best_x: -0.285000000000,  step: 1.00e-4 ,  n_iterations: 5000
best_x: -0.285000000000,  step: 1.00e-5 ,  n_iterations: 50000
best_x: -0.285000000000,  step: 1.00e-6 ,  n_iterations: 500000
best_x: -0.285000100000,  step: 1.00e-7 ,  n_iterations: 5000000
best_x: -0.285000120000,  step: 1.00e-8 ,  n_iterations: 50000000
best_x: -0.285000121000,  step: 1.00e-9 ,  n_iterations: 500000000
best_x: -0.285000121700,  step: 1.00e-10,  n_iterations: 5000000000
best_x: -0.285000121750,  step: 1.00e-11,  n_iterations: 50000000000

Although I didn't find a way to formally prove that my solution had at least seven correct significant digits, it looked like the first seven digits were getting stable enough (as I reduced the step size) to trust that they were the correct ones.

After a couple of hours of struggling to make progress, I was quite excited and happy to have found such a simple algorithm to solve the problem! 🙂

Final Thoughts

After I was satisfied with my solution, I still had one question at the back of my mind: "Does this problem have an analytical closed-form solution?" It reminded me of the goat problem, which got some attention recently when a closed-form solution for it was finally found.

Turns out, it does! When a solution was posted on the Jane Street website, I was quite delighted to see that they provided a closed-form solution:

p1(x) = (sin(x) + cos(x)) / (sin(0.5) + cos(0.5))

Solution for p1(s) = 0.5 :  s = arcsin(sin(0.5 + 𝜋/4) / 2) - 𝜋/4
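
One way to get from the integral equation we derived earlier to this closed form (this is my own sketch, not necessarily how Jane Street solved it): differentiating

p1(x) = 1 - ∫_x^0.5 p1(-y) dy

with respect to x gives p1'(x) = p1(-x), and differentiating once more gives p1''(x) = -p1'(-x) = -p1(x). That's the simple harmonic equation, so p1(x) = A*sin(x) + B*cos(x). The relation p1'(x) = p1(-x) forces A = B, and plugging x = 0.5 into the integral equation gives p1(0.5) = 1, which fixes A = 1 / (sin(0.5) + cos(0.5)). Setting p1(s) = 0.5 and using sin(s) + cos(s) = √2 * sin(s + 𝜋/4) then yields the arcsin expression above, which evaluates to about -0.2850001, matching the numerical approximation.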

Now look at that solution: it's full of trigonometric expressions! Imagine how puzzled someone who had studied probability but never seen trigonometry would be if they were presented with this problem and that solution. That's something that never ceases to amaze me whenever I see a new example of it: how many connections we can find between apparently unconnected areas of mathematics.