P8253 [NOI Online 2022 提高组] 如何正确地排序 题解
KingPowers
·
2024-02-13 13:47:23
·
题解
前言
题目链接
114514 年前就听过的题了,现在终于是写上了。
巨大二/三维数点题,甚至还能带上 min-max 容斥。
正文
首先注意到 \min 和 \max 的做法本质上是相同的,下文只讨论 \max 该怎么做,\min 的部分偷懒的话可以直接将所有数取相反数再跑一边 \max。
由于 m 很小,且观察数据点 m 的分布,此题做法一定和 m 是强相关的,因此下文对 m=2,3,4 的情况分开讨论。
为了方便,提前约定 a_i,b_i,c_i,d_i 为原题中的 a_{1,i},a_{2,i},a_{3,i},a_{4,i}。
Case 1:m=2
这里有个重要的且下文会一直沿用的思想:考虑每对 i,j 会对答案贡献几次。
假如 a_i+a_j 会被算作最大值,那么有一条限制就是 a_i+a_j\ge b_i+b_j,移项得到 a_i-b_i\ge b_j-a_j。
贡献可以写成:
\sum_{i=1}^n\sum_{j=1}^n[a_i-b_i\ge b_j-a_j](a_i+a_j)
换个写法:
\sum_{i=1}^n(a_i\sum_{j=1}^n[a_i-b_i>b_j-a_j]+\sum_{j=1}^n[a_i-b_i>b_j-a_j]a_j)
每个 i 拆出属性为 a_i-b_i 和 b_i-a_i 的点,实际上是个一维数点,直接排序或者树状数组统计即可。
同理,如果 b_i+b_j 会被算作最大值,那么限制为 b_i-a_i>a_j-b_i(注意这里是大于号,用于防止算重),再跑一遍一维数点即可。
时间复杂度 O(n\log n)。
给出核心代码:
int lowbit(int x){return x & (-x);}
void add(int x, int y){
x += 200001;
while(x <= 400005){
sum[x] += y; cnt[x]++;
x += lowbit(x);
}
}
int query_sum(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += sum[x];
return res;
}
int query_cnt(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += cnt[x];
return res;
}
int solve2(int a[], int b[]){
int res = 0;
For(i, 1, n) add(b[i] - a[i], a[i]);
For(i, 1, n) res += query_cnt(a[i] - b[i]) * a[i] + query_sum(a[i] - b[i]);
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum);
For(i, 1, n) add(a[i] - b[i], b[i]);
For(i, 1, n) res += query_cnt(b[i] - a[i] - 1) * b[i] + query_sum(b[i] - a[i] - 1);
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum);
return res;
}
Case 2:m=3
考虑继续沿用 m=2 的做法,假如 a_i+a_j 被算作最大值,限制为:
\begin{cases}
a_i+a_j\ge b_i+b_j \\
a_i+a_j\ge c_i+c_j
\end{cases}
同样进行移项,可以得到:
\begin{cases}
a_i-b_i\ge b_j-a_j \\
a_i-c_i\ge c_j-a_j
\end{cases}
那么贡献写出来和 m=2 那个式子其实是一样的,只是方括号中的条件变成了满足上面这两个式子。
如果你真正理解了 m=2 的做法,那么不难发现这部分的贡献可以直接跑一个二维数点算出来,每个点的属性为 (a_i-b_i,a_i-c_i) 或 (b_i-a_i,c_i-a_i)。
同理,b_i+b_j 被算作最大值的限制就是:
\begin{cases}
b_i-a_i>a_j-b_j \\
b_i-c_i\ge c_j-b_j
\end{cases}
$$
\begin{cases}
c_i-a_i>a_j-c_j \\
c_i-b_i>b_j-c_j
\end{cases}
$$
一定要注意不等号,否则会算重或算漏。
那这部分直接跑三遍二维数点就做完了,时间复杂度 $O(n\log n)$。
同样放出核心代码:
```cpp
struct node{
int x, y, w, op;
}q[N];
bool cmp(const node &a, const node &b){
if(a.x != b.x) return a.x < b.x;
if(a.y != b.y) return a.y < b.y;
return a.op < b.op;
}
int lowbit(int x){return x & (-x);}
void add(int x, int y){
x += 200001;
while(x <= 400005){
sum[x] += y; cnt[x]++;
x += lowbit(x);
}
}
int query_sum(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += sum[x];
return res;
}
int query_cnt(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += cnt[x];
return res;
}
int solve3(int a[], int b[], int c[]){
int res = 0, tot = 0;
For(i, 1, n){
q[++tot] = {a[i] - b[i], a[i] - c[i], a[i], 1};
q[++tot] = {b[i] - a[i], c[i] - a[i], a[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
For(i, 1, n){
q[++tot] = {b[i] - a[i] - 1, b[i] - c[i], b[i], 1};
q[++tot] = {a[i] - b[i], c[i] - b[i], b[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
For(i, 1, n){
q[++tot] = {c[i] - a[i] - 1, c[i] - b[i] - 1, c[i], 1};
q[++tot] = {a[i] - c[i], b[i] - c[i], c[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
return res;
}
```
看着码量略大,其实重复的部分相当多,多写一个函数应该就能缩短很多。
### Case 3:$m=4
同样如果你也理解了 m=3 的做法,m=4 你或许也已经会了:继续沿用上面的做法,每行的限制写出来,直接跑三维数点,复杂度 O(n\log^2n),但是要跑四遍所以可能会比较慢。
事实上可以不用跑三维数点,注意到我们有 min-max 容斥:
\min(S)=\sum_{T\subseteq S}(-1)^{|T|+1}\max(T)
写出来我们要求的东西:
\sum_{i=1}^n\sum_{j=1}^n(\max_{k\in\{1,2,3,4\}}(a_{k,i}+a_{k,j})+\min_{k\in\{1,2,3,4\}}(a_{k,i}+a_{k,j}))
直接用 min-max 容斥把后面那个 \min 给替换掉:
\sum_{i=1}^n\sum_{j=1}^n(\max_{k\in\{1,2,3,4\}}(a_{k,i}+a_{k,j})+\sum_{k\in T,T\subseteq\{1,2,3,4\}}(-1)^{|T|+1}\max(a_{k,i}+a_{k,j}))
注意到当 T=\{1,2,3,4\} 时系数为 -1,与前面正好相抵消,所以实际上是求:
\sum_{i=1}^n\sum_{j=1}^n\sum_{k\in T,T\subsetneq\{1,2,3,4\}}(-1)^{|T|+1}\max(a_{k,i}+a_{k,j})
可以拆成若干个 m=1,2,3 的答案相加减得到。
复杂度 $O(n\log n)$,当然这个二维数点实际上要跑十遍左右,但应该还是比摁跑三维数点要快。
最后给出本题完整代码:
```cpp
#include
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
const int N = 5e5 + 5;
int m, n, a[5][N], cnt[N], sum[N];
struct node{
int x, y, w, op;
}q[N];
bool cmp(const node &a, const node &b){
if(a.x != b.x) return a.x < b.x;
if(a.y != b.y) return a.y < b.y;
return a.op < b.op;
}
int lowbit(int x){return x & (-x);}
void add(int x, int y){
x += 200001;
while(x <= 400005){
sum[x] += y; cnt[x]++;
x += lowbit(x);
}
}
int query_sum(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += sum[x];
return res;
}
int query_cnt(int x){
int res = 0; x += 200001;
for(; x; x -= lowbit(x)) res += cnt[x];
return res;
}
int solve1(int a[]){
int sum = 0;
For(i, 1, n) sum += a[i];
return sum * 2 * n;
}
int solve2(int a[], int b[]){
int res = 0;
For(i, 1, n) add(b[i] - a[i], a[i]);
For(i, 1, n) res += query_cnt(a[i] - b[i]) * a[i] + query_sum(a[i] - b[i]);
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum);
For(i, 1, n) add(a[i] - b[i], b[i]);
For(i, 1, n) res += query_cnt(b[i] - a[i] - 1) * b[i] + query_sum(b[i] - a[i] - 1);
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum);
return res;
}
int solve3(int a[], int b[], int c[]){
int res = 0, tot = 0;
For(i, 1, n){
q[++tot] = {a[i] - b[i], a[i] - c[i], a[i], 1};
q[++tot] = {b[i] - a[i], c[i] - a[i], a[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
For(i, 1, n){
q[++tot] = {b[i] - a[i] - 1, b[i] - c[i], b[i], 1};
q[++tot] = {a[i] - b[i], c[i] - b[i], b[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
For(i, 1, n){
q[++tot] = {c[i] - a[i] - 1, c[i] - b[i] - 1, c[i], 1};
q[++tot] = {a[i] - c[i], b[i] - c[i], c[i], 0};
}
sort(q + 1, q + tot + 1, cmp);
For(i, 1, tot){
if(q[i].op == 1) res += query_cnt(q[i].y) * q[i].w + query_sum(q[i].y);
else add(q[i].y, q[i].w);
}
memset(cnt, 0, sizeof cnt); memset(sum, 0, sizeof sum); tot = 0;
return res;
}
void Solve(){
cin >> m >> n;
For(i, 1, m) For(j, 1, n) cin >> a[i][j];
if(m == 2){
int ans = solve2(a[1], a[2]);
For(i, 1, m) For(j, 1, n) a[i][j] = -a[i][j];
ans -= solve2(a[1], a[2]);
cout << ans << '\n';
}
else if(m == 3){
int ans = solve3(a[1], a[2], a[3]);
For(i, 1, m) For(j, 1, n) a[i][j] = -a[i][j];
ans -= solve3(a[1], a[2], a[3]);
cout << ans << '\n';
}
else{
int ans = solve1(a[1]) + solve1(a[2]) + solve1(a[3]) + solve1(a[4]);
ans -= solve2(a[1], a[2]) + solve2(a[1], a[3]) + solve2(a[1], a[4]) + solve2(a[2], a[3]) + solve2(a[2], a[4]) + solve2(a[3], a[4]);
ans += solve3(a[1], a[2], a[3]) + solve3(a[1], a[2], a[4]) + solve3(a[1], a[3], a[4]) + solve3(a[2], a[3], a[4]);
cout << ans << '\n';
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int T = 1; //cin >> T;
while(T--) Solve();
return 0;
}
```