#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// rep(i, a, n): iterate i over the half-open range [a, n) using a 64-bit index.
#define rep(i,a,n) for(ll i=(a);i<(ll)(n);i++)

// Segment-tree addressing helpers. Nodes use a 0-based heap layout: the
// children of node o are 2o+1 and 2o+2, and a node covers the half-open
// interval [l, r).
// SEG_ROOT expands to the root-call arguments and requires a variable `nn`
// (the number of positions) to be in scope at the use site.
#define SEG_ROOT 0,0,nn
#define SEG_L (o*2+1)
#define SEG_R (o*2+2)
// Midpoint of [l, r). The outer parentheses keep it correct inside any
// surrounding expression (e.g. `2*mid`).
#define mid ((l+r)/2)
#define SEG_L_CHILD SEG_L,l,mid
#define SEG_R_CHILD SEG_R,mid,r
/// Reads one integer from stdin.
/// Returns 0 when no integer can be parsed (EOF or malformed input) instead of
/// returning an uninitialized value.
long long read() {
    long long value = 0;
    if (std::scanf("%lld", &value) != 1) return 0;
    return value;
}

// Global tree state, indexed by vertex id (the caller uses ids 1..n).
// Capacity is fixed at compile time.
std::vector<int> e[300010];          // e[u]: children of u, in insertion order
int c[300010];                       // c[u]: color of vertex u
std::vector<int> cstk[300010];       // cstk[col]: internal vertices of color col on the current dfsc path
std::vector<int> childc[300010];     // childc[u]: descendants whose nearest same-colored ancestor is u
std::array<int, 2> lr[300010];       // lr[u]: half-open range [lr[u][0], lr[u][1]) of leaf positions under u
std::pair<int, int> seg[4 * 300000 + 10];  // seg[o] = {max over node's range, pending add for children}
                                           // (presumably sized for up to 300000 positions)

/// Numbers the leaves under u consecutively starting at idx (idx ends one past
/// the last leaf), stores each vertex's leaf range in lr, and appends every
/// vertex to childc of its nearest same-colored ancestor.
/// Uses an explicit stack so that very deep trees cannot overflow the call stack;
/// the visiting order is the same as a recursive pre/post-order walk over e.
void dfsc(int u, int &idx) {
    // Pre-order work for x. Returns true when x has children to descend into.
    auto enter = [&idx](int x) {
        if (!cstk[c[x]].empty()) childc[cstk[c[x]].back()].push_back(x);
        lr[x][0] = idx;
        if (e[x].empty()) {
            ++idx;  // a leaf occupies exactly one position
            lr[x][1] = idx;
            return false;
        }
        cstk[c[x]].push_back(x);  // x is now the nearest ancestor of color c[x] for its subtree
        return true;
    };

    // Each frame is (vertex, index of the next child to visit).
    std::vector<std::pair<int, std::size_t>> frames;
    if (enter(u)) frames.emplace_back(u, 0);
    while (!frames.empty()) {
        auto &[x, child_pos] = frames.back();
        if (child_pos < e[x].size()) {
            const int v = e[x][child_pos++];
            // emplace_back may invalidate x/child_pos; they are not used afterwards.
            if (enter(v)) frames.emplace_back(v, 0);
        } else {
            // Post-order work: all children done.
            cstk[c[x]].pop_back();
            lr[x][1] = idx;
            frames.pop_back();
        }
    }
}

/// Applies +inc to node o: raises its maximum and accumulates its lazy tag.
void up(int o, int inc) {
    seg[o].first += inc;
    seg[o].second += inc;
}

/// Pushes node o's pending add down to both children and clears it.
void down(int o) {
    if (seg[o].second) {
        up(2 * o + 1, seg[o].second);
        up(2 * o + 2, seg[o].second);
        seg[o].second = 0;
    }
}

/// Adds inc to every position of [ql, qr) inside node o, which covers [l, r).
/// Assumes [ql, qr) is non-empty and intersects [l, r).
void add(int o, int l, int r, int ql, int qr, int inc) {
    if (ql <= l && r <= qr) {
        up(o, inc);
        return;
    }
    down(o);
    const int m = (l + r) / 2;
    if (ql < m) add(2 * o + 1, l, m, ql, qr, inc);
    if (qr > m) add(2 * o + 2, m, r, ql, qr, inc);
    seg[o].first = std::max(seg[2 * o + 1].first, seg[2 * o + 2].first);
}

/// Resets every node of the subtree rooted at o (covering [l, r)) to {0, 0}.
void build(int o, int l, int r) {
    seg[o] = {0, 0};
    if (l + 1 == r) return;
    const int m = (l + r) / 2;
    build(2 * o + 1, l, m);
    build(2 * o + 2, m, r);
}

/// Returns the maximum over positions [ql, qr) inside node o, which covers [l, r).
/// A fully covered node returns its stored maximum. A partially covered node
/// combines its children's answers starting from 0, so such results are floored at 0.
int query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return seg[o].first;
    down(o);
    const int m = (l + r) / 2;
    int best = 0;
    if (ql < m) best = std::max(best, query(2 * o + 1, l, m, ql, qr));
    if (qr > m) best = std::max(best, query(2 * o + 2, m, r, ql, qr));
    return best;
}
void dfs1(int u,int nn, ll &ans){ auto [l,r] = lr[u]; for(auto v: e[u]) dfs1(v,nn,ans); for(auto v: childc[u]) add(SEG_ROOT,lr[v][0],lr[v][1],-1); add(SEG_ROOT, l, r, 1); vector<ll> mx; for(auto v: e[u]) mx.push_back(query(SEG_ROOT,lr[v][0], lr[v][1])); sort(begin(mx),end(mx)); if(mx.size() > 0) ans = max(ans,mx[mx.size()-1]); if(mx.size() > 1) ans = max(ans,mx[mx.size()-1] * mx[mx.size()-2]); }
void w(){ int n = read(); rep(i,1,n+1) e[i] = {}; rep(i,1,n+1) childc[i] = {}; rep(i,2,n+1) e[read()].push_back(i); rep(i,1,n+1) c[i] = read(); int nn = 0; dfsc(1, nn); build(SEG_ROOT); ll ans = 1; dfs1(1, nn, ans); printf("%lld\n",ans); }
/// Entry point: reads the number of test cases and runs w() once per case.
int main() {
    int t = read();
    // `t-- > 0` (not `t--`) so a negative count from malformed input ends the
    // loop instead of counting down into signed overflow.
    while (t-- > 0) w();
    return 0;
}