1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
| #include <bits/stdc++.h> #include <atcoder/modint> const int MOD = 998244353; using mint = atcoder::static_modint<MOD>; using namespace std;
// Shorthand for the 64-bit signed integer type used by the loop macros below.
typedef long long ll;

// rep(i, a, n): iterate i = a, a+1, ..., n-1 (ascending, half-open range).
// Each directive must sit on its own line; arguments are parenthesised so
// that compound expressions expand safely.
#define rep(i, a, n) for (ll i = (a); i < (ll)(n); i++)

// per(i, a, n): iterate i = n-1, n-2, ..., a (descending, same half-open range).
#define per(i, a, n) for (ll i = (n); i-- > (ll)(a);)
// Reads one whitespace-delimited signed integer from standard input.
//
// Returns the parsed value, or 0 when no integer could be read (EOF or
// malformed input). The return type is spelled `long long` (the type `ll`
// aliases) so that it visibly matches the "%lld" conversion.
long long read() {
    long long value = 0;
    // scanf reports how many items it converted; anything other than 1 means
    // `value` was not written, so fall back to a well-defined result.
    if (scanf("%lld", &value) != 1) {
        return 0;
    }
    return value;
}
// Capacity of every per-vertex / per-colour global array in this file.
// The directive must be alone on its line; otherwise the declarations below
// become part of the macro's replacement text.
#define N 200010

// a[i]: value read for vertex i (compared against the current colour in dfs).
int a[N];

// e[u]: adjacency list of vertex u.
std::vector<int> e[N];
// sz[u]: number of vertices in the subtree of u.
// fa[u]: parent of u (F as passed by the caller for the start vertex).
// dep[u]: depth of u, defined as dep[parent] + 1.
int sz[N], fa[N], dep[N];

// First pass over the tree rooted at the initial call's vertex.
//
// For vertex `u` whose parent is `F`:
//   * removes F from e[u], so e[u] afterwards holds children only;
//   * fills fa[u], dep[u] and sz[u];
//   * moves the child with the largest subtree to e[u][0].
void dfs1(int u, int F) {
    std::vector<int> &adj = e[u];

    // Walk from the back so that swap-with-last + pop_back never skips an
    // entry that has not been inspected yet.
    for (int idx = static_cast<int>(adj.size()) - 1; idx >= 0; --idx) {
        if (adj[idx] == F) {
            std::swap(adj[idx], adj.back());
            adj.pop_back();
        }
    }

    fa[u] = F;
    dep[u] = dep[F] + 1;
    sz[u] = 1;

    for (std::size_t k = 0; k < adj.size(); ++k) {
        const int child = adj[k];
        dfs1(child, u);
        sz[u] += sz[child];
        // Keep the largest subtree seen so far in the first slot.
        if (sz[child] > sz[adj.front()]) {
            std::swap(adj.front(), adj[k]);
        }
    }
}
int dfn[N],root[N]; void dfs2(int u,int R,int &cur) { dfn[u]=cur++; root[u]=R; if(sz[u] == 1) return; dfs2(e[u][0],R,cur); rep(i,1,size(e[u])) dfs2(e[u][i],e[u][i],cur); } int lca(int u,int v) { for(;root[u]!=root[v];u=fa[root[u]]) if(dep[root[u]]<dep[root[v]]) swap(u,v); return dep[u]<dep[v]?u:v; } vector<int> eg[N]; vector<int> fake(vector<int> v) { auto cmp=[&](int x,int y){return dfn[x]<dfn[y];}; vector<int> ans=v; sort(v.begin(),v.end(),cmp); rep(i,1,size(v)) ans.push_back(lca(v[i-1],v[i])); ans.push_back(1); sort(ans.begin(),ans.end(),cmp); ans.resize(unique(ans.begin(),ans.end())-ans.begin()); rep(i,1,size(ans)) eg[lca(ans[i-1],ans[i])].push_back(ans[i]); return ans; } mint f[N]; void dfs(int x,int C,mint&ans) { f[x]=1; mint tot=0; for(auto i:eg[x]) { dfs(i,C,ans); tot+=f[i]; f[x]*=(f[i]+1); } if(a[x]!=C){ f[x]-=1; ans+=f[x]-tot; }else{ ans+=f[x]; } } vector<int> c2i[N]; int main() { int n=read(); rep(i,1,n+1) { a[i]=read(); c2i[a[i]].push_back(i); } rep(i,1,n) { int u=read(); int v=read(); e[u].push_back(v); e[v].push_back(u); } dfs1(1,0); int cur=0; dfs2(1,1,cur); mint ans=0; rep(i,1,n+1) if(c2i[i].size()) { vector<int> v=fake(c2i[i]); dfs(1,i,ans); for(auto j:v) eg[j].clear(); } printf("%d",ans.val()); return 0; }
|