洛谷P3369 普通平衡树
题目描述:
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入x数
- 删除x数(若有多个相同的数,因只删除一个)
- 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)
- 查询排名为x的数
- 求x的前驱(前驱定义为小于x,且最大的数)
- 求x的后继(后继定义为大于x,且最小的数)
解题思路:
由于数据范围很大,所以我们需要高效的数据结构来完成上述操作。能够胜任这一要求并且易于操作的当然首选splay树,所以我们利用splay树来实现这六个操作。为了完成上述六个操作,我们定义六个函数:
insert(x): 向树中插入数x
delete(x): 删除树中的数x
getRankByVal(x): 查找数值对应的排名
getValByRank(k): 查找排名对应的数值
getPre(x): 查找前驱
getNext(x): 查找后继
而为了实现六个函数,我们还需要一些辅助函数,它们是:
join(x,y): 将树x和树y合并为一棵新树,并返回新树根节点编号。
split(x,y,a): 将树x按照数值a分为两棵子树x和y,一棵子树中所有元素小于等于a,另一棵大于a
splay(x): 将节点x旋转到根节点
rotate(x): 旋转x节点
update(x): 更新x节点的size信息
为了实现上述函数,我们需要一些变量,我们定义fa[x]为x父节点编号,ch[x][0]为x左子节点编号,ch[x][1]为x右子节点编号。type[x]为x的类型(左子节点还是右子节点);定义size[x]为以x为根的子树的大小。我们采用数组模拟指针实现splay树。实现思想见笔记:伸展树笔记
代码示例:
#include<cstdio>
const int SZ = 1e5+10;
const int INF = 0x3f3f3f3f;
struct SplayTree{
int cnt,rt;
int size[SZ],ch[SZ][2],type[SZ],fa[SZ],val[SZ],num[SZ];
void init(){
cnt = 0; //初始化节点计数器从0开始
rt = ++cnt; //根的编号为1,大小为无穷小
type[cnt] = 2;
val[cnt] = -INF;
}
//新建一棵仅有一个节点的树
int Malloc(int x){
type[++cnt] = 2;
val[cnt] = x;
num[cnt] = 1;
size[cnt] = 1;
return cnt;
}
//将数x插入树
void insert(int x){
int y = getIdByVal(rt,x);
if(val[y] == x) num[y]++,size[y]++,update(fa[y]);
else{
split(rt,y,x); //先将原树分为值小于x和大于x的
int t = Malloc(x);
join(rt,t); //将小于x的树和树t合并
join(rt,y);//再将rt和大于x的树y合并
}
}
//从树中删除数x,假设x存在
void Delete(int x){
int y = getIdByVal(rt,x);
num[y]--;
if(!num[y]){
rt = ch[y][0];
join(rt,ch[y][1]);
type[rt] = 2;
}
}
//求包含元素y的节点编号,若y不存在,则求小于y的值最大的节点的编号
int getIdByVal(int x,int y){
while(true){
if(val[x] == y) break;
else if(val[x] < y){
if(ch[x][1]) x = ch[x][1];
else break;
}else{
if(ch[x][0]) x = ch[x][0];
else break;
}
}
splay(x);rt = x;//这里把x旋转到根了,别忘了更新根
if(val[x] > y){
x = ch[x][0];
while(ch[x][1]) x = ch[x][1];
splay(x);rt = x;//这里把x旋转到根了,别忘了更新根
}
return x;
}
//求排名为k的值的值
int getValByRank(int k){
int x = rt;
while(true){
if(k <= size[ch[x][0]]) x = ch[x][0];
else if(k - size[ch[x][0]] <= num[x]) break;
else{
k -= size[ch[x][0]] + num[x];
x = ch[x][1];
}
}
return val[x];
}
//求x的前驱
int getPre(int x){
int y = getIdByVal(rt,x);
if(val[y] < x) return val[y];
y = ch[y][0];
while(ch[y][1]) y = ch[y][1];
return val[y];
}
//求x的后继
int getNext(int x){
int y = getIdByVal(rt,x);
while(val[y] <= x) y = ch[y][1];
while(ch[y][0]) y = ch[y][0];
return val[y];
}
//更新x的size
void update(int x){
size[x] = size[ch[x][0]] + num[x] + size[ch[x][1]];
}
//将节点x旋转至根节点
void splay(int x){
while(type[x] != 2){
int y = fa[x];
if(type[x] == type[y]) rotate(y);
else rotate(x);
if(type[x] == 2) break;
rotate(x);
}
update(x);
}
//旋转x和其父节点的连线
void rotate(int x){
int t = type[x],y = fa[x],z = ch[x][1-t];
type[x] = type[y];
fa[x] = fa[y];
if(type[x] != 2) ch[fa[x]][type[x]] = x;
type[y] = 1-t;
fa[y] = x;
ch[x][1-t] = y;
if(z) type[z] = t,fa[z] = y;
ch[y][t] = z;
update(y);
}
//返回树x中最大节点的编号
int getMaxId(int x){
while(ch[x][1]) x = ch[x][1];
return x;
}
//将树x和y合并,新根赋值给x
void join(int& x,int y){
x = getMaxId(x);
splay(x);
ch[x][1] = y,type[y] = 1,fa[y] = x;
update(x);
}
//将树x按照a分为子树x和y
void split(int& x,int& y,int a){
x = getIdByVal(x,a);
y = ch[x][1],type[y] = 2,fa[y] = 0;
ch[x][1] = 0;
update(x);
}
int getRank(int x){
int y = getIdByVal(rt,x);
return size[ch[y][0]] + 1;
}
}sp;
int main(){
int n,opt,x;
scanf("%d",&n);
sp.init();
for(int i = 1;i <= n;i++){
scanf("%d%d",&opt,&x);
if(opt == 1) sp.insert(x);
else if (opt == 2) sp.Delete(x);
else if (opt == 3) printf("%d\n",sp.getRank(x));
else if (opt == 4) printf("%d\n",sp.getValByRank(x));
else if (opt == 5) printf("%d\n",sp.getPre(x));
else printf("%d\n",sp.getNext(x));
}
return 0;
}