杜教筛习题:hdu 5608 function

    #数学 #数论 #莫比乌斯反演 #杜教筛

    题目描述

    设数论函数$f$满足$n^2−3n+2=\sum_{d \mid n}f(d) $

    求$\sum_{i=1}^{n}f(i)$

    答案对$10^9+7$取模

    题解

    回忆杜教筛的一个表达形式

    设$g(n)=1,h(n)=n^2-3n+2$,则有$f \times g = h$,即$f \times 1 = h$

    所以

    因为$f \times 1 = h$,所以$f$是积性函数,且$f=h \times \mu$,即$f(n)=\sum_{d \mid n}h(d)\mu(\frac{n}{d})$

    不妨对于每一个$\mu(i)$,枚举$j$,将$f(i \times j)$加上$\mu(i)h(j)$

    之后就是递归计算了

    代码

    #include "bits/stdc++.h"
    using namespace std;
    #define DEBUG printf("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)
    typedef long long ll;
    typedef pair<int, int> pii;
    const int N = 1e6, mod = 1e9 + 7;
    
    ll ans, inv6, inv2, n;
    
    int mu[N + 10], pri[N + 10], tot, S[N + 10], dn[N + 10], vis[N + 10], f[N + 10];
    
    ll pw(ll a, ll b) {
        ll r = 1;
        for( ; b ; b >>= 1, a = a * a % mod) if(b & 1) r = r * a % mod;
        return r;
    }
    
    int h(ll n) {
        return (n * n % mod - 3 * n % mod + 2) % mod;
    }
    
    int H(ll n) {
        return (((
            n * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod
        ) - (
            3 * n % mod * (1 + n) % mod * inv2 % mod
        )) % mod + (
            2 * n % mod
        )) % mod;
    }
    
    map<ll, ll> val;
    ll F(ll n) {
        if(n <= N) return f[n];
        else if(val.find(n) != val.end()) return val[n];
        else {
            ll res = H(n);
            for(ll i = 2, j ; i <= n ; i = j + 1) {
                j = n / (n / i);
                res = (res - F(n / i) * (j - i + 1) % mod) % mod;
            }
            return val[n] = res;
        }
    }
    
    void sol() {
        scanf("%lld", &n);
        printf("%d\n", (F(n) % mod + mod) % mod);
    }
    
    int main() {
        inv6 = pw(6, mod - 2), inv2 = pw(2, mod - 2);
        mu[1] = 1;
        for(int i = 2 ; i <= N ; ++ i) {
            if(!vis[i]) pri[++ tot] = i, mu[i] = -1;
            for(int j = 1 ; j <= tot && i * pri[j] <= N ; ++ j) {
                vis[i * pri[j]] = 1;
                if(i % pri[j] == 0) break;
                mu[i * pri[j]] = -mu[i];
            }
        }
        for(int i = 1 ; i <= N ; ++ i) {
            for(int j = 1 ; i * j <= N ; ++ j) {
                f[i * j] = ((ll) f[i * j] + mu[j] * h(i) % mod) % mod;
            }
            f[i] = ((ll) f[i - 1] + f[i]) % mod;
        }
        int T; scanf("%d", &T); while(T --) sol();
    }