树状数组
- 基本原理
- 扩展
- 例题: 哪些题型可以使用树状数组
题型:
一般遇见这两类题型,或者“二合一”的问法,就说明需要树状数组了 ⚠️
- 快速求前缀和
O(logN): \(\sum a[L,R]\)
- 修改某一个数
O(logN): a[x] += C
树状数组是一种支持 单点修改 和 区间查询 的,代码量小的数据结构
事实上,树状数组能解决的问题是线段树能解决的问题的子集:树状数组能做的,线段树一定能做;线段树能做的,树状数组不一定可以。然而,树状数组的代码要远比线段树短,时间效率常数也更小,因此仍有学习价值。
有时,在差分数组和辅助数组的帮助下,树状数组还可解决更强的 区间加单点值 和 区间加区间和 问题。

基本原理

过程1小时,代码1分钟
模板:
| C++ |
|---|
| // tr[i]: 以 i 结尾, 长度是 lowbit(i) 的一段区间和
// tr[i] = SUM ( a[i-lowbit(i)+1, i] )
// 修改操作: 子 --> 父1 --> 父2 --> ...
for (int i=x; i<=n; i+=lowbit(i)) tr[i] += c;
// 查询操作:
for (int i=x; i; i-=lowbit(i)) tr[i] += c;
|
(1) 前置知识lowbit(x)
定义: 非负整数x在二进制表示下 最低位1及其后面的0 构成的 数值
| C++ |
|---|
| int lowbit(int x) {
return x & -x;
}
|
(2) add(x, k) 表示将序列中第 x 个数加上 k

在整棵树上维护这个值,需要一层一层向上找到父结点,并将这些结点上的t[x]值都加上k
| C++ |
|---|
| void add(int x, int k) { // 在 第x个数 的位置加上 k
for(int i = x; i <= n; i += lowbit(i)) // 向上找“谁被影响” (上: 数值意义)
t[i] += k;
}
|
(3) ask(x) 表示将查询序列前 x 个数的和

查询这个点的前缀和,需要从这个点向左上找到上一个结点,将加上其结点的值。向左上找到上一个结点,只需要将下标 x -= lowbit(x),例如 7 - lowbit(7) = 6
| C++ |
|---|
| int ask(int x) { // 找到 前x个数 的和
int sum = 0;
for(int i = x; i; i -= lowbit(i)) // 向下找“组成部分” (下: 数值意义)
sum += t[i];
return sum;
}
|
(4) 让谁成为“树状数组”
本质上, 是想让 idx=i 的位置等于谁 ("我们的前缀和究竟是想维护谁") 🌟
| C++ |
|---|
| // (1) i -> a[i] 原数组:
add(i, a[i]);
// (2) i -> b[i] 差分数组:
add(i, a[i]-a[i-1]);
|
241 楼兰图腾

这题非常不适合入门,建议先跳过,看 242 and 243
(1) 朴素的想法 (O(N³))
最直接的想法是枚举所有可能的三个点 (i, y_i), (j, y_j), (k, y_k) 且 i < j < k,然后判断它们是否构成 V 或 ∧ 图腾。这需要三层循环,时间复杂度是 O(N³),对于 N = 200000 来说是绝对无法通过的。
(2) 优化:固定中间点 (O(N²))
设计思路:
与其枚举三个点,不如我们枚举中间那个点 j,然后统计它能和左边、右边的多少个点组成图腾
-
对于一个固定的中间点 (j, y_j):
-
要形成一个 V 图腾 (y_i > y_j < y_k),我们需要在它左边(i < j)找一个点 (i, y_i) 使得 y_i > y_j,同时在它右边(k > j)找一个点 (k, y_k) 使得 y_k > y_j
-
假设点 j 左边有 L_greater 个点的 y 值比 y_j 大
-
假设点 j 右边有 R_greater 个点的 y 值比 y_j 大
-
根据乘法原理,以 j 为谷底的 V 图腾就有 L_greater * R_greater 个
-
要形成一个 ∧ 图腾 (y_i < y_j > y_k),我们需要在它左边(i < j)找一个点 (i, y_i) 使得 y_i < y_j,同时在它右边(k > j)找一个点 (k, y_k) 使得 y_k < y_j
-
假设点 j 左边有 L_lower 个点的 y 值比 y_j 小
-
假设点 j 右边有 R_lower 个点的 y 值比 y_j 小
-
根据乘法原理,以 j 为山峰的 ∧ 图腾就有 L_lower * R_lower 个
-
总数:
算法的整体流程设计得非常巧妙:
-
第一次遍历(从左到右):
-
遍历 i 从 1 到 n。
-
对于每个 a[i],我们利用树状数组快速查询在它左边(也就是 a[1] 到 a[i-1] 中)有多少个数比 a[i] 小,有多少个数比 a[i] 大。
-
将这些结果存起来。Lower[i] 存左边比 a[i] 小的数的个数,Greater[i] 存左边比 a[i] 大的数的个数。
-
查询完后,把 a[i] 这个数“加入”到树状数组中,供后面的元素查询。
-
第二次遍历(从右到左):
-
清空树状数组。
-
遍历 i 从 n 到 1。
-
对于每个 a[i],我们利用树状数组快速查询在它右边(也就是 a[i+1] 到 a[n] 中,因为是逆序遍历,这些都是已经处理过的)有多少个数比 a[i] 小,有多少个数比 a[i] 大。
-
在这一步,我们不需要把右边的计数值存起来。因为当我们处理到 a[i] 时,我们已经从第一次遍历中知道了 Lower[i] 和 Greater[i](左边的计数值)。我们可以直接计算 a[i] 作为中间点的贡献:
-
将这些贡献累加到总和中。
-
查询和计算贡献后,把 a[i] 这个数“加入”到树状数组中,供更左边的元素查询。
| C++ |
|---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77 | #include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 2000010;
typedef long long LL;
int n;
//t[i]表示树状数组i结点覆盖的范围和
int a[N], t[N];
//Lower[i]表示左边比第i个位置小的数的个数
//Greater[i]表示左边比第i个位置大的数的个数
int Lower[N], Greater[N];
//返回非负整数x在二进制表示下最低位1及其后面的0构成的数值
int lowbit(int x)
{
return x & -x;
}
//将序列中第x个数加上k
void add(int x, int k)
{
for(int i = x; i <= n; i += lowbit(i)) t[i] += k;
}
//查询序列前x个数的和
int ask(int x)
{
int sum = 0;
for(int i = x; i; i -= lowbit(i)) sum += t[i];
return sum;
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
//从左向右, 依次统计每个位置左边比第i个数y小的数的个数、以及大的数的个数
for(int i = 1; i <= n; i++)
{
int y = a[i]; //第i个数
//在前面已加入树状数组的所有数中统计在区间[1, y - 1]的数字的出现次数
Lower[i] = ask(y - 1);
//在前面已加入树状数组的所有数中统计在区间[y + 1, n]的数字的出现次数
Greater[i] = ask(n) - ask(y);
//将y加入树状数组, 即数字y出现1次, 供后面的元素查询
add(y, 1);
}
//清空树状数组, 从右往左统计每个位置右边比第i个数y小的数的个数、以及大的数的个数
memset(t, 0, sizeof t);
LL resA = 0, resV = 0;
//从右往左统计
for(int i = n; i >= 1; i--)
{
int y = a[i];
resA += (LL)Lower[i] * ask(y - 1);
resV += (LL)Greater[i] * (ask(n) - ask(y));
//将y加入树状数组, 即数字y出现1次
add(y, 1);
}
printf("%lld %lld\n", resV, resA);
return 0;
}
|
242 一个简单的整数问题
给定长度为 N 的数列 A,然后输入 M 行操作指令
第一类指令形如 C l r d,表示把数列中第 l∼r 个数都加 d
第二类指令形如 Q x,表示询问数列中第 x 个数的值
对于每个询问,输出一个整数表示答案
本质: 前缀和 -> 差分
Review: 差分数组
| C++ |
|---|
| // definition
b1 = a1
b2 = a2 - a1
b3 = a3 - a2
...
bn = an - a(n-1)
// a[l,r] ---add--> c
b[l]+=c
b[r+1]-=c
|
代码:
| C++ |
|---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64 | #include <iostream>
#include <string>
using namespace std;
typedef long long LL;
const int maxn = 1e5+10;
int n, m;
int a[maxn]; // "原数组"
LL tr[maxn]; // 维护的是 "差分数组"
// 树状数组三件套: lowbit(x) + add(x, d) + query(x)
int lowbit(int x)
{
return x & (-x);
}
void add(int x, int d) // b[x] + d
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += d;
}
LL query(int x) // a[x]
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int main()
{
cin >> n >> m;
for (int i=1; i<=n; i++) cin >> a[i];
// 初始化: 构建树状数组
// 重点: 让谁成为树状数组
for (int i=1; i<=n; i++)
{
// 为 "差分" 维护树状数组
add(i, a[i] - a[i-1]);
}
while (m--)
{
string op;
cin >> op;
if (op == "C") {
int l, r, d;
cin >> l >> r >> d;
// 回顾: 差分数组
add(l, d);
add(r+1, -d);
}
else {
int x;
cin >> x;
cout << query(x) <<endl;
}
}
return 0;
}
|
243 一个简单的整数问题 2
给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:
C l r d,表示把 A[l], A[l+1] , … , A[r] 都加上 d
Q l r,表示询问数列中第 l∼r 个数的和
对于每个询问,输出一个整数表示答案
分析:

| C++ |
|---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 | #include <iostream>
#include <string>
using namespace std;
typedef long long LL;
const int maxn = 1e5 + 10;
LL n, m;
LL a[maxn];
LL tr1[maxn]; // 维护 b[i] 前缀和
LL tr2[maxn]; // 维护
LL lowbit(LL x)
{
return x & (-x);
}
void add(LL tr[], LL x, LL d)
{
for (LL i = x; i <= n; i += lowbit(i)) tr[i] += d;
}
LL query(LL tr[], LL x)
{
LL res = 0;
for (LL i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
LL sum(LL x)
{
return query(tr1, x) * (x + 1) - query(tr2, x);
}
int main()
{
cin >> n >> m;
for (LL i = 1; i <= n; i++) cin >> a[i];
// 初始化
for (LL i=1; i<=n; i++)
{
// 树状数组 维护 "bi" -> sigma(bi) -> ai
add(tr1, i, a[i] - a[i-1]);
// 树状数组 维护 "i*bi" -> sigma(i*bi)
add(tr2, i, i * (a[i] - a[i-1]));
}
while (m--)
{
string op;
cin >> op;
if (op == "C")
{
LL l, r, d;
cin >> l >> r >> d;
// 回顾: 差分数组 bi
add(tr1, l, d);
add(tr1, r+1, -d);
// 同理: 更新数组 i*bi
add(tr2, l, l*d);
add(tr2, r+1, (r+1)*(-d));
}
else
{
LL l, r;
cin >> l >> r;
cout << (sum(r) - sum(l-1)) <<endl;
}
}
return 0;
}
|
244 谜一样的牛
USACO, 略