「BJWC 2010」次小生成树

题目大意:求一张无向图的严格次小生成树(即边权和严格大于最小生成树的生成树中,边权和最小的那一棵)的边权和。$n \leq 10^5, m \leq 3 \times 10^5$

先求出最小生成树,然后枚举每一条非树边,找出它的两个端点在树上路径中的最大边(如果最大边与该非树边边权相同,则改用严格次大边),用「非树边边权 − 被替换的树边边权」更新最小增量。

最大边和次大边可以通过树上倍增求出,分别为 $mx_1[N][18],mx_2[N][18]$。

#include <cstdio>
#include <algorithm>
#include <cstring>

// Capacity of the per-vertex arrays (vertices are numbered from 1, so this
// leaves a little headroom above the stated limit n <= 10^5).
const int N = 100005;

// An input edge between vertices u and v with weight w.
// Edges compare by weight only, so sorting yields non-decreasing weight order.
struct Edge1 {
    int u;
    int v;
    int w;

    // True when this edge is strictly lighter than `other`; endpoints are ignored.
    bool operator < (const Edge1 &other) const {
        return other.w > w;
    }
} e[300005];

// One directed adjacency-list entry: the neighbouring vertex and the weight
// of the edge leading to it.
struct Edge2 {
    int to;
    int cost;
};

// Storage for adjacency entries; sized 2 * N because every tree edge is
// inserted once in each direction.
Edge2 g[N << 1];

// Linked adjacency list over g: head[v] is the index of the most recently
// added entry of v, nxt[i] is the entry added before i (0 ends the chain).
int head[N];
int nxt[N << 1];
int tot;            // number of entries of g / nxt currently in use

int fa[N];          // union-find parent pointers
int dep[N];         // depth of each vertex in the rooted tree
int f[N][18];       // f[v][k]: the 2^k-th ancestor of v (0 when absent)
int mx1[N][18];     // largest edge weight on the 2^k-step path above v
int mx2[N][18];     // largest weight strictly below mx1 on that path (-1 if none)

int one;            // largest weight collected by the current path query
int two;            // largest weight strictly below `one` (-1 if none)

/**
 * Reads the next non-negative decimal integer from stdin.
 *
 * Every character before the first digit is skipped (a leading '-' is not
 * interpreted as a sign). No overflow checking is performed.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
int read() {
    // Keep the result of getchar() as an int: narrowing it to char makes EOF
    // indistinguishable from ordinary input and the skip loop would never end.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();

    int x = 0;
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + (c ^ 48);  // x * 10 + digit value
        c = getchar();
    }
    return x;
}
// Exchanges the values of x and y.
// Uses a temporary rather than the XOR trick, which would zero the value
// when x and y refer to the same object.
void swap(int &x, int &y) {
    const int tmp = x;
    x = y;
    y = tmp;
}
// Returns the larger of the two integers.
int max(int x, int y) {
    if (y < x) return x;
    return y;
}
// Returns the smaller of the two 64-bit integers.
long long min(long long x, long long y) {
    if (x < y) return x;
    return y;
}
// Union-find lookup with path compression: returns the representative of the
// set containing x and re-points every vertex on the walked path at it.
int find(int x) {
    // First pass: locate the representative.
    int root = x;
    while (fa[root] != root) root = fa[root];

    // Second pass: attach each vertex on the path directly to the representative.
    while (fa[x] != root) {
        const int parent = fa[x];
        fa[x] = root;
        x = parent;
    }
    return root;
}
// Prepends a directed entry u -> v with weight w to u's adjacency chain.
void add_edge(int u, int v, int w) {
    const int id = ++tot;   // entries are numbered from 1; 0 marks end of chain
    g[id].to = v;
    g[id].cost = w;
    nxt[id] = head[u];
    head[u] = id;
}
// Depth-first traversal that fills the binary-lifting tables for `cur` and
// then recurses into its children.
//   cur  - vertex being visited
//   fa   - parent of cur in the rooted tree (0 for the root)
//   deep - depth of cur
// Expects mx1[cur][0] to have been set by the caller (weight of the edge to
// the parent) and untouched mx1/mx2 entries to hold -1.
void dfs(int cur, int fa, int deep) {
    dep[cur] = deep;
    f[cur][0] = fa;

    // Level k exists only while the 2^k-th ancestor is still inside the tree.
    for (int k = 1; (1 << k) < deep; ++k) {
        const int mid = f[cur][k - 1];          // 2^(k-1)-th ancestor
        const int lowerMax = mx1[cur][k - 1];   // best of the lower half
        const int upperMax = mx1[mid][k - 1];   // best of the upper half
        const int best = max(lowerMax, upperMax);

        f[cur][k] = f[mid][k - 1];
        mx1[cur][k] = best;

        // A half whose maximum is strictly below the overall maximum offers
        // that maximum as a runner-up; otherwise it offers its own runner-up.
        int second = mx2[cur][k];
        second = max(second, lowerMax < best ? lowerMax : mx2[cur][k - 1]);
        second = max(second, upperMax < best ? upperMax : mx2[mid][k - 1]);
        mx2[cur][k] = second;
    }

    for (int i = head[cur]; i; i = nxt[i]) {
        const int child = g[i].to;
        if (child == fa) continue;
        mx1[child][0] = g[i].cost;
        dfs(child, cur, deep + 1);
    }
}
// Merges one path segment into the running result (`one`, `two`).
//   x - largest weight on the segment
//   y - largest weight on the segment strictly below x (-1 if none)
// Afterwards `one` is the overall maximum and `two` the largest value
// strictly below it.
void update(int x, int y) {
    if (y > two) two = y;

    if (x > one) {
        // The previous maximum becomes a runner-up candidate.
        if (one > two) two = one;
        one = x;
    } else if (x < one) {
        if (x > two) two = x;
    }
}
// Walks the tree path between x and y via binary lifting and leaves the
// largest edge weight on it in `one` and the largest weight strictly below
// that in `two` (each -1 when no such edge exists).
void lca(int x, int y) {
    one = -1;
    two = -1;

    // Make x the deeper endpoint.
    if (dep[x] < dep[y]) {
        const int tmp = x;
        x = y;
        y = tmp;
    }

    // Lift x until it is level with y.
    for (int k = 17; k >= 0; --k) {
        const int up = f[x][k];
        if (dep[up] < dep[y]) continue;
        update(mx1[x][k], mx2[x][k]);
        x = up;
    }
    if (x == y) return;

    // Lift both endpoints to just below their lowest common ancestor.
    for (int k = 17; k >= 0; --k) {
        if (f[x][k] == f[y][k]) continue;
        update(mx1[x][k], mx2[x][k]);
        x = f[x][k];
        update(mx1[y][k], mx2[y][k]);
        y = f[y][k];
    }

    // Account for the final edge from each side up to the common ancestor.
    update(mx1[x][0], -1);
    update(mx1[y][0], -1);
}

// Reads n, m and m weighted undirected edges from stdin, builds a minimum
// spanning tree with Kruskal's algorithm, then tries to swap every other edge
// into the tree and prints the smallest total weight strictly greater than
// the MST weight.
// NOTE(review): assumes the graph is connected and that a strictly heavier
// spanning tree exists; otherwise the initial value of `ans` is printed.
int main() {
    int n = read(), m = read();
    long long sum = 0, ans = 1e18;
    for (int i = 1; i <= m; ++i)
        e[i].u = read(), e[i].v = read(), e[i].w = read();

    // Kruskal: take edges in non-decreasing weight order, keeping those that
    // join two different components.
    std::sort(e + 1, e + m + 1);
    for (int i = 1; i <= n; ++i) fa[i] = i;
    for (int i = 1; i <= m; ++i) {
        int fx = find(e[i].u), fy = find(e[i].v);
        if (fx != fy) {
            fa[fx] = fy, sum += e[i].w;
            add_edge(e[i].v, e[i].u, e[i].w);
            add_edge(e[i].u, e[i].v, e[i].w);
        }
    }

    // -1 marks "no edge" in the lifting tables; it must be in place before
    // dfs fills them.
    memset(mx1, -1, sizeof mx1);
    memset(mx2, -1, sizeof mx2);
    dfs(1, 0, 1);

    for (int i = 1; i <= m; ++i) {
        // A self-loop closes no cycle through a tree edge. Its path query
        // would leave one == -1 and wrongly offer sum + 1 + w as an answer.
        if (e[i].u == e[i].v) continue;

        lca(e[i].u, e[i].v);
        // Replace the heaviest path edge if that strictly increases the total;
        // otherwise fall back to the strictly smaller runner-up, if any.
        // For tree edges the path is the edge itself, so neither branch fires.
        if (one < e[i].w) ans = min(ans, sum - one + e[i].w);
        else if (two != -1) ans = min(ans, sum - two + e[i].w);
    }
    printf("%lld\n", ans);
    return 0;
}

/*
3 3
1 2 2
2 3 1
1 3 2
*/

时间复杂度 $O(m \log (nm))$