橋梁管理日誌

日誌と言いながら日次での更新はされない模様

Bonsai Grafting

問題ページ - B - Bonsai Grafting

概要

N頂点の木AM頂点の木Bが与えられる。木Aの辺は頂点p_{A_i}と頂点q_{A_i} (1 \leq i \leq N - 1)を結び、木Bの辺は頂点p_{B_i}と頂点q_{B_i} (1 \leq i \leq M - 1)を結ぶ。2つの木から頂点を1つずつ選んでそれらの頂点間に辺を張り、N + M頂点の木をつくる。NM通りの頂点の選び方それぞれについて、新しい木の直径(同じ頂点を2回以上通らない最長のパスの長さ)を求め、その総和を出力せよ。

制約
  • 2 \leq N, M \leq 10^5
  •  1 \leq p_{A_i}, q_{A_i} \leq N
  •  1 \leq p_{B_i}, q_{B_i} \leq M
  • 与えられる2つのグラフはそれぞれ木である
  • 入力はすべて整数である
解法

Aの各頂点について、木Bの各頂点を選んだときの新しい木の直径の総和をO(1)で求めることができれば、この計算部分はO(N)時間でできる。まず、それぞれの木の直径と、2点間の最短距離が直径となるような2点を求める。2つの木の直径のうち大きい方をmaxiとする。次に、その2点から全頂点への最短距離を求める。各頂点について、2点からの最短距離のうち大きいほうが、その頂点を根としたときの葉までの最短距離の最大値である。木Aの頂点iについてのこの値をt_{A_i}とし、木Bの頂点iについてのこの値をt_{B_i}とする。ここで、木Bについて、t + 1iとなる頂点の数を数え、前から累積和を取る。また、 t + 1を昇順に並べ、これも前から累積和を取る。前者を配列cntとし、後者を配列sumとする。このとき、木Aの頂点iについて、木Bの各頂点を選んだときの新しい木の直径として、次の2つの場合が考えられる。

  • maxiと等しい場合
  • 新しい木の直径が前者よりも大きい場合

新しい木の直径は、木Bの頂点jについて、t_{A_i} + t_{B_j} + 1 \leq maxiのときに前者となり、そうでないときに後者となる。前者の個数はcnt_{maxi - t_{A_i}}であり、後者の個数はM - cnt_{maxi - t_{A_i}}である。後者の場合の新しい木の直径は、木Bの頂点jを選んだとき、t_{A_i} + t_{B_j} + 1となる。よって、これらの和はmaxi \times cnt_{maxi - t_{A_i}} + t_i \times (M - cnt_{maxi - t_{A_i}}) + sum_M - sum_{cnt_{maxi - t_{A_i}}}となる。この値をS_iとしたとき、\sum_{k = 1}^{N} S_kがこの問題の答えとなる。この解法では、全体の計算量はO(N + MlogM)である。

提出コード
https://atcoder.jp/contests/yahoo-procon2019-final-open/submissions/4363192

#include<bits/stdc++.h>
using namespace std;

const int IINF = 1e9 + 100;
const long long LLINF = 2e18 + 129;
const long long MOD = 1e9 + 7;
const int dx4[] = {1, 0, -1, 0}, dy4[] = {0, 1, 0, -1};
const int dx8[] = {1, 1, 0, -1, -1, -1, 0, 1}, dy8[] = {0, -1, -1, -1, 0, 1, 1, 1};
const double EPS = 1e-8;

template<typename T>
bool chmax(T &a, T b){
    if(a < b){
        a = b;
        return true;
    }else{
        return false;
    }
}

pair<int, int> dfs1(int v, int prev, vector< vector<int> > &g, int &maxi, int &d1, int &d2){
    pair<int, int> m1 = {-1, -1}, m2 = {-1, -1};
    for(auto x : g[v]){
        if(x != prev){
            pair<int, int> tmp = dfs1(x, v, g, maxi, d1, d2);
            ++tmp.first;
            if(m1 < tmp){
                m2 = m1;
                m1 = tmp;
            }else if(m2 < tmp){
                m2 = tmp;
            }
        }
    }

    if(m1 == make_pair(-1, -1)){
        return {0, v};
    }else if(m2 == make_pair(-1, -1)){
        if(chmax(maxi, m1.first)){
            d1 = v;
            d2 = m1.second;
        }
    }else{
        if(chmax(maxi, m1.first + m2.first)){
            d1 = m1.second;
            d2 = m2.second;
        }
    }

    return m1;
}

void dfs2(int v, int depth, vector< vector<int> > &g, vector<int> &dis){
    dis[v] = depth;
    for(auto x : g[v]){
        if(dis[x] > depth + 1){
            dfs2(x, depth + 1, g, dis);
        }
    }
    return;
}

int main(){
    cin.tie(0);
    ios::sync_with_stdio(false);

    int n;
    cin >> n;

    vector< vector<int> > ga(n);
    for(int i = 0 ; i < n - 1 ; ++i){
        int p, q;
        cin >> p >> q;
        --p;
        --q;
        ga[p].push_back(q);
        ga[q].push_back(p);
    }

    int m;
    cin >> m;

    vector< vector<int> > gb(m);
    for(int i = 0 ; i < m - 1 ; ++i){
        int p, q;
        cin >> p >> q;
        --p;
        --q;
        gb[p].push_back(q);
        gb[q].push_back(p);
    }

    int ma = 0, mb = 0, da1, da2, db1, db2;
    dfs1(0, -1, ga, ma, da1, da2);
    dfs1(0, -1, gb, mb, db1, db2);

    vector<int> disa1(n, IINF), disa2(n, IINF), disb1(m, IINF), disb2(m, IINF);
    dfs2(da1, 0, ga, disa1);
    dfs2(da2, 0, ga, disa2);
    dfs2(db1, 0, gb, disb1);
    dfs2(db2, 0, gb, disb2);

    long long maxi = max(disa1[da2], disb1[db2]);
    int N = max((int)maxi, m) + 5;
    vector<long long> cntb(N), sumb(m + 1);
    for(int i = 0 ; i < m ; ++i){
        long long t = max(disb1[i], disb2[i]) + 1;
        ++cntb[t];
        sumb[i] = t;
    }
    sort(sumb.begin(), sumb.end());
    sumb.resize(N);
    for(int i = 1 ; i < N ; ++i){
        cntb[i] += cntb[i - 1];
        sumb[i] += sumb[i - 1];
    }

    long long ans = 0;
    for(int i = 0 ; i < n ; ++i){
        long long t = max(disa1[i], disa2[i]);
        ans += maxi * cntb[maxi - t] + t * (m - cntb[maxi - t]) + sumb[m] - sumb[cntb[maxi - t]];
    }

    cout << ans << endl;

    return 0;
}

感想
700点自力AC嬉しい