バイトの競プロメモ

主に競技プログラミング

D - Two Sequences AtCoder Regular Contest 092 

D - Two Sequences

 

長さNの数列 a,bが与えられる。

ai bjのN^2の全ての選び方について、ai + bjを計算しする。

それらのxorを求めよ

制約

  • 入力は全て整数
  • 1N200,000

解法

愚直に計算すると間に合わない。

xorは繰り上がりしないので、ある桁で1が奇数回足されればその桁は1になる。

0から数えてK桁目について考える。

ai bjを足した時、K桁目が1になるのはどんな条件の時だろうか?2 ^ KをTとすると、合計がTの時はK桁目だけが1になることが分かる。K桁目が繰り上がるまでは条件を満たす。具体的には[T 2T)の時だ。

同じように[3T 4T)の時にも条件を満たし、[5T 6T),[7T 8T)と続いていく。

 

ここで、K桁目に影響を与えるのはK行目以下であるので、a,bの下位Kビットだけ取り出す。そうすればa,bは 2T未満のため、a + b < 4Tとなるため、ai bj の合計が[T 2T) [3T 4T)となる個数だけ調べれば良くなる。

その個数回、K桁目でxorが行われ、それが奇数なら桁Kは1だ。

 

それぞれの合計の区間は、bをソートしておけば一つのaについてO(log n)で求められる

のでO(n log n)となる

 

問題の芯

xor演算は繰り上がりがないので、桁ごとに独立して考えられる

調べる範囲が複数ある場合、工夫をして調べる範囲を減らしたい。

今回の場合、下位Kビットだけ見れば良いってかんじに

こうなるのはどういう場合、というのを考察しよう(K桁目が1になるのは~)

何と何が足される時という風に,2つのものについて考えると難しくなるので、合計が〇〇なら~という風に考えたい

 

public static void main(String[] args)
{
        N = sc.nextInt();
        a = new long[N];
        b = new long[N];
        c = sc.nextLongArray(N);
        d = sc.nextLongArray(N);
        long res = 0;
        for (int i = 0; i < 29; i++)
        {
            long cou = 0;
            long T = 1 << i;
            for (int j = 0; j < N; j++)
            {
                a[j] = c[j] % (2 * T);
                b[j] = d[j] % (2 * T);
            }
            Arrays.sort(b);
            //a + b は 4T 未満
            for (int ai = 0; ai < N; ai++)
            {
                //かっこいい                      //2つの合計がT以上2T未満になる個数
                cou += lowerBound(b, 2 * T - a[ai]) - lowerBound(b, 1 * T - a[ai]); 
                cou += lowerBound(b, 4 * T - a[ai]) - lowerBound(b, 3 * T - a[ai]);
            }
            res |= (cou & 1) << i;
        }
        System.out.println(res);
    }
    /**
     * <h1>指定した値以上の先頭のインデクスを返す</h1>
     * <p>配列要素が0のときは、0が返る。</p>
     *
     * @return<b>int</b> : 探索した値以上で、先頭になるインデクス
     * 値が無ければ、挿入できる最小のインデックス
     */
    public static int lowerBound(final int[] arr, final int value)
    {
        int low = 0;
        int high = arr.length;
        int mid;

        while (low < high)
        {
            mid = ((high - low) >>> 1) + low;    //(low + high) / 2 (オーバーフロー対策)
            if (arr[mid] < value)
            {
                low = mid + 1;
            }
            else
            {
                high = mid;
            }
        }
        return low;
    }