1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
| #include <bits/stdc++.h> #include <atcoder/modint> const int MOD=998244353; using mint = atcoder::static_modint<MOD>; using namespace std; typedef long long ll; #define rep(i,a,n) for (ll i=a;i<(ll)(n);i++) ll read(){ll r;scanf("%lld",&r);return r;} const int N=17; mint f[1<<N]; mint g[1<<N]={1}; mint ans[1<<N]; array<int,2> rb[1<<N]; array<int,2> rbm[1<<N]; int LG[1<<N]; mint fac[100010]={1};
array<int,2> operator+(const array<int,2>&a0,const array<int,2>&a1){ return {a0[0]+a1[0],a0[1]+a1[1]}; } int main(){ rep(i,0,N) LG[1<<i]=i; rep(i,1,100000+1) fac[i]=fac[i-1]*i; int n=read(); int m=read(); rep(i,0,m) rb[1<<(read()-1)][0]++; rep(i,0,m) rb[1<<(read()-1)][1]++; rep(msk,1,1<<n) rbm[msk] = rbm[msk&(msk-1)] + rb[msk&-msk]; rep(msk,1,1<<n) g[msk] = rbm[msk][0]==rbm[msk][1]?fac[rbm[msk][0]]:0; rep(msk,1,1<<n) { f[msk] = g[msk]; int lowbit = msk&-msk; int highbits = msk-lowbit; if(highbits) for(int revsubmsk=highbits;revsubmsk != 0;revsubmsk=(revsubmsk-1)&highbits){ f[msk] -= f[lowbit + (highbits-revsubmsk)]*g[revsubmsk]; } } rep(msk,1,1<<n){ int lowbit = msk&-msk; int highbits = msk-lowbit; for(int revsubmsk=highbits;;revsubmsk=(revsubmsk-1)&highbits){ ans[msk] += f[lowbit + (highbits-revsubmsk)]*(ans[revsubmsk] + g[revsubmsk]); if(revsubmsk==0)break; } } printf("%d\n",(ans[(1<<n)-1]*fac[m].inv()).val()); return 0; }
|