Loading [MathJax]/jax/output/HTML-CSS/jax.js

Friday, 31 December 2021

Lagrange Interpolation

There are some well-known formulas

$\sum_{i=1}^{n} i = 1 + 2 + \cdots + n = \frac{n(n+1)}{2}$

$\sum_{i=1}^{n} i^2 = 1^2 + 2^2 + \cdots + n^2 = \frac{n(n+1)(2n+1)}{6}$

$\sum_{i=1}^{n} i^3 = 1^3 + 2^3 + \cdots + n^3 = \left(\frac{n(n+1)}{2}\right)^2$

What, then, is the value of the following sum of $k$-th powers?

$\sum_{i=1}^{n} i^k = 1^k + 2^k + \cdots + n^k \pmod{10^9+7}$

given $1 \le n \le 10^9$ and $0 \le k \le 10^6$.

The target sum $f(x) = \sum_{i=1}^{x} i^k$ is a polynomial in $x$ of degree $k+1$, so it is uniquely determined by (and can be interpolated from) $k+2$ data points, i.e. $\deg(f)+1$ points. To obtain the data points we use $f(0) = 0$ and $f(x) = f(x-1) + x^k$.

If $x \le k+1$, the value $f(x)$ is itself one of the data points, so we can calculate the answer directly.

if (x <= k + 1) {
  // Only a handful of terms: adding 1^k + 2^k + ... + x^k directly costs no
  // more than building the k + 2 interpolation points would.
  int total = 0;
  for (int term = 1; term <= x; ++term) total = (total + qpow(term, k)) % mod;
  return total;
}

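Otherwise, let $f(x_1) = y_1, f(x_2) = y_2, \dots, f(x_n) = y_n$, where $f$ is the unique polynomial of degree at most $n-1$ through these points. We are interested in $f(x) = \sum_{i=1}^{n} f_i(x)$, where $f_i(x) = y_i \prod_{j=1, j \ne i}^{n} \frac{x - x_j}{x_i - x_j}$. Therefore, the Lagrange interpolation formula is $f(x) = \sum_{i=1}^{n} y_i \prod_{j=1, j \ne i}^{n} \frac{x - x_j}{x_i - x_j}$.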

However, evaluating this formula directly takes $O(n^2)$ time. Substituting $x_i = i$ and $y_i = f(i)$ gives $f(x) = \sum_{i=1}^{n} f(i) \frac{\prod_{j=1, j \ne i}^{n} (x - j)}{\prod_{j=1, j \ne i}^{n} (i - j)}$. What can we do with the numerator and the denominator? For the numerator, we can precalculate the prefix and suffix products of the factors $(x - j)$, so that for each $i$ the numerator is available in $O(1)$.

$\prod_{j=1, j \ne i}^{n} (x - j) = [(x-1)(x-2)\cdots(x-(i-1))]\,[(x-(i+1))(x-(i+2))\cdots(x-n)]$

// Interpolation nodes are 0, 1, ..., k + 1.
// pre[i] = x (x - 1) ... (x - i)            for i = 0..k    (prefix products)
// suf[i] = (x - i) (x - i - 1) ... (x - (k + 1)) for i = 1..k+1 (suffix products)
// so the numerator prod_{j != i} (x - j) for node i is pre[i - 1] * suf[i + 1].
// NOTE(review): x itself is not reduced mod p; this assumes k + 1 < x < mod
// (small x is handled by the brute-force branch), so every factor is positive
// and each product stays below mod^2, which fits in 64 bits.
vector<int> pre(k + 2), suf(k + 2);
pre[0] = x;
suf[k + 1] = x - (k + 1);
for (int i = 1; i <= k; i++) pre[i] = pre[i - 1] * (x - i) % mod;
for (int i = k; i >= 1; i--) suf[i] = suf[i + 1] * (x - i) % mod;

For the denominator, we can precompute the factorials and their modular inverses, so each denominator (more precisely, its inverse) is also available in $O(1)$.

$\prod_{j=1, j \ne i}^{n} (i - j) = [(i-1)(i-2)(i-3)\cdots(i-(i-1))]\,[(i-(i+1))(i-(i+2))\cdots(i-n)] = (-1)^{n-i}\,(n-i)!\,(i-1)!$

// Computes base^exp mod `mod` by binary exponentiation in O(log exp) steps.
// base may be any value (negative, or >= mod): it is reduced into [0, mod)
// first, so every intermediate product is below mod^2 and cannot overflow.
// Precondition: exp >= 0 (qpow(b, 0) == 1).
int qpow(int base, int exp) {
  base %= mod;
  if (base < 0) base += mod;
  int res = 1;
  while (exp > 0) {
    if (exp & 1) res = res * base % mod;
    base = base * base % mod;
    exp >>= 1;
  }
  return res;
}

// Memo for rv(): residue in [0, mod) -> its modular inverse.
unordered_map<int, int> rv_m;
// Returns the modular inverse of x, computed as x^(mod - 2) via Fermat's
// little theorem (mod is prime). x is first reduced into [0, mod), so
// equivalent arguments share a single cache entry and qpow never receives an
// oversized base. x must not be a multiple of mod: no inverse exists and 0
// is returned.
int rv(int x) {
  x %= mod;
  if (x < 0) x += mod;
  auto it = rv_m.find(x);  // single hash lookup instead of count() + operator[]
  if (it != rv_m.end()) return it->second;
  const int result = qpow(x, mod - 2);
  rv_m.emplace(x, result);
  return result;
}

// inv[i] = (i!)^{-1} mod p for i = 0..k+1. Build (k+1)! first, invert it once
// with Fermat's little theorem, then walk down using
// ((i-1)!)^{-1} = (i!)^{-1} * i. This needs one modular exponentiation in
// total instead of one per index, and no hash-map memo.
vector<int> inv(k + 2);
int fact = 1;
for (int i = 1; i <= k + 1; i++) fact = fact * i % mod;
inv[k + 1] = qpow(fact, mod - 2);
for (int i = k + 1; i >= 1; i--) inv[i - 1] = inv[i] * i % mod;

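Overall, once the data points are known, $f(x)$ can be evaluated with $O(n)$ arithmetic operations, where $n = k+2$ is the number of interpolation points (computing the points $i^k$ themselves takes $O(k \log k)$ with fast exponentiation). Note that the code below uses the nodes $x_i = i$ for $i = 0, 1, \dots, k+1$. The same formulas apply with the index shifted by one, so the denominator becomes $(-1)^{k+1-i}\, i!\,(k+1-i)!$.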

Complete Code:

#include <bits/stdc++.h>
using namespace std;

// Every `int` in this file is widened to 64 bits, so a product of two residues
// (< mod^2, about 1e18) fits without overflow; main() is therefore declared
// with int32_t.
#define int long long
// Prime modulus; being prime lets a^(mod - 2) serve as the inverse of a.
const int mod = 1e9 + 7;

// Computes base^exp mod `mod` by binary exponentiation in O(log exp) steps.
// base may be any value (negative, or >= mod): it is reduced into [0, mod)
// first, so every intermediate product is below mod^2 and cannot overflow.
// Precondition: exp >= 0 (qpow(b, 0) == 1).
int qpow(int base, int exp) {
  base %= mod;
  if (base < 0) base += mod;
  int res = 1;
  while (exp > 0) {
    if (exp & 1) res = res * base % mod;
    base = base * base % mod;
    exp >>= 1;
  }
  return res;
}

// Memo for rv(): residue in [0, mod) -> its modular inverse.
unordered_map<int, int> rv_m;
// Returns the modular inverse of x, computed as x^(mod - 2) via Fermat's
// little theorem (mod is prime). x is first reduced into [0, mod), so
// equivalent arguments share a single cache entry and qpow never receives an
// oversized base. x must not be a multiple of mod: no inverse exists and 0
// is returned.
int rv(int x) {
  x %= mod;
  if (x < 0) x += mod;
  auto it = rv_m.find(x);  // single hash lookup instead of count() + operator[]
  if (it != rv_m.end()) return it->second;
  const int result = qpow(x, mod - 2);
  rv_m.emplace(x, result);
  return result;
}

// Returns (1^k + 2^k + ... + x^k) mod `mod`.
//
// S(x) = sum_{i=1}^{x} i^k is a polynomial in x of degree k + 1, so it is fixed
// by its values at the k + 2 nodes 0, 1, ..., k + 1. For x beyond those nodes
// it is evaluated by Lagrange interpolation with O(k) arithmetic operations,
// plus O(k log k) to compute the node values i^k.
//
// x  : upper bound of the sum (x <= 0 gives 0 when k >= 1). x may exceed mod.
// k  : exponent; requires k >= 0 and k + 1 < mod so all node factorials are
//      invertible.
// bf : force the brute-force summation (e.g. to cross-check small inputs).
int lagrange_interpolate(int x, int k, bool bf = false) {
  // k == 0: the sum is just x. Handled separately because qpow(0, 0) == 1
  // would give node 0 the wrong value y_0 = 1 instead of 0.
  if (k == 0) return x % mod;
  // Small x (or forced brute force): summing directly is no more work than
  // building the k + 2 interpolation points.
  if (x <= k + 1 || bf) {
    int s = 0;
    for (int i = 1; i <= x; i++) {
      s = (s + qpow(i, k)) % mod;
    }
    return s;
  }

  const int n = k + 1;  // nodes are 0..n
  // The interpolating polynomial has p-integral coefficients (its Lagrange
  // denominators i!(n-i)! are coprime to p because n < p), so evaluating it at
  // x mod p gives the same residue. Reducing also keeps every product below
  // mod^2 when x >= mod.
  const int xm = x % mod;

  // pre[i] = prod_{j=0}^{i} (x - j) for i = 0..n-1,
  // suf[i] = prod_{j=i}^{n} (x - j) for i = 1..n  (all mod p).
  vector<int> pre(n + 1), suf(n + 1);
  pre[0] = xm;
  for (int i = 1; i < n; i++) pre[i] = pre[i - 1] * ((xm - i + mod) % mod) % mod;
  suf[n] = (xm - n + mod) % mod;
  for (int i = n - 1; i >= 1; i--) suf[i] = suf[i + 1] * ((xm - i + mod) % mod) % mod;

  // ifact[i] = (i!)^{-1}: invert n! once, then ((i-1)!)^{-1} = (i!)^{-1} * i.
  // One modular exponentiation in total, instead of one per index.
  vector<int> ifact(n + 1);
  int fact_n = 1;
  for (int i = 1; i <= n; i++) fact_n = fact_n * i % mod;
  ifact[n] = qpow(fact_n, mod - 2);
  for (int i = n; i >= 1; i--) ifact[i - 1] = ifact[i] * i % mod;

  int ans = 0;
  int yi = 0;  // running value S(i) = 0^k + 1^k + ... + i^k
  for (int i = 0; i <= n; i++) {
    yi = (yi + qpow(i, k)) % mod;  // interpolation point (i, yi)
    // Numerator: prod_{j != i} (x - j).
    int num;
    if (i == 0) num = suf[1];
    else if (i == n) num = pre[n - 1];
    else num = pre[i - 1] * suf[i + 1] % mod;
    // Denominator: prod_{j != i} (i - j) = (-1)^(n - i) * i! * (n - i)!.
    const int term = yi * num % mod * ifact[i] % mod * ifact[n - i] % mod;
    if ((n - i) & 1) ans = (ans - term + mod) % mod;
    else ans = (ans + term) % mod;
  }
  return ans;
}

// Reads one test case "n k" from stdin and prints (1^k + ... + n^k) mod `mod`.
void solve() {
  // Zero-initialised: if the input is empty, extraction leaves the variables
  // untouched, and reading them uninitialised would be undefined behaviour.
  int n = 0;
  int k = 0;
  cin >> n >> k;
  // '\n' instead of endl: no per-line flush (the stream is flushed at exit).
  cout << lagrange_interpolate(n, k) << '\n';
}

// Entry point. Declared int32_t because `int` is macro-expanded to long long.
// Unsynchronised, untied iostreams speed up input; then one test case runs.
int32_t main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  // Multi-test variant:
  // int T; cin >> T;
  // while(T--) solve();
  solve();

  return 0;
}

No comments:

Post a Comment

A Fun Problem - Math

# Problem Statement JATC's math teacher always gives the class some interesting math problems so that they don't get bored. Today t...