Friday 31 December 2021

Lagrange Interpolation

There are some well-known formulas:

$$ \sum_{i=1}^n i = 1 + 2 + \dots + n = \frac{n(n + 1)}{2} $$

$$ \sum_{i=1}^n i^2 = 1^2 + 2^2 + \dots + n^2 = \frac{n(n + 1)(2n + 1)}{6} $$

$$ \sum_{i=1}^n i^3 = 1^3 + 2^3 + \dots + n^3 = \left(\frac{n(n + 1)}{2}\right)^2 $$

What, then, is the value of the sum of k-th powers

$$ \sum_{i=1}^n i^k = 1^k + 2^k + \dots + n^k \bmod (10^9 + 7) $$

given $1 \le n \le 10^9$ and $0 \le k \le 10^6$?

The sum is a polynomial of degree $k + 1$ in $n$, so it can be interpolated from $k + 2$ data points, i.e. $\deg(f) + 1$ points. To obtain the data points, use $f(0) = 0$ and $f(x) = f(x - 1) + x^k$. If $n \le k + 1$ (in the code, the evaluation point `x` is $n$), the sum has at most $k + 1$ terms, so we can compute it directly:

```
if (x <= k + 1) {
    int s = 0;
    for (int i = 1; i <= x; i++) {
        s = (s + qpow(i, k)) % mod;
    }
    return s;
}
```

Otherwise, suppose $f(x_1) = y_1, f(x_2) = y_2, \dots, f(x_n) = y_n$, where $f$ is the unique polynomial of degree at most $n - 1$ through these points. We are interested in $f(x) = \sum_{i=1}^n f_i(x)$, where $f_i(x) = y_i \prod_{j=1, j \ne i}^n \frac{x - x_j}{x_i - x_j}$. This gives the Lagrange interpolation formula

$$ f(x) = \sum_{i=1}^n y_i \prod_{j=1, j \ne i}^n \frac{x - x_j}{x_i - x_j}. $$

However, evaluating it directly takes $O(n^2)$ time. Substituting $x_i = i$ and $y_i = f(i)$, we get

$$ f(x) = \sum_{i=1}^n f(i) \frac{\prod_{j=1, j \ne i}^n (x - j)}{\prod_{j=1, j \ne i}^n (i - j)}. $$

What can we do about the numerator and the denominator? For the numerator, we can precompute the prefix and suffix products of the factors $(x - j)$; then the numerator for each $i$ takes $O(1)$ to compute:
$$ \prod_{j=1, j \ne i}^n (x - j) = [(x - 1)(x - 2)\cdots(x - (i - 1))] \cdot [(x - (i + 1))(x - (i + 2))\cdots(x - n)] $$

In the code, the $k + 2$ interpolation points are $x_i = i$ for $i = 0, 1, \dots, k + 1$, so `pre[i]` holds $\prod_{j=0}^{i} (x - j)$ and `suf[i]` holds $\prod_{j=i}^{k+1} (x - j)$:

```
vector<int> pre(k + 2), suf(k + 2);
pre[0] = x;
suf[k + 1] = x - (k + 1);
for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod;
for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod;
```

For the denominator, we precompute the inverse factorials $inv[i] = (i!)^{-1} \bmod (10^9 + 7)$, so each denominator is also available in $O(1)$:

$$ \prod_{j=1, j \ne i}^n (i - j) = [(i - 1)(i - 2)\cdots(i - (i - 1))] \cdot [(i - (i + 1))(i - (i + 2))\cdots(i - n)] = (-1)^{n - i} (n - i)!\,(i - 1)! $$

With the 0-indexed points $0, 1, \dots, k + 1$ used in the code, this becomes $(-1)^{k + 1 - i}\, i!\,(k + 1 - i)!$. That is why the code adds the $i$-th term when $i + k$ is odd and subtracts it when $i + k$ is even.

```
int qpow(int base, int exp) {
    int res = 1;
    while (exp) {
        if (exp & 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return res;
}
unordered_map<int, int> rv_m;
int rv(int x) {
    if (rv_m.count(x)) {
        return rv_m[x];
    }
    return rv_m[x] = qpow(x, mod - 2);
}
vector<int> inv(k + 2);
inv[0] = 1;
for (int i = 1; i <= k + 1; i++) inv[i] = inv[i - 1] * rv(i) % mod;
```

Overall, apart from the modular exponentiations used for the powers $i^k$ and for the inverses (one per point, each $O(\log p)$ with $p = 10^9 + 7$), evaluating $f(x)$ takes only $O(k)$ operations. The total is $O(k \log p)$ time instead of $O(k^2)$.

Complete Code:

```
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 1e9 + 7;
int qpow(int base, int exp) {
    int res = 1;
    while (exp) {
        if (exp & 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return res;
}
unordered_map<int, int> rv_m;
int rv(int x) {
    if (rv_m.count(x)) {
        return rv_m[x];
    }
    return rv_m[x] = qpow(x, mod - 2);
}
int lagrange_interpolate(int x, int k, bool bf = false) {
    if (k == 0) return x;
    // find 1 ^ k + 2 ^ k + ... + x ^ k
    // (k + 1) degree polynomial -> (k + 2) points
    if (x <= k + 1 || bf) {
        int s = 0;
        for (int i = 1; i <= x; i++) {
            s = (s + qpow(i, k)) % mod;
        }
        return s;
    }
    vector<int> pre(k + 2), suf(k + 2), inv(k + 2);
    inv[0] = 1, pre[0] = x;
    suf[k + 1] = x - (k + 1);
    for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod;
    for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod;
    for (int i = 1; i <= k + 1; i++) inv[i] = inv[i - 1] * rv(i) % mod;
    int ans = 0;
    int yi = 0; // 0 ^ k + ~ i ^ k
    int num, denom;
    for (int i = 0; i <= k + 1; i++) {
        yi = (yi + qpow(i, k)) % mod; // interpolate point: (i, yi)
        if (i == 0) num = suf[1];
        else if (i == k + 1) num = pre[k];
        else num = pre[i - 1] * suf[i + 1] % mod; // numerator
        denom = inv[i] * inv[k + 1 - i] % mod; // denominator
        if ((i + k) & 1) ans += (yi * num % mod) * denom % mod;
        else ans -= (yi * num % mod) * denom % mod;
        ans = (ans % mod + mod) % mod;
    }
    return ans;
}
void solve() {
    int n, k;
    cin >> n >> k;
    cout << lagrange_interpolate(n, k) << endl;
}
int32_t main() {
    ios_base::sync_with_stdio(false), cin.tie(nullptr);
    // int T; cin >> T;
    // while(T--) solve();
    solve();
    return 0;
}
```

No comments:

Post a Comment

A Fun Problem - Math

# Problem Statement JATC's math teacher always gives the class some interesting math problems so that they don't get bored. Today t...