洛谷P3369 普通平衡树

洛谷P3369 普通平衡树

题目描述:

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  • 插入x数
  • 删除x数(若有多个相同的数,因只删除一个)
  • 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)
  • 查询排名为x的数
  • 求x的前驱(前驱定义为小于x,且最大的数)
  • 求x的后继(后继定义为大于x,且最小的数)

解题思路:

由于数据范围很大,所以我们需要高效的数据结构来完成上述操作。能够胜任这一要求并且易于操作的当然首选splay树,所以我们利用splay树来实现这六个操作。为了完成上述六个操作,我们定义六个函数:
insert(x): 向树中插入数x
delete(x): 删除树中的数x
getRankByVal(x): 查找数值对应的排名
getValByRank(k): 查找排名对应的数值
getPre(x): 查找前驱
getNext(x): 查找后继

而为了实现六个函数,我们还需要一些辅助函数,它们是:

join(x,y): 将树x和树y合并为一棵新树,并返回新树根节点编号。
split(x,y,a): 将树x按照数值a分为两棵子树x和y,一棵子树中所有元素小于等于a,另一棵大于a
splay(x): 将节点x旋转到根节点
rotate(x): 旋转x节点
update(x): 更新x节点的size信息

为了实现上述函数,我们需要一些变量,我们定义fa[x]为x父节点编号,ch[x][0]为x左子节点编号,ch[x][1]为x右子节点编号。type[x]为x的类型(左子节点还是右子节点);定义size[x]为以x为根的子树的大小。我们采用数组模拟指针实现splay树。实现思想见笔记:伸展树笔记

代码示例:

#include<cstdio>
const int SZ = 1e5+10;
const int INF = 0x3f3f3f3f;

struct SplayTree{
    int cnt,rt;
    int size[SZ],ch[SZ][2],type[SZ],fa[SZ],val[SZ],num[SZ];
    void init(){
        cnt = 0;    //初始化节点计数器从0开始
        rt = ++cnt; //根的编号为1,大小为无穷小
        type[cnt] = 2;
        val[cnt] = -INF;
    }
    //新建一棵仅有一个节点的树
    int Malloc(int x){
        type[++cnt] = 2;
        val[cnt] = x;
        num[cnt] = 1;
        size[cnt] = 1;
        return cnt;
    }
    //将数x插入树
    void insert(int x){
        int y = getIdByVal(rt,x);
        if(val[y] == x) num[y]++,size[y]++,update(fa[y]);
        else{
            split(rt,y,x); //先将原树分为值小于x和大于x的
            int t = Malloc(x);
            join(rt,t); //将小于x的树和树t合并
            join(rt,y);//再将rt和大于x的树y合并
        }
    }
    //从树中删除数x,假设x存在
    void Delete(int x){
        int y = getIdByVal(rt,x);
        num[y]--;
        if(!num[y]){
            rt = ch[y][0];
            join(rt,ch[y][1]);
            type[rt] = 2;
        }
    }
    //求包含元素y的节点编号,若y不存在,则求小于y的值最大的节点的编号
    int getIdByVal(int x,int y){
        while(true){
            if(val[x] == y) break;
            else if(val[x] < y){
                if(ch[x][1]) x = ch[x][1];
                else break;
            }else{
                if(ch[x][0]) x = ch[x][0];
                else break;
            } 
        }
        splay(x);rt = x;//这里把x旋转到根了,别忘了更新根
        if(val[x] > y){
            x = ch[x][0];
            while(ch[x][1]) x = ch[x][1];
            splay(x);rt = x;//这里把x旋转到根了,别忘了更新根
        }
        return x;
    }
    //求排名为k的值的值
    int getValByRank(int k){
        int x = rt;
        while(true){
            if(k <= size[ch[x][0]]) x = ch[x][0];
            else if(k - size[ch[x][0]] <= num[x]) break;
            else{
                k -= size[ch[x][0]] + num[x];
                x = ch[x][1];
            }
        }
        return val[x];
    }
    //求x的前驱
    int getPre(int x){
        int y = getIdByVal(rt,x);
        if(val[y] < x) return val[y];
        y = ch[y][0];
        while(ch[y][1]) y = ch[y][1];
        return val[y];
    }
    //求x的后继
    int getNext(int x){
        int y = getIdByVal(rt,x);
        while(val[y] <= x) y = ch[y][1];
        while(ch[y][0]) y = ch[y][0];
        return val[y];
    }
    //更新x的size
    void update(int x){
        size[x] = size[ch[x][0]] + num[x] + size[ch[x][1]];
    }
    //将节点x旋转至根节点
    void splay(int x){
        while(type[x] != 2){
            int y = fa[x];
            if(type[x] == type[y]) rotate(y);
            else rotate(x);
            if(type[x] == 2) break;
            rotate(x);
        }
        update(x);
    }
    //旋转x和其父节点的连线
    void rotate(int x){
        int t = type[x],y = fa[x],z = ch[x][1-t];
        type[x] = type[y];
        fa[x] = fa[y];
        if(type[x] != 2) ch[fa[x]][type[x]] = x;
        type[y] = 1-t;
        fa[y] = x;
        ch[x][1-t] = y;
        if(z) type[z] = t,fa[z] = y;
        ch[y][t] = z;
        update(y);
    }
    //返回树x中最大节点的编号
    int getMaxId(int x){
        while(ch[x][1])   x = ch[x][1];
        return x;
    }
    //将树x和y合并,新根赋值给x
    void join(int& x,int y){
        x = getMaxId(x);
        splay(x);
        ch[x][1] = y,type[y] = 1,fa[y] = x;
        update(x);
    }
    //将树x按照a分为子树x和y
    void split(int& x,int& y,int a){
        x = getIdByVal(x,a);
        y = ch[x][1],type[y] = 2,fa[y] = 0;
        ch[x][1] = 0;
        update(x);
    }
    int getRank(int x){
        int y = getIdByVal(rt,x);
        return size[ch[y][0]] + 1;
    }
}sp;
int main(){
    int n,opt,x;
    scanf("%d",&n);
    sp.init();
    for(int i = 1;i <= n;i++){
        scanf("%d%d",&opt,&x);
        if(opt == 1) sp.insert(x);
        else if (opt == 2) sp.Delete(x);
        else if (opt == 3) printf("%d\n",sp.getRank(x));
        else if (opt == 4) printf("%d\n",sp.getValByRank(x));
        else if (opt == 5) printf("%d\n",sp.getPre(x));
        else printf("%d\n",sp.getNext(x));
    }
    return 0;
}
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 程序猿惹谁了 设计师:白松林 返回首页