题目大意:求满足 $ax+by=c$,且 $x\in[l_1,r_1],y\in[l_2,r_2]$ 的整数解有多少对?
等式两边同时除以 $(a,b)$ 得
$exgcd$ 求解
可知 $x,y$ 的所有整数解为
然后计算边界
根据 $b’$ 的正负得出 $t$,其他边界同理,可以得出四组卡边界的解,选取其中两组符合情况的即可,注意如果有三组解符合条件,那么必有两组重复。
#include <cstdio>
#include <cmath>
typedef long long LL;
LL a, b, c, x, y, l1, r1, l2, r2, t1, t2;
LL ax, ay, bx, by, cx, cy, dx, dy, lx = 1e18, ly, rx = 1e18, ry;
void swap(LL &x, LL &y) {
x ^= y, y ^= x, x ^= y;
}
LL gcd(LL a, LL b) {
return b ? gcd(b, a % b) : a;
}
void exgcd(LL a, LL b, LL &x, LL &y) {
if (b == 0) x = 1, y = 0;
else exgcd(b, a % b, y, x), y -= a / b * x;
}
bool check(int x, int y) {
if (l1 <= x && x <= r1 && l2 <= y && y <= r2) return true;
return false;
}
void update(int x, int y) {
if (lx == 1e18) lx = x, ly = y;
else if (rx == 1e18) rx = x, ry = y;
else if (lx == rx && ly == ry) rx = x, ry = y;
}
int main() {
scanf("%lld%lld%lld%lld%lld%lld%lld", &a, &b, &c, &l1, &r1, &l2, &r2);
c = -c;
if (a == 0 && b == 0) {
if (c) puts("0");
else printf("%lld\n", (r1 - l1 + 1) * (r2 - l2 + 1));
return 0;
}
if (a == 0) {
if (c % b) puts("0");
else if (l2 <= c / b && c / b <= r2) printf("%lld\n", r1 - l1 + 1);
else puts("0");
return 0;
}
if (b == 0) {
if (c % a) puts("0");
else if (l1 <= c / a && c / a <= r1) printf("%lld\n", r2 - l2 + 1);
else puts("0");
return 0;
}
t1 = gcd(a, b);
if (c % t1) {
puts("0"); return 0;
}
a /= t1, b /= t1, c /= t1;
exgcd(a, b, x, y);
x *= c, y *= c;
if (b > 0) {
t1 = ceil((l1 - x) * 1.0 / b);
t2 = floor((r1 - x) * 1.0 / b);
} else {
t1 = floor((l1 - x) * 1.0 / b);
t2 = ceil((r1 - x) * 1.0 / b);
}
ax = x + b * t1, ay = y - a * t1;
bx = x + b * t2, by = y - a * t2;
if (a > 0) {
t1 = ceil((l2 - y) * 1.0 / a);
t2 = floor((r2 - y) * 1.0 / a);
} else {
t1 = floor((l2 - y) * 1.0 / a);
t2 = ceil((r2 - y) * 1.0 / a);
}
cx = x - b * t1, cy = y + a * t1;
dx = x - b * t2, dy = y + a * t2;
if (check(ax, ay)) update(ax, ay);
if (check(bx, by)) update(bx, by);
if (check(cx, cy)) update(cx, cy);
if (check(dx, dy)) update(dx, dy);
if (rx == 1e18) { puts("0"); return 0; }
if (lx > rx) swap(lx, rx);
if (ly > ry) swap(ly, ry);
if (b < 0) b = -b;
LL ans = lx == rx ? 1 : (rx - lx - 1) / b + 1; //这种向上取整法只适用于rx>lx
if ((x - lx) % b != 0 && (x - rx) % b != 0) --ans;
if ((x - lx) % b == 0 && (x - rx) % b == 0 && lx < rx) ++ans;
printf("%lld\n", ans);
return 0;
}
时间复杂度 $O()$