バイトの競プロメモ

主に競技プログラミング

AtCoder Grand Contest 039 C - Division by Two with Something

理解するのに時間がかかったので書いてみます

 

問題文

https://atcoder.jp/contests/agc039/tasks/agc039_c

 

まず、文中の操作は下のように、一番右のビットを反転して左へ移動する操作と言い換えられます

f:id:baitop:20191009183634p:plain

N回の操作で全ビットが反転し、2N回の操作で元に戻ります

ここで元に戻るような操作回数をkとすると、kとしてあり得るのは2Nの約数です(k回の操作を複数回繰り返した場合も元に戻り、その中に2Nが含まれるため)

またN回操作をしても元に戻らないことからkはNの約数ではなく、kは偶数だと分かります

 

k回の操作で元に戻るようなビット列の条件を考えます

大まかに考えると、k個ずれても同じ列である事から、下のように長さkの同じ列が並んだ形になっているはずです

f:id:baitop:20191009190114p:plain

 

先頭を具体的に考えると、Nはkの倍数ではない事から、途中で途切れた形になることが分かります

f:id:baitop:20191009190810p:plain

 

また先頭の長さは必ずk/2になります(

(2N % k = 0 より 先頭の長さr = N%kで

(N+N)%k = 0

(N%k + N%k)% k = 0

(r+r)%k = 0

r = k / 2 となるため)

 

上の列にk回操作を施した列も同じであることから

長さkの文字列に必要な条件は010 101 のようにs ~sと書ける事だと分かります

(下の図で a = ~c, b = ~d)

 

f:id:baitop:20191009192042p:plain


次に各kについて、上の条件を満たしX以下である列を数えます

上の数字がXです

f:id:baitop:20191010140026p:plain

 

ここでkで数えた物は kの倍数でも重複して数えられてしまう

(たとえばN=9,k = 2で 0 10101010は k=6で 010 101010 )

ためkをカウントした際に、kの倍数から予めその場合の数を引いておく

下ではエラトステネスの篩みたいな感じでやった

 https://atcoder.jp/contests/agc039/submissions/7925235

#include <bits/stdc++.h>

using namespace std;
using vi = vector<int>;
#define int long long
#define rep(i, n) for (int i = 0; i < n; i++)
#define rer(i, n) for (int i = n; i >= 0; i--)
#define sz(a) (int)a.size()


int MOD = 998244353;
struct mint {
int x;
mint() : x(0) {}
mint(int a) {x = a % MOD;if (x < 0) x += MOD;}
mint &operator+=(mint that) {x = (x + that.x) % MOD;return *this;}
mint &operator-=(mint that) {x = (x + MOD - that.x) % MOD;return *this;}
mint &operator*=(mint that) {x = (int) x * that.x % MOD;return *this;}
mint &operator/=(mint that) { return *this *= that.inverse(); }
mint operator-() { return mint(-this->x); }
friend ostream &operator<<(ostream &out, mint m) { return out << m.x; }
mint inverse() {int a = x, b = MOD, u = 1, v = 0;while (b) {int t = a / b;a -= t * b;u -= t * v;swap(a, b);swap(u, v);}return mint(u);}
operator int() const { return x; }
template<class T> mint &operator+=(T that) { return operator+=((mint) that); }
template<class T> mint &operator-=(T that) { return operator-=((mint) that); }
template<class T> mint &operator*=(T that) { return operator*=((mint) that); }
template<class T> mint &operator/=(T that) { return operator/=(that); }
mint operator+(mint that) { return mint(*this) += that; }
mint operator-(mint that) { return mint(*this) -= that; }
mint operator*(mint that) { return mint(*this) *= that; }
mint operator/(mint that) { return mint(*this) /= that; }
template<class T> mint operator+(T that) { return mint(*this) += (mint) that; }
template<class T> mint operator-(T that) { return mint(*this) -= (mint) that; }
template<class T> mint operator*(T that) { return mint(*this) *= (mint) that; }
template<class T> mint operator/(T that) { return mint(*this) /= (mint) that; }
bool operator==(mint that) const { return x == that.x; }
bool operator!=(mint that) const { return x != that.x; }
bool operator<(mint that) const { return x < that.x; }
bool operator<=(mint that) const { return x <= that.x; }
bool operator>(mint that) const { return x > that.x; }
bool operator>=(mint that) const { return x >= that.x; }
};
typedef vector<mint> vm;
template<typename T, typename U> mint mpow(const T a, const U b) {
assert(b >= 0);
int x = a, res = 1;
U p = b;
while (p > 0) {
if (p & 1) (res *= x) %= MOD;
(x *= x) %= MOD;
p >>= 1;
}
return res;
}
using vm = vector<mint>;



//aをbit列と見なし反転
string operator~(string &a) {
string res = a;
for (auto &&c:res) {
if (c == '0')c = '1';
else if (c == '1')c = '0';
else {
cerr << "cant ~" << a << "must bit" << endl;
exit(0);
}
}
return res;
}

vi divisors(int v) {
vi res;
double lim = std::sqrt(v);
for (int i = 1; i <= lim; ++i) { if (v % i == 0) { res.push_back(i);/* if (i != v / i)res.push_back(v / i);*/ }}
for (int i = sz(res) - 1; i >= 0; i--) { if (res[i] != v / res[i])res.push_back(v / res[i]); }
return res;
}
constexpr int bsetlen = 202020;
bool operator<(bitset<bsetlen> &a, bitset<bsetlen> &b) {rer(i, bsetlen - 1) {if (a[i] < b[i])return true;if (a[i] > b[i])return false;}return false;}
bool operator>(bitset<bsetlen> &a, bitset<bsetlen> &b) {rer(i, bsetlen - 1) {if (a[i] > b[i])return true;if (a[i] < b[i])return false;}return false;}
bool operator<=(bitset<bsetlen> &a, bitset<bsetlen> &b) {rer(i, bsetlen - 1) {if (a[i] < b[i])return true;if (a[i] > b[i])return false;}return true;}
bool operator>=(bitset<bsetlen> &a, bitset<bsetlen> &b) {rer(i, bsetlen - 1) {if (a[i] > b[i])return true;if (a[i] < b[i])return false;}return true;}


signed main() {
int N;
string X;
cin >> N >> X;
//Kを操作回数として扱っている
vi K;
{
auto D = divisors(N * 2);
for (auto &&k: D) {
if ((k % 2) == 0)K.push_back(k);//kは偶数でN*2の約数
}
}
int lim = N * 2 + 1;
vector<mint> cou(lim);//周期が[k]となるような物の数
mint res = 0;

bitset<bsetlen> bx(X);
for (auto &&k:K) {
mint part = 0;
//先頭k/2桁を決めれば後は一意に定まる
//先頭k/2桁の自明に小さい取り方を数える
rep(i, k / 2) {
if (X[i] == '1')part += mpow(2, k / 2 - 1 - i);
}
{
string tail = X.substr(0, k / 2);
string head = ~tail;//mainの上

string tx = tail;
string s = head+tail;
//size
while (sz(tx) < N) tx += s;
if (sz(tx) != N)continue; //予防だが使われない txのサイズは必ずN
bitset<bsetlen> btx(tx);
if (bx >= btx) part += 1;
}
cou[k] += part;
res += cou[k] * k;
for (int s = k * 2; s < lim; s += k) {
cou[s] -= cou[k];//周期kはkの倍数sでも数えられるため引いておく
}
}
cout << res << endl;
return 0;
}