Friday, 31 December 2021
Lagrange Interpolation
There are some well-known formulas:
$$
\sum_{i=1}^n i = 1 + 2 + \dots + n = \frac{n(n + 1)}{2}
$$
$$
\sum_{i=1}^n i^2 = 1^2 + 2^2 + \dots + n^2 = \frac{n(n + 1)(2n + 1)}{6}
$$
$$
\sum_{i=1}^n i^3 = 1^3 + 2^3 + \dots + n^3 = \left(\frac{n(n + 1)}{2}\right)^2
$$
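As a quick sanity check (not part of the original solution), the three closed forms can be compared against running sums in exact 64-bit arithmetic. The function name `checked_cube_sum` is illustrative only:

```cpp
#include <cassert>
#include <cstdint>

// Checks the three closed forms above against running sums for every
// n = 1..max_n and returns 1^3 + 2^3 + ... + max_n^3.
// Exact 64-bit arithmetic is enough for max_n <= 10000.
int64_t checked_cube_sum(int64_t max_n) {
    int64_t s1 = 0, s2 = 0, s3 = 0;  // running sums of i, i^2 and i^3
    for (int64_t n = 1; n <= max_n; n++) {
        s1 += n;
        s2 += n * n;
        s3 += n * n * n;
        const int64_t tri = n * (n + 1) / 2;          // n(n + 1) / 2
        assert(s1 == tri);                             // sum of i
        assert(s2 == n * (n + 1) * (2 * n + 1) / 6);   // sum of i^2
        assert(s3 == tri * tri);                       // sum of i^3
    }
    return s3;
}
```

For example, `checked_cube_sum(10)` returns $55^2 = 3025$.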
What, then, is the value of the following sum of $k$-th powers
$$
\left(\sum_{i=1}^n i^k\right) \bmod (10^9 + 7) = \left(1^k + 2^k + \dots + n^k\right) \bmod (10^9 + 7)
$$
given $1 \le n \le 10^9$ and $0 \le k \le 10^6$?
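Define $f(x) = \sum_{i=1}^x i^k$, so the answer is $f(n)$. As a function of $x$, $f$ is a polynomial of degree $k + 1$ (the formulas above are the cases $k = 1, 2, 3$), so it is uniquely determined by $k + 2$ data points, i.e. $\deg(f) + 1$ points. The data points themselves are easy to generate: $f(0) = 0$ and $f(x) = f(x - 1) + x^k$.
If $n \le k + 1$, the sum has at most $k + 1$ terms, so we can simply compute it directly (in the code, `x` is the point at which we evaluate $f$, i.e. $n$):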
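To see the degree claim concretely, here is a small sketch (hypothetical helpers `prefix_power_sums` and `forward_difference`, exact integers, small inputs only) that builds the data points with the recurrence and then takes repeated forward differences. For a polynomial of degree $k + 1$ the $(k + 1)$-th difference is constant and the $(k + 2)$-th difference is zero:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Data points f(0), f(1), ..., f(m) for f(x) = 1^k + 2^k + ... + x^k, built with
// the recurrence f(0) = 0, f(x) = f(x - 1) + x^k. Exact 64-bit arithmetic, so
// this is only meant for small k and m.
std::vector<int64_t> prefix_power_sums(int k, int m) {
    std::vector<int64_t> f(m + 1, 0);
    for (int x = 1; x <= m; x++) {
        int64_t p = 1;
        for (int e = 0; e < k; e++) p *= x;  // p = x^k
        f[x] = f[x - 1] + p;
    }
    return f;
}

// Applies the forward difference g(x) -> g(x + 1) - g(x) d times; each pass
// shortens the sequence by one element.
std::vector<int64_t> forward_difference(std::vector<int64_t> g, int d) {
    assert(d >= 0 && static_cast<std::size_t>(d) <= g.size());
    for (int step = 0; step < d; step++) {
        for (std::size_t i = 0; i + 1 < g.size(); i++) g[i] = g[i + 1] - g[i];
        g.pop_back();
    }
    return g;
}
```

With $k = 3$ and the points $f(0), \dots, f(10)$, the fourth differences all equal $3! = 6$ (the leading coefficient of $f$ is $\frac{1}{k + 1}$, so the $(k + 1)$-th difference is $k!$) and the fifth differences are all $0$.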
```
if (x <= k + 1) {
int s = 0;
for (int i = 1; i <= x; i++) {
s = (s + qpow(i, k)) % mod;
}
return s;
}
```
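Otherwise, suppose we know $n$ points $f(x_1) = y_1, f(x_2) = y_2, \dots, f(x_n) = y_n$ with distinct $x_i$ (in this general discussion $n$ is the number of points, not the upper limit of the sum), and $f$ is the unique polynomial of degree at most $n - 1$ through them. Write $f(x) = \sum_{i=1}^n f_i(x)$, where $f_i(x) = y_i \prod_{j=1, j \ne i}^n \frac{x - x_j}{x_i - x_j}$: each $f_i$ equals $y_i$ at $x_i$ and vanishes at every other $x_j$. Therefore, the Lagrange interpolation formula is $ f(x) = \sum_{i=1}^n y_i \prod_{j=1, j \ne i}^n \frac{x - x_j}{x_i - x_j} $.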
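For reference, a direct implementation of this formula modulo $10^9 + 7$ might look like the sketch below. It is a generic version for arbitrary distinct nodes, not the author's code; `lagrange_eval` and `power_mod` are illustrative names, and each denominator is inverted with Fermat's little theorem, which requires a prime modulus:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const int64_t MOD = 1000000007;  // prime, as in the problem statement

// base^exp mod MOD by binary exponentiation.
int64_t power_mod(int64_t base, int64_t exp) {
    int64_t res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) res = res * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return res;
}

// Evaluates the unique polynomial through the points (xs[i], ys[i]) at x, mod MOD,
// directly from the Lagrange formula: O(n^2) multiplications plus one modular
// inverse per term. Assumes the xs are distinct and all inputs lie in [0, MOD).
int64_t lagrange_eval(const std::vector<int64_t>& xs,
                      const std::vector<int64_t>& ys, int64_t x) {
    assert(xs.size() == ys.size());
    int64_t result = 0;
    for (std::size_t i = 0; i < xs.size(); i++) {
        int64_t num = 1, den = 1;  // prod (x - x_j) and prod (x_i - x_j), j != i
        for (std::size_t j = 0; j < xs.size(); j++) {
            if (j == i) continue;
            num = num * ((x - xs[j] + MOD) % MOD) % MOD;
            den = den * ((xs[i] - xs[j] + MOD) % MOD) % MOD;
        }
        const int64_t term = ys[i] * num % MOD * power_mod(den, MOD - 2) % MOD;
        result = (result + term) % MOD;
    }
    return result;
}
```

Interpolating $\sum_{i=1}^x i^2$ through $(0, 0), (1, 1), (2, 5), (3, 14)$ and evaluating at $x = 10$ gives $385 = \frac{10 \cdot 11 \cdot 21}{6}$.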
However, evaluating this formula directly takes $O(n^2)$ time. Since we may choose the nodes freely, let's substitute $x_i = i$ and $y_i = f(i)$, which gives $ f(x) = \sum_{i=1}^n f(i) \frac{\prod_{j=1, j \ne i}^n (x - j)}{\prod_{j=1, j \ne i}^n (i - j)} $. What can we do about the numerator and the denominator? For the numerator, we can precompute the prefix and suffix products of the factors $(x - j)$; then, for each $i$, the numerator is one prefix product times one suffix product, which takes $O(1)$:
$$
\prod_{j=1, j \ne i}^n (x - j) = [(x - 1)(x - 2) \cdots (x - (i - 1))] \cdot [(x - (i + 1))(x - (i + 2)) \cdots (x - n)]
$$
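```
// The code uses the k + 2 nodes 0, 1, ..., k + 1 (rather than 1, ..., n):
// pre[i] = x (x - 1) ... (x - i),  suf[i] = (x - i) (x - (i + 1)) ... (x - (k + 1))
vector<int> pre(k + 2), suf(k + 2);
pre[0] = x;
suf[k + 1] = x - (k + 1);
for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod;
for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod;
```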
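The snippet above indexes the nodes from $0$. A sketch that follows the formula's 1-based nodes exactly, with empty products set to $1$ so that the boundary cases $i = 1$ and $i = n$ need no special handling, could look like this (`numerators` is a hypothetical helper, not from the original code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 1000000007;

// For the nodes 1, 2, ..., n, returns num with num[i] = prod_{j != i} (x - j) mod MOD
// for 1 <= i <= n (num[0] is unused). Prefix and suffix products use empty
// products equal to 1:
//   pre[i] = (x - 1)(x - 2)...(x - i),        pre[0] = 1
//   suf[i] = (x - i)(x - (i + 1))...(x - n),  suf[n + 1] = 1
// so num[i] = pre[i - 1] * suf[i + 1]. Assumes 0 <= x < MOD and 1 <= n < MOD.
std::vector<int64_t> numerators(int64_t x, int n) {
    assert(n >= 1 && 0 <= x && x < MOD);
    auto factor = [&](int j) { return ((x - j) % MOD + MOD) % MOD; };  // (x - j) mod MOD
    std::vector<int64_t> pre(n + 2, 1), suf(n + 2, 1), num(n + 1, 0);
    for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] * factor(i) % MOD;
    for (int i = n; i >= 1; i--) suf[i] = suf[i + 1] * factor(i) % MOD;
    for (int i = 1; i <= n; i++) num[i] = pre[i - 1] * suf[i + 1] % MOD;
    return num;
}
```

For instance, with $x = 10$ and $n = 4$, the entry for $i = 2$ is $(10 - 1)(10 - 3)(10 - 4) = 378$.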
For the denominator, the product splits into two factorials times a sign, so after precomputing the inverse factorials it also takes $O(1)$ per $i$:
$$
\prod_{j=1, j \ne i}^n (i - j) = [(i - 1)(i - 2) \cdots (i - (i - 1))] \cdot [(i - (i + 1))(i - (i + 2)) \cdots (i - n)]
= (-1)^{n - i} (n - i)! (i - 1)!
$$
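This identity is easy to confirm numerically. The sketch below (illustrative names, exact 64-bit integers, $n \le 20$ assumed) compares the direct product with $(-1)^{n - i}(n - i)!(i - 1)!$ for every $i$:

```cpp
#include <cassert>
#include <cstdint>

// m! as an exact 64-bit integer (fine for m <= 20).
int64_t factorial(int m) {
    int64_t r = 1;
    for (int t = 2; t <= m; t++) r *= t;
    return r;
}

// Exact check of  prod_{j = 1, j != i}^{n} (i - j) = (-1)^(n - i) (n - i)! (i - 1)!
// for every 1 <= i <= n. Restricted to n <= 20 so nothing overflows.
bool denominator_identity_holds(int n) {
    assert(1 <= n && n <= 20);
    for (int i = 1; i <= n; i++) {
        int64_t direct = 1;
        for (int j = 1; j <= n; j++)
            if (j != i) direct *= (i - j);
        const int64_t sign = ((n - i) % 2 == 0) ? 1 : -1;
        if (direct != sign * factorial(n - i) * factorial(i - 1)) return false;
    }
    return true;
}
```

For example, with $n = 4$ and $i = 2$ the direct product is $(1)(-1)(-2) = 2 = (-1)^2 \cdot 2! \cdot 1!$.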
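```
// Assumes, as in the complete code, #define int long long and const int mod = 1e9 + 7.
int qpow(int base, int exp) {  // base^exp % mod by binary exponentiation
int res = 1;
while (exp) {
if (exp & 1) res = (res * base) % mod;
base = (base * base) % mod;
exp >>= 1;
}
return res;
}
unordered_map<int, int> rv_m;
int rv(int x) {  // modular inverse via Fermat's little theorem (mod is prime), memoized
if (rv_m.count(x)) {
return rv_m[x];
}
return rv_m[x] = qpow(x, mod - 2);
}
// inv[i] = (i!)^{-1} % mod. With the nodes 0, 1, ..., k + 1 used in the code,
// the denominator for node i is (-1)^(k + 1 - i) * i! * (k + 1 - i)!.
vector<int> inv(k + 2);
inv[0] = 1;
for (int i = 1; i <= k + 1; i++) inv[i] = inv[i - 1] * rv(i) % mod;
```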
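A common alternative to calling `rv` once per $i$ (swapped in here; it is not what the complete code does) is to invert only the largest factorial and then walk downwards, so the whole table costs a single modular exponentiation. A minimal sketch, assuming a prime modulus and the hypothetical name `inverse_factorials`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 1000000007;  // must be prime for Fermat's little theorem

// base^exp mod MOD by binary exponentiation.
int64_t power_mod(int64_t base, int64_t exp) {
    int64_t res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) res = res * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return res;
}

// Returns inv_fact with inv_fact[i] = (i!)^{-1} mod MOD for 0 <= i <= m.
// Only m! is inverted; the rest follow from ((i - 1)!)^{-1} = (i!)^{-1} * i.
// Assumes 0 <= m < MOD, so no factorial is divisible by MOD.
std::vector<int64_t> inverse_factorials(int m) {
    assert(0 <= m && m < MOD);
    std::vector<int64_t> fact(m + 1), inv_fact(m + 1);
    fact[0] = 1;
    for (int i = 1; i <= m; i++) fact[i] = fact[i - 1] * i % MOD;
    inv_fact[m] = power_mod(fact[m], MOD - 2);
    for (int i = m; i >= 1; i--) inv_fact[i - 1] = inv_fact[i] * i % MOD;
    return inv_fact;
}
```

With `m = k + 1`, the result plays the role of the `inv` vector above; for example `inverse_factorials(10)[2]` is $500000004$, the inverse of $2! = 2$ modulo $10^9 + 7$.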
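Overall, once the values $f(i)$ and the inverse factorials are known, evaluating $f(x)$ takes $O(k)$ multiplications (there are $k + 2$ nodes). In the complete code below, computing each $f(i)$ with `qpow(i, k)` and each inverse with `rv` adds a logarithmic factor, so the total cost for $n > k + 1$ is $O(k \log p)$ with $p = 10^9 + 7$.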
Complete Code:
```
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 1e9 + 7;
int qpow(int base, int exp) {
int res = 1;
while (exp) {
if (exp & 1) res = (res * base) % mod;
base = (base * base) % mod;
exp >>= 1;
}
return res;
}
unordered_map<int, int> rv_m;
int rv(int x) {
if (rv_m.count(x)) {
return rv_m[x];
}
return rv_m[x] = qpow(x, mod - 2);
}
int lagrange_interpolate(int x, int k, bool bf = false) {
if (k == 0) return x;
// find 1 ^ k + 2 ^ k + ... + x ^ k
// (k + 1) degree polynomial -> (k + 2) points
if (x <= k + 1 || bf) {
int s = 0;
for (int i = 1; i <= x; i++) {
s = (s + qpow(i, k)) % mod;
}
return s;
}
vector<int> pre(k + 2), suf(k + 2), inv(k + 2);
inv[0] = 1, pre[0] = x;
suf[k + 1] = x - (k + 1);
for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod;
for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod;
for (int i = 1; i <= k + 1; i++) inv[i] = inv[i - 1] * rv(i) % mod;
int ans = 0;
int yi = 0; // y_i = f(i) = 0^k + 1^k + ... + i^k
int num, denom;
for (int i = 0; i <= k + 1; i++) {
yi = (yi + qpow(i, k)) % mod; // interpolate point: (i, yi)
if (i == 0) num = suf[1];
else if (i == k + 1) num = pre[k];
else num = pre[i - 1] * suf[i + 1] % mod; // numerator
denom = inv[i] * inv[k + 1 - i] % mod; // denominator
if ((i + k) & 1) ans += (yi * num % mod) * denom % mod;
else ans -= (yi * num % mod) * denom % mod;
ans = (ans % mod + mod) % mod;
}
return ans;
}
void solve() {
int n, k;
cin >> n >> k;
cout << lagrange_interpolate(n, k) << endl;
}
int32_t main() {
ios_base::sync_with_stdio(false), cin.tie(nullptr);
// int T; cin >> T;
// while(T--) solve();
solve();
return 0;
}
```