「Wannafly 023D」漂亮的公园

题目大意:给定一棵 $n$ 个节点的树,第 $i$ 个节点的颜色为 $c_i$,$q$ 次询问两种颜色的点的最大距离。$n,q \leq 10^5$

根据树直径定理可推出:若颜色 $A$ 的直径端点为 $a,b$,颜色 $B$ 直径端点为 $c,d$,那么 $A$ 到 $B$ 的最长距离只有可能是 $ac,ad,bc,bd$ 四种情况。

预处理每个颜色的直径时,不需要 $n$ 遍 $dfs$,倍增计算出当前节点到旧端点的距离并更新端点即可。

#include <cstdio>
#include <algorithm>
 
const int N = 100005;
 
struct Node {
    int a, b, c;
} a[N];
 
int siz[N], fa[N], son[N], dep[N], top[N], head[N], nxt[N<<1], pre[N<<1], b[N], c[N], tot, n;
 
int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + (c ^ 48);
        c = getchar();
    }
    return x;
}
int max(int x, int y) {
    return x > y ? x : y;
}
void add_edge(int u, int v) {
    pre[++tot] = head[u];
    head[u] = tot, nxt[tot] = v;
}
int find(int x) {
    int l = 1, r = n;
    while (l <= r) {
        int mid = l + (r - l >> 1);
        if (b[mid] == x) return mid;
        if (b[mid] < x) l = mid + 1;
        else r = mid - 1;
    }
    return 0;
}
void dfs1(int cur, int father, int deep) {
    siz[cur] = 1, fa[cur] = father, dep[cur] = deep;
    for (int i = head[cur]; i; i = pre[i]) {
        if (nxt[i] == father) continue;
        dfs1(nxt[i], cur, deep + 1);
        siz[cur] += siz[nxt[i]];
        if (siz[nxt[i]] > siz[son[cur]]) son[cur] = nxt[i];
    }
}
void dfs2(int cur, int tp) {
    top[cur] = tp;
    if (son[cur]) dfs2(son[cur], tp);
    for (int i = head[cur]; i; i = pre[i])
        if (nxt[i] != fa[cur] && nxt[i] != son[cur]) dfs2(nxt[i], nxt[i]);
}
int calc(int x, int y) {
    if (!x || !y) return 0; //注意特判一个颜色只有一个点的情况
    int a = x, b = y;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
        else y = fa[top[y]];
    }
    int c = dep[x] < dep[y] ? x : y;
    return dep[a] - dep[c] + dep[b] - dep[c];
}
void dfs(int cur) { //预处理每个颜色的直径端点
    int x = find(c[cur]);
    if (!a[x].a) a[x].a = cur;
    else if (!a[x].b) a[x].b = cur, a[x].c = calc(a[x].a, cur);
    else {
        int y = calc(a[x].a, cur), z = calc(a[x].b, cur);
        if (y > z && y > a[x].c) a[x].b = cur, a[x].c = y;
        else if (z > a[x].c) a[x].a = cur, a[x].c = z;
    }
    for (int i = head[cur]; i; i = pre[i])
        if (nxt[i] != fa[cur]) dfs(nxt[i]);
}
 
int main() {
    n = read(); int q = read();
    for (int i = 1; i <= n; ++i) b[i] = c[i] = read();
    std::sort(b + 1, b + n + 1);
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        add_edge(u, v), add_edge(v, u);
    }
    dfs1(1, 0, 0), dfs2(1, 1), dfs(1);
    while (q--) {
        int x = read(), y = read();
        x = find(x), y = find(y);
        if (!x || !y) puts("0");
        else printf("%d\n", max(max(calc(a[x].a, a[y].a), calc(a[x].a, a[y].b)),
                                max(calc(a[x].b, a[y].a), calc(a[x].b, a[y].b))));
    }
    return 0;
}

时间复杂度 $O((n+q)\log n)$