Binary Indexed Tree

Catalogue
  1. 1. What is Binary Indexed Tree?
    1. 1.1. Binary Indexed Tree vs Segmented Tree
  2. 2. Operation
    1. 2.1. 1 Build
  3. 3. 2 Update
  4. 4. 3 Range Sum
  5. 5. Application
    1. 5.1. 307. Range Sum Query - Mutable
      1. 5.1.1. Solution1. Update O(1), RangeSum O(n)
      2. 5.1.2. Solution2. Update O(n), RangeSum O(1)
      3. 5.1.3. Solution3. Update O(logn), RangeSum O(logn)
    2. 5.2. 308. Range Sum Query 2D - Mutable
      1. 5.2.1. Solution1. PrefixSum
      2. 5.2.2. Solution2. BIT
    3. 5.3. 315. Count of Smaller Numbers After Self

A Fenwick tree or binary indexed tree is a data structure that can efficiently update elements and calculate prefix sums in a table of numbers.

What is Binary Indexed Tree?

1
2
3
4
5
# PrefixSum
def update(idx, n):
# O(n), update idx-th num in the array
def rangeSum(idx1, idx2):
# O(1), calculate the sum from idx1-th to idx2-th in the arrayk

传统的数组单点修改的复杂度为 O(1),查询子段和的复杂度为 O(n)
前缀和数组单点修改的复杂度为 O(n),查询子段和的复杂度为 O(1)

  • Binary Indexed Tree 修改和查询子段和复杂度均为 O(logn)
  • 所以在多组查询或动态查询时,用树状数组可以有效减小耗时,提高程序效率。

Binary Indexed Tree vs Segmented Tree

  • 树状数组 容易实现,代码量小,时间复杂度低,并且经过数学处理后也可以实现成段更新。线段树也可以做到和树状数组一样的效果,但是代码要复杂得多。
  • 不过要注意,一般情况下 树状数组能解决的问题线段树都能解决,反之有些线段树能解决的问题树状数组却不行

Operation

1 Build

从已知数组构建树状数组就是把线性的数组变成一棵树。那么,树状数组是如何把线性结构的数组变成一棵树的呢?以下以一个长度为8的数组为例:

原始数组:

1
A[1], A[2], A[3], A[4], A[5], A[6], A[7], A[8]

在修改和查询子段和时,很容易想到一种类似二分的想法来构建一棵树状的数组来保存原数组的所有信息。

1
2
3
4
5
6
7
8
C1 = A1
C2 = C1 + A2 = A1 + A2
C3 = A3
C4 = C2 + C3 + A4 = A1 + A2 + A3 + A4
C5 = A5
C6 = C5 + A6 = A5 + A6
C7 = A7
C8 = C4 + C6 + C7 + A8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8

从中可以发现,若结点的标号为 n ,则该结点的求和区域长度为 2k ,此处的 k 为 n 的二进制表示的末尾 0 的个数。

1
2
# 只保留n的二进制里最低位的1
2^k = n & (n ^ (n-1)) = n & (-n)

  • 前n项和分别保存在n二进制表示的每个“1”表示
i 二进制 包含A的个数
1 0001 1
2 0010 2
3 0011 3
4 0100 4
5 0101 1
6 0110 2
7 0111 1
8 1000 8
1
2
3
4
5
6
7
8
9
10
# Time : O(nlogn)
def build(self, nums):
n = len(nums)
# BIT 数组比原数组多一位!
A, C = nums, [0] * (n + 1)
for i in range(n):
k = i + 1 # Start From i+1
while k <= n:
C[k] += A[i]
k += (k & -k) # Next Parent Node

2 Update

  • C[i]的父节点为C[i + i & (-i)]

当我们修改A[i]的值时,记录变化,可以 从C[i]往根节点一路上溯,调整这条路上的所有C[p]即可,这个操作的复杂度在最坏情况下就是树的高度即O(logn)。

1
2
3
4
5
6
def update(self, i, val):
diff, self.A[i] = val - self.A[i], val
i += 1 # Start From i+1
while i <= self.n:
self.C[i] += diff
i += (i & -i) # Next Parent Node

3 Range Sum

  • 而对于求数列的前n项和S[n],只需找到C[n]以前(包括C[n])的所有最大子树,把其根节点的C[c]加起来即可。
1
2
3
4
5
6
7
8
9
def sumRange(self, i, j):
res, j = 0, j + 1
while j: # 前j项和(j=j+1了,数组是从0开始index的!)
res += self.C[j]
j -= (j & -j) # Next Sum Node
while i: # 前i-1项和
res -= self.C[i]
i -= (i & -i)
return res

Application

307. Range Sum Query - Mutable

  • update & range sum

    Solution1. Update O(1), RangeSum O(n)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    class NumArray(object):

    def __init__(self, nums):
    self.nums = nums

    def update(self, i, val):
    self.nums[i] = val

    def sumRange(self, i, j):
    s = 0
    for k in range(i, j+1):
    s += self.nums[k]
    return s

Solution2. Update O(n), RangeSum O(1)

  • Prefix Sum array

Solution3. Update O(logn), RangeSum O(logn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class NumArray(object):

def __init__(self, nums):
self.n = n = len(nums)
self.A, self.C = nums, [0] *(n + 1)
for i in xrange(n):
k = i + 1 # tip1 : BIT index from 1
while k <= n:
self.C[k] += nums[i]
k += k & (-k)

def update(self, i, val):
diff = val - self.A[i]
self.A[i] = val # tip2 : remember to update original array
i += 1
while i <= self.n:
self.C[i] += diff
i += i & (-i)


def sumRange(self, i, j):
res = 0
j += 1
while j:
res += self.C[j]
j -= j & (-j)
while i: # tip3 : excluding i, so i do not need to +1
res -= self.C[i]
i -= i & (-i)
return res

308. Range Sum Query 2D - Mutable

Solution1. PrefixSum

  • Build : O(mn)
  • Update : O(n)
  • Region Sum : O(m)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class NumMatrix(object):

def __init__(self, matrix):
if not matrix:
return

self.matrix = matrix
self.preSum = copy.deepcopy(matrix)

for row in self.preSum:
for j in range(1, len(matrix[0])):
row[j] += row[j-1]

def update(self, row, col, val):
diff = val - self.matrix[row][col]
self.matrix[row][col] = val

for j in range(col, len(self.matrix[0])):
self.preSum[row][j] += diff

def sumRegion(self, row1, col1, row2, col2):
s = 0
for i in range(row1, row2+1):
row_sum = self.preSum[i][col2] - (self.preSum[i][col1-1] if col1 > 0 else 0)
s += row_sum
return s

Solution2. BIT

  • Build : O(mn(logm)(logn))
  • Update : O(logm logn)
  • Region Sum : O(logm logn)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class NumMatrix(object):

def __init__(self, matrix):
if not matrix:
return

m, n = len(matrix), len(matrix[0])
self.m, self.n = m, n
self.matrix = matrix
self.bit = [[0] * (n+1) for _ in xrange(m+1)]
for i in xrange(m):
for j in xrange(n):
self.build(i, j)

def build(self, row, col):
val = self.matrix[row][col]
i = row+1
while i <= self.m:
j = col + 1
while j <= self.n:
self.bit[i][j] += val
j += j & (-j)
i += i & (-i)

def update(self, row, col, val):
diff = val - self.matrix[row][col]
self.matrix[row][col] = val
i = row+1
while i <= self.m:
j = col + 1
while j <= self.n:
self.bit[i][j] += diff
j += j & (-j)
i += i & (-i)

def getSum(self, row, col):
i = row+1
res = 0
while i:
j = col + 1
while j:
res += self.bit[i][j]
j -= j & (-j)
i -= i & (-i)
return res

def sumRegion(self, row1, col1, row2, col2):
return self.getSum(row2, col2) - self.getSum(row1-1, col2) - self.getSum(row2, col1-1) + self.getSum(row1-1, col1-1)

315. Count of Smaller Numbers After Self

Binary Indexed Tree & Fenwick Tree

  • 对原数组nums进行 离散化处理 排序+去重,将nums从实数范围映射到 [1, len(set(nums))],记得到的新数组为iNums
1
2
3
4
5
6
idxes = {}
for k, v in enumerate(sorted(set(nums))):
idxes[v] = k + 1
iNums = [idxes[x] for x in nums]
# iNums 相当于重新映射后的Array,其间数值的相对大小没有改变,
# 但是值总的范围映射到了[0, n]这样就可以作为BIT的index了!!
  • 从右向左遍历iNums,对树状数组的iNums[i]位置执行+1操作,然后统计(0, iNums[i])的区间和,也可以用线段树
  • 把计数问题转化成了求区间和的问题!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class FenwickTree(object):

def __init__(self, n):
self.n = n
self.BIT = [0] * (n+1)

def add(self, i, val):
while i <= self.n:
self.BIT[i] += val
i += i & -i

def sum(self, i):
res = 0
while i:
res += self.BIT[i]
i -= i & -i
return res

class Solution(object):
def countSmaller(self, nums):
if not nums: return []
idxs = {}
for k, v in enumerate(sorted(set(nums))):
idxs[v] = k + 1
n = len(nums)
ftree = FenwickTree(n)
res = []
for i in xrange(n-1, -1, -1):
res.append(ftree.sum(idxs[nums[i]]-1))
ftree.add(idxs[nums[i]], 1)
return res[::-1]
Share