Consistent Hashing for Dummies

Today I'll discuss an interesting concept: consistent hashing. It's a widely used technique for sharding data in distributed storage systems. I'm not aiming at a rigorous explanation (please don't use the raw snippets I provide in production code!), but I hope I can make the concept simple enough.

The problem

What problem are we aiming to solve?
Let's suppose we need to handle more data than a single server can store. What's the problem? Just create a number of shards and distribute the data!

But, then, clients need to know where to connect in order to store or retrieve a value. So we pick a primary key and feed it into a function, a hashing function in fact, since it maps our data to a fixed-size output; such a function will tell us where to connect by returning a shard index such that 0 <= shard_index < shard_count, e.g.:

def get_shard(primary_key: int, shard_count: int) -> int:
    pass

Let's suppose our primary key is a 64-bit unsigned integer, and that we want to create 10 shards. Our first thought could be: let's split the keyspace into ten consecutive parts of (almost) identical size, check where our primary key belongs, and return the position of that part as the shard index, that is:

MAX_UINT64 = 2**64 - 1

def get_shard_linear(primary_key: int, shard_count: int) -> int:
    # Ceiling division: we don't want the shard size to be accidentally too
    # small, or the highest keys would map to an out-of-range shard index.
    shard_size = -(-MAX_UINT64 // shard_count)
    return primary_key // shard_size

This would work. But there's a problem:
What happens if there's a lot of data with primary keys clustered around a certain value, and very little elsewhere? We'll have one or two shards doing far too much work while the others sit idle.
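
To see how bad this can get, here's a quick sketch using get_shard_linear with a made-up skewed workload (the numbers below are purely illustrative):

import random
random.seed(0)

counts = [0] * 10
for _ in range(10_000):
    # Hypothetical hot spot: every key falls near 5,000,000.
    key = 5_000_000 + random.randrange(0, 1_000_000)
    counts[get_shard_linear(key, 10)] += 1

print(counts)  # every single key lands in shard 0; the other nine shards stay empty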

Modulo to the rescue! What if we just do primary_key modulo shard_count?

def get_shard_mod(primary_key: int, shard_count: int) -> int:
    return primary_key % shard_count

Looks nice! Unless the primary keys happen to follow a pattern that lines up with shard_count (say, they're all multiples of 10), it will probably distribute our load evenly.
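
Re-running the made-up skewed workload from before through get_shard_mod shows the difference (again, just a sketch):

random.seed(0)

counts = [0] * 10
for _ in range(10_000):
    key = 5_000_000 + random.randrange(0, 1_000_000)
    counts[get_shard_mod(key, 10)] += 1

print(counts)  # roughly 1,000 keys per shard this time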

But then, what happens if we discover that our system gets overloaded and we'd like to add another shard? Our indexes will change, and we'll need to move some data around. But how much data, and which?

With linear sharding, the shard size will shrink, we'll need to recalculate the shard index, and we'll probably have to move a lot of objects around in most shards: many objects that were towards the end of shard 0 will shift to shard 1, and so on. That's quite a lot of work:

ten_shard_size = -(-MAX_UINT64 // 10)

# A few keys near the end of (old) shard 0...
for n in range(ten_shard_size-3, ten_shard_size):
    print("10-shard: {} -> 11-shard: {}".format(get_shard_linear(n, 10), get_shard_linear(n, 11)))

# ...and a few near the end of (old) shard 1.
for n in range(ten_shard_size*2-3, ten_shard_size*2):
    print("10-shard: {} -> 11-shard: {}".format(get_shard_linear(n, 10), get_shard_linear(n, 11)))

Remember, when using ten shards, indexes go from 0 to 9; when using eleven, indexes go from 0 to 10.

Output:

10-shard: 0 -> 11-shard: 1
10-shard: 0 -> 11-shard: 1
10-shard: 0 -> 11-shard: 1
10-shard: 1 -> 11-shard: 2
10-shard: 1 -> 11-shard: 2
10-shard: 1 -> 11-shard: 2

But even modulo sharding won't help us: the shard index will change for most values.

for n in range(20, 30):
    print("10-shard: {} -> 11-shard: {}".format(get_shard_mod(n, 10), get_shard_mod(n, 11)))

Output:

10-shard: 0 -> 11-shard: 9
10-shard: 1 -> 11-shard: 10
10-shard: 2 -> 11-shard: 0
10-shard: 3 -> 11-shard: 1
10-shard: 4 -> 11-shard: 2
10-shard: 5 -> 11-shard: 3
10-shard: 6 -> 11-shard: 4
10-shard: 7 -> 11-shard: 5
10-shard: 8 -> 11-shard: 6
10-shard: 9 -> 11-shard: 7
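
To put a number on it, we can count how many of a batch of random keys change shard under modulo sharding when going from 10 to 11 shards (a quick sketch; the exact count depends on the seed):

import random
random.seed(1024910)

moved = 0
for _ in range(10_000):
    n = random.randrange(0, 2**64)
    if get_shard_mod(n, 10) != get_shard_mod(n, 11):
        moved += 1

print(f"{moved} of 10,000 keys changed shard")  # roughly 10 out of every 11 keys end up moving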

The solution

So, the properties we'd like to get from our ideal hashing function are balance (objects should be distributed evenly across our shards) and monotonicity (when a shard is added, objects should flow only from existing shards to the new one, with no reshuffling among the existing shards). Ideally, going from n shards to n+1 should move only about 1/(n+1) of the objects.
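
Expressed as code, the monotonicity property we're after looks roughly like this (a sketch of the invariant itself, with a helper name I just made up, not part of any particular algorithm):

def is_monotonic_for(get_shard, primary_key: int, old_shard_count: int) -> bool:
    # After adding one shard, the key must either stay where it was
    # or move to the newly added shard (index old_shard_count).
    old_index = get_shard(primary_key, old_shard_count)
    new_index = get_shard(primary_key, old_shard_count + 1)
    return new_index == old_index or new_index == old_shard_count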

And, guess what? That's exactly what consistent hashing does!

A simple implementation of one such algorithm, the Lamping-Veach "jump" consistent hash, looks like this:

# Python implementation by Peter Lithammer
def get_shard_lamping_veach(primary_key: int, shard_count: int) -> int:
    b, j = -1, 0.0

    if shard_count < 1:
        raise ValueError(
            f"'shard_count' must be a positive number, got {shard_count}"
        )

    while j < shard_count:
        b = int(j)
        # Advance a 64-bit linear congruential generator seeded with the key...
        primary_key = (primary_key * 2862933555777941757 + 1) & 0xFFFFFFFFFFFFFFFF
        # ...and use it to pick the next candidate shard index for the key.
        j = float(b + 1) * (float(1 << 31) / float((primary_key >> 33) + 1))

    return int(b)

See it in action:

import random
random.seed(1024910)
moved_to_new = 0
for x in range(0, 10000):
    n = random.randrange(0, 2**64)
    ten_shard = get_shard_lamping_veach(n, 10)
    eleven_shard = get_shard_lamping_veach(n, 11)
    if ten_shard != eleven_shard:
        # Monotonicity check: a key may only move to the newly added shard (index 10).
        if eleven_shard != 10:
            raise ValueError("object flows to non-new index")
        moved_to_new += 1

print(f"{moved_to_new} objects changed index to 10")

Output:

898 objects changed index to 10

We picked 10,000 random keys. As you can see, no object's shard index changed to a value other than 10, the newly added shard index, and a reasonable number of objects (quite close to 1/11 of the 10,000, in fact) moved to the new shard.
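
The other property, balance, can be eyeballed with a similar experiment: count how many of the random keys land in each of the 11 shards (again just a quick sketch, not a statistical test):

from collections import Counter

random.seed(1024910)
counts = Counter(
    get_shard_lamping_veach(random.randrange(0, 2**64), 11) for _ in range(10_000)
)
print(sorted(counts.items()))  # each of the 11 shards should hold roughly 1/11 of the keys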

Explaining how this algorithm works is beyond the scope of this post; take a look at the last paper in the references if you're interested! But I hope you now understand what consistent hashing does: a consistent hashing function maps its inputs to evenly distributed outputs, and if the number of shards changes slightly, only a small fraction of objects change location. I haven't tested how the Lamping-Veach algorithm behaves if you wildly modify the number of shards (e.g. go from 10 to 20).

References