# Disjoint Set — Union & Find

Source: https://shawnlyu.com/algorithms/disjoint-set-union-find/

> In computer science, a disjoint-set data structure… is a data structure that stores a collection of disjoint (non-overlapping) sets. … It provides operations for adding new sets, merging sets (replacing them by their union), and finding a representative member of a set.

Source: https://en.wikipedia.org/wiki/Disjoint-set_data_structure

A disjoint set groups distinct elements into a collection of non-overlapping sets. It supports two major operations: finding the set that a given element belongs to, and merging two sets into one (Cormen, Thomas H., et al. *Introduction to Algorithms*). This post introduces implementations of the two functions `union(u,v)` and `find(p)`, then walks through Leetcode 200. Number of Islands as an example.

# `find(p)`, `union(u,v)`, and optimization

Each function has its own optimization: path compression for `find(p)` and merge by rank for `union(u,v)`.

## `find(p)` and path compression

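Given an element `p`, `find(p)` returns the representative of the set that `p` belongs to. We keep an array `root` in which `root[i]` is the parent of element `i`. Initially every element is its own parent, i.e. its own root; after merges, an element's parent may itself have a parent, so we follow `root` recursively or iteratively until we reach an element that is its own parent. In the example below, elements `0` to `3` form the set rooted at `0`, and elements `4` to `8` form the set rooted at `4`.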

```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

# recursively
def find(p):
    if root[p] != p:
        return find(root[p])
    return p

# iteratively
def find(p):
    while root[p] != p:
        p = root[p]
    return p
```
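To make the example concrete, here is a small sketch that follows parent pointers in the `root` array above and records every element visited on the way to the root. The helper name `find_path` is hypothetical, not from the original post.

```python
def find_path(root, p):
    """Return the elements visited from p up to its root (inclusive)."""
    path = [p]
    # A root is an element that is its own parent.
    while root[p] != p:
        p = root[p]
        path.append(p)
    return path


root = [0, 0, 0, 0, 4, 4, 5, 5, 7]
print(find_path(root, 8))  # [8, 7, 5, 4]
print(find_path(root, 3))  # [3, 0]
```

So `find(8)` returns `4` after three hops, while `find(3)` returns `0` after one.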

We can add **path compression** as an optimization. While searching for the root of `p`, we point every element along the path directly at the root, so later calls on those elements finish in one step. Again, there are two ways of implementing this.

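```python
# recursively
def find(p):
    if root[p] != p:
        root[p] = find(root[p])  # point p directly at the root
    return root[p]

# iteratively
def find(p):
    node = p
    while node != root[node]:  # first pass: locate the root
        node = root[node]
    while p != node:           # second pass: re-point every element on the path
        par = root[p]
        root[p] = node
        p = par
    return p
```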
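Running the recursive version once on the example array shows what compression changes. In this sketch `root` is passed as a parameter (an assumption made only to keep the snippet self-contained); the logic is otherwise the same as above.

```python
def find(root, p):
    """Recursive find with path compression; root is passed explicitly."""
    if root[p] != p:
        root[p] = find(root, root[p])  # re-point p at the root on the way back
    return root[p]


root = [0, 0, 0, 0, 4, 4, 5, 5, 7]
print(find(root, 8))  # 4
print(root)           # [0, 0, 0, 0, 4, 4, 5, 4, 4]
```

Elements `8`, `7`, and `5` now point straight at `4`. Element `6` was not on the search path, so it still points at `5` until a later `find(6)` compresses it as well.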

## `union(u,v)` and merge by rank

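Given two elements `u` and `v`, `union(u,v)` merges the sets that `u` and `v` belong to into one by attaching one root under the other. If we always attach in the same direction regardless of tree shape, a sequence of unions can build a long chain, and `find` degrades to walking the whole chain. To avoid this, we can add merge by rank as an optimization.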

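We keep an array `rank` that records the height of the tree under each root (once path compression is used, it becomes an upper bound on the height rather than the exact value). When we merge two sets, we always put the tree with lower rank under the root of the tree with higher rank. Only when the two ranks are equal does the merged tree grow taller, so the new root's rank increases by 1.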

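```python
def union(u, v):
    u_root = find(u)
    v_root = find(v)
    if u_root == v_root:
        return  # already in the same set; nothing to merge
    if rank[u_root] > rank[v_root]:
        root[v_root] = u_root
    elif rank[u_root] < rank[v_root]:
        root[u_root] = v_root
    else:
        # equal ranks: either root can win, and its tree grows by one level
        root[v_root] = u_root
        rank[u_root] += 1
```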
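A small experiment shows the chain problem and how merge by rank avoids it. The helper names `naive_union`, `union_by_rank`, and `depth` are hypothetical, and `find` here deliberately skips path compression so that tree heights stay visible. Both variants union `0` with `1`, `1` with `2`, and so on; the naive one always hangs the first root under the second.

```python
def find(root, p):
    # Plain find without path compression.
    while root[p] != p:
        p = root[p]
    return p


def depth(root, p):
    """Number of parent hops from p to its root."""
    d = 0
    while root[p] != p:
        p = root[p]
        d += 1
    return d


def naive_union(root, u, v):
    # Always hang u's root under v's root, ignoring tree heights.
    ru, rv = find(root, u), find(root, v)
    if ru != rv:
        root[ru] = rv


def union_by_rank(root, rank, u, v):
    ru, rv = find(root, u), find(root, v)
    if ru == rv:
        return
    if rank[ru] < rank[rv]:
        ru, rv = rv, ru          # make ru the higher-rank root
    root[rv] = ru                # attach the lower-rank tree under it
    if rank[ru] == rank[rv]:
        rank[ru] += 1


n = 8
naive = list(range(n))
for i in range(n - 1):
    naive_union(naive, i, i + 1)

ranked, rank = list(range(n)), [0] * n
for i in range(n - 1):
    union_by_rank(ranked, rank, i, i + 1)

print(max(depth(naive, p) for p in range(n)))   # 7
print(max(depth(ranked, p) for p in range(n)))  # 1
```

With 8 elements the naive tree is a chain 7 hops deep, so `find(0)` walks the entire set; merge by rank instead attaches each new singleton directly under the existing root.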

# Complexities

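Without path compression and **merge by rank**, the time complexity of `find(p)` can be $O(n)$ in the worst case, when the tree degenerates into a chain. With both optimizations, any sequence of $m$ operations on $n$ elements takes $O(m\,\alpha(n))$ time, where $\alpha$ is the inverse Ackermann function, which is at most 4 for any practical $n$. That is nearly constant amortized time per operation.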
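Merge by rank alone already removes the $O(n)$ worst case: a root of rank $r$ always has at least $2^r$ elements in its tree, so ranks, and therefore tree heights, never exceed $\log_2 n$. The sketch below checks that invariant on random unions. The random workload, the `size` array, and the helper names are assumptions added only for this check; they are not part of the original post.

```python
import math
import random


def find(root, p):
    # No path compression: with merge by rank alone, rank equals tree height.
    while root[p] != p:
        p = root[p]
    return p


def height(root, p):
    """Number of parent hops from p up to its root."""
    h = 0
    while root[p] != p:
        p = root[p]
        h += 1
    return h


def check_rank_invariant(n, num_unions, seed=0):
    """Apply random merge-by-rank unions, then test size >= 2**rank at every root."""
    rng = random.Random(seed)
    root, rank, size = list(range(n)), [0] * n, [1] * n
    for _ in range(num_unions):
        ru = find(root, rng.randrange(n))
        rv = find(root, rng.randrange(n))
        if ru == rv:
            continue
        if rank[ru] < rank[rv]:
            ru, rv = rv, ru          # ru is the higher-rank root
        root[rv] = ru
        size[ru] += size[rv]         # sizes are only kept up to date at roots
        if rank[ru] == rank[rv]:
            rank[ru] += 1
    roots = [p for p in range(n) if root[p] == p]
    sizes_ok = all(size[r] >= 2 ** rank[r] for r in roots)
    max_height = max(height(root, p) for p in range(n))
    return sizes_ok, max_height


ok, h = check_rank_invariant(n=1000, num_unions=2000)
print(ok, h <= math.log2(1000))  # True True
```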

# Leetcode 200. Number of Islands

Initially, we count every `'1'` cell as an isolated island; cell `(r, c)` is stored at index `r*col + c` of the flat `root` array. While iterating from top to bottom and from left to right, if a cell's right neighbour or the neighbour below it is also `'1'`, we call `union(u,v)` on the two cells. Remember to subtract `1` from the total number of islands whenever two different sets are merged.
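Checking only the right and lower neighbours is enough: every horizontally or vertically adjacent pair of `'1'` cells is still seen exactly once, from its upper or left cell. A hypothetical helper `land_edges` (not from the original post) lists the pairs this scan would pass to `union`, using the same `r*col + c` indexing:

```python
from typing import List, Tuple


def land_edges(grid: List[List[str]]) -> List[Tuple[int, int]]:
    """Return (u, v) index pairs of adjacent '1' cells, scanning down and right only."""
    row, col = len(grid), len(grid[0])
    edges = []
    for r in range(row):
        for c in range(col):
            if grid[r][c] != '1':
                continue
            u = r * col + c  # flatten (r, c) into a single index
            if r + 1 < row and grid[r + 1][c] == '1':
                edges.append((u, (r + 1) * col + c))  # neighbour below
            if c + 1 < col and grid[r][c + 1] == '1':
                edges.append((u, u + 1))  # right neighbour
    return edges


grid = [
    ["1", "1", "0"],
    ["0", "1", "0"],
    ["1", "0", "1"],
]
print(land_edges(grid))  # [(0, 1), (1, 4)]
```

Here five land cells and two merges leave 5 - 2 = 3 islands: `{0, 1, 4}`, `{6}`, and `{8}`. A listed pair does not always produce a merge, though: in a 2x2 block of land the last pair joins cells that are already connected, which is why `union` must return early when both roots are equal instead of decrementing the count.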

```python
from typing import List

class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        if not grid or not grid[0]: return 0
        row, col = len(grid), len(grid[0])
        root = [i for i in range(row*col)]
        ranks = [0]*(row*col)
        cnt = 0
        for r in range(row):
            for c in range(col):
                # count each '1' as an isolated island
                if grid[r][c] == '1':
                    cnt += 1

        def find(p):
            # add path compression
            if root[p] != p:
                root[p] = find(root[p])
            return root[p]

        def union(u, v):
            # add merge by rank
            nonlocal cnt
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root: return
            if ranks[u_root] > ranks[v_root]: root[v_root] = u_root
            elif ranks[u_root] < ranks[v_root]: root[u_root] = v_root
            else:
                root[v_root] = u_root
                ranks[u_root] += 1
            # remember to deduct 1 from the total number of islands
            cnt -= 1

        for r in range(row):
            for c in range(col):
                if grid[r][c] == '0': continue
                # union connected '1's with the cell below and the cell to the right
                if r+1 < row and grid[r+1][c] == '1': union(r*col+c, (r+1)*col+c)
                if c+1 < col and grid[r][c+1] == '1': union(r*col+c, r*col+c+1)
        return cnt
```
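The nested `find` and `union` closures can also be packaged as a reusable class. The sketch below is a hypothetical refactor, not code from the original post: `UnionFind`, its `count` attribute, and `num_islands` are illustrative names. It uses the iterative form of path compression, so no recursion is involved.

```python
from typing import List


class UnionFind:
    """Disjoint set over 0..n-1 with path compression and merge by rank."""

    def __init__(self, n: int) -> None:
        self.root = list(range(n))
        self.rank = [0] * n
        self.count = n  # number of disjoint sets

    def find(self, p: int) -> int:
        node = p
        while node != self.root[node]:  # first pass: locate the root
            node = self.root[node]
        while p != node:                # second pass: point the path at the root
            parent = self.root[p]
            self.root[p] = node
            p = parent
        return node

    def union(self, u: int, v: int) -> bool:
        """Merge the sets of u and v; return False if they were already joined."""
        ru, rv = self.find(u), self.find(v)
        if ru == rv:
            return False
        if self.rank[ru] < self.rank[rv]:
            ru, rv = rv, ru  # ru is now the higher-rank root
        self.root[rv] = ru
        if self.rank[ru] == self.rank[rv]:
            self.rank[ru] += 1
        self.count -= 1
        return True


def num_islands(grid: List[List[str]]) -> int:
    if not grid or not grid[0]:
        return 0
    row, col = len(grid), len(grid[0])
    uf = UnionFind(row * col)
    water = 0
    for r in range(row):
        for c in range(col):
            if grid[r][c] == '0':
                water += 1  # water cells stay singleton sets; subtract them at the end
                continue
            if r + 1 < row and grid[r + 1][c] == '1':
                uf.union(r * col + c, (r + 1) * col + c)
            if c + 1 < col and grid[r][c + 1] == '1':
                uf.union(r * col + c, r * col + c + 1)
    return uf.count - water


grid = [
    ["1", "1", "0", "0", "0"],
    ["1", "1", "0", "0", "0"],
    ["0", "0", "1", "0", "0"],
    ["0", "0", "0", "1", "1"],
]
print(num_islands(grid))  # 3
```

Instead of counting `'1'` cells up front, this version lets the class track the total number of sets and subtracts the water cells at the end; both approaches give the same island count.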