Bonsai Grafting
問題ページ - B - Bonsai Grafting
概要
頂点の木と頂点の木が与えられる。木の辺は頂点と頂点を結び、木の辺は頂点と頂点を結ぶ。2つの木から頂点を1つずつ選んでそれらの頂点間に辺を張り、頂点の木をつくる。通りの頂点の選び方それぞれについて、新しい木の直径(同じ頂点を2回以上通らない最長のパスの長さ)を求め、その総和を出力せよ。
制約
- 与えられる2つのグラフはそれぞれ木である
- 入力はすべて整数である
解法
木の各頂点について、木の各頂点を選んだときの新しい木の直径の総和をで求めることができれば、この計算部分は時間でできる。まず、それぞれの木の直径と、2点間の最短距離が直径となるような2点を求める。2つの木の直径のうち大きい方をとする。次に、その2点から全頂点への最短距離を求める。各頂点について、2点からの最短距離のうち大きいほうが、その頂点を根としたときの葉までの最短距離の最大値である。木の頂点についてのこの値をとし、木の頂点についてのこの値をとする。ここで、木について、がとなる頂点の数を数え、前から累積和を取る。また、を昇順に並べ、これも前から累積和を取る。前者を配列とし、後者を配列とする。このとき、木の頂点について、木の各頂点を選んだときの新しい木の直径として、次の2つの場合が考えられる。
- と等しい場合
- 新しい木の直径が前者よりも大きい場合
新しい木の直径は、木の頂点について、のときに前者となり、そうでないときに後者となる。前者の個数はであり、後者の個数はである。後者の場合の新しい木の直径は、木の頂点を選んだとき、となる。よって、これらの和はとなる。この値をとしたとき、がこの問題の答えとなる。この解法では、全体の計算量はである。
提出コード
https://atcoder.jp/contests/yahoo-procon2019-final-open/submissions/4363192#include<bits/stdc++.h> using namespace std; const int IINF = 1e9 + 100; const long long LLINF = 2e18 + 129; const long long MOD = 1e9 + 7; const int dx4[] = {1, 0, -1, 0}, dy4[] = {0, 1, 0, -1}; const int dx8[] = {1, 1, 0, -1, -1, -1, 0, 1}, dy8[] = {0, -1, -1, -1, 0, 1, 1, 1}; const double EPS = 1e-8; template<typename T> bool chmax(T &a, T b){ if(a < b){ a = b; return true; }else{ return false; } } pair<int, int> dfs1(int v, int prev, vector< vector<int> > &g, int &maxi, int &d1, int &d2){ pair<int, int> m1 = {-1, -1}, m2 = {-1, -1}; for(auto x : g[v]){ if(x != prev){ pair<int, int> tmp = dfs1(x, v, g, maxi, d1, d2); ++tmp.first; if(m1 < tmp){ m2 = m1; m1 = tmp; }else if(m2 < tmp){ m2 = tmp; } } } if(m1 == make_pair(-1, -1)){ return {0, v}; }else if(m2 == make_pair(-1, -1)){ if(chmax(maxi, m1.first)){ d1 = v; d2 = m1.second; } }else{ if(chmax(maxi, m1.first + m2.first)){ d1 = m1.second; d2 = m2.second; } } return m1; } void dfs2(int v, int depth, vector< vector<int> > &g, vector<int> &dis){ dis[v] = depth; for(auto x : g[v]){ if(dis[x] > depth + 1){ dfs2(x, depth + 1, g, dis); } } return; } int main(){ cin.tie(0); ios::sync_with_stdio(false); int n; cin >> n; vector< vector<int> > ga(n); for(int i = 0 ; i < n - 1 ; ++i){ int p, q; cin >> p >> q; --p; --q; ga[p].push_back(q); ga[q].push_back(p); } int m; cin >> m; vector< vector<int> > gb(m); for(int i = 0 ; i < m - 1 ; ++i){ int p, q; cin >> p >> q; --p; --q; gb[p].push_back(q); gb[q].push_back(p); } int ma = 0, mb = 0, da1, da2, db1, db2; dfs1(0, -1, ga, ma, da1, da2); dfs1(0, -1, gb, mb, db1, db2); vector<int> disa1(n, IINF), disa2(n, IINF), disb1(m, IINF), disb2(m, IINF); dfs2(da1, 0, ga, disa1); dfs2(da2, 0, ga, disa2); dfs2(db1, 0, gb, disb1); dfs2(db2, 0, gb, disb2); long long maxi = max(disa1[da2], disb1[db2]); int N = max((int)maxi, m) + 5; vector<long long> cntb(N), sumb(m + 1); for(int i = 0 ; i < m ; ++i){ long long t = max(disb1[i], disb2[i]) + 1; ++cntb[t]; sumb[i] = t; } sort(sumb.begin(), sumb.end()); sumb.resize(N); for(int i = 1 ; i < N ; ++i){ cntb[i] += cntb[i - 1]; sumb[i] += sumb[i - 1]; } long long ans = 0; for(int i = 0 ; i < n ; ++i){ long long t = max(disa1[i], disa2[i]); ans += maxi * cntb[maxi - t] + t * (m - cntb[maxi - t]) + sumb[m] - sumb[cntb[maxi - t]]; } cout << ans << endl; return 0; }