题目大意:给定一棵 $n$ 个节点的树,第 $i$ 个节点的颜色为 $c_i$,$q$ 次询问两种颜色的点的最大距离。$n,q \leq 10^5$
根据树直径定理可推出:若颜色 $A$ 的直径端点为 $a,b$,颜色 $B$ 直径端点为 $c,d$,那么 $A$ 到 $B$ 的最长距离只有可能是 $ac,ad,bc,bd$ 四种情况。
预处理每个颜色的直径时,不需要 $n$ 遍 $dfs$,倍增计算出当前节点到旧端点的距离并更新端点即可。
#include <cstdio>
#include <algorithm>
const int N = 100005;
struct Node {
int a, b, c;
} a[N];
int siz[N], fa[N], son[N], dep[N], top[N], head[N], nxt[N<<1], pre[N<<1], b[N], c[N], tot, n;
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x;
}
int max(int x, int y) {
return x > y ? x : y;
}
void add_edge(int u, int v) {
pre[++tot] = head[u];
head[u] = tot, nxt[tot] = v;
}
int find(int x) {
int l = 1, r = n;
while (l <= r) {
int mid = l + (r - l >> 1);
if (b[mid] == x) return mid;
if (b[mid] < x) l = mid + 1;
else r = mid - 1;
}
return 0;
}
void dfs1(int cur, int father, int deep) {
siz[cur] = 1, fa[cur] = father, dep[cur] = deep;
for (int i = head[cur]; i; i = pre[i]) {
if (nxt[i] == father) continue;
dfs1(nxt[i], cur, deep + 1);
siz[cur] += siz[nxt[i]];
if (siz[nxt[i]] > siz[son[cur]]) son[cur] = nxt[i];
}
}
void dfs2(int cur, int tp) {
top[cur] = tp;
if (son[cur]) dfs2(son[cur], tp);
for (int i = head[cur]; i; i = pre[i])
if (nxt[i] != fa[cur] && nxt[i] != son[cur]) dfs2(nxt[i], nxt[i]);
}
int calc(int x, int y) {
if (!x || !y) return 0; //注意特判一个颜色只有一个点的情况
int a = x, b = y;
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
int c = dep[x] < dep[y] ? x : y;
return dep[a] - dep[c] + dep[b] - dep[c];
}
void dfs(int cur) { //预处理每个颜色的直径端点
int x = find(c[cur]);
if (!a[x].a) a[x].a = cur;
else if (!a[x].b) a[x].b = cur, a[x].c = calc(a[x].a, cur);
else {
int y = calc(a[x].a, cur), z = calc(a[x].b, cur);
if (y > z && y > a[x].c) a[x].b = cur, a[x].c = y;
else if (z > a[x].c) a[x].a = cur, a[x].c = z;
}
for (int i = head[cur]; i; i = pre[i])
if (nxt[i] != fa[cur]) dfs(nxt[i]);
}
int main() {
n = read(); int q = read();
for (int i = 1; i <= n; ++i) b[i] = c[i] = read();
std::sort(b + 1, b + n + 1);
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
add_edge(u, v), add_edge(v, u);
}
dfs1(1, 0, 0), dfs2(1, 1), dfs(1);
while (q--) {
int x = read(), y = read();
x = find(x), y = find(y);
if (!x || !y) puts("0");
else printf("%d\n", max(max(calc(a[x].a, a[y].a), calc(a[x].a, a[y].b)),
max(calc(a[x].b, a[y].a), calc(a[x].b, a[y].b))));
}
return 0;
}
时间复杂度 $O((n+q)\log n)$