Hatena::Grouptopcoder

naoya_t@topcoder RSSフィード

2008-12-11

SRM428 Div1 Medium: TheLongPalindrome

| 18:09 | SRM428 Div1 Medium: TheLongPalindrome - naoya_t@topcoder を含むブックマーク はてなブックマーク - SRM428 Div1 Medium: TheLongPalindrome - naoya_t@topcoder SRM428 Div1 Medium: TheLongPalindrome - naoya_t@topcoder のブックマークコメント

  • 奇数長のpalindromeは、それより1つ長い(偶数長の)palindromeと同数のパターンを持つ
  • 剰余計算

まずはナイーブな実装。

#include <vector>
using namespace std;

const long long H = 1234567891LL;
inline long long sub_(long long a, long long b) {
  return (a + H - (b % H)) % H;
}
inline long long add_(long long a, long long b) {
  return (a + b) % H;
}
inline long long mul_(long long a, long long b) {
  return ((long long)(a % H) * (b % H)) % H;
}

long long c_(int n, int r)
{
  if (n == 0 || r == 0 || r == n) return 1;
  if (r > n-r) r = n-r;
  return (c_(n-1,r-1) * n / r) % H;
}
long long expt_(int n, int k) { // n ^ k
  long long p = 1LL;
  for (int i=0; i<k; i++) p = mul_(p, n);
  return p;
}
long long fac_(int n) { // n !
  long long p = 1LL;
  for (int k=n; k>1; k--) p = mul_(p, k);
  return p;
}

class TheLongPalindrome {
private:
  long long f_(int k, int len) {
    if (k == 1) return 1;
    if (k == len) return fac_(k);
    long long t = 0;
    for (int j=k,pm=1; j>=1; j--,pm=-pm)
      t = (t + pm * expt_(j,len) * c_(k,j)) % H;
    return t;
  }

public:
  int count(int n, int k) {
    vector<int> expts(27,1);

    int h = (n + 1) / 2, m = n % 2;
    if (k > h) k = h;

    long long c = 0LL;
    for (int len=1; len<=h; len++) {
      // i:何文字長?
      int k_ = min(len,k); // 字種数
      long long o = 0LL;
      for (int j=1; j<=k_; j++)
        o = add_(o, mul_(c_(26,j), f_(j,len))); // o += 26Cj x f(j,len)
      c = (len == h && m == 1)? add_(c,o) : add_(c,o*2);
    }
    return (int)c;
  }
};
  • サンプルケース4つは通るが、n=1000000000とか渡したら死ぬのは目に見えている
  • 長さlen/2(奇数長の場合は(len+1)/2)の文字列で、1〜k種類の文字が使われる、とは言っても len/2 < k の場合、一度に使える文字数はlen/2種類
  • というわけで len/2 = k の上と下で場合分け
  • 累乗の計算を速いやつに置き換える

ここで関数 f_() がどのように展開されるかを分析し、パラメータから結果が直接得られるような式が書けないか考える。結論から言うと、len/2 > k の部分については、len/2 = [k+1..n/2] の範囲で

26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } x 1^(len/2)

  1. 26C2 x { 24C0 - 24C1 + 24C2 - ... ± 24C(k-2) } x 2^(len/2)
  2. ...
  3. 26Ck x k^(len/2)

の総和を求めればよい。(len/2 <= k の部分は最初のやり方で構わない)

ここで係数 26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } の部分は文字列長に関わらず一定なので計算は一度だけ行い、これに Σk^(len/2) を掛け合わせる。等比数列の和とか・・・20年来使っていない式を活用。計算量は...O(k^2*log(len))かな。


最終的なコードはこんな感じ:

#include <vector>
using namespace std;

const long long H = 1234567891LL;

inline long long sub_(long long a, long long b) {
  return (a + H - (b % H)) % H;
}
inline long long add_(long long a, long long b) {
  return (a + b) % H;
}
inline long long mul_(long long a, long long b) {
  return ((a % H) * (b % H)) % H;
}

long long fast_expt(long long r, long long n) { // r^n % H
  if (n == 1) return r;
  if (n % 2 == 0)
    return fast_expt(mul_(r,r), n/2);
  else
    return mul_(fast_expt(r,n-1), r);
}

inline long long div_(long long a, long long b) {
  return mul_(a, fast_expt(b,H-2));
}

long long geo(int r, int n) {
  // 1,r,r^2,...,r^n
  if (r == 1) return n % H;

  return div_(fast_expt(r,n+1)-1, r-1);
}

long long C(int n, int r)
{
  if (n == 0 || r == 0 || r == n) return 1;
  if (r > n-r) r = n-r;
  return C(n-1,r-1) * n / r;
}

long long fac_(int n) { // n !
  long long p = 1LL;
  for (int k=n; k>1; k--) p = mul_(p, k);
  return p;
}

vector<long long> coeffs_(int k) {
  vector<long long> cs(k+1, 1LL); cs[0] = 0;
  
  for (int j=1; j<=k; j++) {
    long long t = 0;
    for (int i=0,pm=1; i<=k-j; i++,pm=-pm) {
      t += C(26-j, i) * pm;
    }
    cs[j] = C(26, j) * t;
  }
  return cs;
}

class TheLongPalindrome {
  long long f_(int k, int len) {
    if (k == 1) return 1;
    if (k == len) return fac_(k);
    
    long long t = 0;
    for (int j=k,pm=1; j>=1; j--,pm=-pm) {
      if (pm >= 0)
        t = add_(t, fast_expt(j,len) * C(k,j));
      else
        t = sub_(t, fast_expt(j,len) * C(k,j));
    }
    return t;
  }

public:
  int count(int n, int k) {
    int h = (n + 1) / 2, m = n % 2;
    if (k > h) k = h;

    vector<long long> expts(k+1, 1LL); expts[0] = 0;
    vector<long long> coeffs = coeffs_(k);

    long long c = 0LL;
    for (int len=1; len<=k; len++) {
      int k_ = len;
      long long o = 0LL;
      for (int j=1; j<=k_; j++)
        o = add_(o, mul_(C(26,j), f_(j,len))); // o += 26Cj x f(j,len)
      c = (len == h && m == 1)? add_(c,o) : add_(c,o*2);
    }

    if (h > k) {
      long long o = 0LL;
      for (int r=1; r<=k; r++) {
        long long co = coeffs[r];

        if (co >= 0)
          o = add_(o, mul_(co, sub_(geo(r,h-1), geo(r,k))));
        else
          o = sub_(o, mul_(-co, sub_(geo(r,h-1), geo(r,k))));
      }
      c = add_(c, o*2);

      o = 0LL;
      for (int r=1; r<=k; r++) {
        long long co = coeffs[r];
        if (co >= 0)
          o = add_(o, mul_(co, fast_expt(r,h)));
        else
          o = sub_(o, mul_(-co, fast_expt(r,h)));
      }
      c = (m == 1) ? add_(c, o) : add_(c, o*2);
    }

    return (int)c;
  }
};

最後の要素で、全体長が偶数の場合の加算を忘れていたためにCase 3の数値が合わず、剰余演算のミスかと思って小一時間悩んだ。

このコードでは、n=1000000000な最悪ケースでも(コンテストサーバ時間で)3msec以内で演算が終了する。

当然ながら問題は、このコードを75分の間に書いてsubmitできるか、である・・・><

追記

(r^(n+1)-1)/(r-1)%1234567891 の計算で使うdiv_() のコードは、fast_expt()で法にあたる数値を引数の1つとして余分に取るものを使った別の書き方をしていた。が剰余演算のミスかと思って悩んだ間に差し替えた。原因はここではなかったが良い勉強になった。よい機会なので剰余系のライブラリを整備しておこう。

トラックバック - https://topcoder-g-hatena-ne-jp.jag-icpc.org/n4_t/20081211