2008-12-11
SRM428 Div1 Medium: TheLongPalindrome
- 奇数長のpalindromeは、それより1つ長い(偶数長の)palindromeと同数のパターンを持つ
- 剰余計算
まずはナイーブな実装。
#include <vector>
using namespace std;
const long long H = 1234567891LL;
inline long long sub_(long long a, long long b) {
return (a + H - (b % H)) % H;
}
inline long long add_(long long a, long long b) {
return (a + b) % H;
}
inline long long mul_(long long a, long long b) {
return ((long long)(a % H) * (b % H)) % H;
}
long long c_(int n, int r)
{
if (n == 0 || r == 0 || r == n) return 1;
if (r > n-r) r = n-r;
return (c_(n-1,r-1) * n / r) % H;
}
long long expt_(int n, int k) { // n ^ k
long long p = 1LL;
for (int i=0; i<k; i++) p = mul_(p, n);
return p;
}
long long fac_(int n) { // n !
long long p = 1LL;
for (int k=n; k>1; k--) p = mul_(p, k);
return p;
}
class TheLongPalindrome {
private:
long long f_(int k, int len) {
if (k == 1) return 1;
if (k == len) return fac_(k);
long long t = 0;
for (int j=k,pm=1; j>=1; j--,pm=-pm)
t = (t + pm * expt_(j,len) * c_(k,j)) % H;
return t;
}
public:
int count(int n, int k) {
vector<int> expts(27,1);
int h = (n + 1) / 2, m = n % 2;
if (k > h) k = h;
long long c = 0LL;
for (int len=1; len<=h; len++) {
// i:何文字長?
int k_ = min(len,k); // 字種数
long long o = 0LL;
for (int j=1; j<=k_; j++)
o = add_(o, mul_(c_(26,j), f_(j,len))); // o += 26Cj x f(j,len)
c = (len == h && m == 1)? add_(c,o) : add_(c,o*2);
}
return (int)c;
}
};
- サンプルケース4つは通るが、n=1000000000とか渡したら死ぬのは目に見えている
- 長さlen/2(奇数長の場合は(len+1)/2)の文字列で、1〜k種類の文字が使われる、とは言っても len/2 < k の場合、一度に使える文字数はlen/2種類
- というわけで len/2 = k の上と下で場合分け
- 累乗の計算を速いやつに置き換える
ここで関数 f_() がどのように展開されるかを分析し、パラメータから結果が直接得られるような式が書けないか考える。結論から言うと、len/2 > k の部分については、len/2 = [k+1..n/2] の範囲で
26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } x 1^(len/2)
- 26C2 x { 24C0 - 24C1 + 24C2 - ... ± 24C(k-2) } x 2^(len/2)
- ...
- 26Ck x k^(len/2)
の総和を求めればよい。(len/2 <= k の部分は最初のやり方で構わない)
ここで係数 26C1 x { 25C0 - 25C1 + 25C2 - ... ± 25C(k-1) } の部分は文字列長に関わらず一定なので計算は一度だけ行い、これに Σk^(len/2) を掛け合わせる。等比数列の和とか・・・20年来使っていない式を活用。計算量は...O(k^2*log(len))かな。
最終的なコードはこんな感じ:
#include <vector>
using namespace std;
const long long H = 1234567891LL;
inline long long sub_(long long a, long long b) {
return (a + H - (b % H)) % H;
}
inline long long add_(long long a, long long b) {
return (a + b) % H;
}
inline long long mul_(long long a, long long b) {
return ((a % H) * (b % H)) % H;
}
long long fast_expt(long long r, long long n) { // r^n % H
if (n == 1) return r;
if (n % 2 == 0)
return fast_expt(mul_(r,r), n/2);
else
return mul_(fast_expt(r,n-1), r);
}
inline long long div_(long long a, long long b) {
return mul_(a, fast_expt(b,H-2));
}
long long geo(int r, int n) {
// 1,r,r^2,...,r^n
if (r == 1) return n % H;
return div_(fast_expt(r,n+1)-1, r-1);
}
long long C(int n, int r)
{
if (n == 0 || r == 0 || r == n) return 1;
if (r > n-r) r = n-r;
return C(n-1,r-1) * n / r;
}
long long fac_(int n) { // n !
long long p = 1LL;
for (int k=n; k>1; k--) p = mul_(p, k);
return p;
}
vector<long long> coeffs_(int k) {
vector<long long> cs(k+1, 1LL); cs[0] = 0;
for (int j=1; j<=k; j++) {
long long t = 0;
for (int i=0,pm=1; i<=k-j; i++,pm=-pm) {
t += C(26-j, i) * pm;
}
cs[j] = C(26, j) * t;
}
return cs;
}
class TheLongPalindrome {
long long f_(int k, int len) {
if (k == 1) return 1;
if (k == len) return fac_(k);
long long t = 0;
for (int j=k,pm=1; j>=1; j--,pm=-pm) {
if (pm >= 0)
t = add_(t, fast_expt(j,len) * C(k,j));
else
t = sub_(t, fast_expt(j,len) * C(k,j));
}
return t;
}
public:
int count(int n, int k) {
int h = (n + 1) / 2, m = n % 2;
if (k > h) k = h;
vector<long long> expts(k+1, 1LL); expts[0] = 0;
vector<long long> coeffs = coeffs_(k);
long long c = 0LL;
for (int len=1; len<=k; len++) {
int k_ = len;
long long o = 0LL;
for (int j=1; j<=k_; j++)
o = add_(o, mul_(C(26,j), f_(j,len))); // o += 26Cj x f(j,len)
c = (len == h && m == 1)? add_(c,o) : add_(c,o*2);
}
if (h > k) {
long long o = 0LL;
for (int r=1; r<=k; r++) {
long long co = coeffs[r];
if (co >= 0)
o = add_(o, mul_(co, sub_(geo(r,h-1), geo(r,k))));
else
o = sub_(o, mul_(-co, sub_(geo(r,h-1), geo(r,k))));
}
c = add_(c, o*2);
o = 0LL;
for (int r=1; r<=k; r++) {
long long co = coeffs[r];
if (co >= 0)
o = add_(o, mul_(co, fast_expt(r,h)));
else
o = sub_(o, mul_(-co, fast_expt(r,h)));
}
c = (m == 1) ? add_(c, o) : add_(c, o*2);
}
return (int)c;
}
};
最後の要素で、全体長が偶数の場合の加算を忘れていたためにCase 3の数値が合わず、剰余演算のミスかと思って小一時間悩んだ。
このコードでは、n=1000000000な最悪ケースでも(コンテストサーバ時間で)3msec以内で演算が終了する。
当然ながら問題は、このコードを75分の間に書いてsubmitできるか、である・・・><
追記
(r^(n+1)-1)/(r-1)%1234567891 の計算で使うdiv_() のコードは、fast_expt()で法にあたる数値を引数の1つとして余分に取るものを使った別の書き方をしていた。が剰余演算のミスかと思って悩んだ間に差し替えた。原因はここではなかったが良い勉強になった。よい機会なので剰余系のライブラリを整備しておこう。
コメント
トラックバック - https://topcoder-g-hatena-ne-jp.jag-icpc.org/n4_t/20081211