Codeforces Round 791

E

https://codeforces.com/contest/1679/problem/E

长度n,前17个小写字母组成的字符串s,其中有?表示t中任意字符

q次询问,每次给17个字母中的一个子集,问把?所有填子集的填法中,原串的回文子串个数和, 答案mod 998244353

范围

n 1e3

t 17

q 2e5

2s

256MB

我的思路

首先看到q 2e5, 和 t 17, 因为2^17 == 131072,其实可以变成求所有方案的答案

统计所有回文子串 的效率最好就是选取中心,再枚举长度

那么这个题也是这样做, 枚举中心向两边搜索

其中注意到的是,当选定中心后,如果对称位置上有一个问号一个确定字符,那么这个问号必定填这个字符

如果两个都是问号,那么自由度+1, 它

如果对称都填了但是不一样,那么到此结束

也就是n方可以统计 cnt[必要字符bitmask][自由度] = 次数

那么 每次询问bitmask 是 M 的话

ans[M] = sum { cnt[m 是M的子集][i = 0..] * size(M)^i }

那么问题来了, 朴素计算我必定超时, 但是我不知道如何dp, 看Codeforces的群里有人说sosdp

SOSDP 子集和DP

$F[mask] =\sum_{i \in mask}A[i]$

暴力枚举 $O(4^n)$

1
2
3
4
5
rep(mask,0,1 << N){
rep(i,0,1 << N){ // 枚举 mask, 检测是否是它的子集
if((i&mask) == i) F[mask] += A[i];
}
}

枚举子集和$O(3^n)$

子集遍历

也可以dfs,当然dfs会多函数调用和堆栈 比如mask = 111, i依次为 111,110,101,100,011,010,001,000, 注意到的是mask中间穿插0对应的就是i对应位置穿插0,看起来就像每次-1

1
2
3
4
5
6
rep(mask,0,1 << N){
F[mask] = A[0]; // 所有集合包含空集 空集合也作为停止条件
for(int i = mask;i;i = (i-1)&mask){ // 这就是传说中的二进制枚举子集 ,
F[mask] += A[i];
}
}

SOSdp $O(n2^n)$

dp[mask][i] 表示 高位和mask一致, 低[0..i]位所有方案的和

dp[10110][2] = A[10000]+A[10010]+A[10100]+A[10110]

状态转移

第i位为0时,dp[mask][i] = dp[mask][i-1]

第i位为1时,dp[mask][i] = dp[mask][i-1] + dp[mask xor (1 << i)][i-1]

这样变成递推 或者记忆化搜索,可以 $O(n2^n)$ 完成

上面合并一下变成,dp[mask][i] = dp[mask][i-1] + (mask & (1 << i)?dp[mask xor (1 << i)][i-1]:0)

注意到i依赖于i-1还可以滚动数组降低空间

1
2
3
4
5
6
7
rep(mask,0,1<<N) f[mask] = A[mask];
rep(i,0,N){
// 这里不需要从大到小, 因为dp[mask]已经初始化了,只会更新1 << i上为1的,而更新的来源是1 << i上不是1的
rep(mask,0,1 << N){
if(mask & (1 << i)) f[mask]+=f[mask ^ (1 << i)];
}
}

题解

我的暴力的代码

1
2
3
4
5
6
7
8
ans[mask] = 0;
sz = bitcount(mask)
rep(qmark,0,501){ // 自由的问号个数
ti = pow(sz,qmark)
for(int i = mask;i;i = (i-1)&mask){
ans[mask] += cnt[i][qmark] * ti
}
}

这是$O(n 3^{|t|})$ 的复杂度 肯定过不了

通过sosdp降低到$O(n t 2^{|t|})$ 虽然降低了不少,但是依然过不了

这里dp 改一个方式设计, 变成贡献统计

先交换一下循环层级

1
2
3
4
5
6
7
ans[mask] = 0;
sz = bitcount(mask)
for(int i = mask;i;i = (i-1)&mask){
rep(qmark,0,501){
ans[mask] += cnt[i][qmark] * pow(sz,qmark)
}
}

因为{i,qmark}中i是mask的子集,而{i,qmark}对mask 的贡献来讲 只与bitcount(mask) 有关,与mask 具体无关

1
2
3
4
5
6
7
8
9
10
11
12
13
rep(i,0,(1 << N)){
rep(sz,1,17+1){
rep(qmark,0,501){
cost[i][sz] += cnt[i][qmark] * pow(sz,qmark);
}
}
}
ans[mask] = 0;
sz = bitcount(mask)
// sosdp 优化掉
for(int i = mask;i;i = (i-1)&mask){
ans[mask] += cost[i][sz];
}

下面得到优化了, 但上面看起来复杂度并没有变好

但既然都说道贡献统计了,就去看贡献来源

在我们初始找回文时[i...j] 如果是一个回文,它的必须字符集为mask,自由度为qmark

那么

1
2
3
rep(sz,1,17+1){
cost[mask][sz] += pow(sz,qmark);
}

这样 初始化就变成$O(|t| n^2)$


综上 $O(初始化幂次 |t|n + 初始化贡献 |t| n^2 + 初始化答案 |t|^2 2^{|t|} + 查询 |t|q )$

注意到 题目要统计 所有字符串个数,也就是说 同一个位置的回文串 在不同字符串中出现,要多次统计

所以

1
2
3
rep(sz,1,17+1){
cost[mask][sz] += pow(sz,qmark) * pow(sz,total qmark - used qmark);
}

代码

https://codeforces.com/contest/1679/submission/158244316

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define MOD 998244353
#define rep(i,a,n) for (ll i=a;i<(ll)n;i++)
#define per(i,a,n) for (ll i=n;i-->(ll)a;)
#define pb push_back

const int t = 17;
char s[1010]; // 读入
char s2[2010]; // 插入井号简化搜索
int n;

ll p[t+10][1010]; // 幂次预处理
ll cost[(1<<t)+10][t+10]; // 具体的值
ll ans[(1<<t)+10][t+10]; // 由具体值的子集和的答案
char query[t+10];

int main(){
rep(i,1,t+1){
p[i][0] = 1;
rep(pwr,1,1000+1){
p[i][pwr] = p[i][pwr-1]*i%MOD;
}
}
scanf("%d",&n);
scanf("%s",s);
int qtotal = 0;
rep(i,0,n) qtotal += s[i] == '?';
// 字符两两间插入井号方便处理奇偶
s2[0] = s[0];
rep(i,1,n){
s2[i*2-1] = '#';
s2[i*2] = s[i];
}
// 中心
rep(c,0,n*2-1){
int mask = 0;
int qmark = 0;
int qcnt = 0;
rep(l,0,n*2-1){ // 长度
int il = c-l;
int ir = c+l;
if(il < 0 || ir >= n*2-1)break;
qcnt += s2[il] == '?';
qcnt += il != ir && s2[ir] == '?';
if(s2[il] == '#')continue; // 不贡献的
if(s2[il] == s2[ir]){
if(s2[il] == '?'){
qmark++;
}
}else{ // 不等
if(s2[il] == '?'){
mask |= (1 << (s2[ir] - 'a'));
}else if(s2[ir] == '?'){
mask |= (1 << (s2[il] - 'a'));
}else{ // 不同的字符
break;
}
}
// 贡献统计
rep(sz,1,17+1){
// 不同字符串 同一个位置回文串的贡献要多次统计 所以要乘上 其余位置是问号的所有放入方案 让此处的贡献倍数 sz**(qtotal - qcnt)
(cost[mask][sz] += p[sz][qmark] * p[sz][qtotal - qcnt] % MOD )%=MOD;
}
}
}
// sosdp
rep(sz,1,t+1){
rep(mask,0,1 << t){
ans[mask][sz] = cost[mask][sz];
}
rep(i,0,t){
rep(mask,0,1 << t){
if(mask & (1 << i)) (ans[mask][sz] += ans[mask ^ (1 << i)][sz])%=MOD;
}
}
}
// query
int q; // 2e5
scanf("%d",&q);
while(q--){
// string to mask
scanf("%s",query);
int sz = strlen(query);
int mask = 0;
rep(i,0,sz){
mask |= (1<<(query[i]-'a'));
}
printf("%lld\n", ans[mask][sz]);
}
return 0;
}

总结

回文的奇偶处理,可以用两两间插入不会出现的字符,如井号,方便枚举中心

学了一手sosdp

看ssr的twitter说有快速zeta变换, 搜下来看转移方程似乎就是sosdp, https://zhuanlan.zhihu.com/p/33328788?ivk_sa=1024320u

顺便这里的dp转化过程中 通过贡献和来优化效率

参考

官方

F

https://codeforces.com/contest/1679/problem/F

$[0..10^n)$的所有n位整数,不足的补前导零

给m个 (ui,vi) 数对, (ui不等于vi)

x 表示成十进制的数字数组 [d1,d2,…,dn]

一次操作可以交换相邻的d, 但需要满足 这两个数的 (di,di+1)或(di+1,di) 出现在 (ui,vi)中

如果一个数x能够通过上述转换变成y,那么认为它们是相等的, x和x自身相等

问题$[0..10^n)$ 有多少个不同的数, 答案mod 998244353

范围

n 5e4

m 45

u,v [0,9]

3s

256MB

题解

我的思路

先提取一下有用没用的信息,首先m没啥用因为实际就是所有数对的上限

如果x 中有两个数字 c0,c1 但是这两个数字没有在d中出现过, 那么这两个数字的相对前后关系不会改变

换句话说,如果两个相等的 它们互相可以转化,那么它们一定属于某个集合,集合里两两可以转化, 而有限集合一定有最小的, 我们用每个集合中数值最小的来表示一整个集合

于是 这个最小值值可以表示成 [单调递增] [单调递增] [单调递增] , 每两个单调递增之间 的值是不在数对里的

所以如果我们可以得到[起始,结束,长度]的方案数就可以考虑转移方程了

f[i][j][len1] = sum{ f[i][k][len0] * inc[ < k][j][len1-len0] } + inc[i][j][len]

看似复杂度没法搞,而实际上连逻辑也不一定对 201, 它允许 (0,1),(2,1) 那么显然120和它相等且更小

题解

思路是类似的, 也是集合的代表元, 但是并不是靠单调递增划分

而是说[序列长度l] 在后面加上mask中的数,它不会被移动到前面

那么,d会移动到前面的条件就是,序列的尾部的一串数都大于d,且都可以和d交换

[x,d1,d2,d3,...,ds,y]

其中 x > y, 且 y 可以和x,d1,...,ds交换

dp[suff][mask] 表示, suff个digits, 且只有mask中的digit可以被移到最左


考虑长度为s 的一个具体的串 等价最小串 X=[d0,d1,d2,.....,ds]

最长的前缀[d0,d1,d2,....dt] 包含的digits 两两可换, 我们把这样的digits变成mask, X 可以表示成贡献到 dp[s][mask]

那么现在如果左边放一个d, 变成X1 = [d,d0,d1,d2,...,ds]

一旦d和 mask 中某个值可换, 记为e, 且d > e

显然,因为mask中两两可换,所以X1 可以变成 [d,e,d0,d1,d2,...,ds], 然后交换e,d 得到 [e,d,d0,d1,d2,...,ds]

因为e < d ,所以 这个值比X1

即是X前面不能插入d, 如果 mask 中存在比d小,且和d相连的任何一个e

这样可选的d的范围就出来了, 这部分可以预处理

从X而非mask的角度来看,就是说插入了d以后,得到的值依然是 集合表示的最小值(代表元)


mask的变化

如果d可以放,那么X1 = [d,d0,d1,d2,...,ds], 显然,它的 最长两两可换前缀中有d, 且剩余的部分从mask 中取, 因为mask本身就是两两可换,所以只用考虑和d是否有边

所以新mask = d | mask 中和d有边的点


注意到mask 的意义是 mask 中的值两两可换, 又是X的最大前缀

所以这里其实会计算不少无效的mask, 但因为算次数,这些次数一定是0, 不影响答案

代码

https://codeforces.com/contest/1679/submission/158638333

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define MOD 998244353
#define rep(i,a,n) for (ll i=a;i<(ll)n;i++)
#define per(i,a,n) for (ll i=n;i-->(ll)a;)
#define pb push_back

namespace X{
ll r(){ll r;scanf("%lld",&r);return r;} // read
void add(ll &v0,ll &v1){(v0+=v1%MOD)%=MOD;} // add with MOD
};
using namespace X;

const int S = 10;

// mask 意义, mask中两两可换, 是X的最大前缀
// camp[mask0] = mask1, mask1任意一个bit 和 mask0 中的比它小的bit都没有链接
// 在 mask1中的digit 才能加入到 mask0
vector<int> camp(1<<S,(1<<S)-1);
bool conn[S][S]; // 连接状态
ll dp[2][1 << S];
int trans[1 << S][S]; // [mask][digit] = newmask,

int main() {
int n = r();
int m = r();
rep(i,0,m) {
int u = r();
int v = r();
conn[u][v] = conn[v][u] = 1;
}
// 计算 camp
// O(S^2 * 2^S)
rep(mask,0,1 << S){
rep(c,0,S){ // c 在 mask 中
if (!(mask & (1 << c))) continue;
rep(j,c+1,S){ // j > c
if (conn[c][j]) {
camp[mask] &= ~(1 << j);
}
}
}
}
// 计算trans
// O(S^2 * 2^S)
rep(mask,0,1 << S) {
rep(c,0,S){
trans[mask][c] = 1 << c;
rep(j,0,S){
// j 出现在mask 中, (c,j) 可以交换
if ((mask & (1 << j)) && conn[c][j]) {// 和 mask 中存在相连
trans[mask][c] |= 1 << j;
}
}
}
}
// 滚动数组
int cur = 0;
dp[0][0] = 1;
// O(n * S * 2^S)
rep(i,0,n){
// clear
fill(dp[cur^1],dp[cur^1] + (1<<S),0);
rep(mask,0,1<<S){
if (dp[cur][mask] == 0) continue;
rep(c,0,S){
// 在camp[mask] 中的才能加
if (!(camp[mask] & (1 << c))) continue;
add(dp[cur ^ 1][trans[mask][c]] , dp[cur][mask]);
}
}
cur ^= 1;
}

ll ans = 0;
rep(mask,1,1<<S) add(ans,dp[cur][mask]);
printf("%lld\n", ans);
}

总结

我的思路里关于 排序,代表元都有了, 这是好事,但是

其实这里一个核心 在于怎么把 最小值元素X,抽象的表示到一个dp的state中

这里给出的state的设计方案是最长两两可换的前缀的bitmask, 和长度来表示

换句话说如果有人告诉我怎么设计state,那么转移方程还是随便写的

参考

官方