Consistent Hashing — Distributed Cache

Sourav Das
11 min read · Mar 29, 2024


What is Consistent Hashing?

Consistent hashing is a technique used in distributed computing to distribute and balance data across multiple servers or nodes in a consistent manner, even as the number of nodes changes. It ensures that when nodes are added or removed, only a fraction of the keys need to be remapped, minimizing the impact on the system. This is achieved by mapping each data item to a point on a ring, with each node assigned a range of points on the ring. When a data item is hashed, it is assigned to the nearest node on the ring, walking in a fixed (clockwise or anti-clockwise) direction.
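
To make the lookup concrete before diving into the implementation below, here is a minimal, illustrative sketch of a static ring. The node names, the md5-based positions and the bisect lookup are assumptions for this example, not the article's design.

import bisect
import hashlib

# a node's position on the ring is the hash of its name (illustrative choice)
def node_pos(name: str) -> int:
    return int(hashlib.md5(name.encode()).hexdigest(), 16)

nodes = {node_pos(n): n for n in ["cache-a", "cache-b", "cache-c"]}
positions = sorted(nodes)

def lookup(key: str) -> str:
    key_hash = int(hashlib.md5(key.encode()).hexdigest(), 16)
    # nearest node position at or below the key hash;
    # index -1 wraps around to the last node if the key falls before the first position
    i = bisect.bisect_right(positions, key_hash) - 1
    return nodes[positions[i]]

print(lookup("user:42"))   # same node every run, and in every process

Because a key's position only depends on its own hash and the node positions, adding or removing a node only moves the keys that fall into that node's range.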

Binary Search Tree Implementation

A Binary Search Tree (BST) implementation of consistent hashing is an alternative approach to building a consistent hashing algorithm: a search tree keyed by ring positions, whose values are the hash maps (cache nodes) distributed across several machines.

# bst implementation api
bst = TreeMap()

for i in range(0, 150, 5):
    bst.put(i, f"Value {i}")

print(bst)                 # {0=Value 0, 5=Value 5, 10=Value 10, 15=Value 15 ... }
print(bst.get(50))         # Value 50
print(bst.lower_key(10))   # 5
print(bst.higher_key(10))  # 15
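
The article does not show where TreeMap comes from; a Java-TreeMap-style package or sortedcontainers could back it. For completeness, here is a minimal stand-in with the same API, an assumption of what the author's TreeMap provides, backed by a plain sorted list and bisect rather than a real balanced tree.

import bisect

# a minimal TreeMap-like sorted map (sketch, not the author's implementation)
class TreeMap:
    def __init__(self):
        self._keys, self._map = [], {}

    def put(self, key, value):
        if key not in self._map:
            bisect.insort(self._keys, key)
        self._map[key] = value

    def get(self, key):
        return self._map.get(key)

    def remove(self, key):
        self._keys.remove(key)
        return self._map.pop(key)

    def key_set(self):
        return list(self._keys)

    def first_key(self):
        return self._keys[0]

    def last_key(self):
        return self._keys[-1]

    def floor_key(self, key):     # greatest key <= key, else None
        i = bisect.bisect_right(self._keys, key) - 1
        return self._keys[i] if i >= 0 else None

    def lower_key(self, key):     # greatest key < key, else None
        i = bisect.bisect_left(self._keys, key) - 1
        return self._keys[i] if i >= 0 else None

    def higher_key(self, key):    # smallest key > key, else None
        i = bisect.bisect_right(self._keys, key)
        return self._keys[i] if i < len(self._keys) else None

    def __len__(self):
        return len(self._keys)

A sorted list gives O(n) inserts, which is fine for a handful of node positions; a balanced tree or skip list would be the choice for many nodes.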

Hashing function

I used a hashlib hash function to hash the keys: it takes a string key and returns a hashed integer. Python's built-in hash function does not produce the same hash for the same key across separate processes (string hashing is randomized per process), so it cannot be used here.

import hashlib

def stable_hash(key: str) -> int:
    str_bytes = bytes(key, "UTF-8")
    m = hashlib.md5(str_bytes)
    return int(m.hexdigest(), base=16)
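
As a quick sanity check (the key "user:42" is just an illustrative example), stable_hash returns the same integer on every run and in every process, which is what the ring relies on:

# the same key always maps to the same ring position, in any process
print(stable_hash("user:42"))
# by contrast, hash("user:42") changes between interpreter runs,
# because Python randomizes str hashing (PYTHONHASHSEED)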

Node

A Node is an object that stores the key and value pairs. It is an LRU cache that removes the least recently used key when the node's size is exhausted. Let's define the node API, i.e. how a node should behave.

import json
from collections import OrderedDict

# Node is a cache instance, where all the key -> value data is present
# instance_no: uniquely identifies a node in a node set
# cache_size: total data that can be present in the node, in bytes
class Node:

    def __init__(self, instance_no: int, cache_size: int):
        self.instance_no = instance_no  # identity
        self.data = OrderedDict()       # LRU cache
        self.cache_size = cache_size    # in bytes
        self.keys_to_remove = set()
        self.data_size = 0              # current size

    def put(self, key, value):
        self.data[key] = value
        # move latest used data to end, LRU cache
        self.data.move_to_end(key)
        self.data_size += len(json.dumps(self.data[key])) + len(key)
        if self.load() >= 1.0:
            # remove the Least Recently Used key, as the total space is used up
            key, value = self.data.popitem(last=False)
            self.data_size -= len(json.dumps(value)) + len(key)

    def get(self, key):
        # move latest used data to end, LRU cache
        self.data.move_to_end(key)
        return self.data[key]

    def has(self, key: str):
        return key in self.data

    def remove(self, key: str):
        value = self.data.pop(key)
        self.data_size -= len(json.dumps(value)) + len(key)
        return value

Data also needs to be moved from one node to another; this is helpful when we need to split data between nodes.

# move keys in the hash range [from_key, to_key) from this node to another
def copy_keys(self, target_node, from_key, to_key):
    for key in list(self.data.keys()):
        key_hash = stable_hash(key)
        if from_key <= key_hash and (to_key is None or key_hash < to_key):
            target_node.put(key, self.data[key])
            # defer removal: if the target node crashes mid-copy, the data is not lost
            self.keys_to_remove.add(key)
    return True

# reduce memory usage if possible
def compact_keys(self):
    for key in list(self.data.keys()):
        if key in self.keys_to_remove:
            self.remove(key)
    self.keys_to_remove.clear()

Some additional methods help in distributing the data across all the nodes.

# calculate the average hash of all keys present
def calculate_mid_key(self):
    if len(self.data) <= 0:
        return 0
    return sum(map(lambda x: stable_hash(x), self.data.keys())) // len(self.data)

# calculate the fraction of the cache size that is used
def load(self):
    return self.data_size / self.cache_size

# compute memory usage of the current node
def metrics(self):
    # used_memory = psutil.virtual_memory().used
    # data_size is the current amount of data present in the node, in bytes
    # load is the fraction of the cache size used
    return {
        "used_memory": self.data_size,
        "load": self.load()
    }
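
A small usage sketch of the Node API: the keys, values and the 90-byte budget are made-up numbers, chosen so that the third insert overflows the cache and triggers an LRU eviction.

node = Node(instance_no=0, cache_size=90)   # tiny budget to force eviction
node.put("a", "x" * 30)      # ~33 bytes accounted (json-encoded value + key)
node.put("b", "y" * 30)
node.get("a")                # touch "a" so it becomes most recently used
node.put("c", "z" * 30)      # budget exceeded -> least recently used "b" is evicted
print(node.has("b"))         # False
print(node.has("a"), node.metrics())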

Node Set

A node set is a collection of nodes where the same data is copied across multiple nodes to provide availability: if one node gets destroyed, the other nodes are available to restore it.

# create a node set
# replication_factor: number of nodes the data is replicated across
# instance_no: unique id of the node set
# cache_size: maximum size of each node, in bytes
class NodeSet:

    def __init__(self, replication_factor: int, instance_no: int, cache_size: int):
        self.replication_factor = replication_factor
        self.instance_no = instance_no
        self.cache_size = cache_size
        self.nodes = []
        for i in range(replication_factor):
            ino = instance_no * replication_factor + i
            self.nodes.append(Node(ino, cache_size))

    # get a specific replica from the set
    def get_replica(self, index):
        return self.nodes[index]

    # put the same key in all replicas
    def put(self, key: str, value):
        # should be done in parallel, O(1)
        for node in self.nodes:
            node.put(key, value)

    # get the value from any of the replicas
    def get(self, key: str):
        # should be done in parallel, O(1)
        for node in self.nodes:
            if node.has(key):
                return node.get(key)
        return None

    def has(self, key: str):
        # should be done in parallel, O(1)
        for node in self.nodes:
            if node.has(key):
                return True
        return False

    def remove(self, key: str):
        # should be done in parallel, O(1)
        # remove from every replica and return the removed value
        value = None
        for node in self.nodes:
            if node.has(key):
                value = node.remove(key)
        return value

The NodeSet API is almost the same as the Node API; the only difference is that nodes are internal to the node set and their details are not exposed to the hash ring. The Ring forwards put, remove & get requests to a node set, and it is the node set's responsibility to apply the data changes to its internal replicas. A quick illustration is shown below.
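
In this sketch (the key, value and cache size are arbitrary), a put lands on every replica, a get can be served by any of them, and a remove clears all of them:

ns = NodeSet(replication_factor=2, instance_no=0, cache_size=1024)
ns.put("user:1", {"name": "abc"})
# both replicas hold the same entry
print(ns.get_replica(0).has("user:1"), ns.get_replica(1).has("user:1"))  # True True
print(ns.get("user:1"))      # served from whichever replica has it
ns.remove("user:1")          # removed from every replica
print(ns.has("user:1"))      # False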

Apart from individual key requests, the entire data can be moved from one node set to another; some of this additional functionality is given below.

# copy data from this node set to another node set
def copy_keys(self, target_node_set: "NodeSet", from_hash: int, to_hash: int):
    # probably this should run in parallel
    for i in range(self.replication_factor):
        # get the corresponding replica of the other node set
        target_node = target_node_set.get_replica(i)
        # copy this replica's data to the target replica
        self.nodes[i].copy_keys(target_node, from_hash, to_hash)
    return True

# reduce memory usage if possible
def compact_keys(self):
    # can be done in parallel, O(1)
    for node in self.nodes:
        node.compact_keys()

# calculate the average hash of all keys present
def calculate_mid_key(self):
    return self.nodes[0].calculate_mid_key()

def load(self):
    return self.nodes[0].load()

# compute memory usage of the current node set
def metrics(self):
    # used_memory = psutil.virtual_memory().used
    return {
        "used_memory": self.nodes[0].data_size,
        "load": self.load()
    }

The Ring

Create a ring backed by the binary search tree. It starts with a single node set at hash 0, the base node, which scales depending on the load.

# this class contains the binary search tree where all the node sets are stored
# max_node: the maximum number of NodeSets allowed to be created
# replication_factor: number of replicas inside a NodeSet
class HashRing:

    def __init__(self, cache_size: int, max_node: int, replication_factor: int) -> None:
        self.replication_factor = replication_factor
        self.max_node = max_node
        self.ring = TreeMap()
        self.cache_size = cache_size
        self.node_counter = 1
        # base node set at hash 0
        self.ring.put(0, NodeSet(replication_factor, 0, cache_size))

The ring only deals with node sets and is unaware of the nodes present inside each node set; this provides a good abstraction and simplifies the hash ring operations.

Find Node

The below code shows how a node-set is resolved from a given key.

# get the nearest node set for a key hash
def resolve_node(self, key: str) -> "NodeSet":
    # hash the key
    key_hash = stable_hash(key)
    # get the nearest node at or below the key hash
    key_hash = self.ring.floor_key(key_hash)
    # if no lower node is present, wrap around to the last node
    if key_hash is None:
        key_hash = self.ring.last_key()
    return self.ring.get(key_hash)

Up Scaling

When a node set is almost full, the hash ring splits the overloaded node's data with a new, empty node, so that ring operations are not affected.

# split a node set into two
def split_node(self, node_hash: int):
    # get the current node set
    current_node = self.ring.get(node_hash)
    LOGGER.info("Splitting {}", current_node)
    # get the next node's hash
    ahead_hash = self.ring.higher_key(node_hash)
    # create a new node set
    new_node = NodeSet(self.replication_factor, self.node_counter, self.cache_size)
    # calculate the hash in-between the two nodes
    mid_hash = current_node.calculate_mid_key()
    LOGGER.info("Moving data {} -> {}", current_node, new_node)
    # move half the data to the new node
    current_node.copy_keys(new_node, mid_hash, ahead_hash)
    # place the new node in-between the two nodes
    self.ring.put(mid_hash, new_node)
    # remove the copied keys from the current node
    current_node.compact_keys()
    self.node_counter += 1

Here the current node's data is moved into the new, free node, and only once that succeeds is the new node placed into the ring. It is placed at the position that divides the existing keys into two halves, i.e. the average key hash.
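
For intuition, here is the placement rule with tiny, purely hypothetical hash values (real key hashes are 128-bit MD5 integers):

# keys currently on the overloaded node, as hypothetical small hashes
key_hashes = [10, 20, 60, 70]
# same rule as calculate_mid_key: the average of the key hashes
mid_hash = sum(key_hashes) // len(key_hashes)     # 40
# keys at or above the mid hash move to the new node placed at 40
moved = [h for h in key_hashes if h >= mid_hash]  # [60, 70]
print(mid_hash, moved)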

Down Scaling

When a node is almost empty, the ring frees it up so that its resources can be used more efficiently.

# merge the data with the previous node, and delete the current node
def merge_node(self, node_hash: int):
    # never remove the first node, at hash 0
    if node_hash == self.ring.first_key():
        return
    # get the current node set
    current_node = self.ring.get(node_hash)
    LOGGER.info("Merging Node {}", current_node)
    # get the previous node set
    behind_hash = self.ring.lower_key(node_hash)
    behind_node = self.ring.get(behind_hash)

    LOGGER.info("Moving data {} -> {}", current_node, behind_node)
    # copy the data from the current node to the previous node
    current_node.copy_keys(behind_node, node_hash, None)
    # remove the node from the ring
    self.ring.remove(node_hash)

The underloaded node is scaled down and deleted by moving all of its values to the previous node; once the node is empty, the ring removes it.

Distribution Algorithms

The Ring runs this method at a regular interval and tries to distribute the data across all the nodes: it identifies the overloaded & underloaded nodes and splits or merges them.

# call this method at a fixed interval
def balance(self):
    LOGGER.info("Balance Started")
    # get the overloaded & underloaded nodes
    overloaded_nodes, underloaded_nodes = self.check_nodes()
    if len(self.ring) < self.max_node:
        for node_hash in overloaded_nodes:
            # split overloaded nodes
            self.split_node(node_hash)

    for node_hash in underloaded_nodes:
        # merge underloaded nodes
        self.merge_node(node_hash)
    LOGGER.info("Balance Completed")

With this, the ring tries to balance the load. This is just one of many possible balancing strategies, chosen here for its simplicity.

Testing the Ring

The below code shows how to use the ring: the test randomly generates key-value pairs and observes how they are balanced.

import random
import string

# generate a random n-length string
def random_str(n: int) -> str:
    m = n // len(string.ascii_uppercase + string.digits) + 1
    return ''.join(random.sample((string.ascii_uppercase + string.digits) * m, n))

# cache with 128 bytes per node
# a maximum of 5 node sets
# with 2 replicas per node set
ring = HashRing(128, 5, 2)

keys = []

for i in range(20):
    # generate a random string key
    key = random_str(3)
    # generate a random string value of random length
    value = random_str(random.randint(1, 100))
    # save in the ring
    ring.put(key, value)
    # store the key for removal later
    keys.append(key)
    LOGGER.info("Inserted Key : {} -> {}", key, value)
    # distribute the data across the nodes
    ring.balance()

for key in keys:
    if ring.has(key):
        # remove the key from the ring
        ring.remove(key)
        LOGGER.info("Removed Key : {}", key)
    # distribute the data & remove unused nodes
    ring.balance()

Output

The below data shows how the ring distributes the data across the nodes.

# when the Ring is started
Balance Started
NodeSet-0: [node-0, node-1] metrics - {'used_memory': 0, 'load': 0.0}
Balance Completed

# after inserting all the keys, the distribution in the Ring
Balance Started
NodeSet-0: [node-0, node-1] metrics - {'used_memory': 112, 'load': 0.875}
NodeSet-4: [node-8, node-9] metrics - {'used_memory': 72, 'load': 0.5625}
NodeSet-2: [node-4, node-5] metrics - {'used_memory': 41, 'load': 0.3203125}
NodeSet-1: [node-2, node-3] metrics - {'used_memory': 90, 'load': 0.703125}
NodeSet-3: [node-6, node-7] metrics - {'used_memory': 137, 'load': 1.0703125}
Balance Completed

# after deleting all the keys, the distribution in the Ring
Balance Started
NodeSet-0: [node-0, node-1] metrics - {'used_memory': 0, 'load': 0.0}
NodeSet-5: [node-10, node-11] metrics - {'used_memory': 0, 'load': 0.0}
Balance Completed

Complete Implementation

# this class contains the bst where all the node sets are stored
# max_node: the maximum number of NodeSets allowed to be created
# replication_factor: number of replicas in a node set
class HashRing:

    def __init__(self, cache_size: int, max_node: int, replication_factor: int) -> None:
        self.replication_factor = replication_factor
        self.max_node = max_node
        self.ring = TreeMap()
        self.cache_size = cache_size
        self.node_counter = 1
        # base node set at hash 0
        self.ring.put(0, NodeSet(replication_factor, 0, cache_size))

    # find the overloaded & underloaded nodes
    def check_nodes(self):
        overloaded_nodes = []
        underloaded_nodes = []
        for node_hash in self.ring.key_set():
            node_set = self.ring.get(node_hash)
            metrics = node_set.metrics()
            load = metrics['load']
            LOGGER.info("{} metrics - {}", node_set, metrics)
            # if usage goes above 80%, split the data into two nodes
            if load > 0.8:
                overloaded_nodes.append(node_hash)
            # if usage falls below 15%, merge the node with the previous node
            if load < 0.15:
                underloaded_nodes.append(node_hash)
        return overloaded_nodes, underloaded_nodes

    # split a node set into two
    def split_node(self, node_hash: int):
        # get the current node set
        current_node = self.ring.get(node_hash)
        LOGGER.info("Splitting {}", current_node)
        # get the next node's hash
        ahead_hash = self.ring.higher_key(node_hash)
        # create a new node set
        new_node = NodeSet(self.replication_factor, self.node_counter, self.cache_size)
        # calculate the hash in-between the two nodes
        mid_hash = current_node.calculate_mid_key()
        LOGGER.info("Moving data {} -> {}", current_node, new_node)
        # move half the data to the new node
        current_node.copy_keys(new_node, mid_hash, ahead_hash)
        # place the new node in-between the two nodes
        self.ring.put(mid_hash, new_node)
        # remove the copied keys from the current node
        current_node.compact_keys()
        self.node_counter += 1

    # merge the data with the previous node, and delete the current node
    def merge_node(self, node_hash: int):
        # never remove the first node, at hash 0
        if node_hash == self.ring.first_key():
            return
        # get the current node set
        current_node = self.ring.get(node_hash)
        LOGGER.info("Merging Node {}", current_node)
        # get the previous node set
        behind_hash = self.ring.lower_key(node_hash)
        behind_node = self.ring.get(behind_hash)

        LOGGER.info("Moving data {} -> {}", current_node, behind_node)
        # copy the data from the current node to the previous node
        current_node.copy_keys(behind_node, node_hash, None)
        # remove the node from the ring
        self.ring.remove(node_hash)

    # call this method at a fixed interval
    def balance(self):
        LOGGER.info("Balance Started")
        # get the overloaded & underloaded nodes
        overloaded_nodes, underloaded_nodes = self.check_nodes()
        if len(self.ring) < self.max_node:
            for node_hash in overloaded_nodes:
                # split overloaded nodes
                self.split_node(node_hash)

        for node_hash in underloaded_nodes:
            # merge underloaded nodes
            self.merge_node(node_hash)
        LOGGER.info("Balance Completed")

    # get the nearest node set for a key hash
    def resolve_node(self, key: str) -> "NodeSet":
        # hash the key
        key_hash = stable_hash(key)
        # get the nearest node at or below the key hash
        key_hash = self.ring.floor_key(key_hash)
        # if no lower node is present, wrap around to the last node
        if key_hash is None:
            key_hash = self.ring.last_key()
        return self.ring.get(key_hash)

    def put(self, key, value):
        node = self.resolve_node(key)
        node.put(key, value)
        return True

    def get(self, key):
        node = self.resolve_node(key)
        return node.get(key)

    def has(self, key):
        node = self.resolve_node(key)
        return node.has(key)

    def remove(self, key):
        node = self.resolve_node(key)
        return node.remove(key)

Conclusion

Overall, while distributed caching offers many benefits in terms of scalability, performance, and availability, it also presents challenges related to complexity, data consistency, and network overhead that need to be carefully addressed during design and implementation.

I have also written a blog on implementing this concept with Kubernetes & Docker: Distributed Hash Table in k8s.

Advantages:

  • Scalability: Distributed caching allows for horizontal scaling, meaning you can add or remove cache nodes dynamically to handle varying workloads.
  • High Availability: NodeSet ensures high availability. If one replica node fails, the data remains accessible from other nodes.

Disadvantages:

  • Complexity: Setting up and managing a distributed caching system is complex, requiring expertise in caching technologies.
  • Data Partitioning: Distributing data across multiple nodes introduces complexities related to data partitioning and distribution, which can impact cache efficiency and performance.
  • Data Eviction: Distributed caching systems typically have limited memory capacity per node, which may require implementing eviction policies to remove the least-used data from the cache.
  • Network Overhead: Distributed caching systems rely on network communication between nodes, which can introduce network overhead and latency, especially in geographically distributed deployments.
