「HAOI2012」Road

题目大意:给出一张图,对于每条道路,计算有多少条不同的最短路经过该道路。$n \leq 1500, m \leq 5000$

枚举起点,设 $f[u]$ 表示当前起点到 $u$ 之间有多少条不同的最短路,$g[u]$ 表示从 $u$ 开始只走最短路图上的边可以得到多少条不同的路径。记忆化搜索求出。

那么对于一条在当前起点的最短路图上的边 $(u,v)$,累计的答案就是 $f[u]*g[v]$。

#include <cstdio>
#include <cstring>
#include <queue>
#include <cstdlib>
 
const int N = 1505, M = 5005, XZY = 1e9 + 7;
struct Edge {
    int v, w, id;
} e[M];
struct Pair {
    int x, y;
    bool operator < (const Pair &rhs) const {
        return x > rhs.x;
    }
};
int n, m, tot, ans[M], f[N], g[N], dis[N], vis[N], head[N], nxt[M];
 
int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + (c ^ 48);
        c = getchar();
    }
    return x;
}
void adde(int u, int v, int w, int id) {
    nxt[++tot] = head[u];
    head[u] = tot;
    e[tot] = (Edge){v, w, id};
}
void dij(int s) {
    std::priority_queue<Pair> Q;
    memset(dis, 0x3f, sizeof dis);
    dis[s] = 0, Q.push((Pair){0, s});
    int cnt = 0;
    while (!Q.empty()) {
        int u = Q.top().y; Q.pop();
        if (vis[u] == s) continue;
        if (++cnt == n) return;
        vis[u] = s;
        for (int i = head[u]; i; i = nxt[i])
            if (dis[e[i].v] > dis[u] + e[i].w) {
                dis[e[i].v] = dis[u] + e[i].w;
                Q.push((Pair){dis[e[i].v], e[i].v});
            }
    }
}
void bfs(int s) { //拓扑求从s走最短路到每个点的方案数
    int d[N] = {}, q[N] = {}, h = 0, t = -1;
    for (int u = 1; u <= n; ++u)
        for (int i = head[u]; i; i = nxt[i])
            if (dis[u] + e[i].w == dis[e[i].v]) ++d[e[i].v];
    for (int i = 1; i <= n; ++i) if (!d[i]) q[++t] = i;
    f[s] = 1;
    while (h <= t) {
        int u = q[h++];
        for (int i = head[u]; i; i = nxt[i]) {
            if (dis[u] + e[i].w > dis[e[i].v]) continue;
            f[e[i].v] += f[u];
            if (f[e[i].v] >= XZY) f[e[i].v] -= XZY;
            if (--d[e[i].v] == 0) q[++t] = e[i].v;
        }
    }
}
int dfs(int u) { //记搜求从每个点出发走最短路的路径方案数
    if (g[u] != -1) return g[u];
    g[u] = 1;
    for (int i = head[u]; i; i = nxt[i])
        if (dis[u] + e[i].w == dis[e[i].v]) {
            g[u] += dfs(e[i].v);
            if (g[u] >= XZY) g[u] -= XZY;
        }
    return g[u];
}
 
int main() {
    n = read(), m = read();
    for (int i = 1; i <= m; ++i) {
        int u = read(), v = read(), w = read();
        adde(u, v, w, i);
    }
    for (int i = 1; i <= n; ++i) { //以点i为起点
        dij(i);
        memset(f, 0, sizeof f);
        memset(g, -1, sizeof g);
        bfs(i);
        for (int j = 1; j <= n; ++j) dfs(j);
        for (int u = 1; u <= n; ++u)
            for (int i = head[u]; i; i = nxt[i])
                if (dis[u] + e[i].w == dis[e[i].v])
                    ans[e[i].id] = (1LL * f[u] * g[e[i].v] + ans[e[i].id]) % XZY;
    }
    for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
    return 0;
}