跳转至

树状数组

  1. 基本原理
  2. 扩展
    • 差分
    • 查分 + 推公式
  3. 例题: 哪些题型可以使用树状数组

题型:

一般遇见这两类题型,或者“二合一”的问法,就说明需要树状数组了 ⚠️

  1. 快速求前缀和 O(logN): \(\sum a[L,R]\)
  2. 修改某一个数 O(logN): a[x] += C

树状数组是一种支持 单点修改区间查询 的,代码量小的数据结构

事实上,树状数组能解决的问题是线段树能解决的问题的子集:树状数组能做的,线段树一定能做;线段树能做的,树状数组不一定可以。然而,树状数组的代码要远比线段树短,时间效率常数也更小,因此仍有学习价值。

有时,在差分数组和辅助数组的帮助下,树状数组还可解决更强的 区间加单点值区间加区间和 问题。

基本原理

过程1小时,代码1分钟

模板:

C++
1
2
3
4
5
6
7
8
// tr[i]: 以 i 结尾, 长度是 lowbit(i) 的一段区间和
// tr[i] = SUM ( a[i-lowbit(i)+1, i] )

// 修改操作: 子 --> 父1 --> 父2 --> ...
for (int i=x; i<=n; i+=lowbit(i)) tr[i] += c;

// 查询操作:
for (int i=x; i; i-=lowbit(i)) tr[i] += c;

(1) 前置知识lowbit(x)

定义: 非负整数x在二进制表示下 最低位1及其后面的0 构成的 数值

C++
1
2
3
int lowbit(int x) {
    return x & -x;
}

(2) add(x, k) 表示将序列中第 x 个数加上 k

在整棵树上维护这个值,需要一层一层向上找到父结点,并将这些结点上的t[x]值都加上k

C++
1
2
3
4
void add(int x, int k) { // 在 第x个数 的位置加上 k
    for(int i = x; i <= n; i += lowbit(i)) // 向上找“谁被影响” (上: 数值意义)
        t[i] += k;
}

(3) ask(x) 表示将查询序列前 x 个数的和

查询这个点的前缀和,需要从这个点向左上找到上一个结点,将加上其结点的值。向左上找到上一个结点,只需要将下标 x -= lowbit(x),例如 7 - lowbit(7) = 6

C++
1
2
3
4
5
6
int ask(int x) { // 找到 前x个数 的和
    int sum = 0;
    for(int i = x; i; i -= lowbit(i)) // 向下找“组成部分” (下: 数值意义)
        sum += t[i];
    return sum;
}

(4) 让谁成为“树状数组”

本质上, 是想让 idx=i 的位置等于谁 ("我们的前缀和究竟是想维护谁") 🌟

C++
1
2
3
4
5
// (1) i -> a[i] 原数组:
add(i, a[i]);

// (2) i -> b[i] 差分数组:
add(i, a[i]-a[i-1]);

241 楼兰图腾

这题非常不适合入门,建议先跳过,看 242 and 243

(1) 朴素的想法 (O(N³))

最直接的想法是枚举所有可能的三个点 (i, y_i), (j, y_j), (k, y_k)i < j < k,然后判断它们是否构成 V 或 ∧ 图腾。这需要三层循环,时间复杂度是 O(N³),对于 N = 200000 来说是绝对无法通过的。

(2) 优化:固定中间点 (O(N²))

设计思路:

与其枚举三个点,不如我们枚举中间那个点 j,然后统计它能和左边、右边的多少个点组成图腾

  • 对于一个固定的中间点 (j, y_j)

    • 要形成一个 V 图腾 (y_i > y_j < y_k),我们需要在它左边(i < j)找一个点 (i, y_i) 使得 y_i > y_j,同时在它右边(k > j)找一个点 (k, y_k) 使得 y_k > y_j

      • 假设点 j 左边有 L_greater 个点的 y 值比 y_j

      • 假设点 j 右边有 R_greater 个点的 y 值比 y_j

      • 根据乘法原理,以 j 为谷底的 V 图腾就有 L_greater * R_greater

    • 要形成一个 ∧ 图腾 (y_i < y_j > y_k),我们需要在它左边(i < j)找一个点 (i, y_i) 使得 y_i < y_j,同时在它右边(k > j)找一个点 (k, y_k) 使得 y_k < y_j

      • 假设点 j 左边有 L_lower 个点的 y 值比 y_j

      • 假设点 j 右边有 R_lower 个点的 y 值比 y_j

      • 根据乘法原理,以 j 为山峰的 ∧ 图腾就有 L_lower * R_lower

  • 总数:

    • 总的 V 图腾数 = ∑j=1n​(L_greaterj​×R_greaterj​)

    • 总的 ∧ 图腾数 = ∑j=1n​(L_lowerj​×R_lowerj​)

算法的整体流程设计得非常巧妙:

  1. 第一次遍历(从左到右)

    • 遍历 i 从 1 到 n

    • 对于每个 a[i],我们利用树状数组快速查询在它左边(也就是 a[1]a[i-1] 中)有多少个数比 a[i] 小,有多少个数比 a[i] 大。

    • 将这些结果存起来。Lower[i] 存左边比 a[i] 小的数的个数,Greater[i] 存左边比 a[i] 大的数的个数。

    • 查询完后,把 a[i] 这个数“加入”到树状数组中,供后面的元素查询。

  2. 第二次遍历(从右到左)

    • 清空树状数组。

    • 遍历 in 到 1。

    • 对于每个 a[i],我们利用树状数组快速查询在它右边(也就是 a[i+1]a[n] 中,因为是逆序遍历,这些都是已经处理过的)有多少个数比 a[i] 小,有多少个数比 a[i] 大。

    • 在这一步,我们不需要把右边的计数值存起来。因为当我们处理到 a[i] 时,我们已经从第一次遍历中知道了 Lower[i]Greater[i](左边的计数值)。我们可以直接计算 a[i] 作为中间点的贡献:

      • V 图腾贡献:Greater[i] * (右边比 a[i] 大的个数)

      • ∧ 图腾贡献:Lower[i] * (右边比 a[i] 小的个数)

    • 将这些贡献累加到总和中。

    • 查询和计算贡献后,把 a[i] 这个数“加入”到树状数组中,供更左边的元素查询。

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int N = 2000010;

typedef long long LL;

int n;
//t[i]表示树状数组i结点覆盖的范围和
int a[N], t[N];
//Lower[i]表示左边比第i个位置小的数的个数
//Greater[i]表示左边比第i个位置大的数的个数
int Lower[N], Greater[N];

//返回非负整数x在二进制表示下最低位1及其后面的0构成的数值
int lowbit(int x)
{
    return x & -x;
}

//将序列中第x个数加上k
void add(int x, int k)
{
    for(int i = x; i <= n; i += lowbit(i)) t[i] += k;
}

//查询序列前x个数的和
int ask(int x)
{
    int sum = 0;
    for(int i = x; i; i -= lowbit(i)) sum += t[i];
    return sum;
}

int main()
{

    scanf("%d", &n);
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]);

    //从左向右, 依次统计每个位置左边比第i个数y小的数的个数、以及大的数的个数
    for(int i = 1; i <= n; i++)
    {
        int y = a[i]; //第i个数

        //在前面已加入树状数组的所有数中统计在区间[1, y - 1]的数字的出现次数
        Lower[i] = ask(y - 1); 

        //在前面已加入树状数组的所有数中统计在区间[y + 1, n]的数字的出现次数
        Greater[i] = ask(n) - ask(y);

        //将y加入树状数组, 即数字y出现1次, 供后面的元素查询
        add(y, 1);
    }

    //清空树状数组, 从右往左统计每个位置右边比第i个数y小的数的个数、以及大的数的个数
    memset(t, 0, sizeof t);

    LL resA = 0, resV = 0;
    //从右往左统计
    for(int i = n; i >= 1; i--)
    {
        int y = a[i];
        resA += (LL)Lower[i] * ask(y - 1);
        resV += (LL)Greater[i] * (ask(n) - ask(y));

        //将y加入树状数组, 即数字y出现1次
        add(y, 1);
    }

    printf("%lld %lld\n", resV, resA);

    return 0;
}

242 一个简单的整数问题

给定长度为 N 的数列 A,然后输入 M 行操作指令

第一类指令形如 C l r d,表示把数列中第 l∼r 个数都加 d

第二类指令形如 Q x,表示询问数列中第 x 个数的值

对于每个询问,输出一个整数表示答案

本质: 前缀和 -> 差分

Review: 差分数组

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
// definition
b1 = a1
b2 = a2 - a1
b3 = a3 - a2
...
bn = an - a(n-1)

// a[l,r] ---add--> c
b[l]+=c
b[r+1]-=c

代码:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <iostream>
#include <string>
using namespace std;

typedef long long LL;
const int maxn = 1e5+10;

int n, m;
int a[maxn];    // "原数组"
LL tr[maxn];    // 维护的是 "差分数组"

// 树状数组三件套: lowbit(x) + add(x, d) + query(x)

int lowbit(int x)
{
    return x & (-x);
}

void add(int x, int d) // b[x] + d
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += d;
}

LL query(int x) // a[x]
{
    LL res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

int main()
{
    cin >> n >> m;
    for (int i=1; i<=n; i++) cin >> a[i];

    // 初始化: 构建树状数组
    // 重点: 让谁成为树状数组
    for (int i=1; i<=n; i++)
    {
        // 为 "差分" 维护树状数组
        add(i, a[i] - a[i-1]);
    }

    while (m--)
    {
        string op;
        cin >> op;

        if (op == "C") {
            int l, r, d;
            cin >> l >> r >> d;
            // 回顾: 差分数组
            add(l, d);
            add(r+1, -d);
        }
        else {
            int x;
            cin >> x;
            cout << query(x) <<endl;
        }
    }

    return 0;
}

243 一个简单的整数问题 2

给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. C l r d,表示把 A[l], A[l+1] , … , A[r] 都加上 d
  2. Q l r,表示询问数列中第 l∼r 个数的和

对于每个询问,输出一个整数表示答案

分析:

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <iostream>
#include <string>
using namespace std;

typedef long long LL;

const int maxn = 1e5 + 10;

LL n, m;
LL a[maxn];
LL tr1[maxn]; // 维护 b[i] 前缀和
LL tr2[maxn]; // 维护

LL lowbit(LL x)
{
    return x & (-x);
}

void add(LL tr[], LL x, LL d)
{
    for (LL i = x; i <= n; i += lowbit(i)) tr[i] += d;
}

LL query(LL tr[], LL x)
{
    LL res = 0;
    for (LL i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

LL sum(LL x)
{
    return query(tr1, x) * (x + 1) - query(tr2, x);
}

int main()
{
    cin >> n >> m;

    for (LL i = 1; i <= n; i++) cin >> a[i];

    // 初始化
    for (LL i=1; i<=n; i++)
    {
        // 树状数组 维护 "bi" -> sigma(bi) -> ai
        add(tr1, i, a[i] - a[i-1]);
        // 树状数组 维护 "i*bi"  ->  sigma(i*bi)
        add(tr2, i, i * (a[i] - a[i-1]));
    }

    while (m--)
    {
        string op;
        cin >> op;
        if (op == "C")
        {
            LL l, r, d;
            cin >> l >> r >> d;
            // 回顾: 差分数组 bi
            add(tr1, l, d);
            add(tr1, r+1, -d);
            // 同理: 更新数组 i*bi
            add(tr2, l, l*d);
            add(tr2, r+1, (r+1)*(-d));
        }
        else
        {
            LL l, r;
            cin >> l >> r;
            cout << (sum(r) - sum(l-1)) <<endl;
        }
    }
    return 0;
}

244 谜一样的牛

USACO, 略