- N, M が与えられる。整数 1〜N の部分集合 A と 1〜M の部分集合 B を互いに重複なしで選んだとき、A の全要素の XOR < B の全要素の XOR となる選び方は何通りあるか。mod 1,000,000,007 で求める。
- 1≦N,M≦2000
- 1〜2000 の部分集合の全要素の XOR が 0 以上 2048 未満なのがポイントっぽい
int dp[2][2048][2048];
としてみると
初期値は dp[0][0][0] = 1
更新は
i in [1, max(N, M)], xa,xb in [0, 2048) について
数字 i を使わない場合 → dp[i][xa][xb] = dp[i-1][xa][xb]
i≦N かつ数字 i を A に入れる場合 → dp[i][xa^i][xb] += dp[i-1][xa][xb]
i≦M かつ数字 i を B に入れる場合 → dp[i][xa][xb^i] += dp[i-1][xa][xb]
答え = Σ dp[max(N,M)][xa][xb] ただし xa<xb
- 書いてみてサンプル通ったけどそういえば最内ループの処理回数が 2000*2048*2048 になってだめだ! (終了数分前)
- おわり
- 8 分で解いた Petr の解答と kinaba さんのツイートを見る
- 基本は上のを使うんだけど、あと2ステップ必要っぽい
int dp[2][2048][2048];
class WinterAndSnowmenSameBits {
public:
int getNumber(int N, int M) {
CLEAR(dp, 0);
int men=0;
dp[men][0][0]=1;
int xors = 1;
while(xors <= max(N, M)) xors*=2;
RANGE(i, 1, max(N, M)+1) {
CLEAR(dp[men^1], 0);
REP(xa, xors) REP(xb, xors) {
if(dp[men][xa][xb]==0) continue;
PLUS(dp[men^1][xa][xb], dp[men][xa][xb]);
if(i<N+1) PLUS(dp[men^1][xa^(i)][xb], dp[men][xa][xb]);
if(i<M+1) PLUS(dp[men^1][xa][xb^(i)], dp[men][xa][xb]);
}
men^=1;
}
int ans=0;
REP(xa, xors) REP(xb, xors) if(xa<xb) PLUS(ans, dp[men][xa][xb]);
return ans;
}
};
- (2) dp定義は変えず、全 XA<XB のペアを「上から見て same_bits ビットは同じでその次のビットで XA<XB が決まる」ような仲間にグループわけする。(same_bits in [0, 最大ビット幅) )
- グループ分けしただけなのでまだ遅い
int dp[2][2048][2048];
class WinterAndSnowmen {
public:
int getNumber(int N, int M) {
int bits = 1;
while(1<<bits <= max(N, M)) bits++;
int xors = 1<<bits;
int ans=0;
REP(same_bits, bits) {
CLEAR(dp, 0);
int men=0;
dp[men][0][0]=1;
RANGE(i, 1, max(N, M)+1) {
CLEAR(dp[men^1], 0);
REP(xa, xors) REP(xb, xors) {
if(dp[men][xa][xb]==0) continue;
PLUS(dp[men^1][xa][xb], dp[men][xa][xb]);
if(i<N+1) PLUS(dp[men^1][xa^(i)][xb], dp[men][xa][xb]);
if(i<M+1) PLUS(dp[men^1][xa][xb^(i)], dp[men][xa][xb]);
}
men^=1;
}
REP(xa, xors) REP(xb, xors) if(((xa^xb)>>(bits-same_bits))==0 && ((xa^xb)>>(bits-same_bits-1))&1 && xa<xb) PLUS(ans, dp[men][xa][xb]);
}
return ans;
}
};
- (3) さらに dp の定義を以下のように変えると状態数が 2*2048*2*2 に減る(うおおぉー)
- 処理回数も 11*2000*11*2*2 = 968000 以下となって間に合う
int dp[2][2048][2][2];
#define PLUS(a, b) (a) = ((ll)(a)+(b))%mod
class WinterAndSnowmen {
public:
int getNumber(int N, int M) {
int bits = 1;
while(1<<bits <= max(N, M)) bits++;
int ans=0;
REP(same_bits, bits) {
CLEAR(dp, 0);
int men=0;
dp[men][0][0][0]=1;
RANGE(i, 1, max(N, M)+1) {
CLEAR(dp[men^1], 0);
REP(prefix, 1<<same_bits) REP(diffA, 2) REP(diffB, 2) {
if(dp[men][prefix][diffA][diffB]==0) continue;
int iPrefix = i>>(bits-same_bits);
int iDiff = (i>>(bits-same_bits-1))&1;
PLUS(dp[men^1][prefix][diffA][diffB], dp[men][prefix][diffA][diffB]);
if(i<N+1) PLUS(dp[men^1][prefix^iPrefix][diffA^iDiff][diffB], dp[men][prefix][diffA][diffB]);
if(i<M+1) PLUS(dp[men^1][prefix^iPrefix][diffA][diffB^iDiff], dp[men][prefix][diffA][diffB]);
}
men^=1;
}
PLUS(ans, dp[men][0][0][1]);
}
return ans;
}
};