Atcoder abc321

https://atcoder.jp/contests/abc321

G - Electric Circuit

3种点

n个 part点

m个红点

m个蓝点

给定每个 颜色点连到一个part点

然后红点和蓝点的直接连线 需要单射且满射,那么有m!种

对于所有连法 求 Exp(连接后的连通块个数)

n 17

m 1e5

3s

1024mb

我的思路

n 很小想用bitmask

首先根据 n来整理,变成 有n个点,每个点有ri个红色,bi个蓝色

然后要把所有红蓝相连,求连通块的个数期望

1
2
3
4
5
6
7
8
9
f[mask] = mask内连成一个连通块 且无剩余蓝红的方案数
g[mask] = mask内无剩余蓝红的方案数
ans[s] = sum s的所有方案的连通块个数
ans[s] = for mask, mask 包含最小点
f[mask] * ans[s-mask] 两边互不影响计算右边贡献
+ f[mask] * g[mask] 计算左边贡献

g[mask] = sum红! , 满足sum红==sum蓝
f[mask] = g[mask] - sum f[submask]*g[mask-submask], submask包含mask中最小的

似乎就AC了

代码

https://atcoder.jp/contests/abc321/submissions/49762974

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
#include <atcoder/modint>
const int MOD=998244353;
using mint = atcoder::static_modint<MOD>;
using namespace std;
typedef long long ll;
#define rep(i,a,n) for (ll i=a;i<(ll)(n);i++)
ll read(){ll r;scanf("%lld",&r);return r;}
const int N=17;
mint f[1<<N]; // = mask内连成一个连通块 且无剩余蓝红的方案数
mint g[1<<N]={1}; // = mask内无剩余蓝红的方案数, 红=蓝, g[] = 个数!
mint ans[1<<N]; // sum s的所有方案的连通块个数
array<int,2> rb[1<<N];
array<int,2> rbm[1<<N]; // red blue mask
int LG[1<<N];
mint fac[100010]={1};

array<int,2> operator+(const array<int,2>&a0,const array<int,2>&a1){ return {a0[0]+a1[0],a0[1]+a1[1]}; }
int main(){
rep(i,0,N) LG[1<<i]=i;
rep(i,1,100000+1) fac[i]=fac[i-1]*i;
int n=read();
int m=read();
rep(i,0,m) rb[1<<(read()-1)][0]++;
rep(i,0,m) rb[1<<(read()-1)][1]++;
rep(msk,1,1<<n) rbm[msk] = rbm[msk&(msk-1)] + rb[msk&-msk]; // 统计msk 红蓝
rep(msk,1,1<<n) g[msk] = rbm[msk][0]==rbm[msk][1]?fac[rbm[msk][0]]:0; // 红==蓝,则 红!
rep(msk,1,1<<n) {
f[msk] = g[msk];
// submsk != msk, submsk \subset msk, submsk 包含msk最低位, submsk + revsubmsk = msk
int lowbit = msk&-msk;
int highbits = msk-lowbit;
if(highbits) for(int revsubmsk=highbits;revsubmsk != 0;revsubmsk=(revsubmsk-1)&highbits){
f[msk] -= f[lowbit + (highbits-revsubmsk)]*g[revsubmsk];
}
}
rep(msk,1,1<<n){
int lowbit = msk&-msk;
int highbits = msk-lowbit;
for(int revsubmsk=highbits;;revsubmsk=(revsubmsk-1)&highbits){
ans[msk] += f[lowbit + (highbits-revsubmsk)]*(ans[revsubmsk] + g[revsubmsk]);
if(revsubmsk==0)break;
}
}
printf("%d\n",(ans[(1<<n)-1]*fac[m].inv()).val());
return 0;
}

总结

G: 没啥难的