バイトの競プロメモ

主に競技プログラミング

D - 漸化式 AtCoder Beginner Contest 009

問題概略
長さKの数列A,Cが与えられる。
n >= 1で
An+k = (c1 & An+k-1) xor (c2 & An+k-2) xor .....(ck & An)である時、
Amを求めよ

制約
M <= 100
N <= 10^9
A <= 2^32-1

&とxorが*,+なら行列のべき乗の形にして高速に解ける。
実は今回も同様に出来る。
32ビットの非負の数は&,xor上で半環であり、単位元や交換法則が成り立つので。
単位元には~0Lを使えば良い

public static void solve() throws Exception {
        //longを忘れるなオーバーフローするぞ
        K = ni();
        M = ni();
        A = nla(K);
        C = nla(K);
        long[][] l = new long[K][K];
        for (int i = 0; i < K; i++) {
            l[0][i] = C[i];
        }
        for (int i = 0; i < K - 1; i++) {
            l[i + 1][i] = ~0L;
        }
        long[][] r = new long[K][1];
        for (int i = 0; i < K; i++) {
            r[i][0] = A[K - i - 1];
        }
        if (M <= K) {
            System.out.println(A[M - 1]);
            return;
        }

        long[][] res = matMul(matPow(l, M - K), r);
        System.out.println(res[0][0]);
    }

    public static void matPrint(long[][] a) {
        for (int hi = 0; hi < a.length; hi++) {
            for (int wi = 0; wi < a[0].length; wi++) {
                System.out.print(a[hi][wi] + " ");
            }
            System.out.println("");
        }
    }

    //rにlを掛ける l * r
    public static long[][] matMul(long[][] l, long[][] r) throws IOException {
        int lh = l.length;
        int lw = l[0].length;
        int rh = r.length;
        int rw = r[0].length;
        //lwとrhが,同じである必要がある
        if (lw != rh) throw new IOException();
        long[][] res = new long[lh][rw];
        for (int i = 0; i < lh; i++) {
            for (int j = 0; j < rw; j++) {
                for (int k = 0; k < lw; k++) {
                    res[i][j] ^= l[i][k] & r[k][j];
                }
            }
        }
        return res;
    }

    public static long[][] matPow(long[][] a, int n) throws IOException {
        int h = a.length;
        int w = a[0].length;
        if (h != w) throw new IOException();
        long[][] res = new long[h][h];
        for (int i = 0; i < h; i++) {
            res[i][i] = ~0L;
        }
        long[][] pow = a.clone();
        while (n > 0) {
            if (bitGet(n, 0)) res = matMul(pow, res);
            pow = matMul(pow, pow);
            n >>= 1;
        }
        return res;
    }