25 Apr 2018

Binary Indexed Tree and Segment Tree

Binary Indexed Tree (BIT) 是一种数据结构, 看名字就能看出来是个 binary tree. 它可以用来解决数组区间的问题. 比如给定一个数组, 求范围 [i, j] 内所有值的和. 如果每次通过遍历求和, 那么时间复杂度是 ; 通过使用 BIT 可以把时间复杂度降低到 .

Segment Tree 和 BIT 比较类似, BIT 能解决的问题一般用 Segment Tree 也能解决.

Binary Indexed Tree

先给个例子, 比较好理解.

Range Sum Query - Mutable: 给定一个都是整型的数组, 计算出下标 i 到下标 j 这个区间内的所有数字的和, i <= j 并且包含 i 和 j. update(i, val) 函数把下标 i 所在位置的值更新为 val.

假设数组为 .

构建 BIT

那么通过这个给定的数组我们可以构造一个 BIT 的数据结构, 如下图所示.

上图的最下层是数组索引, 然后是原数组; 最上层是生成的 Binary Indexed Tree, 下方是表示该 BIT 的数组.

我们在原数组的头部新增加一个空值, 然后忽略索引 0 , 即从 1 开始表示原数组. 新数组 .

对于奇数索引(如上图中的 1, 3, 5), 在新的 BIT 数组中填入原数组对应的值, 如 . 对于偶数索引(如上图中的 0, 2, 4), 在新的 BIT 数组中的值, 是它左子树的和加上原数组对应的值得到的. 比如 .

计算区间和

这样 BIT 构建成功之后, 如果要计算原来的数组区间 [i, j] 的和, 就可以通过 tree[] 数组计算出 [0, i] 的和 sumi, 以及 [0, j + 1] 的和 sumj, 再用 sumj - sumi 得到结果.

那么 sumi 要怎么计算? 假设 j = 5, 需要计算 sum5 .即 j + 1 = 6, 6 用二进制表示为 110. 那么使用二进制表示: . 即 .

再假设 j = 2, 需要计算 sum2 . 即 j + 1 = 3, 3 用二进制表示为 011. 那么使用二进制表示: .

这里计算和 s 其实就是对于某个索引的二进制表示, 从低位到高位依次把 1 变成 0 得到值就是 bit 数组的索引, 直到所有位都变成 0. 再把这些值和该索引在 bit 的值相加. 比如上面的 011 -> 010 -> 000.

从低位到高位依次把 1 变成 0 可以通过补码实现. 对于数字 num 可以表示成 a1b , 其中 1 是最低位的 1, a 表示 1 这一位高位的其他位, b 表示 1 这一位低位的其他位. 因为 1 是最低位的 1, 所以 b 全部由 0 组成, b 的补码 b- 全部由 1 组成 . num 的补码 (a1b)- . 那么 -num = (a1b)- + 1 = a-0b- + 1 = a-0(1…1) + 1 = a-1(0…0) = a-1b . 即 num & -num = a1b & a-1b = (0…0)1(0…0) . 即 num - (num & -num) 就可以把低位的 1 变成 0.

以上计算 BIT 索引 i 的和, 用代码表示为:

1
2
3
4
5
6
7
8
public int getTreeSum(int i) {
    int sum = 0;
    while (i > 0) {
        sum += tree[i];
        i -= (i & -i);
    }
    return sum;
}

更新索引值

另一个需求是数组的在某个索引的值可以被更新. 这样我们生成的 Binary Indexed Tree 也要更新.

从上面的分析中可知, tree[] 只用更新用到这个节点计算和的节点, 即索引大于它的父节点们. 比如如果执行 update(4, 1), 那么原数组变为 [2, 3, 1, 4, 1, 2]. tree[5] 为奇数索引, 直接更新 tree[5] = 1, 另外因为 tree[6] 是由 tree[5] 求和得到, 所以 tree[6] 也需要更新, tree[6] = 3 .

即更新后如下所示(更新的节点用红框表示):

上面我们已经知道了 num & -num = a1b & a-1b = (0…0)1(0…0) , 那么 num + (num & - num) 就可以得到大于 num 的下一个待更新节点.

用代码表示为:

1
2
3
4
5
6
7
8
9
public void update(int i, int val) {
    int j = i + 1;
    int diff = val - nums[j];
    while (j < nums.length) {
        tree[j] += diff;
        j += (j & -j);
    }
    nums[i + 1] = val;
}

获取原始数组值

如果我们保存了一份数组 nums[], 那么可以直接获取. 但如果为了减少内存, 我们只保留了 Binary Indexed Tree, 那么要如何获取原始的数组值呢?

假设我们需要获取索引 i 的原始值.

一种方法是通过 getTreeSum(i + 1) - getTreeSum(i) 获取, 这种方法的时间复杂度是 .

另一种方法, 对于任意索引 x, 它的前序节点 y, 可以把 y 表示成 a0b, 其中 b 全部由 1 组成. 那么 x = a1b-, 通过之前的算法知道在把最低位的 1 变成 0 之后 x = a0b-, 记为 z. 对 y 同样从低位开始把 1 转成 0, 那么 y 会变成 a0b-</sub> 即 z.

代码表示:

1
2
3
4
5
6
7
8
9
10
11
12
public int getSingle(int i) {
    int idx = i + 1;
    int sum = tree[idx];
    if (i > 0) {
        int z = idx - (idx & -idx);
        idx--;
        if (idx != z) {
            sum -= tree[idx];
        }
    }
    return sum;
}

第二种方法, 对于 BIT 奇数索引, 时间复杂度是 ; 对于 BIT 偶数索引, 时间复杂度是 , c < 1.

其他用法

本文最后的参考链接里, topcoder 给了另一个用法.

假设有 n 张牌, 每张牌都是朝下放置的. 有两种操作:

  • T(i, j) : 把 [i, j] 区间内的牌翻面, 包含 i 和 j. 即朝上的牌翻面后朝下, 朝下的牌翻面后朝上
  • Q(i) : 如果第 i 张牌朝下那么返回 0, 朝上返回 1

最直接的做法是每次翻牌就遍历一遍. 但通过 Binary Indexed Tree 可以把时间复杂度控制到 .

新建一个数组 f[], f[i] 初始化为 0. 当执行 T(i, j) 的时候, 把 f[i]++ 并且 f[j+1]–. 当执行 Q(i) 的时候, 其实是否返回 f[0, i] 区间内的和 sum % 2.

Segment Tree

Segment Tree 也是一种 binary tree, 类似上面介绍的 Binary Indexed Tree, 用于解决区间问题. Segment Tree 的每一个节点都代表一个区间.

假设有数组 A[], 大小为 N, 那么对应的 Segment Tree 记为 T:

  1. T 的根节点代表整个数组区间 [0, N - 1]
  2. T 的每个叶子节点都是数组里的一个元素 A[i], 0 <= i < N
  3. T 的中间节点代表数组区间 A[i, j], 0 <= i < j < N

Segment Tree 的根节点代表整个数组区间 , 根节点的两个子节点分别代表区间 . 如此往复把区间折半, 直到叶子节点代表数组里某个具体的值. Segment Tree 的高度为 , 有 N 个叶子节点代表数组的 N 个元素, 有 N - 1 个内部节点, 所以总节点数为 2 * N - 1.

构建 Segment Trees

可以通过递归的方式从顶至下构建 Segment Tree. 相应代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void build(int node, int start, int end) {
    if (start == end) {
        // Leaf node will have a single element
        tree[node] = A[start];
    } else {
        int mid = (start + end) / 2;
        // Recurse on the left child
        build(2 * node, start, mid);
        // Recurse on the right child
        build(2 * node + 1, mid + 1, end);
        // Internal node will have the sum of both of its children
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }
}

这样对于之前使用的数组 , 构建出 Segment Tree 如下:

计算区间和

对于 Segment Tree 的某个节点表示的区间范围, 有几种情况:

  • 节点表示的区间 正好在请求计算的区间 范围内. 那么直接返回节点值
  • 节点表示的区间 完全不在请求计算的区间 范围内. 那么直接返回 0
  • 节点表示的区间 部分在请求计算的区间 范围内. 那么计算该节点两个子节点, 再返回和

时间复杂度为 . 相应代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int query(int node, int start, int end, int l, int r) {
    if (r < start || end < l) {
        // range represented by a node is completely outside the given range
        return 0;
    }
    if (l <= start && end <= r) {
        // range represented by a node is completely inside the given range
        return tree[node];
    }
    // range represented by a node is partially inside and partially outside the given range
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start, mid, l, r);
    int p2 = query(2 * node + 1, mid + 1, end, l, r);
    return p1 + p2;
}

更新索引值

更新数组的某个值, 只用在 Segment Tree 中找到包含该值的区间并递归更新即可. 更新的时间复杂度为 .

相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void update(int node, int start, int end, int idx, int val) {
    if (start == end) {
        // Leaf node
        A[idx] += val;
        tree[node] += val;
    } else {
        int mid = (start + end) / 2;
        if (start <= idx && idx <= mid) {
            // If idx is in the left child, recurse on the left child
            update(2 * node, start, mid, idx, val);
        } else {
            // if idx is in the right child, recurse on the right child
            update(2 * node + 1, mid + 1, end, idx, val);
        }
        // Internal node will have the sum of both of its children
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }
}

Reference