题目大意:给出一张图,对于每条道路,计算有多少条不同的最短路经过该道路。$n \leq 1500, m \leq 5000$
枚举起点,设 $f[u]$ 表示当前起点到 $u$ 之间有多少条不同的最短路,$g[u]$ 表示从 $u$ 开始只走最短路图上的边可以得到多少条不同的路径。记忆化搜索求出。
那么对于一条在当前起点的最短路图上的边 $(u,v)$,累计的答案就是 $f[u]*g[v]$。
#include <cstdio>
#include <cstring>
#include <queue>
#include <cstdlib>
const int N = 1505, M = 5005, XZY = 1e9 + 7;
struct Edge {
int v, w, id;
} e[M];
struct Pair {
int x, y;
bool operator < (const Pair &rhs) const {
return x > rhs.x;
}
};
int n, m, tot, ans[M], f[N], g[N], dis[N], vis[N], head[N], nxt[M];
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x;
}
void adde(int u, int v, int w, int id) {
nxt[++tot] = head[u];
head[u] = tot;
e[tot] = (Edge){v, w, id};
}
void dij(int s) {
std::priority_queue<Pair> Q;
memset(dis, 0x3f, sizeof dis);
dis[s] = 0, Q.push((Pair){0, s});
int cnt = 0;
while (!Q.empty()) {
int u = Q.top().y; Q.pop();
if (vis[u] == s) continue;
if (++cnt == n) return;
vis[u] = s;
for (int i = head[u]; i; i = nxt[i])
if (dis[e[i].v] > dis[u] + e[i].w) {
dis[e[i].v] = dis[u] + e[i].w;
Q.push((Pair){dis[e[i].v], e[i].v});
}
}
}
void bfs(int s) { //拓扑求从s走最短路到每个点的方案数
int d[N] = {}, q[N] = {}, h = 0, t = -1;
for (int u = 1; u <= n; ++u)
for (int i = head[u]; i; i = nxt[i])
if (dis[u] + e[i].w == dis[e[i].v]) ++d[e[i].v];
for (int i = 1; i <= n; ++i) if (!d[i]) q[++t] = i;
f[s] = 1;
while (h <= t) {
int u = q[h++];
for (int i = head[u]; i; i = nxt[i]) {
if (dis[u] + e[i].w > dis[e[i].v]) continue;
f[e[i].v] += f[u];
if (f[e[i].v] >= XZY) f[e[i].v] -= XZY;
if (--d[e[i].v] == 0) q[++t] = e[i].v;
}
}
}
int dfs(int u) { //记搜求从每个点出发走最短路的路径方案数
if (g[u] != -1) return g[u];
g[u] = 1;
for (int i = head[u]; i; i = nxt[i])
if (dis[u] + e[i].w == dis[e[i].v]) {
g[u] += dfs(e[i].v);
if (g[u] >= XZY) g[u] -= XZY;
}
return g[u];
}
int main() {
n = read(), m = read();
for (int i = 1; i <= m; ++i) {
int u = read(), v = read(), w = read();
adde(u, v, w, i);
}
for (int i = 1; i <= n; ++i) { //以点i为起点
dij(i);
memset(f, 0, sizeof f);
memset(g, -1, sizeof g);
bfs(i);
for (int j = 1; j <= n; ++j) dfs(j);
for (int u = 1; u <= n; ++u)
for (int i = head[u]; i; i = nxt[i])
if (dis[u] + e[i].w == dis[e[i].v])
ans[e[i].id] = (1LL * f[u] * g[e[i].v] + ans[e[i].id]) % XZY;
}
for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
return 0;
}