跳转至

平衡树

平衡树: treap = tree + heap

  1. Binary Search Tree: BST
    1. 当前节点的左子树中的任何一个点的权值,都小于,当前点的权值
    2. 当前节点的右子树中的任何一个点的权值,都大于,当前点的权值
  2. Heap: 大根堆

中序遍历: 左

本质: 动态维护一个有序数列

(1) 二叉搜索树 BST

常见操作:

  1. 插入: insert()
  2. 删除:(叶子节点)erase()
  3. 找 前驱 / 后继 (在中序遍历中的 前一个 / 后一个 位置) ++/--
    1. 找前驱:
      1. 存在左子树: 从当前节点先向左走一步,然后一直向右走到底 ✅
      2. 不存在左子树: 右上爬父节点,寻找到首个“父节点作为右儿子”的位置即可 ✅
    2. 找后继: 同理
  4. 找全局 max / min: begin() / end()-1
    1. max: 从当前节点一直向右走
    2. min: 从当前节点一直向左走
  5. 求某个数的排名
  6. 求排名是k的数是哪个
  7. 比某个数小的最大值
  8. 比某个数大的最小值

前四个操作在#incldude<set>里都有, 只有后面三个需要自己实现

(2) Treap

代码实现的过程中,一定要特别注意,哪些是引用 &p !!!

(2.1) 节点的一般定义

C++
1
2
3
4
5
6
7
8
9
struct Node {
    int l, r;
    int key;   // BST中每个点的权值
    int value; // Heap(大根) 中每个点的权值, 这里取随机值
    int cnt;   // 该key的数目
    int size;  // 以该点为根的子树的大小 (子树下所有cnt相加)
}tr[N];
// Heap: 任一节点的 value 大于 left_son 和 right_son 的 value
// Heap: 只是为了让 tree 更加 "平衡", 最好 "val大" 的在上面

注意, 我们需要 两个“哨兵”, 用于初始化

一开始,我们需要建一棵“空”的平衡树. 建两个哨兵:

  • 一个编号为 1, key 为 -INF
  • 另一个编号为 2, key 为 INF
C++
1
2
3
4
5
6
7
8
void build() {
    get_node(-INF), get_node(INF); // 两个哨兵
    root = 1, tr[1].r = 2; // 建立初始态
    pushup(root); // 信息同步更新

    // 确保 "val大" 的在上面, 保持 "平衡"
    if (tr[1].val < tr[2].val) zag(root);
}

(2.2) zig && zag

左旋zag / 右旋zig 是 Treap 中最重要的两个辅助操作 💣

这里讲的 “一个点”(&p) 指代的都是这颗子树的 root

右旋 (zig): 把一个点整到原位置的右子树上去 (比如这里的4号node)

左旋 (zag): 把一个点整到原位置的左子树上去 (比如这里的2号node)

zig 或 zag 后,会自动确保“中序遍历”的顺序不变

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
// 右旋:
void zig(int &p) {
    int q = tr[p].l;
    tr[p].l = tr[q].r;
    tr[q].r = p;
    p = q;                    // 通过指针间接改变和原来的p关联的值
    pushup(tr[p].r);
    pushup(p);                // pushup就是更新size
}

// 左旋:
void zag(int &p) {
    int q = tr[p].r;
    tr[p].r = tr[q].l;
    tr[q].l = p;
    p = q;                    // 通过指针间接改变和原来的p关联的值
    pushup(tr[p].l);
    pushup(p);                // 更新size
}

模板直接看下面的 "253 普通平衡树" 即可

253 普通平衡树

平衡树 纯模板题

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入数值 x
  2. 删除数值 x (若有多个相同的数,应只删除一个)
  3. 查询数值 x 的排名 (若有多个相同的数,应输出最小的排名)
  4. 查询排名为 x 的数值
  5. 求数值 x 的前驱 (前驱定义为小于 x 的最大的数)
  6. 求数值 x 的后继 (后继定义为大于 x 的最小的数)
C++
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 100010, INF = 1e8;

int n;
struct Node {
    int l, r;
    int key, val;
    int cnt, size; // cnt: 一个数出现的次数; size: 一个数的子树大小
}tr[N];

int root, idx; // idx: 当前做到哪个点了 (跟链表里的一样, 全局变量)

void pushup(int p) // 更新一个节点的信息
{
    tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}

int get_node(int key) // 如何 "初始化" 一个 node
{
    tr[ ++ idx].key = key; // 全局标号++, key赋值
    tr[idx].val = rand();  // val 随机值 (确保平衡)
    tr[idx].cnt = tr[idx].size = 1;
    return idx;
}

void zig(int &p) // 右旋 (&p)
{
    int q = tr[p].l;
    tr[p].l = tr[q].r, tr[q].r = p, p = q;
    pushup(tr[p].r), pushup(p);
}

void zag(int &p) // 左旋 (&p)
{
    int q = tr[p].r;
    tr[p].r = tr[q].l, tr[q].l = p, p = q;
    pushup(tr[p].l), pushup(p);
}

void build() // 初始化 Tree
{
    get_node(-INF), get_node(INF);
    root = 1, tr[1].r = 2;
    pushup(root);

    if (tr[1].val < tr[2].val) zag(root);
}

void insert(int &p, int key) // 给 当前点p 插入 值key (&p)
{
    if (!p) p = get_node(key); // 该node不存在, 新建一个
    else if (tr[p].key == key) tr[p].cnt ++ ; // 正好就是自己. cnt++
    else if (tr[p].key > key) // 说明应该去到 left_son_tree
    {
        insert(tr[p].l, key); // "插入行为" 递归到 left_son_tree
        // 确保val大的在上面 
        if (tr[tr[p].l].val > tr[p].val) zig(p); // 左子树, 右旋
    }
    else // 说明应该去到 right_son_tree
    {
        insert(tr[p].r, key);
        // 确保val大的在上面
        if (tr[tr[p].r].val > tr[p].val) zag(p); // 右子树, 左旋
    }
    pushup(p); // 更新信息, 维护 size
}

void remove(int &p, int key)
{
    if (!p) return; // 不存在的话皆大欢喜,直接return
    if (tr[p].key == key) // 要删除的就在这里!
    {
        if (tr[p].cnt > 1) tr[p].cnt -- ; // 有多个, 删一个即可
        else if (tr[p].l || tr[p].r) // 有左子树或右子树
        {
            // 右子树不存在 OR 左子树更大, "删除" 去到 左子树
            // PS: 先删var大的(玄学理论)
            if (!tr[p].r || tr[tr[p].l].val > tr[tr[p].r].val)
            {
                // 左子树, 右旋
                zig(p);
                remove(tr[p].r, key); // "删除行为"去到现在的right_son_tree
            }
            else
            {
                zag(p);
                remove(tr[p].l, key);
            }
        }
        else p = 0; // 叶子节点
    }
    else if (tr[p].key > key) remove(tr[p].l, key); // "删除" 去到 左子树
    else remove(tr[p].r, key); // "删除" 去到 右子树

    pushup(p); // 更新信息, 维护 size
}

int get_rank_by_key(int p, int key)    // 通过数值找排名
{
    if (!p) return 0;   // node不存在 (本题中不会发生此情况)
    if (tr[p].key == key) return tr[tr[p].l].size + 1;
    if (tr[p].key > key) return get_rank_by_key(tr[p].l, key);
    return tr[tr[p].l].size + tr[p].cnt + get_rank_by_key(tr[p].r, key);
}

int get_key_by_rank(int p, int rank)   // 通过排名找数值
{
    if (!p) return INF;     // node不存在 (本题中不会发生此情况)
    if (tr[tr[p].l].size >= rank) { // 去到 left_son_tree
        return get_key_by_rank(tr[p].l, rank); // 习惯上来讲排名从1开始, 要加1
    }
    if (tr[tr[p].l].size + tr[p].cnt >= rank) { // 正好就是自己
        return tr[p].key;
    }
    // 去到 right_son_tree
    return get_key_by_rank(tr[p].r, rank - tr[tr[p].l].size - tr[p].cnt);
}

int get_prev(int p, int key)   // 找到严格小于key的最大数
{
    if (!p) return -INF; // 前驱不存在
    if (tr[p].key >= key) return get_prev(tr[p].l, key); // 去 左子树
    return max(tr[p].key, get_prev(tr[p].r, key)); // 右子树 或 本身(def)
}

int get_next(int p, int key)    // 找到严格大于key的最小数
{
    if (!p) return INF; // 后继不存在
    if (tr[p].key <= key) return get_next(tr[p].r, key); // 去 右子树
    return min(tr[p].key, get_next(tr[p].l, key)); // 左子树 或 本身(def)
}

int main()
{
    build();

    scanf("%d", &n);
    while (n -- )
    {
        int opt, x;
        scanf("%d%d", &opt, &x);
        if (opt == 1) insert(root, x); // 从root这棵树开始往下看
        else if (opt == 2) remove(root, x);
        else if (opt == 3) printf("%d\n", get_rank_by_key(root, x) - 1);
        else if (opt == 4) printf("%d\n", get_key_by_rank(root, x + 1));
        else if (opt == 5) printf("%d\n", get_prev(root, x));
        else printf("%d\n", get_next(root, x));
    }

    return 0;
}

265 营业额统计

找 前驱 和 后继

C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 33010, INF = 1e7;

int n;
struct Node
{
    int l, r;
    int key, val;
}tr[N];

int root, idx;

int get_node(int key)
{
    tr[ ++ idx].key = key;
    tr[idx].val = rand();
    return idx;
}

void build()
{
    get_node(-INF), get_node(INF);
    root = 1, tr[1].r = 2;
}

void zig(int &p)
{
    int q = tr[p].l;
    tr[p].l = tr[q].r, tr[q].r = p, p = q;
}

void zag(int &p)
{
    int q = tr[p].r;
    tr[p].r = tr[q].l, tr[q].l = p, p = q;
}

void insert(int &p, int key)
{
    if (!p) p = get_node(key);
    else if (tr[p].key == key) return;
    else if (tr[p].key > key)
    {
        insert(tr[p].l, key);
        if (tr[tr[p].l].val > tr[p].val) zig(p);
    }
    else
    {
        insert(tr[p].r, key);
        if (tr[tr[p].r].val > tr[p].val) zag(p);
    }
}

int get_prev(int p, int key)    // 找到小于等于key的最大数
{
    if (!p) return -INF;
    if (tr[p].key > key) return get_prev(tr[p].l, key);
    return max(tr[p].key, get_prev(tr[p].r, key));
}

int get_next(int p, int key)    // 找到大于等于key的最小数
{
    if (!p) return INF;
    if (tr[p].key < key) return get_next(tr[p].r, key);
    return min(tr[p].key, get_next(tr[p].l, key));
}

int main()
{
    build();
    scanf("%d", &n);

    LL res = 0;
    for (int i = 1; i <= n; i ++ )
    {
        int x;
        scanf("%d", &x);
        if (i == 1) res += x;
        else res += min(x - get_prev(root, x), get_next(root, x) - x);

        insert(root, x);
    }

    printf("%lld\n", res);

    return 0;
}