0%

2025 河南省赛——Gym-105941-Problem B. 随机栈 II

题目大意

题目描述
维护一个初始为空的多重集。每次从中取出一个元素时,集合中每个元素被取出的概率均等,取出后该元素从集合中删除。每次取出事件相互独立。
给定 nn 次操作序列,操作分为两种:

  1. 向集合中插入一个给定的非负整数。

  2. 从集合中随机取出一个元素。
    保证每次取出时集合非空,且整个操作序列至少包含一次取出操作。
    求所有被取出的元素排成的序列满足单调不降(即序列中每一项都小于等于它的后一项)的概率。答案对 998244353 取模。

输入格式
第一行包含一个正整数 TT1T2.5×1031 \le T \le 2.5 \times 10^3),表示测试数据组数。
每组数据第一行包含一个整数 nn2n5×1032 \le n \le 5 \times 10^3),表示操作总数。
第二行包含 nn 个整数 a1,a2,,ana_1, a_2, \dots, a_n1ain-1 \le a_i \le n),表示操作序列:

  • ai0a_i \ge 0 表示向集合中插入 aia_i

  • ai=1a_i = -1 表示从集合中随机取出一个元素。
    保证所有数据的 n5×103\sum n \le 5 \times 10^3

输出格式
对于每组数据,输出一行一个整数,表示概率对 998244353 取模的结果。

样例输入 1

1
2
3
4
5
6
7
3
3
1 2 -1
5
1 2 -1 3 -1
7
1 2 3 4 -1 -1 -1

样例输出 1

1
2
3
1
249561089
166374059

样例输入 2

1
2
3
4
5
6
7
3
4
1 2 -1 -1
6
1 2 -1 -1 1 -1
8
1 -1 2 -1 3 -1 4 -1

样例输出 2

1
2
3
499122177
0
1

样例解释
对于样例一的第 1 组数据,由于总共只取出了一个元素,序列无论如何都是单调不降的,概率为 1。

对于样例一的第 2 组数据,操作过程和对应概率如下:

  • 加入 1,集合变为 {1}\{1\};加入 2,集合变为 {1,2}\{1, 2\}

  • 此时取出 1 的概率为 12\frac{1}{2}。若取出 1,接下来加入 3,集合变为 {2,3}\{2, 3\}。之后无论取出哪个数,得到的序列(1 然后是 2 或 1 然后是 3)均满足单调不降。此情况的合法概率为 12×1=12\frac{1}{2} \times 1 = \frac{1}{2}

  • 此时取出 2 的概率为 12\frac{1}{2}。若取出 2,接下来加入 3,集合变为 {1,3}\{1, 3\}。之后为了保持单调不降,必须取出 3,取出 3 的概率为 12\frac{1}{2}。此情况的合法概率为 12×12=14\frac{1}{2} \times \frac{1}{2} = \frac{1}{4}
    最终单调不降的总概率为 12+14=34\frac{1}{2} + \frac{1}{4} = \frac{3}{4}34\frac{3}{4} 对 998244353 取模的结果为 249561089。

思路讲解

image

这道题目其实不难,就是细节比较多。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
for (ll op = 1; op <= cnt_op; ++op) {
vector<ll> vals;
for (int val = 0; val <= N; ++val) {
if (op_val_cnt[op][val] >= 1) {
vals.push_back(val);
}
}
if (op == 1) {
for (auto val: vals) {
dp_val_cnt[val].resize(2);
// 注意,有多个数字可选的时候,有多种可能,答案不是 1
dp_val_cnt[val][1] = op_val_cnt[op][val];
dpacc_val[val] = op_val_cnt[op][val];
}
} else {
vector<vector<ll> > ndp_val_cnt(N + 2);
vector<ll> ndpacc_val(N + 2);
vector<ll> presum_dpacc(N + 2);
// 注意使用前缀和优化
partial_sum(all(dpacc_val), presum_dpacc.begin(), [](ll a, ll b) {
return (a + b) % mod;
});
for (auto val: vals) {
ll val_cnt = op_val_cnt[op][val];
ndp_val_cnt[val].resize(min(op, val_cnt) + 2);
for (int cnt = 1; cnt <= min(op, val_cnt); ++cnt) {
ll rem_val = val_cnt - cnt + 1;
if (cnt >= 2) {
if (cnt - 1 > SZ(dp_val_cnt[val]) - 1) {
break;
}
ndp_val_cnt[val][cnt] = rem_val * dp_val_cnt[val][cnt - 1];
ndp_val_cnt[val][cnt] %= mod;
} else {
// for (int s_val = 0; s_val < val; ++s_val) {
// ndp_val_cnt[val][cnt] += dpacc_val[s_val];
// ndp_val_cnt[val][cnt] %= mod;
// }
if (val - 1 >= 0) {
ndp_val_cnt[val][cnt] = presum_dpacc[val - 1];
}
ndp_val_cnt[val][cnt] *= rem_val;
ndp_val_cnt[val][cnt] %= mod;
}
ndpacc_val[val] += ndp_val_cnt[val][cnt];
ndpacc_val[val] %= mod;
}
}
swap(dp_val_cnt, ndp_val_cnt);
swap(dpacc_val, ndpacc_val);
}
}

AC代码

AC
https://codeforces.com/gym/105941/submission/367931586

心路历程(WA,TLE,MLE……)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
do {
// 注意,还是要把 numc 显式地写出来
ll opc = 0, numc = 0;
vector<ll> cnt(N + 2);
for (int i = 1; i <= N; ++i) {
if (A[i] == -1) {
tot_num *= (numc - opc);
tot_num %= mod;
opc++;
op_val_cnt[opc] = cnt;
} else {
cnt[A[i]]++;
numc++;
}
}
} while (false);