「BZOJ 4987」Tree

题目大意:给定一棵 $n$ 个点的树,找出 $k$ 个点并按任意顺序排成一个序列 $A_1,A_2,\cdots,A_k$,使得 $\sum dis(A_i,A_i+1)$ 最小。$n \leq 3000$

显然这 $k$ 个点要在一个连通块内,且一个连通块的答案为所有的边权和 $\times 2$ 减去这个连通块的直径。

设 $f[i][j][k]$ 表示在 $i$ 的子树上选择了 $j$ 个点的连通块(包含 $i$)且确定了 $k=0/1/2$ 个直径端点的最小值,$k=0$ 时 $f[i][j][k]$ 表示所有边权和 $\times 2$ 的最小值,$k=1$ 表示所有边权和 $\times 2$ 减去 $\max{dis(i,j)}$ 的最小值($j$ 为连通块中的点),$k=2$ 表示所有边权和 $\times 2$ 减去该连通块的直径的最小值。

复杂度分析:相当于在以 $i$ 为根的子树中枚举点对,且这两个点分别在不同儿子的子树中,那么对于任意点对 $(u,v)$,当且仅当dfs到 $lca(u,v)$ 时会被枚举到,因此时间复杂度为 $O(n^2)$。

注意树形背包要枚举当前子树外和当前子树,否则会很慢。

#include <cstdio>
#include <cstring>
 
const int N = 3005;
 
struct Edge {
    int to, cost;
} e[N<<1];
 
int f[N][N][3], siz[N], head[N], nxt[N<<1], ans = 0x7fffffff, tot, n, m;
 
int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + (c ^ 48);
        c = getchar();
    }
    return x;
}
int min(int x, int y) {
    return x < y ? x : y;
}
void add_edge(int u, int v, int w) {
    nxt[++tot] = head[u];
    head[u] = tot;
    e[tot] = (Edge){v, w};
}
void dfs(int cur, int fa) {
    f[cur][1][0] = f[cur][1][1] = 0, siz[cur] = 1;
    for (int i = head[cur]; i; i = nxt[i]) {
        if (e[i].to == fa) continue;
        dfs(e[i].to, cur);
        for (int j = min(siz[cur], m); j >= 0; --j)
            for (int k = min(siz[e[i].to], m - j); k >= 0; --k) {
                f[cur][j+k][0] = min(f[cur][j+k][0], f[cur][j][0] + f[e[i].to][k][0] + (e[i].cost << 1));
                f[cur][j+k][1] = min(f[cur][j+k][1], f[cur][j][0] + f[e[i].to][k][1] + e[i].cost);
                f[cur][j+k][1] = min(f[cur][j+k][1], f[cur][j][1] + f[e[i].to][k][0] + (e[i].cost << 1));
                f[cur][j+k][2] = min(f[cur][j+k][2], f[cur][j][1] + f[e[i].to][k][1] + e[i].cost);
                f[cur][j+k][2] = min(f[cur][j+k][2], f[cur][j][0] + f[e[i].to][k][2] + (e[i].cost << 1));
                f[cur][j+k][2] = min(f[cur][j+k][2], f[cur][j][2] + f[e[i].to][k][0] + (e[i].cost << 1));
            }
        siz[cur] += siz[e[i].to];
    }
    ans = min(ans, f[cur][m][2]);
}
 
int main() {
    n = read(), m = read();
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read(), w = read();
        add_edge(u, v, w), add_edge(v, u, w);
    }
    memset(f, 0x3f, sizeof f);
    if (m == 1) ans = 0; else dfs(1, 0); //特判m=1的情况
    printf("%d\n", ans);
    return 0;
}

时间复杂度 $O(n^2)$