Logo handezheng 的博客

博客

各种模板

...
handezheng
2025-12-01 12:50:48
崖谷涂足无人问,山巅独饮众生哗。

本文章由 WyOJ Shojo 从洛谷专栏拉取,原发布时间为 2024-02-19 09:18:43

板子大集

快读和快写

inline void read(int &x){ // fast read: parses the next (optionally negative) integer from stdin
	// getchar() returns int; keeping it as int keeps EOF distinguishable from real bytes,
	// so the skip loop below terminates when input runs out (x is then 0).
	int s = getchar();
	bool neg = false;
	while(s != EOF && (s < '0' || s > '9')) {
		if(s == '-') neg = true;
		s = getchar();
	}
	x = 0;
	// Accumulate as a non-positive value so INT_MIN is representable without overflow.
	while(s >= '0' && s <= '9') {
		x = x * 10 - (s - '0');
		s = getchar();
	}
	if(!neg) x = -x;
}

inline void write(int x){ // fast write: prints x in decimal to stdout
	// Work on the magnitude as unsigned: -x would overflow (UB) for INT_MIN.
	unsigned int mag = static_cast<unsigned int>(x);
	if(x < 0){
		putchar('-');
		mag = 0u - mag;
	}
	// Collect digits least-significant first (an unsigned int has at most 10), then print them reversed.
	char buf[12];
	int len = 0;
	do{
		buf[len++] = static_cast<char>('0' + mag % 10);
		mag /= 10;
	}while(mag);
	while(len) putchar(buf[--len]);
}

基础算法


高精度(感谢丁神的高精板子)

#include <bits/stdc++.h>
using namespace std;
// Shorthand typedefs and macros shared by the big-integer / fast-I/O code below.
// NOTE: the namespace only scopes the typedefs. #define macros ignore namespaces and stay
// active for the rest of the translation unit.
namespace CommonlyIO {
typedef long long ll;
typedef std::pair<int, int> PII;
// WARNING: every later token `x` / `y` becomes `first` / `second` (so `p.x` works on pairs);
// this also renames any variable called x or y in the code that follows.
#define x first
#define y second
// Arguments are parenthesised so expressions expand correctly, e.g. lowbit(a + b).
#define sz(a) (a).size()
#define lowbit(x) ((x) & -(x))
#define debug(x) cout << #x << ':' << (x) << endl
#define VI vector<int>
#define all(a) (a).begin(), (a).end()
#define rep(i, c, n) for (int i = (c); i <= (n); ++i)
#define per(i, c, n) for (int i = (c); i >= (n); --i)
#define w(x) while ((x)--)
#define db double
#define pb(x) push_back(x)
#define gc() getchar()
}  // namespace CommonlyIO
using namespace CommonlyIO;
// Non-negative arbitrary-precision integer stored little-endian in base 10^7:
// limb 0 is the least significant. Zero is the single limb {0}.
// Arithmetic assumes normalised operands (no high zero limbs except a lone 0).
struct uinT : public std::vector<int> {
    const static int mod = (int)1e7;  // limb base
    const static int wei = 7;         // decimal digits per limb (mod == 10^wei)
    uinT() {}
    // From a non-negative int. Values >= mod are split into several limbs so that the
    // size-based comparisons and one-step carries below stay valid (e.g. uinT(100000000)).
    // A negative t is stored unchanged as a single limb.
    uinT(int t) : std::vector<int>(1, t) {
        while (back() >= mod) {
            int high = back() / mod;
            back() %= mod;
            push_back(high);
        }
    }
    // From decimal digits, most significant first. Leading zeros are stripped so
    // "00000000123" equals 123; an empty digit string yields 0.
    uinT(std::vector<char> s) {
        for (int i = (int)s.size() - 1; i >= 0; i -= wei) {
            int t = 0;
            for (int j = std::max(0, i - wei + 1); j <= i; ++j) {
                t = (t * 10 + (s[j] - '0'));
            }
            push_back(t);
        }
        while (size() > 1 && back() == 0)
            pop_back();
        if (empty())
            push_back(0);
    }
    // Read-only limb access: limbs past the end read as 0. Returned by value, because
    // returning a reference to a temporary 0 would dangle.
    int operator[](const int n) const {
        return (n >= 0 && n < (int)size()) ? *(cbegin() + n) : 0;
    }
    int& operator[](const int n) { return *(begin() + n); }
    // Reads the next run of decimal digits from stdin, skipping anything before it.
    // Stops at EOF instead of looping forever.
    friend void input(uinT& a) {
        std::vector<char> s;
        int c = getchar();
        while (c != EOF && !std::isdigit(c))
            c = getchar();
        while (c != EOF && std::isdigit(c))
            s.push_back((char)c), c = getchar();
        a = uinT(s);
    }
    // Prints the value in decimal; an empty (default-constructed) value prints as 0.
    friend void output(const uinT& a) {
        if (a.empty()) {
            std::putchar('0');
            return;
        }
        std::printf("%d", a.back());
        for (int i = (int)a.size() - 2; i >= 0; --i) {
            std::printf("%0*d", wei, a[i]);  // inner limbs are zero-padded to wei digits
        }
    }
    friend bool operator==(uinT const& a, uinT const& b) {
        if (a.size() != b.size())
            return 0;
        for (int i = 0; i < (int)a.size(); ++i)
            if (a[i] != b[i])
                return 0;
        return 1;
    }
    friend bool operator!=(uinT const& a, uinT const& b) { return !(a == b); }
    friend bool operator<(uinT const& a, uinT const& b) {
        if (a.size() != b.size())
            return a.size() < b.size();
        for (int i = (int)a.size() - 1; i >= 0; --i) {
            if (a[i] != b[i])
                return a[i] < b[i];
        }
        return 0;
    }
    friend bool operator>(uinT const& a, uinT const& b) { return b < a; }
    friend bool operator<=(uinT const& a, uinT const& b) { return !(b < a); }
    friend bool operator>=(uinT const& a, uinT const& b) { return b <= a; }
    friend uinT operator+(uinT const& a, uinT const& b) {
        uinT c = a;
        c.resize(std::max(a.size(), b.size()) + 1);
        for (int i = 0; i < (int)b.size(); ++i) {
            c[i] += b[i];
            if (c[i] >= mod) {
                c[i] -= mod;
                c[i + 1] += 1;
            }
        }
        for (int i = (int)b.size(); i < (int)c.size() - 1; ++i) {  // propagate the carry
            if (c[i] >= mod) {
                c[i] -= mod;
                c[i + 1] += 1;
            }
        }
        if (c.back() == 0)
            c.pop_back();
        return c;
    }
    // Requires a >= b (the result is unsigned).
    friend uinT operator-(uinT const& a, uinT const& b) {
        uinT c = a;
        for (int i = 0; i < (int)b.size(); ++i) {
            c[i] -= b[i];
            if (c[i] < 0) {
                c[i] += mod;
                c[i + 1] -= 1;
            }
        }
        for (int i = (int)b.size(); i < (int)c.size(); ++i) {  // propagate the borrow
            if (c[i] < 0) {
                c[i] += mod;
                c[i + 1] -= 1;
            } else
                break;
        }
        while (c.size() > 1 && c.back() == 0)
            c.pop_back();
        return c;
    }
    // Schoolbook multiplication with 64-bit column sums, normalised afterwards.
    friend uinT operator*(uinT const& a, uinT const& b) {
        if (a == 0 || b == 0)
            return 0;
        std::vector<long long> t(a.size() + b.size());
        for (int i = 0; i < (int)a.size(); ++i) {
            for (int j = 0; j < (int)b.size(); ++j) {
                t[i + j] += 1ll * a[i] * b[j];
            }
        }
        for (int i = 0; i < (int)t.size() - 1; ++i) {
            t[i + 1] += t[i] / mod;
            t[i] %= mod;
        }
        if (t.back() == 0)
            t.pop_back();
        uinT c;
        c.resize(t.size());
        for (int i = 0; i < (int)t.size(); ++i) {
            c[i] = (int)t[i];
        }
        return c;
    }
    // Long division, one limb of quotient per step. Each quotient limb is binary-searched
    // between bounds estimated from the top limbs. Requires b != 0.
    friend uinT operator/(uinT const& a, uinT const& b) {
        if (a.size() < b.size())
            return 0;
        uinT c, d;  // c: running remainder, d: quotient limbs (most significant first)
        c.assign(a.end() - b.size() + 1, a.end());
        for (int i = (int)a.size() - (int)b.size(); i >= 0; --i) {
            c.insert(c.begin(), a[i]);
            long long t = (c.size() > b.size())
                              ? (1ll * c.back() * mod + *(c.crbegin() + 1))
                              : (c.back());
            int l = (int)(t / (b.back() + 1));
            int r = (int)((t + 1) / b.back());
            while (l < r) {
                int mid = (l + r + 1) >> 1;
                if (b * mid <= c)
                    l = mid;
                else
                    r = mid - 1;
            }
            c -= b * l;
            if (c.back() == 0)
                c.pop_back();
            d.push_back(l);
        }
        std::reverse(d.begin(), d.end());
        if (d.size() > 1 && d.back() == 0)
            d.pop_back();
        return d;
    }
    friend uinT operator%(uinT const& a, uinT const& b) {
        return a - a / b * b;
    }
    friend uinT const& operator+=(uinT& a, uinT const& b) { return a = a + b; }
    friend uinT const& operator-=(uinT& a, uinT const& b) { return a = a - b; }
    friend uinT const& operator*=(uinT& a, uinT const& b) { return a = a * b; }
    friend uinT const& operator/=(uinT& a, uinT const& b) { return a = a / b; }
    friend uinT const& operator%=(uinT& a, uinT const& b) { return a = a % b; }
};

struct bigint {
    bool f;
    uinT t;
    bigint() : f(0) {}
    bigint(int t) : f(t < 0), t(uinT(abs(t))) {}
    bigint(bool f, uinT t) : f(f), t(t) {}
    friend void input(bigint& a) {
        a.f = 0;
        vector<char> s;
        char c = gc();
        for (; !isdigit(c); c = gc())
            if (c == '-')
                a.f = 1;
        while (isdigit(c))
            s.push_back(c), c = gc();
        a.t = uinT(s);
    }
    friend void output(const bigint& a) {
        if (a.f)
            putchar('-');
        output(a.t);
    }
    friend bigint abs(bigint const& a) { return bigint(0, a.t); }
    friend bool operator==(bigint const& a, bigint const& b) {
        return (a.f == b.f) && (a.t == b.t);
    }
    friend bool operator!=(bigint const& a, bigint const& b) {
        return !(a == b);
    }
    friend bool operator<(bigint const& a, bigint const& b) {
        if (a.f != b.f)
            return a.f;
        return a.f ? a.t > b.t : a.t < b.t;
    }
    friend bool operator>(bigint const& a, bigint const& b) { return b < a; }
    friend bool operator<=(bigint const& a, bigint const& b) {
        if (a.f != b.f)
            return a.f;
        return a.f ? a.t >= b.t : a.t <= b.t;
    }
    friend bool operator>=(bigint const& a, bigint const& b) { return b <= a; }
    friend bigint operator-(bigint const& a) { return bigint(!a.f, a.t); }
    friend bigint operator+(bigint const& a, bigint const& b) {
        if (a.f == b.f)
            return bigint(a.f, a.t + b.t);
        else if (a.t > b.t)
            return bigint(a.f, a.t - b.t);
        else if (a.t < b.t)
            return bigint(b.f, b.t - a.t);
        else
            return 0;
    }
    friend bigint operator-(bigint const& a, bigint const& b) {
        if (a.f == b.f) {
            if (a.t > b.t)
                return bigint(a.f, a.t - b.t);
            else if (a.t < b.t)
                return bigint(!a.f, b.t - a.t);
            else
                return 0;
        } else
            return bigint(a.f, a.t + b.t);
    }
    friend bigint operator*(bigint const& a, bigint const& b) {
        if (a == 0 || b == 0)
            return 0;
        return bigint(a.f ^ b.f, a.t * b.t);
    }
    friend bigint operator\/(bigint const& a, bigint const& b) {
        return bigint(a.f ^ b.f, a.t \/ b.t);
    }
    friend bigint operator%(bigint const& a, bigint const& b) {
        return bigint(a.f, a.t % b.t);
    }
    friend bigint const& operator+=(bigint& a, bigint const& b) {
        return a = a + b;
    }
    friend bigint const& operator-=(bigint& a, bigint const& b) {
        return a = a - b;
    }
    friend bigint const& operator*=(bigint& a, bigint const& b) {
        return a = a * b;
    }
    friend bigint const& operator\/=(bigint& a, bigint const& b) {
        return a = a \/ b;
    }
    friend bigint const& operator%=(bigint& a, bigint const& b) {
        return a = a % b;
    }
};
const int INF = 0x3f3f3f3f;
const long long INFll = 9223372036854775807LL;  // largest long long
const double PI = 3.14159265358979323846;       // pi to full double precision
// Buffered replacement for getchar(): refills In[] with fread() 1 MiB at a time and yields
// EOF once stdin is exhausted. Bytes are returned as unsigned char values (like the real
// getchar), so a 0xFF byte is not mistaken for EOF. Only code after this line is affected.
#define getchar()                                                          \
    (tt == ss && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) \
         ? EOF                                                             \
         : (unsigned char)*ss++)
char In[1 << 20], *ss = In, *tt = In;
namespace Fastio {
struct Reader {
    template <typename T>
    Reader& operator>>(T& x) {
        x = 0;
        short f = 1;
        char c = getchar();
        while (c < '0' || c > '9') {
            if (c == '-')
                f *= -1;
            c = getchar();
        }
        while (c >= '0' && c <= '9')
            x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
        x *= f;
        return *this;
    }
    Reader& operator>>(double& x) {
        x = 0;
        double t = 0;
        short f = 1, s = 0;
        char c = getchar();
        while ((c < '0' || c > '9') && c != '.') {
            if (c == '-')
                f *= -1;
            c = getchar();
        }
        while (c >= '0' && c <= '9' && c != '.')
            x = x * 10 + (c ^ 48), c = getchar();
        if (c == '.')
            c = getchar();
        else {
            x *= f;
            return *this;
        }
        while (c >= '0' && c <= '9')
            t = t * 10 + (c ^ 48), s++, c = getchar();
        while (s--)
            t \/= 10.0;
        x = (x + t) * f;
        return *this;
    }
    Reader& operator>>(long double& x) {
        x = 0;
        long double t = 0;
        short f = 1, s = 0;
        char c = getchar();
        while ((c < '0' || c > '9') && c != '.') {
            if (c == '-')
                f *= -1;
            c = getchar();
        }
        while (c >= '0' && c <= '9' && c != '.')
            x = x * 10 + (c ^ 48), c = getchar();
        if (c == '.')
            c = getchar();
        else {
            x *= f;
            return *this;
        }
        while (c >= '0' && c <= '9')
            t = t * 10 + (c ^ 48), s++, c = getchar();
        while (s--)
            t \/= 10.0;
        x = (x + t) * f;
        return *this;
    }
    Reader& operator>>(char& c) {
        c = getchar();
        while (c == ' ' || c == '\n' || c == '\r')
            c = getchar();
        return *this;
    }
    Reader& operator>>(char* str) {
        int len = 0;
        char c = getchar();
        while (c == ' ' || c == '\n' || c == '\r')
            c = getchar();
        while (c != ' ' && c != '\n' && c != '\r')
            str[len++] = c, c = getchar();
        str[len] = '\';
        return *this;
    }
    Reader& operator>>(string& str) {
        str.clear();
        char c = getchar();
        while (c == ' ' || c == '\n' || c == '\r')
            c = getchar();
        while (c != ' ' && c != '\n' && c != '\r')
            str.push_back(c), c = getchar();
        return *this;
    }
    Reader& operator>>(bigint& a) {
        input(a);
        return *this;
    }
    Reader& operator>>(__float128& x) {
        x = 0;
        __float128 t = 0;
        short f = 1, s = 0;
        char c = getchar();
        while ((c < '0' || c > '9') && c != '.') {
            if (c == '-')
                f *= -1;
            c = getchar();
        }
        while (c >= '0' && c <= '9' && c != '.')
            x = x * 10 + (c ^ 48), c = getchar();
        if (c == '.')
            c = getchar();
        else {
            x *= f;
            return *this;
        }
        while (c >= '0' && c <= '9')
            t = t * 10 + (c ^ 48), s++, c = getchar();
        while (s--)
            t \/= 10.0;
        x = (x + t) * f;
        return *this;
    }
    Reader() {}
} cin;
const char endl = '\n';
struct Writer {
    const int Setprecision = 6;
    typedef int mxdouble;
    template <typename T>
    Writer& operator<<(T x) {
        if (x == 0) {
            putchar('0');
            return *this;
        }
        if (x < 0)
            putchar('-'), x = -x;
        static short sta[40];
        short top = 0;
        while (x > 0)
            sta[++top] = x % 10, x \/= 10;
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        return *this;
    }
    Writer& operator<<(const bigint& n) {
        output(n);
        return *this;
    }
    Writer& operator<<(__float128 x) {
        if (x < 0)
            putchar('-'), x = -x;
        mxdouble _ = x;
        x -= (__float128)_;
        static short sta[40];
        short top = 0;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        if (top == 0)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        putchar('.');
        for (int i = 0; i < Setprecision; i++)
            x *= 10;
        _ = x;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        for (int i = 0; i < Setprecision - top; i++)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        return *this;
    }
    Writer& operator<<(double x) {
        if (x < 0)
            putchar('-'), x = -x;
        mxdouble _ = x;
        x -= (double)_;
        static short sta[40];
        short top = 0;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        if (top == 0)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        putchar('.');
        for (int i = 0; i < Setprecision; i++)
            x *= 10;
        _ = x;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        for (int i = 0; i < Setprecision - top; i++)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        return *this;
    }
    Writer& operator<<(long double x) {
        if (x < 0)
            putchar('-'), x = -x;
        mxdouble _ = x;
        x -= (long double)_;
        static short sta[40];
        short top = 0;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        if (top == 0)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        putchar('.');
        for (int i = 0; i < Setprecision; i++)
            x *= 10;
        _ = x;
        while (_ > 0)
            sta[++top] = _ % 10, _ \/= 10;
        for (int i = 0; i < Setprecision - top; i++)
            putchar('0');
        while (top > 0)
            putchar(sta[top] + '0'), top--;
        return *this;
    }
    Writer& operator<<(char c) {
        putchar(c);
        return *this;
    }
    Writer& operator<<(char* str) {
        int cur = 0;
        while (str[cur])
            putchar(str[cur++]);
        return *this;
    }
    Writer& operator<<(const char* str) {
        int cur = 0;
        while (str[cur])
            putchar(str[cur++]);
        return *this;
    }
    Writer& operator<<(string str) {
        int st = 0, ed = str.size();
        while (st < ed)
            putchar(str[st++]);
        return *this;
    }
    Writer() {}
} cout;
}  \/\/ namespace Fastio
// Make the fast reader/writer the default. Both std and Fastio declare cin/cout/endl and
// both namespaces are pulled in with using-directives, so the unqualified names would be
// ambiguous. These macros make every later plain cin/cout/endl mean Fastio's objects.
using namespace Fastio;
#define cin Fastio::cin
#define cout Fastio::cout
#define endl Fastio::endl
// Every later mention of __int128 becomes the arbitrary-precision bigint type.
// NOTE(review): names starting with a double underscore are reserved; defining a macro
// with such a name is technically undefined behaviour.
#define __int128 bigint

矩阵(取模)

int mod = 1000000007; // modulus used by the snippets below: change it to whatever your problem requires
struct Matrix{
	int n,m;
	int a[205][205];
	
	inline void init(){ memset(a,0,sizeof a); }
	inline void init_I(){
		init();
		F(i,1,min(n,m)) a[i][i]=1;
	}

    inline friend void operator = (Matrix &a,Matrix b){
		a.n=b.n,a.m=b.m;
		F(i,1,a.n) F(j,1,a.m) a.a[i][j] = b.a[i][j];
	}
	inline friend Matrix operator + (Matrix a,Matrix b){
		Matrix c; c.init();
		c.n=a.n, c.m=a.m;
		F(i,1,c.n) F(j,1,c.m) c.a[i][j] = (a.a[i][j] + b.a[i][j]) % mod;
		return c;
	}
	inline friend Matrix operator - (Matrix a,Matrix b){
		Matrix c; c.init();
		c.n=a.n, c.m=a.m;
		F(i,1,c.n) F(j,1,c.m) c.a[i][j] = (a.a[i][j] - b.a[i][j]) % mod;
		return c;
	}
	inline friend Matrix operator * (Matrix a,Matrix b){
		Matrix c; c.init();
		c.n = a.n, c.m = b.m;
		F(i,1,c.n) F(j,1,c.m) F(k,1,a.m)
			c.a[i][j] = (c.a[i][j] + a.a[i][k] * b.a[k][j]) % mod;
		return c;
	}
	inline friend bool operator == (Matrix a,Matrix b){
		if(a.n!=b.n || a.m!=b.m) return 0;
		F(i,1,a.n) F(j,1,a.m) if(a.a[i][j] != b.a[i][j]) return 0;
		return 1;
	}
	inline friend bool operator != (Matrix a,Matrix b){ return !(a==b); }
	inline friend void operator += (Matrix &a,Matrix b){ a=a+b; }
	inline friend void operator -= (Matrix &a,Matrix b){ a=a-b; }
};

线性筛(带了欧拉函数和莫比乌斯函数)

int prm[N], tot;
int phi[N], mu[N];
inline void ola(){
	F(i,2,N-50){
		if(!p[i]){
			prm[++tot] = i;
			phi[i] = i-1;
			mu[i] = -1;
		}
		for(int j=1; j<=tot && i*prm[j]<=N-50; j++){
			p[i] = true;
			if(i%prm[j]){
				phi[i*prm[j]] = phi[i] * phi[prm[j]];
				mu[i*prm[j]] = -mu[i];
			}else{
				phi[i*prm[j]] = phi[i] * prm[j];
				mu[i*prm[j]] = 0;
				break;
			}
		}
	}
}

快速幂

// Returns x^p modulo the global `mod` by binary exponentiation (p >= 0 expected).
// Products are formed in long long so they cannot overflow for any int modulus, and x is
// first reduced into [0, mod) so negative or oversized bases work too.
inline int fast_power(int x,int p){
	long long base = ((long long)x % mod + mod) % mod;
	long long ans = 1 % mod;   // 1 % mod makes every power 0 when mod == 1
	while(p){
		if(p % 2) ans = ans * base % mod;
		base = base * base % mod;
		p /= 2;
	}
	return (int)ans;
}

建图——邻接表的连边

无边权

// Adjacency list as linked lists of edge indices (edges numbered from 1, so 0 ends a list):
// head[u] = most recently added edge out of u, nxt[e] = next edge in the same list,
// ver[e] = target vertex of edge e.
int tot;       // number of edges added so far
int head[N];
int ver[N];
int nxt[N];

// Adds the directed edge u -> v at the front of u's list.
inline void add(int u,int v){
	++tot;
	nxt[tot] = head[u];
	ver[tot] = v;
	head[u] = tot;
}

有边权

// Weighted adjacency list as linked lists of edge indices (edges numbered from 1, 0 ends a list):
// head[u] = most recently added edge out of u, nxt[e] = next edge in u's list,
// ver[e] = target of edge e, d[e] = weight of edge e.
int tot;
int head[N], ver[N], nxt[N], d[N];

// Adds the directed edge u -> v with weight wt at the front of u's list.
inline void add(int u,int v,int wt){
	d[++tot] = wt;
	ver[tot] = v;
	nxt[tot] = head[u];
	head[u] = tot;
}

图论

DSU 并查集

int dsu[N]; // parent links; a root points to itself. Set dsu[i] = i for every node before use.

// Returns the root of x's tree and points every node on the path straight at it.
// Iterative, so very long chains cannot overflow the call stack.
inline int find(int x){
	int root = x;
	while(dsu[root] != root) root = dsu[root];
	while(x != root){
		int parent = dsu[x];
		dsu[x] = root;
		x = parent;
	}
	return root;
}

// Merges the trees containing x and y (y's root becomes the parent).
inline void merge(int x,int y){
	int xx = find(x), yy = find(y);
	dsu[xx] = yy;
}

// True if x and y are in the same tree, i.e. have the same root.
inline bool union_find(int x,int y){
	return find(x) == find(y);
}

MST 最小生成树

Kruskal

int n,m;   // number of vertices / edges
// NOTE(review): dsu is indexed by vertex and a[] by edge; confirm that M bounds the vertex
// count and N the edge count for your problem.
int dsu[M];
struct edge{
	int u,v,val;   // endpoints and weight
}a[N];
// Root of x's set, with path compression.
inline int find(int x){
	return (x == dsu[x]) ? x : dsu[x] = find(dsu[x]);
}
inline void merge(int x,int y){
	int xx = find(x),yy = find(y);
	dsu[xx] = yy;
}
// Orders edges by increasing weight.
inline bool cmp(const edge &a,const edge &b){
	return a.val < b.val;
}
// Kruskal: scan edges by increasing weight, keeping each edge that joins two different sets.
signed main(){
	// Read n, m and the edges into a[1..m] here.
	for(int i = 1; i <= n; i++) dsu[i] = i;   // every vertex starts in its own set
	sort(a + 1,a + m + 1,cmp);
	int cnt = 0;
	long long ans = 0;   // long long so the total weight cannot overflow
	for(int i = 1; i <= m; i++){
		if(find(a[i].u) != find(a[i].v)){
			cnt ++;
			ans += a[i].val;
			merge(a[i].u,a[i].v);
		}
		if(cnt == n - 1) break;
	}
	// If cnt < n - 1 here, the graph is disconnected and has no spanning tree.
	cout << ans;
	return 0;
}

数论

拓展欧几里得

// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a*x + b*y == g.
inline int exgcd(int a,int b,int &x,int &y){
	if(!b){
		// gcd(a, 0) = a = a*1 + 0*0
		x = 1;
		y = 0;
		return a;
	}
	// Recurse with the coefficient slots swapped, so that b*y + (a%b)*x == g.
	// Because a%b == a - (a/b)*b, subtracting (a/b)*x from y gives a*x + b*y == g.
	int g = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return g;
}

字符串算法

Trie 字典树

// Trie over the 62-symbol alphabet [A-Za-z0-9] (70 child slots per node). Node 0 is the root.
int idx;        // index of the most recently allocated node
int tr[N][70];  // tr[p][t]: child of node p for symbol t (0 = absent)
int cnt[N];     // cnt[p]: number of inserted strings whose path passes through node p
char c[N];      // presumably the shared input buffer — TODO confirm against callers

// Maps a trie symbol to its child slot: 'A'-'Z' -> 0..25, 'a'-'z' -> 26..51,
// and anything else is treated as a digit: '0'-'9' -> 52..61.
inline int gt(char c){
	if(c >= 'a' && c <= 'z') return 26 + (c - 'a');
	if(c >= 'A' && c <= 'Z') return c - 'A';
	return 52 + (c - '0');
}
inline void insert(char c[]){
	int len=strlen(c),p=0;
	F(i,0,len-1){
		int t=gt(c[i]);
		if(!tr[p][t]) tr[p][t]=++idx;
		p=tr[p][t];
		cnt[p]++;
	}
}
inline int find(char c[]){
	int len=strlen(c),p=0;
	F(i,0,len-1){
		int t=gt(c[i]);
		if(!tr[p][t]) return 0;
		p=tr[p][t];
	}
	return cnt[p];
}

最小表示法

// Returns the lexicographically smallest rotation of s.
// i and j are candidate start positions and k is the length of their common prefix;
// on a mismatch the larger candidate jumps past the compared block.
inline std::string get_min_expression(std::string s){
	const int n = s.size();
	int i = 0, j = 1, k = 0;
	while(i < n && j < n && k < n){
		const char lhs = s[(i + k) % n];
		const char rhs = s[(j + k) % n];
		if(lhs == rhs){
			++k;
			continue;
		}
		if(lhs > rhs) i += k + 1;
		else j += k + 1;
		if(i == j) ++i;   // the two candidates must stay distinct
		k = 0;
	}
	const int start = (i < j) ? i : j;
	return s.substr(start) + s.substr(0, start);
}

树形数据结构

Treap 无旋平衡树

struct Treap{
	#define ls (tr[rt].l)
	#define rs (tr[rt].r)
	struct treap{
		int l, r;\/\/表示其左右儿子的结点编号
		int val, rnd;\/\/BST的权值与heap的值
		int siz;\/\/以其为根的子树的结点个数
	}tr[N];
	int root = 0, n = 0;
	\/\/分别代表根和数的个数
	
	inline void newnode(int v){
		tr[++n].val = v;
		tr[n].rnd = rng();
		tr[n].siz = 1;
	}\/\/新增一个结点
	
	inline void pushup(int rt){
		tr[rt].siz = tr[ls].siz + tr[rs].siz +1;
	}
	inline void split(int rt, int v, int &x, int &y){
		if(rt == 0){
			x = y = 0;
			return ;
		}
		
		if(tr[rt].val <= v){
			x = rt;
			split(rs, v, rs, y);
		}else{
			y = rt;
			split(ls, v, x, ls);
		}\/\/极其重要!极其巧妙! 
		
		pushup(rt);
	}
	inline int merge(int x, int y){
		if(!x || !y) return x+y;
		if(tr[x].rnd < tr[y].rnd){
			tr[x].r = merge(tr[x].r, y);
			pushup(x);
			return x;
		}else{
			tr[y].l = merge(x, tr[y].l);
			pushup(y);
			return y;
		}
	}\/\/基本操作:分裂与合并
	
	inline void insert(int v){
		int x, y;
		split(root, v, x, y);
		newnode(v);
		int z = n;
		root = merge(merge(x, z), y);
	}\/\/插入结点
	inline void del(int v){
		int x, y, z;
		split(root, v, x, z);
		split(x, v-1, x, y);
		y = merge(tr[y].l, tr[y].r);\/\/若y有子树,合并(保留)其子树 
		root = merge(merge(x, y), z);
	}\/\/删除结点 
	inline int rank(int v){
		int x, y;
		split(root, v-1, x, y);
		int ans = tr[x].siz + 1;
		root = merge(x, y);
		return ans;
	}\/\/查找排名 
	inline int topk(int rt, int k){
		int lsz = tr[ls].siz;
		if(k == lsz+1) return tr[rt].val;
		if(k <= lsz) return topk(ls, k);
		return topk(rs, k - lsz - 1);
	}\/\/查询第k小的数 
	inline int get_pre(int v){
		int x, y;
		split(root, v-1, x, y);
		int ans = topk(x, tr[x].siz);
		root = merge(x, y);
		return ans;
	}\/\/查找前驱 
	inline int get_suc(int v){
		int x, y;
		split(root, v, x, y);
		int ans = topk(y, 1);
		root = merge(x, y);
		return ans;
	}\/\/查找后继 
}T;

Splay 平衡树

struct Splay{
	int rt=0,tot=0;
	struct tree{
		int s[2];	\/\/s[0]表示左儿子,s[1]表示右儿子 
		int siz,val;
		int fa;
	}tr[N<<1];
	#define fa(x) tr[x].fa
	#define ls(x) tr[x].s[0]
	#define rs(x) tr[x].s[1]
	inline void pushup(int x){ tr[x].siz = tr[ls(x)].siz+tr[rs(x)].siz+1; }
	inline void clear(int x){ tr[x].siz=fa(x)=ls(x)=rs(x)=tr[x].val=0; }
	inline bool get(int x){
		return rs(fa(x))==x;
	}\/\/返回 0 表示左儿子,返回 1 表示右儿子 
	
	inline void newnode(int val){
		clear(++tot);
		tr[tot].val = val;
		tr[tot].siz=1;
	}
	inline void ronate(int x){
		int c = get(x), y = fa(x), z = fa(y);
		if(tr[x].s[!c]) fa(tr[x].s[!c]) = y;	\/\/若有与自身左右不一样的儿子,则"过继"至原父亲
		tr[y].s[c] = tr[x].s[!c];
		tr[x].s[!c] = y; fa(y) = x;		\/\/原父亲成为儿子
		fa(x) = z;			\/\/原儿子成为原先父亲的父亲的儿子 
		if(z) tr[z].s[(y==rs(z))] = x;
		pushup(y), pushup(x);
	}
	inline void splay(int x){
		while(fa(x)){
			int y = fa(x);
			if(fa(y)) ronate(get(x)==get(y) ? y : x);	\/\/若自身与父亲同向,则先旋转父亲以维护平衡性
			ronate(x);
		}
		rt=x;
	}\/\/将某个非根节点旋转至根上 
	
	inline void insert(int val){
		int now=rt, f=0;
		while(now){
			f = now;
			now = tr[now].s[val > tr[now].val];
		}\/\/在二叉搜索树上寻找当前权值的节点应在的位置 
		newnode(val);		\/\/创建新节点 
		fa(tot) = f;
		tr[f].s[val > tr[f].val] = tot;		\/\/将新创建的节点与其父亲连边 
		splay(tot);		\/\/将新节点旋转至根 
	}\/\/插入新节点 
	inline void del(int val){
		int now=rt,f=0;
		while(now && tr[now].val != val){
			f=now;
			now = tr[now].s[val>tr[now].val];
		}\/\/根据权值,在二叉搜索树上寻找应删除节点的位置 
		if(!now){
			splay(f);
			return ;
		}\/\/若没有则旋转其父亲,并直接返回 
		splay(now);
		int cur = ls(now);
		if(!cur){		\/\/若要删除的点没有左儿子 
			rt = rs(now);	\/\/右儿子为根 
			if(rt) fa(rt) = 0;
			clear(now);	\/\/删除 
			return ;
		}
		while(rs(cur)) cur = rs(cur);	\/\/找到左子树中最右边的叶子 
		rs(cur) = rs(now);
		if(rs(now)) fa(rs(now)) = cur;	\/\/右子树的根成为该叶子的儿子以维护二叉搜索树的性质 
		fa(ls(now)) = 0;
		clear(now);			\/\/删除节点 
		pushup(cur);
		splay(cur);
	}\/\/删除一个节点 
	inline int rhk(int val){
		int res=1,now=rt,f=0;
		while(now){
			f = now;
			if(tr[now].val<val){
				res += tr[ls(now)].siz+1;	\/\/左子树上的点和当前点都比其小
				now = rs(now);
			}else now = ls(now);
		}
		splay(f);
		return res;
	}\/\/查询某个数的 rank 值
	inline int kth(int cnt){
		int now=rt;
		while(now){
			int t = tr[ls(now)].siz+1;
			if(t < cnt) cnt -= t, now=rs(now);
			else if(t==cnt) break;
			else now=ls(now);
		}
		splay(now);
		return tr[now].val;
	}\/\/查询第 k 小的值 
	inline int pre(int val){
		int res=0,now=rt,f=0;
		while(now){
			f=now;
			if(tr[now].val>=val) now=ls(now);
			else res=tr[now].val,now=rs(now);
		}
		splay(f);
		return res;
	}\/\/查询前驱 
	inline int nxt(int val){
		int res=0,now=rt,f=0;
		while(now){
			f=now;
			if(tr[now].val<=val) now=rs(now);
			else res=tr[now].val,now=ls(now);
		}
		splay(f);
		return res;
	}\/\/查询后继 
}T;

评论

暂无评论

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。