😊
ARC 106 | D - Powers
問題
考えたこと
mintを使って以下のように貪欲にやるとTLEになる。オーダを下げる必要がある。
using mint = modint998244353;
int main() {
int n, k;
cin >> n >> k;
vector<int> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
for (int x = 1; x <= k; x++) {
mint res;
for (int l = 0; l < n - 1; l++) {
for (int r = l + 1; r < n; r++) {
mint tmp = mint(a[l] + a[r]).pow(x);
res += tmp;
}
}
cout << res.val() << endl;
}
}
の式は、以下の黒塗りの部分を計算していることになる。
(これ以降の式はmodを省略)
このままだと計算しにくいので以下のように赤と黒の部分に分けて考える。
問題の式は
ここでSを以下のようにおいてSを効率的にもとめる方法を考える。
なお二項係数は以下のように展開できる。
よってSは次のようになる。
定数はシグマの前に移項できるので以下のようになる。
シグマの順番は入れ替えれるので以下のようにできる。
jに対して
以下のようにVを定義する。
するとSは以下に変形できる。
ここですべての
コード
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
const int MOD = 1e9 + 7;
using mint = modint998244353;
// 参考: https://qiita.com/drken/items/f2ea4b58b0d21621bd51
template <class T>
struct bicoef {
vector<T> fact_, inv_, finv_;
constexpr bicoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
int MOD = fact_[0].mod();
for (int i = 2; i < n; i++) {
fact_[i] = fact_[i - 1] * i;
inv_[i] = -inv_[MOD % i] * (MOD / i);
finv_[i] = finv_[i - 1] * inv_[i];
}
}
constexpr T com(int n, int k) const noexcept {
if (n < k || n < 0 || k < 0) return 0;
return fact_[n] * finv_[k] * finv_[n - k];
}
constexpr T fact(int n) const noexcept {
if (n < 0) return 0;
return fact_[n];
}
constexpr T inv(int n) const noexcept {
if (n < 0) return 0;
return inv_[n];
}
constexpr T finv(int n) const noexcept {
if (n < 0) return 0;
return finv_[n];
}
};
int main() {
int N, K;
cin >> N >> K;
const int maxK = K + 1;
bicoef<mint> bc(K + 1);
vector<int> a(N);
for (int i = 0; i < N; i++) {
cin >> a[i];
}
//ak[k][i] = ai^k
vector<vector<mint>> ak(maxK, vector<mint>(N));
for (int k = 0; k < maxK; k++) {
for (int i = 0; i < N; i++) {
if (k == 0) {
ak[k][i] = 1;
continue;
}
ak[k][i] = ak[k - 1][i] * a[i];
}
}
// v[k] = kに対しての \sum_{i}^{N}\frac{A_i^k}{k!}
vector<mint> v(maxK);
for (int k = 0; k < maxK; k++) {
for (int i = 0; i < N; i++) {
v[k] += ak[k][i] * bc.finv(k);
}
}
for (int X = 1; X <= K; X++) {
mint ans = 0;
for (int k = 0; k <= X; k++) {
int i = X - k;
ans += v[i] * v[k];
}
ans *= bc.fact(X);
mint d = 0;
for (int i = 0; i < N; i++) {
d += ak[X][i];
}
d *= mint(2).pow(X);
ans -= d;
ans /= 2;
cout << ans.val() << endl;
}
}
感想
解説見ればわかるけど次同じような問題でもできる気がしない。
あと、実装がACするまでに結構時間かかった。。。大変。
Discussion