题目大意:给出一棵 $n$ 个节点的树,每条边的长度为 $1$,有 $m$ 支军队,第 $i$ 支军队守卫节点 $x_i$ 到节点 $y_i$ 的最短路径,$q$ 次询问,第 $i$ 次询问给出 $v_i,k_i$,请你求出一个距 $1$ 号节点最近的点 $u_i$,使得有不少于 $k_i$ 支军队,每支军队完全覆盖 $v_i$ 到 $u_i$ 的最短路径。$n,m,q \leq 200000$
首先,将 $1$ 号节点设为根节点,显然 $u$ 在 $v \rightarrow root$ 的路径上选择是最优的。
假设一条路径的两个端点分别是 $s$ 和 $t$,其中 $deep[s] \geq deep[t]$,如果这条路径能完全覆盖 $v \rightarrow u$,则一定满足 $s$ 在 $v$ 的子树中,且 $deep[lca(s,t)] \geq deep[u]$。
离线,对于每一条路径 $s \rightarrow t$,我们在 $s$ 和 $t$ 上各开一个权值线段树,下标为 $deep[lca(s,t)]$,自底向上依次进行线段树合并,当遇到一个有询问标记的节点 $v$ 时,所有端点 $s$ 在 $v$ 的子树中的路径,都已经合并到当前线段树里了。
由于我们希望答案 $u$ 的深度尽量小,于是可以让当前线段树中 $lca(s,t)$ 的深度在前 $k$ 小的路径覆盖 $v \rightarrow u$,那么 $u$ 就是深度第 $k$ 小的 $lca(s,t)$。
实现需要注意以下几点:
- 由于权值线段树的下标是 $1 \cdots n$,因此 $deep$ 应该从 $1$ 开始标号
- 查询的时候,如果当前线段树中的元素不足 $k$ 个,应该直接判掉
- 因为每条路径都要在 $s$ 和 $t$ 上各开一棵线段树,所以空间复杂度应该是 $2n (\log{(2n)}+1)$,而不是 $n (\log n+1)$
#include <cstdio>
const int N = 200005, M = 8000005; //2N(log(2N)+1)
struct Node {
int k, id;
} nxt1[N];
int head[N], pre[N<<1], nxt[N<<1], tot, head1[N], pre1[N], tot1, cnt, ans[N], n;
int siz[N], fa[N], son[N], dep[N], top[N], sum[M], lson[M], rson[M], node[M], root[N];
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x;
}
void swap(int &x, int &y) {
x ^= y, y ^= x, x ^= y;
}
void add_edge(int u, int v) {
pre[++tot] = head[u];
head[u] = tot, nxt[tot] = v;
}
void add_tag1(int u, int k, int id) {
pre1[++tot1] = head1[u];
head1[u] = tot1, nxt1[tot1] = (Node){k, id};
}
void dfs1(int cur, int father, int deep) {
siz[cur] = 1, fa[cur] = father, dep[cur] = deep;
for (int i = head[cur]; i; i = pre[i])
if (nxt[i] != father) {
dfs1(nxt[i], cur, deep + 1);
siz[cur] += siz[nxt[i]];
if (siz[nxt[i]] > siz[son[cur]]) son[cur] = nxt[i];
}
}
void dfs2(int cur, int tp) {
top[cur] = tp;
if (son[cur]) dfs2(son[cur], tp);
for (int i = head[cur]; i; i = pre[i])
if (nxt[i] != fa[cur] && nxt[i] != son[cur]) dfs2(nxt[i], nxt[i]);
}
int lca(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
if (dep[x] < dep[y]) return x; return y;
}
void insert(int &cur, int l, int r, int p, int x) {
if (!cur) cur = ++cnt;
++sum[cur];
if (l == r) { node[cur] = x; return; }
int mid = l + (r - l >> 1);
if (p <= mid) insert(lson[cur], l, mid, p, x);
else insert(rson[cur], mid + 1, r, p, x);
}
int merge(int &x, int &y) {
if (!x) return y;
if (!y) return x;
sum[x] += sum[y];
if (!node[x]) node[x] = node[y];
lson[x] = merge(lson[x], lson[y]);
rson[x] = merge(rson[x], rson[y]);
return x;
}
int query(int cur, int l, int r, int k) {
if (l == r) return node[cur];
int mid = l + (r - l >> 1);
if (sum[lson[cur]] >= k) return query(lson[cur], l, mid, k);
else return query(rson[cur], mid + 1, r, k - sum[lson[cur]]);
}
void dfs(int cur) {
for (int i = head[cur]; i; i = pre[i])
if (nxt[i] != fa[cur]) {
dfs(nxt[i]);
root[cur] = merge(root[cur], root[nxt[i]]);
}
for (int i = head1[cur]; i; i = pre1[i]) {
int x = sum[root[cur]] >= nxt1[i].k ? query(root[cur], 1, n, nxt1[i].k) : cur; //不足k个的情况直接判掉
ans[nxt1[i].id] = dep[x] <= dep[cur] ? dep[cur] - dep[x] : 0;
}
}
int main() {
n = read(); int m = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
add_edge(u, v), add_edge(v, u);
}
dfs1(1, 0, 1), dfs2(1, 1); //dep从1开始标号
for (int i = 1; i <= m; ++i) {
int s = read(), t = read(), x = lca(s,t);
insert(root[s], 1, n, dep[x], x);
insert(root[t], 1, n, dep[x], x);
}
int q = read();
for (int i = 1; i <= q; ++i) {
int v = read(), k = read();
add_tag1(v, k, i);
}
dfs(1);
for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
return 0;
}
时间复杂度 $O((n+q)\log n)$
一开始误把题意理解成只要 $v$ 到 $u$ 的最短路径上的所有节点都有不少于 $k$ 支军队覆盖就可以……
于是写了以下代码,是用启发式合并大根堆来离线做的,堆中存的是询问,关键字是 $k$,自底向上依次合并,每次把 $k$ 大于当前节点的军队数量的询问弹出。
#include <cstdio>
#include <queue>
const int N = 200005;
struct Node {
int id, v, k;
bool operator < (const Node &rhs) const {
return k < rhs.k;
}
} ask[N];
std::priority_queue<Node> Q[N];
int head[N], nxt[N<<1], to[N<<1], tot;
int head2[N], nxt2[N], tot2;
int siz[N], dep[N], top[N], son[N], fa[N];
int tag[N], ans[N], be[N], n;
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x;
}
void add_edge(int u, int v) {
nxt[++tot] = head[u];
head[u] = tot, to[tot] = v;
}
void dfs1(int u, int f, int d) {
siz[u] = 1, fa[u] = f, dep[u] = d;
for (int i = head[u]; i; i = nxt[i]) {
if (to[i] != f) {
dfs1(to[i], u, d + 1);
if (siz[to[i]] > siz[son[u]]) son[u] = to[i];
siz[u] += siz[to[i]];
}
}
}
void dfs2(int u, int t) {
top[u] = t;
if (son[u]) dfs2(son[u], t);
for (int i = head[u]; i; i = nxt[i])
if (to[i] != fa[u] && to[i] != son[u]) dfs2(to[i], to[i]);
}
int lca(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
if (dep[x] < dep[y]) return x;
return y;
}
void dfs3(int u) {
for (int i = head[u]; i; i = nxt[i])
if (to[i] != fa[u]) {
dfs3(to[i]);
tag[u] += tag[to[i]];
}
}
void dfs4(int u) {
for (int i = head[u]; i; i = nxt[i]) {
if (to[i] != fa[u]) {
dfs4(to[i]);
if (Q[be[u]].size() < Q[be[to[i]]].size()) {
while (!Q[be[u]].empty()) Q[be[to[i]]].push(Q[be[u]].top()), Q[be[u]].pop();
be[u] = be[to[i]];
} else {
while (!Q[be[to[i]]].empty()) Q[be[u]].push(Q[be[to[i]]].top()), Q[be[to[i]]].pop();
be[to[i]] = be[u];
}
}
}
for (int i = head2[u]; i; i = nxt2[i]) {
if (ask[i].k <= tag[u]) Q[be[u]].push(ask[i]);
else ans[ask[i].id] = 0;
}
while (!Q[be[u]].empty() && Q[be[u]].top().k > tag[fa[u]]) {
ans[Q[be[u]].top().id] = dep[Q[be[u]].top().v] - dep[u];
Q[be[u]].pop();
}
}
int main() {
n = read(); int m = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
add_edge(u, v), add_edge(v, u);
}
dfs1(1, 0, 0), dfs2(1, 1);
for (int i = 1; i <= m; ++i) {
int u = read(), v = read(), c = lca(u, v);
++tag[u], ++tag[v], --tag[c], --tag[fa[c]];
}
dfs3(1);
int q = read();
for (int i = 1; i <= q; ++i) {
int v = read(), k = read();
nxt2[++tot2] = head2[v], head2[v] = tot2, ask[tot2] = (Node){i, v, k};
}
for (int i = 1; i <= n; ++i) be[i] = i;
dfs4(1);
for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
return 0;
}
时间复杂度 $O((n+q)\log n)$