2008-12-11
SRM428 Div1 Medium: TheLongPalindrome
- 奇数長のpalindromeは、それより1つ長い(偶数長の)palindromeと同数のパターンを持つ
- 剰余計算
まずはナイーブな実装。
#include <vector> using namespace std; const long long H = 1234567891LL; inline long long sub_(long long a, long long b) { return (a + H - (b % H)) % H; } inline long long add_(long long a, long long b) { return (a + b) % H; } inline long long mul_(long long a, long long b) { return ((long long)(a % H) * (b % H)) % H; } long long c_(int n, int r) { if (n == 0 || r == 0 || r == n) return 1; if (r > n-r) r = n-r; return (c_(n-1,r-1) * n / r) % H; } long long expt_(int n, int k) { // n ^ k long long p = 1LL; for (int i=0; i<k; i++) p = mul_(p, n); return p; } long long fac_(int n) { // n ! long long p = 1LL; for (int k=n; k>1; k--) p = mul_(p, k); return p; } class TheLongPalindrome { private: long long f_(int k, int len) { if (k == 1) return 1; if (k == len) return fac_(k); long long t = 0; for (int j=k,pm=1; j>=1; j--,pm=-pm) t = (t + pm * expt_(j,len) * c_(k,j)) % H; return t; } public: int count(int n, int k) { vector<int> expts(27,1); int h = (n + 1) / 2, m = n % 2; if (k > h) k = h; long long c = 0LL; for (int len=1; len<=h; len++) { // i:何文字長? int k_ = min(len,k); // 字種数 long long o = 0LL; for (int j=1; j<=k_; j++) o = add_(o, mul_(c_(26,j), f_(j,len))); // o += 26Cj x f(j,len) c = (len == h && m == 1)? add_(c,o) : add_(c,o*2); } return (int)c; } };
- サンプルケース4つは通るが、n=1000000000とか渡したら死ぬのは目に見えている
- 長さlen/2(奇数長の場合は(len+1)/2)の文字列で、1〜k種類の文字が使われる、とは言っても len/2 < k の場合、一度に使える文字数はlen/2種類
- というわけで len/2 = k の上と下で場合分け
- 累乗の計算を速いやつに置き換える
ここで関数 f_() がどのように展開されるかを分析し、パラメータから結果が直接得られるような式が書けないか考える。結論から言うと、len/2 > k の部分については、len/2 = [k+1..n/2] の範囲で
26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } x 1^(len/2)
- 26C2 x { 24C0 - 24C1 + 24C2 - ... ± 24C(k-2) } x 2^(len/2)
- ...
- 26Ck x k^(len/2)
の総和を求めればよい。(len/2 <= k の部分は最初のやり方で構わない)
ここで係数 26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } の部分は文字列長に関わらず一定なので計算は一度だけ行い、これに Σk^(len/2) を掛け合わせる。等比数列の和とか・・・20年来使っていない式を活用。計算量は...O(k^2*log(len))かな。
最終的なコードはこんな感じ:
#include <vector> using namespace std; const long long H = 1234567891LL; inline long long sub_(long long a, long long b) { return (a + H - (b % H)) % H; } inline long long add_(long long a, long long b) { return (a + b) % H; } inline long long mul_(long long a, long long b) { return ((a % H) * (b % H)) % H; } long long fast_expt(long long r, long long n) { // r^n % H if (n == 1) return r; if (n % 2 == 0) return fast_expt(mul_(r,r), n/2); else return mul_(fast_expt(r,n-1), r); } inline long long div_(long long a, long long b) { return mul_(a, fast_expt(b,H-2)); } long long geo(int r, int n) { // 1,r,r^2,...,r^n if (r == 1) return n % H; return div_(fast_expt(r,n+1)-1, r-1); } long long C(int n, int r) { if (n == 0 || r == 0 || r == n) return 1; if (r > n-r) r = n-r; return C(n-1,r-1) * n / r; } long long fac_(int n) { // n ! long long p = 1LL; for (int k=n; k>1; k--) p = mul_(p, k); return p; } vector<long long> coeffs_(int k) { vector<long long> cs(k+1, 1LL); cs[0] = 0; for (int j=1; j<=k; j++) { long long t = 0; for (int i=0,pm=1; i<=k-j; i++,pm=-pm) { t += C(26-j, i) * pm; } cs[j] = C(26, j) * t; } return cs; } class TheLongPalindrome { long long f_(int k, int len) { if (k == 1) return 1; if (k == len) return fac_(k); long long t = 0; for (int j=k,pm=1; j>=1; j--,pm=-pm) { if (pm >= 0) t = add_(t, fast_expt(j,len) * C(k,j)); else t = sub_(t, fast_expt(j,len) * C(k,j)); } return t; } public: int count(int n, int k) { int h = (n + 1) / 2, m = n % 2; if (k > h) k = h; vector<long long> expts(k+1, 1LL); expts[0] = 0; vector<long long> coeffs = coeffs_(k); long long c = 0LL; for (int len=1; len<=k; len++) { int k_ = len; long long o = 0LL; for (int j=1; j<=k_; j++) o = add_(o, mul_(C(26,j), f_(j,len))); // o += 26Cj x f(j,len) c = (len == h && m == 1)? add_(c,o) : add_(c,o*2); } if (h > k) { long long o = 0LL; for (int r=1; r<=k; r++) { long long co = coeffs[r]; if (co >= 0) o = add_(o, mul_(co, sub_(geo(r,h-1), geo(r,k)))); else o = sub_(o, mul_(-co, sub_(geo(r,h-1), geo(r,k)))); } c = add_(c, o*2); o = 0LL; for (int r=1; r<=k; r++) { long long co = coeffs[r]; if (co >= 0) o = add_(o, mul_(co, fast_expt(r,h))); else o = sub_(o, mul_(-co, fast_expt(r,h))); } c = (m == 1) ? add_(c, o) : add_(c, o*2); } return (int)c; } };
最後の要素で、全体長が偶数の場合の加算を忘れていたためにCase 3の数値が合わず、剰余演算のミスかと思って小一時間悩んだ。
このコードでは、n=1000000000な最悪ケースでも(コンテストサーバ時間で)3msec以内で演算が終了する。
当然ながら問題は、このコードを75分の間に書いてsubmitできるか、である・・・><
追記
(r^(n+1)-1)/(r-1)%1234567891 の計算で使うdiv_() のコードは、fast_expt()で法にあたる数値を引数の1つとして余分に取るものを使った別の書き方をしていた。が剰余演算のミスかと思って悩んだ間に差し替えた。原因はここではなかったが良い勉強になった。よい機会なので剰余系のライブラリを整備しておこう。