「NowCoder」保护

题目大意:给出一棵 $n$ 个节点的树,每条边的长度为 $1$,有 $m$ 支军队,第 $i$ 支军队守卫节点 $x_i$ 到节点 $y_i$ 的最短路径,$q$ 次询问,第 $i$ 次询问给出 $v_i,k_i$,请你求出一个距 $1$ 号节点最近的点 $u_i$,使得有不少于 $k_i$ 支军队,每支军队完全覆盖 $v_i$ 到 $u_i$ 的最短路径。$n,m,q \leq 200000$

首先,将 $1$ 号节点设为根节点,显然 $u$ 在 $v \rightarrow root$ 的路径上选择是最优的。

假设一条路径的两个端点分别是 $s$ 和 $t$,其中 $deep[s] \geq deep[t]$,如果这条路径能完全覆盖 $v \rightarrow u$,则一定满足 $s$ 在 $v$ 的子树中,且 $deep[lca(s,t)] \geq deep[u]$。

离线,对于每一条路径 $s \rightarrow t$,我们在 $s$ 和 $t$ 上各开一个权值线段树,下标为 $deep[lca(s,t)]$,自底向上依次进行线段树合并,当遇到一个有询问标记的节点 $v$ 时,所有端点 $s$ 在 $v$ 的子树中的路径,都已经合并到当前线段树里了。

由于我们希望答案 $u$ 的深度尽量小,于是可以让当前线段树中 $lca(s,t)$ 的深度在前 $k$ 小的路径覆盖 $v \rightarrow u$,那么 $u$ 就是深度第 $k$ 小的 $lca(s,t)$。

实现需要注意以下几点:

  • 由于权值线段树的下标是 $1 \cdots n$,因此 $deep$ 应该从 $1$ 开始标号
  • 查询的时候,如果当前线段树中的元素不足 $k$ 个,应该直接判掉
  • 因为每条路径都要在 $s$ 和 $t$ 上各开一棵线段树,所以空间复杂度应该是 $2n (\log{(2n)}+1)$,而不是 $n (\log n+1)$
#include <cstdio>

const int N = 200005, M = 8000005; //2N(log(2N)+1)

struct Node {
	int k, id;
} nxt1[N];
int head[N], pre[N<<1], nxt[N<<1], tot, head1[N], pre1[N], tot1, cnt, ans[N], n;
int siz[N], fa[N], son[N], dep[N], top[N], sum[M], lson[M], rson[M], node[M], root[N];

int read() {
	int x = 0; char c = getchar();
	while (c < '0' || c > '9') c = getchar();
	while (c >= '0' && c <= '9') {
		x = (x << 3) + (x << 1) + (c ^ 48);
		c = getchar();
	}
	return x;
}
void swap(int &x, int &y) {
	x ^= y, y ^= x, x ^= y;
}
void add_edge(int u, int v) {
	pre[++tot] = head[u];
	head[u] = tot, nxt[tot] = v;
}
void add_tag1(int u, int k, int id) {
	pre1[++tot1] = head1[u];
	head1[u] = tot1, nxt1[tot1] = (Node){k, id};
}
void dfs1(int cur, int father, int deep) {
	siz[cur] = 1, fa[cur] = father, dep[cur] = deep;
	for (int i = head[cur]; i; i = pre[i])
		if (nxt[i] != father) {
			dfs1(nxt[i], cur, deep + 1);
			siz[cur] += siz[nxt[i]];
			if (siz[nxt[i]] > siz[son[cur]]) son[cur] = nxt[i];
		}
}
void dfs2(int cur, int tp) {
	top[cur] = tp;
	if (son[cur]) dfs2(son[cur], tp);
	for (int i = head[cur]; i; i = pre[i])
		if (nxt[i] != fa[cur] && nxt[i] != son[cur]) dfs2(nxt[i], nxt[i]);
}
int lca(int x, int y) {
	while (top[x] != top[y]) {
		if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
		else y = fa[top[y]];
	}
	if (dep[x] < dep[y]) return x; return y;
}
void insert(int &cur, int l, int r, int p, int x) {
	if (!cur) cur = ++cnt;
	++sum[cur];
	if (l == r) { node[cur] = x; return; }
	int mid = l + (r - l >> 1);
	if (p <= mid) insert(lson[cur], l, mid, p, x);
	else insert(rson[cur], mid + 1, r, p, x);
}
int merge(int &x, int &y) {
	if (!x) return y;
	if (!y) return x;
	sum[x] += sum[y];
	if (!node[x]) node[x] = node[y];
	lson[x] = merge(lson[x], lson[y]);
	rson[x] = merge(rson[x], rson[y]);
	return x;
}
int query(int cur, int l, int r, int k) {
	if (l == r) return node[cur];
	int mid = l + (r - l >> 1);
	if (sum[lson[cur]] >= k) return query(lson[cur], l, mid, k);
	else return query(rson[cur], mid + 1, r, k - sum[lson[cur]]);
}
void dfs(int cur) {
	for (int i = head[cur]; i; i = pre[i])
		if (nxt[i] != fa[cur]) {
			dfs(nxt[i]);
			root[cur] = merge(root[cur], root[nxt[i]]);
		}
	for (int i = head1[cur]; i; i = pre1[i]) {
		int x = sum[root[cur]] >= nxt1[i].k ? query(root[cur], 1, n, nxt1[i].k) : cur; //不足k个的情况直接判掉
		ans[nxt1[i].id] = dep[x] <= dep[cur] ? dep[cur] - dep[x] : 0;
	}
}

int main() {
	n = read(); int m = read();
	for (int i = 1; i < n; ++i) {
		int u = read(), v = read();
		add_edge(u, v), add_edge(v, u);
	}
	dfs1(1, 0, 1), dfs2(1, 1); //dep从1开始标号
	for (int i = 1; i <= m; ++i) {
		int s = read(), t = read(), x = lca(s,t);
		insert(root[s], 1, n, dep[x], x);
		insert(root[t], 1, n, dep[x], x);
	}
	int q = read();
	for (int i = 1; i <= q; ++i) {
		int v = read(), k = read();
		add_tag1(v, k, i);
	}
	dfs(1);
	for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
	return 0;
}

时间复杂度 $O((n+q)\log n)$


一开始误把题意理解成只要 $v$ 到 $u$ 的最短路径上的所有节点都有不少于 $k$ 支军队覆盖就可以……

于是写了以下代码,是用启发式合并大根堆来离线做的,堆中存的是询问,关键字是 $k$,自底向上依次合并,每次把 $k$ 大于当前节点的军队数量的询问弹出。

#include <cstdio>
#include <queue>

const int N = 200005;

struct Node {
	int id, v, k;
	bool operator < (const Node &rhs) const {
		return k < rhs.k;
	}
} ask[N];
std::priority_queue<Node> Q[N];

int head[N], nxt[N<<1], to[N<<1], tot;
int head2[N], nxt2[N], tot2;
int siz[N], dep[N], top[N], son[N], fa[N];
int tag[N], ans[N], be[N], n;

int read() {
	int x = 0; char c = getchar();
	while (c < '0' || c > '9') c = getchar();
	while (c >= '0' && c <= '9') {
		x = (x << 3) + (x << 1) + (c ^ 48);
		c = getchar();
	}
	return x;
}
void add_edge(int u, int v) {
	nxt[++tot] = head[u];
	head[u] = tot, to[tot] = v;
}
void dfs1(int u, int f, int d) {
	siz[u] = 1, fa[u] = f, dep[u] = d;
	for (int i = head[u]; i; i = nxt[i]) {
		if (to[i] != f) {
			dfs1(to[i], u, d + 1);
			if (siz[to[i]] > siz[son[u]]) son[u] = to[i];
			siz[u] += siz[to[i]];
		}
	}
}
void dfs2(int u, int t) {
	top[u] = t;
	if (son[u]) dfs2(son[u], t);
	for (int i = head[u]; i; i = nxt[i])
		if (to[i] != fa[u] && to[i] != son[u]) dfs2(to[i], to[i]);
}
int lca(int x, int y) {
	while (top[x] != top[y]) {
		if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
		else y = fa[top[y]];
	}
	if (dep[x] < dep[y]) return x;
	return y;
}
void dfs3(int u) {
	for (int i = head[u]; i; i = nxt[i])
		if (to[i] != fa[u]) {
			dfs3(to[i]);
			tag[u] += tag[to[i]];
		}
}
void dfs4(int u) {
	for (int i = head[u]; i; i = nxt[i]) {
		if (to[i] != fa[u]) {
			dfs4(to[i]);
			if (Q[be[u]].size() < Q[be[to[i]]].size()) {
				while (!Q[be[u]].empty()) Q[be[to[i]]].push(Q[be[u]].top()), Q[be[u]].pop();
				be[u] = be[to[i]];
			} else {
				while (!Q[be[to[i]]].empty()) Q[be[u]].push(Q[be[to[i]]].top()), Q[be[to[i]]].pop();
				be[to[i]] = be[u];
			}
		}
	}
	for (int i = head2[u]; i; i = nxt2[i]) {
		if (ask[i].k <= tag[u]) Q[be[u]].push(ask[i]);
		else ans[ask[i].id] = 0;
	}
	while (!Q[be[u]].empty() && Q[be[u]].top().k > tag[fa[u]]) {
		ans[Q[be[u]].top().id] = dep[Q[be[u]].top().v] - dep[u];
		Q[be[u]].pop();
	}
}

int main() {
	n = read(); int m = read();
	for (int i = 1; i < n; ++i) {
		int u = read(), v = read();
		add_edge(u, v), add_edge(v, u);
	}
	dfs1(1, 0, 0), dfs2(1, 1);
	for (int i = 1; i <= m; ++i) {
		int u = read(), v = read(), c = lca(u, v);
		++tag[u], ++tag[v], --tag[c], --tag[fa[c]];
	}
	dfs3(1);
	int q = read();
	for (int i = 1; i <= q; ++i) {
		int v = read(), k = read();
		nxt2[++tot2] = head2[v], head2[v] = tot2, ask[tot2] = (Node){i, v, k};
	}
	for (int i = 1; i <= n; ++i) be[i] = i;
	dfs4(1);
	for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
	return 0;
}

时间复杂度 $O((n+q)\log n)$