Union Find (Disjoint Set)

Contents
  1. What is Union Find?
    1.1. Use Cases
    1.2. Union Find vs DFS
  2. Algorithm
    2.1. Quick-Find
    2.2. Quick-Union
    2.3. Weighted Quick-Union
    2.4. Path Compression
    2.5. Weighted Quick-Union With Path Compression
  3. Connected
    3.1. LintCode 589. Connecting Graph
    3.2. LintCode 590. Connecting Graph II
    3.3. 130. Surrounded Regions
    3.4. 737. Sentence Similarity II
  4. Counting Connected Components
    4.1. LintCode 591. Connecting Graph III
    4.2. 323. Number of Connected Components in an Undirected Graph
    4.3. 305. Number of Islands II
    4.4. 547. Friend Circles
  5. Redundant Connection
    5.1. 261. Graph Valid Tree
    5.2. 684. Redundant Connection
    5.3. 685. Redundant Connection II

What is Union Find?

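  • The Union-Find algorithm (disjoint-set algorithm) solves the Dynamic Connectivity problem; as the saying goes, "birds of a feather flock together".
  • It is a data structure for querying and merging sets. Once weighted union and path compression are both applied, find and union run in near-constant amortized time.
  1. Find
    • Determine which set an element x belongs to
  2. Union
    • Merge two sets into one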

Use Cases

  1. Computer Network
    • Whether two network nodes are connected
    • The minimum wiring that connects the entire network
  2. Social Network
    • LinkedIn: people two users may know
  3. Set theory
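The "minimum wiring" case is the minimum spanning tree problem, and Kruskal's algorithm is one standard way to solve it with Union Find: sort the cables by cost and keep a cable only if its two endpoints are not already connected. Below is a minimal sketch, assuming cables are given as hypothetical `(weight, u, v)` tuples over nodes `0..n-1`; `kruskal_mst_weight` is an illustrative name, not from the original notes.

```python
def kruskal_mst_weight(n, edges):
    """edges: list of (weight, u, v). Return the total MST weight,
    or None if the nodes cannot all be connected."""
    root = list(range(n))

    def find(x):
        while root[x] != x:
            root[x] = root[root[x]]  # path halving: skip to the grandparent
            x = root[x]
        return x

    total, used = 0, 0
    for w, u, v in sorted(edges):  # cheapest cables first
        ru, rv = find(u), find(v)
        if ru != rv:               # u and v are not wired together yet
            root[ru] = rv
            total += w
            used += 1
    return total if used == n - 1 else None


# 4 nodes: the cables of weight 1, 2 and 3 connect everything
print(kruskal_mst_weight(4, [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 2, 3), (5, 1, 3)]))  # 6
```

The Union Find check is what makes each "keep or skip" decision cheap: a cable is skipped exactly when it would only close a loop.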

Union Find vs DFS

When modeling a problem, first be clear about exactly what question needs to be answered,
because the right data structure and algorithm depend on that question.

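  • Union Find: given two nodes, decide whether they are connected; if they are, no concrete path is required
  • DFS: given two nodes, decide whether they are connected; if they are, the concrete path must also be returned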
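To make the difference concrete, here is a small sketch that answers the same connectivity question both ways (all function names are illustrative). The DFS version returns an actual path, while the Union Find version can only say yes or no.

```python
from collections import defaultdict


def build_graph(edges):
    # Undirected adjacency list: node -> set of neighbors
    graph = defaultdict(set)
    for u, v in edges:
        graph[u].add(v)
        graph[v].add(u)
    return graph


def dfs_path(graph, start, target):
    """Iterative DFS: return one path from start to target, or None."""
    stack = [(start, [start])]
    visited = {start}
    while stack:
        node, path = stack.pop()
        if node == target:
            return path
        for nxt in graph[node]:
            if nxt not in visited:
                visited.add(nxt)
                stack.append((nxt, path + [nxt]))
    return None


def uf_connected(edges, a, b):
    """Union Find: only answers whether a and b share a component."""
    root = {}

    def find(x):
        root.setdefault(x, x)
        while root[x] != x:
            x = root[x]
        return x

    for u, v in edges:
        root[find(u)] = find(v)
    return find(a) == find(b)


edges = [(1, 2), (2, 3), (4, 5)]
print(dfs_path(build_graph(edges), 1, 3))  # [1, 2, 3]
print(uf_connected(edges, 1, 3))           # True
```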

Algorithm

Quick-Find

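Quick-Find works much like coloring: every node starts with its own color (its set ID), and nodes in the same set are painted the same color.
The quick-find algorithm is very intuitive and matches the simplest way of thinking about the problem.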

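```python
# root[i] holds the set ID ("color") of element i,
# initialized as root = list(range(n))

# Time : O(1)
def find(x):
    return root[x]

# Time : O(n)
def union(x, y):
    rootx = root[x]
    rooty = root[y]
    if rootx == rooty:
        return
    # Recolor every element of x's set with y's color
    for i in range(len(root)):
        if root[i] == rootx:
            root[i] = rooty
```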

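Every new connection (union) "pulls one hair and moves the whole body": it may have to relabel elements all over the array. The key to fixing this is making union more efficient, so that it no longer has to scan the entire array.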
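A minimal self-contained version of Quick-Find, assuming integer elements `0..n-1`. The `scans` counter is a hypothetical addition used only to show that every union walks the whole array.

```python
class QuickFind:
    """Quick-Find: root[i] stores the set ID ("color") of element i."""

    def __init__(self, n):
        self.root = list(range(n))
        self.scans = 0  # illustrative counter: array slots inspected by union

    def find(self, x):
        # O(1): the set ID is stored directly
        return self.root[x]

    def union(self, x, y):
        rootx, rooty = self.root[x], self.root[y]
        if rootx == rooty:
            return
        # O(n): every element colored rootx is recolored rooty
        for i in range(len(self.root)):
            self.scans += 1
            if self.root[i] == rootx:
                self.root[i] = rooty

    def connected(self, x, y):
        return self.find(x) == self.find(y)


qf = QuickFind(8)
for i in range(7):
    qf.union(i, i + 1)
print(qf.scans)  # 56
```

Seven unions on 8 elements inspect 7 × 8 = 56 slots, no matter which sets are merged.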

Quick-Union

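  • Represent each set as a tree!
  • This is the key idea of Union Find: root[x] stores the parent of x, and the root of each tree identifies its set. Organized this way, both find and union can be implemented efficiently.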
```python
# Time : O(Tree Height), Worst Case O(n)
# Recursion
def find(x):
    if root[x] == x:
        return x
    return find(root[x])

# Iteration
def find(x):
    while root[x] != x:
        x = root[x]
    return x
```
```python
# Time : O(Tree Height), Worst Case O(n)
def union(x, y):
    rootx = find(x)
    rooty = find(y)
    if rootx != rooty:  # only merge if x and y are in different sets
        root[rootx] = rooty
```
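The worst case is easy to trigger. The sketch below (the helper name `quick_union_depth` is illustrative) performs union(0, 1), union(1, 2), ... with plain Quick-Union and counts how many parent links find(0) must follow afterwards: the tree degenerates into a chain of height n - 1.

```python
def quick_union_depth(n):
    """Union 0-1, 1-2, ..., (n-2)-(n-1) with plain Quick-Union and
    return the number of hops find(0) takes afterwards."""
    root = list(range(n))

    def find(x):
        # Iterative find that also counts the hops taken
        hops = 0
        while root[x] != x:
            x = root[x]
            hops += 1
        return x, hops

    for i in range(n - 1):
        rootx, _ = find(i)
        rooty, _ = find(i + 1)
        if rootx != rooty:
            root[rootx] = rooty  # always hang x's tree under y's tree

    return find(0)[1]


print(quick_union_depth(1000))  # 999
```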

Weighted Quick-Union

Since tree height is the bottleneck for the time complexity, we look for a way to keep the trees balanced!

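  • Building on Quick-Union, keep an extra array size[] that stores the number of elements in each connected component (indexed by its root).
  • In union(), always attach the component with fewer elements to the component with more elements.
  • This keeps trees shallow: a node's depth grows only when its tree is merged under a tree at least as large, so the size of its tree at least doubles each time. The height therefore stays at most log2(n), which improves the time complexity of Quick-Union.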
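```python
# size[r] = number of elements in the tree rooted at r, initialized to 1
# Time : O(logn), since the tree height is at most log2(n)
def find(x):
    if root[x] == x:
        return x
    return find(root[x])

# Time : O(logn)
def union(x, y):
    rootx = find(x)
    rooty = find(y)
    if rootx == rooty:  # already in the same set; without this check size[] would be double-counted
        return
    if size[rootx] >= size[rooty]:
        root[rooty] = rootx
        size[rootx] += size[rooty]
    else:
        root[rootx] = rooty
        size[rooty] += size[rootx]
```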
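One way to check the height bound is to run many unions and measure the deepest node. This sketch uses an iterative find without path compression (so the raw tree is measured) and a hypothetical `depth()` helper; by the doubling argument above, the maximum depth should never exceed log2(n), whatever the union order.

```python
import math
import random


class WeightedQuickUnion:
    def __init__(self, n):
        self.root = list(range(n))
        self.size = [1] * n

    def find(self, x):
        # Iterative find, no path compression
        while self.root[x] != x:
            x = self.root[x]
        return x

    def union(self, x, y):
        rootx, rooty = self.find(x), self.find(y)
        if rootx == rooty:
            return
        if self.size[rootx] < self.size[rooty]:
            rootx, rooty = rooty, rootx  # make rootx the larger tree
        self.root[rooty] = rootx         # smaller tree goes under the larger one
        self.size[rootx] += self.size[rooty]

    def depth(self, x):
        # Hypothetical helper: number of parent links from x to its root
        d = 0
        while self.root[x] != x:
            x = self.root[x]
            d += 1
        return d


n = 1024
random.seed(0)
uf = WeightedQuickUnion(n)
for _ in range(5000):
    uf.union(random.randrange(n), random.randrange(n))
print(max(uf.depth(i) for i in range(n)), "<=", math.log2(n))
```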

Path Compression

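As the data grows, the trees get deeper and performance gradually degrades. Path compression addresses this: while computing the root of a node, detach the subtree rooted at that node and hang it directly under the root of the current tree. This lowers the tree depth, which improves efficiency and reduces the time complexity.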

```python
# Path compression is done inside find
def find(x):
    if root[x] == x:
        return x

    # Make every node on the path point directly to the root.
    root[x] = find(root[x])  # Only one extra line

    return root[x]
```
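In Python, the recursive find above can raise `RecursionError` when a path is longer than the interpreter's recursion limit (1000 frames by default in CPython), and plain Quick-Union can build such a path before any compression happens. Here is a sketch of an iterative two-pass alternative with the same end result (every node on the path points at the root). The standalone `find(root, x)` signature is an assumption made for illustration.

```python
def find(root, x):
    """Iterative find with full path compression."""
    # Pass 1: walk up to the root
    r = x
    while root[r] != r:
        r = root[r]
    # Pass 2: re-point every node on the path directly at the root
    while x != r:
        parent = root[x]
        root[x] = r
        x = parent
    return r


# A chain 0 -> 1 -> ... -> 5000, far deeper than the default recursion limit
parent = list(range(5001))
for i in range(5000):
    parent[i] = i + 1
print(find(parent, 0))  # 5000
```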

Weighted Quick-Union With Path Compression

The proof is very difficult, but the algorithm is still simple!

```python
# The weighting happens in union (path compression happens in find)
# Time : amortized O(α(n)) per operation, where α is the inverse
#        Ackermann function; very near to O(1) in practice
def union(x, y):
    rootx = find(x)
    rooty = find(y)
    if rootx == rooty:
        return
    if size[rootx] >= size[rooty]:
        root[rooty] = rootx
        size[rootx] += size[rooty]
    else:
        root[rootx] = rooty
        size[rooty] += size[rootx]
```
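Putting the pieces together, here is one possible reusable class: weighting in union plus path compression in find. It is a sketch, and the class name, the `count` field, `connected()`, and the boolean return value of `union()` are conveniences added here, not part of the original snippets.

```python
class UnionFind:
    """Weighted Quick-Union with path compression over elements 0..n-1."""

    def __init__(self, n):
        self.root = list(range(n))
        self.size = [1] * n   # size[r] is meaningful only when r is a root
        self.count = n        # number of disjoint sets

    def find(self, x):
        # Locate the root iteratively, then compress the path
        r = x
        while self.root[r] != r:
            r = self.root[r]
        while x != r:
            parent = self.root[x]
            self.root[x] = r
            x = parent
        return r

    def union(self, x, y):
        """Merge the sets of x and y; return False if already joined."""
        rootx, rooty = self.find(x), self.find(y)
        if rootx == rooty:
            return False
        if self.size[rootx] < self.size[rooty]:
            rootx, rooty = rooty, rootx
        self.root[rooty] = rootx   # attach the smaller tree under the larger
        self.size[rootx] += self.size[rooty]
        self.count -= 1
        return True

    def connected(self, x, y):
        return self.find(x) == self.find(y)
```

Returning a boolean from `union` is handy for cycle detection: `False` means the edge joins two nodes that are already connected.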

Connected

Check whether two elements are in the same set, i.e. whether find(x) == find(y).

LintCode 589. Connecting Graph

Given n nodes in a graph labeled from 1 to n. There are no edges in the graph at the beginning.

You need to support the following methods:

  1. connect(a, b), add an edge to connect node a and node b.
  2. query(a, b), check if two nodes are connected
```python
class ConnectingGraph:
    def __init__(self, n):
        # Nodes are labeled 1..n, so allocate n+1 slots
        self.root = list(range(n + 1))

    # Find the root of node x (with path compression)
    def find(self, x):
        root = self.root  # tip 1: inside the class, root must be accessed as self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def union(self, x, y):
        root = self.root
        rootx = self.find(x)  # tip 2: use find(x), not root[x]; root[x] is only the parent
        rooty = self.find(y)
        if rootx != rooty:
            root[rootx] = rooty

    def connect(self, a, b):
        self.union(a, b)

    def query(self, a, b):
        return self.find(a) == self.find(b)
```

LintCode 590. Connecting Graph II

  • Track the number of elements in each connected component
  • query(a): returns the number of nodes in the connected component that contains node a.
```python
class ConnectingGraph2:
    def __init__(self, n):
        self.root = list(range(n + 1))
        self.size = [1] * (n + 1)  # size[r]: number of nodes in the component rooted at r

    def find(self, x):
        root = self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def connect(self, a, b):
        root = self.root
        size = self.size
        roota = self.find(a)
        rootb = self.find(b)
        if roota != rootb:
            root[roota] = rootb
            size[rootb] += size[roota]  # the merged size is stored at the new root

    def query(self, a):
        return self.size[self.find(a)]
```

130. Surrounded Regions

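  • Solution 1: DFS

Starting from each 'O' on the border, run DFS and temporarily mark every reachable 'O' as '#'. Any 'O' left unmarked is surrounded, so after the traversal flip all of them to 'X' (and restore the '#' cells back to 'O').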
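The DFS approach can be sketched as follows, assuming the board is a list of lists of 'X'/'O' characters modified in place. It uses an explicit stack instead of recursion, a choice made here so that large boards do not exceed Python's recursion limit; `solve_dfs` is an illustrative name.

```python
def solve_dfs(board):
    """Flip every 'O' that is not connected to the border into 'X', in place."""
    if not board or not board[0]:
        return
    m, n = len(board), len(board[0])
    # Seed the stack with every 'O' on the border
    stack = [(i, j) for i in range(m) for j in range(n)
             if (i in (0, m - 1) or j in (0, n - 1)) and board[i][j] == 'O']
    # Mark every 'O' reachable from the border as '#'
    while stack:
        i, j = stack.pop()
        if 0 <= i < m and 0 <= j < n and board[i][j] == 'O':
            board[i][j] = '#'
            stack.extend([(i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)])
    # Unmarked 'O' cells are surrounded; '#' cells go back to 'O'
    for i in range(m):
        for j in range(n):
            if board[i][j] == 'O':
                board[i][j] = 'X'
            elif board[i][j] == '#':
                board[i][j] = 'O'
```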

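  • Solution 2: Union Find

Union every 'O' connected to the border into a dummy node (a hasEdge[] array would also work, but it uses more memory).
Finally, mark every 'O' that is not in the same component as the dummy node as 'X'.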

```python
class UnionFind(object):
    def __init__(self, n):
        self.root = list(range(n))

    def find(self, x):
        root = self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def union(self, x, y):
        root = self.root
        rootx = self.find(x)
        rooty = self.find(y)
        if rootx != rooty:
            # tip: the larger index always becomes the parent, so the dummy
            # node (total, the largest index) always stays the root of its set
            root[min(rootx, rooty)] = max(rootx, rooty)

class Solution(object):
    def solve(self, board):
        if not board:
            return
        m, n = len(board), len(board[0])
        total = m * n  # index of the dummy node
        uf = UnionFind(total + 1)
        grid = board
        d = [(1, 0), (0, 1), (-1, 0), (0, -1)]
        for i in range(m):
            for j in range(n):
                if grid[i][j] == 'X':
                    continue
                # Connect border 'O's to the dummy "total" root
                if i == 0 or j == 0 or i == m - 1 or j == n - 1:
                    uf.union(total, i * n + j)
                else:
                    # Interior cell: all four neighbors are in bounds
                    for k in range(4):
                        ni, nj = i + d[k][0], j + d[k][1]
                        if grid[ni][nj] == 'O':
                            uf.union(ni * n + nj, i * n + j)
        for i in range(m):
            for j in range(n):
                if grid[i][j] == 'X':
                    continue
                if uf.find(i * n + j) != total:
                    grid[i][j] = 'X'
```

737. Sentence Similarity II

A classic Union Find problem: similarity here is transitive, so whether two words are similar is exactly whether they are in the same set (the connected operation)!

```python
class UnionFind(object):
    def __init__(self, n):
        self.root = list(range(n))

    def find(self, x):
        root = self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def union(self, x, y):
        self.root[self.find(x)] = self.find(y)

class Solution(object):
    def areSentencesSimilarTwo(self, words1, words2, pairs):
        m, n = len(words1), len(words2)
        if m != n:
            return False

        # Map each word to an integer index: this UnionFind only supports integer indices!
        uf = UnionFind(len(pairs) * 2)  # at most two new words per pair
        cnt = 0
        pdic = {}
        for w1, w2 in pairs:
            if w1 not in pdic:
                pdic[w1] = cnt
                cnt += 1
            if w2 not in pdic:
                pdic[w2] = cnt
                cnt += 1
            uf.union(pdic[w1], pdic[w2])

        for w1, w2 in zip(words1, words2):
            if w1 == w2:
                continue
            # A word that never appears in pairs is only similar to itself
            if w1 not in pdic or w2 not in pdic:
                return False
            if uf.find(pdic[w1]) != uf.find(pdic[w2]):
                return False
        return True
```
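The word-to-index mapping exists only because this UnionFind takes integers. An alternative is to key the parent map by the words themselves. A sketch under that assumption (`DictUnionFind` and `are_sentences_similar` are hypothetical names); `find` lazily creates a singleton set the first time it sees a word.

```python
class DictUnionFind:
    """Union Find keyed by arbitrary hashable items, such as words."""

    def __init__(self):
        self.root = {}

    def find(self, x):
        self.root.setdefault(x, x)  # unseen item: its own singleton set
        r = x
        while self.root[r] != r:
            r = self.root[r]
        while x != r:               # path compression
            parent = self.root[x]
            self.root[x] = r
            x = parent
        return r

    def union(self, x, y):
        self.root[self.find(x)] = self.find(y)


def are_sentences_similar(words1, words2, pairs):
    if len(words1) != len(words2):
        return False
    uf = DictUnionFind()
    for a, b in pairs:
        uf.union(a, b)
    # Identical words are always similar; otherwise they must share a root
    return all(w1 == w2 or uf.find(w1) == uf.find(w2)
               for w1, w2 in zip(words1, words2))
```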

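Counting Connected Components

Track the number of connected components: start the count at n (every node is its own component) and decrement it on every union that actually merges two different sets.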

LintCode 591. Connecting Graph III

  • query(): returns the number of connected components in the graph
```python
class ConnectingGraph3:
    def __init__(self, n):
        self.root = list(range(n + 1))
        self.cnt = n  # n isolated nodes = n components

    def find(self, x):
        root = self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def connect(self, a, b):
        root = self.root
        roota = self.find(a)
        rootb = self.find(b)
        if roota != rootb:
            root[roota] = rootb
            self.cnt -= 1  # two components merged into one

    def query(self):
        return self.cnt
```

323. Number of Connected Components in an Undirected Graph

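  • Solution 1: DFS

Convert the graph from its nodes-and-edges form into an adjacency list built from hash sets, so that we can look up the nodes reachable from each node!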

```python
class Solution(object):
    def countComponents(self, n, edges):
        visited = [0] * n
        graph = [set() for _ in range(n)]  # Adjacency list
        for i, j in edges:
            graph[i].add(j)
            graph[j].add(i)
        res = 0
        for i in range(n):
            if visited[i] == 1:
                continue
            self.dfs(i, visited, graph)  # visit the whole component of i
            res += 1
        return res

    def dfs(self, start, visited, graph):
        # Iterative DFS with an explicit stack, so a long path
        # cannot exceed Python's recursion limit
        stack = [start]
        while stack:
            node = stack.pop()
            if visited[node] == 1:
                continue
            visited[node] = 1
            stack.extend(graph[node])
```

  • Solution 2: Union Find
```python
class Solution(object):
    def countComponents(self, n, edges):
        root = list(range(n))
        self.cnt = n

        def find(x):
            if root[x] == x:
                return x
            root[x] = find(root[x])
            return root[x]

        def union(x, y):
            rootx = find(x)
            rooty = find(y)
            if rootx != rooty:
                root[rootx] = rooty
                self.cnt -= 1

        for i, j in edges:
            union(i, j)
        return self.cnt
```

305. Number of Islands II

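Islands are added one at a time and the number of connected components is reported after each addition, which makes this an online algorithm!

  • The basic UF works on 1D indices, so convert 2D coordinates (i, j) to the 1D index i*n + j
  • This shows off the online nature of Union Find: edges can be added on the fly!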
```python
# Time : O(m * n + k), where k = len(positions), treating each find/union as near O(1)
class UnionFind(object):
    def __init__(self, n):
        self.root = [-1] * n  # -1 marks water (a cell not yet added)
        self.cnt = 0          # current number of islands

    def find(self, x):
        root = self.root
        if root[x] == x:
            return x
        root[x] = self.find(root[x])
        return root[x]

    def add(self, x):
        self.root[x] = x  # a new land cell starts as its own island
        self.cnt += 1

    def union(self, x, y):
        root = self.root
        rootx = self.find(x)
        rooty = self.find(y)
        if rootx != rooty:
            root[rootx] = rooty
            self.cnt -= 1

class Solution(object):
    def numIslands2(self, m, n, positions):
        uf = UnionFind(m * n)
        res = []
        d = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        for i, j in positions:
            p = i * n + j  # map 2D (i, j) to a 1D index
            if uf.root[p] != -1:
                # Position is already land (duplicate): the count is unchanged
                res.append(uf.cnt)
                continue
            uf.add(p)
            for k in range(4):
                ni, nj = i + d[k][0], j + d[k][1]
                q = ni * n + nj
                if (0 <= ni <= m - 1 and
                        0 <= nj <= n - 1 and
                        uf.root[q] != -1):
                    uf.union(p, q)
            res.append(uf.cnt)
        return res
```

547. Friend Circles

```python
class Solution(object):
    def findCircleNum(self, M):
        n = len(M)
        root = list(range(n))
        self.cnt = n  # every person starts as their own circle

        def find(x):
            if root[x] == x:
                return x
            root[x] = find(root[x])
            return root[x]

        def union(x, y):
            rootx = find(x)
            rooty = find(y)
            if rootx != rooty:
                root[rootx] = rooty
                self.cnt -= 1

        for i in range(n):
            for j in range(i + 1, n):  # M is symmetric: scan the upper triangle only
                if M[i][j]:
                    union(i, j)
        return self.cnt
```

Redundant Connection

261. Graph Valid Tree

```python
class Solution(object):
    def validTree(self, n, edges):
        root = list(range(n))
        self.cnt = n

        def find(x):
            if root[x] == x:
                return x
            root[x] = find(root[x])
            return root[x]

        def union(x, y):
            rootx = find(x)
            rooty = find(y)
            if rootx != rooty:
                root[rootx] = rooty
                self.cnt -= 1
                return True
            return False  # x and y are already connected: this edge forms a cycle

        for i, j in edges:
            if not union(i, j):
                return False
        return self.cnt == 1  # a tree must be a single connected component
```
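An equivalent formulation uses the fact that a tree on n nodes has exactly n - 1 edges: once that holds, the graph is a tree if and only if no edge closes a cycle (an acyclic graph with n nodes and n - 1 edges has exactly one component). A sketch with an illustrative function name; `find` here uses path halving, a variant of path compression in which each visited node skips to its grandparent.

```python
def valid_tree(n, edges):
    # A tree on n nodes has exactly n - 1 edges
    if len(edges) != n - 1:
        return False
    root = list(range(n))

    def find(x):
        while root[x] != x:
            root[x] = root[root[x]]  # path halving
            x = root[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # this edge closes a cycle
        root[ru] = rv
    # n - 1 edges and no cycle => connected, hence a tree
    return True
```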

684. Redundant Connection

```python
class Solution(object):
    def findRedundantConnection(self, edges):
        # n nodes labeled 1..n, and the input has exactly n edges
        root = list(range(len(edges) + 1))

        def find(x):
            if root[x] == x:
                return x
            root[x] = find(root[x])
            return root[x]

        for i, j in edges:
            rooti = find(i)
            rootj = find(j)
            if rooti == rootj:  # i and j are already connected: this edge closes a cycle
                return [i, j]
            root[rooti] = rootj
```

685. Redundant Connection II

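  • Case 1: There is a loop in the graph, and no vertex has more than 1 parent.
    • A cycle exists and no node has in-degree > 1 => Union Find: return the edge that closes the cycle
  • Case 2: A vertex has more than 1 parent, but there isn't a loop in the graph.
    • No cycle, but one node has in-degree > 1 => return the later of the two edges pointing into that node
  • Case 3: A vertex has more than 1 parent, and is part of a loop.
    • A cycle exists and one node has in-degree > 1
    • How do we pick the right edge in this harder case?
    • Delete the second edge! If a cycle still remains, the first edge is the redundant one.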
```python
class Solution(object):
    def findRedundantDirectedConnection(self, edges):
        n = len(edges)
        parent = [0] * (n + 1)
        ans = None
        # Step 1: find the node with in-degree > 1, if any
        for i in range(n):
            u, v = edges[i]
            if parent[v] == 0:
                parent[v] = u
            else:
                ans = [[parent[v], v], [u, v]]  # [first edge, second edge]
                # !!! Delete the second edge (mark its target as 0)
                edges[i][1] = 0

        # Step 2: Union Find to detect a cycle
        root = list(range(n + 1))

        def find(x):
            # Iterative, so a long chain cannot hit Python's recursion limit
            while root[x] != x:
                x = root[x]
            return x

        for u, v in edges:
            if v == 0:  # skip the deleted edge
                continue
            rootu = find(u)
            rootv = find(v)
            if rootu == rootv:  # Detect cycle
                if not ans:
                    return [u, v]  # Case 1
                else:
                    return ans[0]  # Case 3: the cycle survives deleting the second edge
            root[rootu] = rootv
        return ans[1]  # Case 2: no cycle once the second edge is removed
```