1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
|
#include <bits/stdc++.h>

// Helper macros shared by the rest of the file.
// Each directive lives on its own line (a preprocessor directive ends at the
// newline), and every macro argument is parenthesized so that passing an
// expression such as `a | b` or `s + t` expands with the intended grouping.

// Expands to the begin/end iterator pair of a container.
#define all(vec) (vec).begin(), (vec).end()
// Shorthand for the push_back member name.
#define pb push_back
// Container size as a signed int.
#define SZ(a) ((int)(a).size())
// Inclusive ascending / descending int loops over [a, b] and [b, a].
#define FOR(i, a, b) for (int i = (a); i <= (b); ++i)
#define ROF(i, a, b) for (int i = (a); i >= (b); --i)
// Prints `name:value` to cerr; the trailing semicolon is part of the macro.
#define debug(var) cerr << #var << ":" << (var) << "\n";
// Index arithmetic var*2 and var*2+1.
#define lson(var) ((var) << 1)
#define rson(var) (((var) << 1) + 1)
using namespace std;

// Short aliases for the numeric types.
using ll = long long;
using ull = unsigned long long;
using DB = double;
using LD = long double;

// Pair and fixed-size array aliases built on the numeric aliases above.
using pdd = pair<DB, DB>;
using plb = pair<ll, bool>;
using pll = pair<ll, ll>;
using arr3 = array<ll, 3>;
using arr2 = array<ll, 2>;

// Size limits and numeric constants.
constexpr ll MAXN = static_cast<ll>(1e6) + 10;  // capacity of the global array A
constexpr ll INF = static_cast<ll>(1e17) + 9;   // large "no value yet" sentinel
constexpr ll MAXM = static_cast<ll>(1e6) + 10;
constexpr ll MAXV = static_cast<ll>(1e5) + 10;
constexpr ll mod = static_cast<ll>(1e9) + 7;
constexpr double eps = 1e-8;
const double pi = acos(-1.0);

// Global scalars (zero-initialized) and the global value array.
ll N, M, Q, X, K, T, lT;
ll A[MAXN];
inline void solve(){ cin>>N; vector<vector<ll> > g(N+5); FOR(i,1,N){ cin>>A[i]; } FOR(i,1,N-1){ ll u,v; cin>>u>>v; g[u].pb(v);g[v].pb(u); } vector<ll> sz(N+5,0);ll mxTree=INF; vector<bool> vis(N+5,0);ll root=0; ll ans=0; FOR(i,1,N){ for(auto &v:g[i]){ if(A[v]==A[i]) ++ans; } } ans/=2; auto get=[&](auto &&self,ll u,ll p,ll tot)->void{ ll val=0;sz[u]=1; for(auto &v:g[u]){ if(v==p || vis[v]) continue; self(self,v,u,tot); sz[u]+=sz[v]; val=max(val,sz[v]); } val=max(val,tot-sz[u]); if(val<mxTree){ mxTree=val; root=u; } }; auto cal=[&](ll u)->void{ map<ll,ll> cnt; auto dfs=[&](auto &&self,ll u,ll p,ll mn)->void{ if(mn>A[u]){ if(cnt.count(A[u])){ ans+=cnt[A[u]]; } ++cnt[A[u]]; } mn=min(A[u],mn); for(auto &v:g[u]){ if(v==p || vis[v]) continue; self(self,v,u,mn); } }; dfs(dfs,u,u,INF); }; auto dfz=[&](auto &&self,ll u)->void{ vis[u]=true; cal(u); for(auto &v:g[u]){ if(vis[v]) continue; mxTree=INF; get(get,v,v,sz[v]); self(self,root); } }; get(get,1,1,N); dfz(dfz,root); ans*=2; cout<<ans<<"\n"; }
/// Program entry point: unties and unsyncs the standard streams, then runs
/// solve() exactly once.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    solve();
    return 0;
}
|