There are some well-known formulas
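$$\sum_{i=1}^{n} i = 1 + 2 + \cdots + n = \frac{n(n+1)}{2}$$
$$\sum_{i=1}^{n} i^2 = 1^2 + 2^2 + \cdots + n^2 = \frac{n(n+1)(2n+1)}{6}$$
$$\sum_{i=1}^{n} i^3 = 1^3 + 2^3 + \cdots + n^3 = \left(\frac{n(n+1)}{2}\right)^2$$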
Then what is the value of the following sum of the k-th power?
$$\sum_{i=1}^{n} i^k = 1^k + 2^k + \cdots + n^k \pmod{10^9+7}$$
given $1 \le n \le 10^9$ and $0 \le k \le 10^6$.
The target sum $f(x)=\sum_{i=1}^{x} i^k$ is a polynomial of degree $k+1$ in $x$, so it is uniquely determined by $k+2$ data points (i.e. $\deg(f)+1$ points) and we can interpolate the answer from them. The data points are easy to generate: $f(0)=0$ and $f(x)=f(x-1)+x^k$.
If $x \le k+1$, i.e. $x$ is itself one of the $k+2$ data points $0, 1, \ldots, k+1$, we can simply calculate the answer directly.
// Direct summation: with x <= k + 1 there are at most k + 1 terms, so add them one by one.
if (x <= k + 1) {
int s = 0; // running sum, reduced modulo mod after every addition
for (int i = 1; i <= x; i++) {
s = (s + qpow(i, k)) % mod; // add i^k (mod mod)
}
return s;
}
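Otherwise, suppose $f(x_1)=y_1, f(x_2)=y_2, \ldots, f(x_n)=y_n$, where $f$ is the unique polynomial of degree at most $n-1$ through these points. We are interested in $f(x)=\sum_{i=1}^{n} f_i(x)$, where $f_i(x)=y_i\prod_{j=1,\,j\ne i}^{n}\frac{x-x_j}{x_i-x_j}$. Therefore, we have our Lagrange interpolation formula $f(x)=\sum_{i=1}^{n} y_i\prod_{j=1,\,j\ne i}^{n}\frac{x-x_j}{x_i-x_j}$.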
However, evaluating this formula directly takes $O(n^2)$ time. Let's substitute $x_i=i$ and $y_i=f(i)$, which gives $f(x)=\sum_{i=1}^{n} f(i)\,\frac{\prod_{j=1,\,j\ne i}^{n}(x-j)}{\prod_{j=1,\,j\ne i}^{n}(i-j)}$. What can we do for the numerator and the denominator here? For the numerator, we can precompute the prefix and suffix products of the factors $(x-j)$; then, for each $i$, we can calculate the numerator in $O(1)$.
$$\prod_{j=1,\,j\ne i}^{n}(x-x_j)=\bigl[(x-1)(x-2)\cdots(x-(i-1))\bigr]\cdot\bigl[(x-(i+1))(x-(i+2))\cdots(x-n)\bigr]$$
// Prefix/suffix products of the factors (x - j) for j = 0..k+1:
//   pre[i] = (x - 0)(x - 1)...(x - i),  suf[i] = (x - i)(x - (i+1))...(x - (k+1)).
// Only pre[0..k] and suf[1..k+1] are filled in here.
vector<int> pre(k + 2), suf(k + 2);
pre[0] = x; // the factor (x - 0)
suf[k + 1] = x - (k + 1); // the last factor
for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod; // extend the prefix by (x - i)
for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod; // extend the suffix by (x - i)
For the denominator, we can precompute the inverses of the factorials, so that each denominator is also available in $O(1)$.
$$\prod_{j=1,\,j\ne i}^{n}(i-j)=\bigl[(i-1)(i-2)(i-3)\cdots(i-(i-1))\bigr]\cdot\bigl[(i-(i+1))(i-(i+2))\cdots(i-n)\bigr]=(-1)^{n-i}\,(n-i)!\,(i-1)!$$
// Fast modular exponentiation (square-and-multiply): returns base^exp modulo `mod`.
//
// base: any value; it is normalised into [0, mod) before use.
// exp:  the exponent; must be non-negative (the loop only ends once exp reaches 0).
// Returns a value in [0, mod). Note that qpow(0, 0) is 1.
//
// `int` in this file is macro-replaced by `long long`, so products of two
// values below mod (about 1e9) fit without overflow.
int qpow(int base, int exp) {
  // Reduce first: squaring an unreduced base larger than ~3e9 would overflow
  // 64 bits, and a negative base would propagate negative residues.
  base %= mod;
  if (base < 0) base += mod;

  int res = 1;
  while (exp) {
    if (exp & 1) res = (res * base) % mod;  // current bit set: fold this power in
    base = (base * base) % mod;             // base^(2^t) -> base^(2^(t+1))
    exp >>= 1;
  }
  return res;
}
// Cache of inverses already computed by rv(), keyed by the argument.
unordered_map<int, int> rv_m;

// Returns the modular inverse of x modulo `mod`, computed as x^(mod-2) by
// Fermat's little theorem (`mod` is prime). Results are memoised in rv_m.
//
// x must not be a multiple of mod: such a value has no inverse.
int rv(int x) {
  // Single hash lookup instead of count() followed by operator[].
  const auto hit = rv_m.find(x);
  if (hit != rv_m.end()) return hit->second;

  const auto inverse = qpow(x, mod - 2);
  rv_m.emplace(x, inverse);
  return inverse;
}
// inv[i] holds the modular inverse of i! (inv[0] = 1 for 0!).
vector<int> inv(k + 2);
inv[0] = 1;
// Build each entry from the previous one: 1/i! = 1/(i-1)! * 1/i, with rv(i) = 1/i (mod mod).
for (int i = 1; i <= k + 1; i++) inv[i] = inv[i - 1] * rv(i) % mod;
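Overall, with the prefix/suffix products and the inverse factorials precomputed, we can evaluate $f(x)$ in $O(n)$, where $n$ is the number of data points (not counting the fast exponentiations used to build the sample values $y_i$).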
Complete Code:
#include <bits/stdc++.h>
using namespace std;
// NOTE: every `int` token below is macro-replaced by `long long`, so all "int"
// variables in this file are 64-bit; this is why main is declared as int32_t.
#define int long long
// Modulus used for all arithmetic: 10^9 + 7.
const int mod = 1e9 + 7;
// Fast modular exponentiation (square-and-multiply): returns base^exp modulo `mod`.
//
// base: any value; it is normalised into [0, mod) before use.
// exp:  the exponent; must be non-negative (the loop only ends once exp reaches 0).
// Returns a value in [0, mod). Note that qpow(0, 0) is 1.
//
// `int` in this file is macro-replaced by `long long`, so products of two
// values below mod (about 1e9) fit without overflow.
int qpow(int base, int exp) {
  // Reduce first: squaring an unreduced base larger than ~3e9 would overflow
  // 64 bits, and a negative base would propagate negative residues.
  base %= mod;
  if (base < 0) base += mod;

  int res = 1;
  while (exp) {
    if (exp & 1) res = (res * base) % mod;  // current bit set: fold this power in
    base = (base * base) % mod;             // base^(2^t) -> base^(2^(t+1))
    exp >>= 1;
  }
  return res;
}
// Cache of inverses already computed by rv(), keyed by the argument.
unordered_map<int, int> rv_m;

// Returns the modular inverse of x modulo `mod`, computed as x^(mod-2) by
// Fermat's little theorem (`mod` is prime). Results are memoised in rv_m.
//
// x must not be a multiple of mod: such a value has no inverse.
int rv(int x) {
  // Single hash lookup instead of count() followed by operator[].
  const auto hit = rv_m.find(x);
  if (hit != rv_m.end()) return hit->second;

  const auto inverse = qpow(x, mod - 2);
  rv_m.emplace(x, inverse);
  return inverse;
}
// Returns (1^k + 2^k + ... + x^k) modulo `mod`.
//
// x:  upper limit of the sum; for x <= 0 the sum is empty (except k == 0,
//     which returns x % mod).
// k:  the exponent, k >= 0. k + 1 must be smaller than `mod` so that
//     (k + 1)! is invertible.
// bf: when true, always use direct summation (handy for cross-checking).
//
// The sum is a polynomial of degree k + 1 in x, so it is determined by its
// values at the k + 2 points 0, 1, ..., k + 1. For x beyond those points the
// value is obtained with Lagrange interpolation in O(k) multiplications plus
// one fast exponentiation per sample point.
//
// `int` is macro-replaced by `long long` in this file, so every product of two
// residues below fits in 64 bits.
int lagrange_interpolate(int x, int k, bool bf = false) {
  // k == 0: every term is 1, so the sum is x itself (reduced in case x >= mod).
  if (k == 0) return x % mod;

  // x is one of the sample points (or brute force was requested): sum directly.
  if (x <= k + 1 || bf) {
    int s = 0;
    for (int i = 1; i <= x; i++) {
      s = (s + qpow(i, k)) % mod;
    }
    return s;
  }

  const int last = k + 1;   // sample points are 0..last
  const int xm = x % mod;   // reduce x once so no product can overflow
  // The factor (x - j) as a residue in [0, mod); j <= last < mod.
  const auto factor = [&](int j) { return (xm - j + mod) % mod; };

  // pre[i] = prod_{j=0..i} (x - j),  suf[i] = prod_{j=i..last} (x - j).
  // Only pre[0..k] and suf[1..last] are ever read.
  vector<int> pre(last + 1), suf(last + 1), inv(last + 1);
  pre[0] = factor(0);
  for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * factor(i) % mod;
  suf[last] = factor(last);
  for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * factor(i) % mod;

  // inv[i] = (i!)^{-1}. One modular exponentiation for (last!)^{-1}, then
  // walk down with 1/(i-1)! = (1/i!) * i, instead of inverting every i.
  int fact = 1;
  for (int i = 2; i <= last; i++) fact = fact * i % mod;
  inv[last] = qpow(fact, mod - 2);
  for (int i = last; i >= 1; i--) inv[i - 1] = inv[i] * i % mod;

  int ans = 0;
  int yi = 0;  // yi = 0^k + 1^k + ... + i^k, the sample value at point i
  for (int i = 0; i <= last; i++) {
    yi = (yi + qpow(i, k)) % mod;  // interpolation point: (i, yi)

    // Numerator: product of (x - j) over all sample points j != i.
    int num;
    if (i == 0) num = suf[1];
    else if (i == last) num = pre[k];
    else num = pre[i - 1] * suf[i + 1] % mod;

    // Denominator magnitude: i! * (last - i)!, applied through its inverse.
    const int denom = inv[i] * inv[last - i] % mod;
    const int term = (yi * num % mod) * denom % mod;

    // Denominator sign is (-1)^(last - i): positive exactly when (i + k) is odd.
    if ((i + k) & 1) ans += term;
    else ans -= term;
    ans = (ans % mod + mod) % mod;
  }
  return ans;
}
// Reads n and k from standard input and prints (1^k + 2^k + ... + n^k)
// modulo `mod`, followed by a newline.
void solve() {
  int n = 0, k = 0;
  // Stop on missing or malformed input rather than computing with
  // indeterminate values.
  if (!(cin >> n >> k)) return;
  // '\n' instead of endl: no forced flush; the stream is flushed at exit.
  cout << lagrange_interpolate(n, k) << '\n';
}
// Program entry point. Declared with int32_t because the `int` token is
// macro-replaced by `long long` in this file.
int32_t main() {
  // Faster stream I/O: drop synchronisation with C stdio and untie cin from cout.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  // Single test case. For multiple test cases, read a count T first and call
  // solve() T times.
  solve();
  return 0;
}