ARC 107 | C - Shuffle Permutation

2 min read読了の目安(約2600字

問題

https://atcoder.jp/contests/arc107/tasks/arc107_c

考えたこと

入力例1について考える。列が交換できるかどうか考えると以下のように赤と緑が交換可能である。
何回でも交換できるので、該当の列は交換可能な列の中ではどの列にも移動できることができる。

交換可能な列がA列あるとするとA!通りの組み合わせが考えられる。

ここで以下のような行列があるとする。ここで列0,1は交換でき列2,3も交換できるが列0,2などは交換できない。

よって交換できるグループを探しその列の数でできる組み合わせを行と列にわたってかけ合わせれば答えとなる。

コード

実装時のTips

  • N \leq 50なのでbicoefは使わなくても間に合うはず
#include <bits/stdc++.h>

#include <atcoder/all>

using namespace std;
using namespace atcoder;
using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
const int MOD = 1e9 + 7;

using mint = modint998244353;

template <class T>
struct bicoef {
  vector<T> fact_, inv_, finv_;
  constexpr bicoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
    int MOD = fact_[0].mod();
    for (int i = 2; i < n; i++) {
      fact_[i] = fact_[i - 1] * i;
      inv_[i] = -inv_[MOD % i] * (MOD / i);
      finv_[i] = finv_[i - 1] * inv_[i];
    }
  }
  constexpr T com(int n, int k) const noexcept {  // nCk
    if (n < k || n < 0 || k < 0) return 0;
    return fact_[n] * finv_[k] * finv_[n - k];
  }
  constexpr T fact(int n) const noexcept {  // n!
    if (n < 0) return 0;
    return fact_[n];
  }
  constexpr T inv(int n) const noexcept {
    if (n < 0) return 0;
    return inv_[n];
  }
  constexpr T finv(int n) const noexcept {  // 1/n!
    if (n < 0) return 0;
    return finv_[n];
  }
};

int main() {
  ll N, K;
  cin >> N >> K;
  vector<vector<int>> a(N, vector<int>(N));
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      cin >> a[i][j];
    }
  }
  bicoef<mint> bc(51);

  dsu rows(N);
  for (int i = 0; i < N - 1; i++) {
    for (int j = i + 1; j < N; j++) {
      bool same = true;
      for (int k = 0; k < N; k++) {
        if (a[k][i] + a[k][j] > K) {
          same = false;
          break;
        }
      }
      if (same) {
        rows.merge(i, j);
      }
    }
  }
  dsu columns(N);
  for (int i = 0; i < N - 1; i++) {
    for (int j = i + 1; j < N; j++) {
      bool same = true;
      for (int k = 0; k < N; k++) {
        if (a[i][k] + a[j][k] > K) {
          same = false;
          break;
        }
      }
      if (same) {
        columns.merge(i, j);
      }
    }
  }
  mint ans = 1;
  for (auto group : rows.groups()) {
    ans *= bc.fact(group.size());
  }
  for (auto group : columns.groups()) {
    ans *= bc.fact(group.size());
  }
  cout << ans.val() << endl;
}

参考