Algorithm

행렬 제곱 (정방행렬) <백준10830번 문제>

whyWhale 2020. 11. 11. 07:55

 

 

 

안녕하세요. 이번에는 행렬의 제곱에 대하여 설명드리겠습니다.

 

먼저 예를 하나 들어보겠습니다.

 

행렬 의 (18,19)승을 구하고 싶다고 하면, 어떻게 계산하는 것이 빠르겠습니까?

 

가장 먼저 생각하시는것이 순차적으로 A * A 을 구하고 또 A^2 * A  ... 이런식으로 하나하나 곱해가는 방식이 가장 먼저 생각이 나실 겁니다!!!!

 

혹시 여기에서 단서가 있지 않을 까 싶어 노트에 적고 생각을 해보았습니다. 

 

기본적으로 A*A -> A^2 을 구하게 되고  A^2을 가지고 A^2 * A^2 곱하게 되면  지수가 점점 커지는 것을 알 수 있습니다.  행렬읮 제곱을 구할 때 지수가 짝수일 때는 A^2 을가지고 더 시간을 단축시킬 수 있다는 것을 알게되며 홀수 일때는 단순히 홀수-1 => 짝수가 될 수 있는 원리를 생각하고 거기에 A를 더해주면 되겠다고 생각을 하게 되었습니다.

 

분할정복이라는 문제유형을 알고 분할정복의 방식으로 일단 가장 작은 단위로 나뉘어서 점점 작은 단위의 문제를 해결하게 되는 방식으로 코드를 만들어 보았습니다.   

 

감사합니다.

 

 

import java.util.Scanner;

 class MatrixSquare {
    static int n;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt(); // 행열의 값
        long b = sc.nextLong(); // 지수
        sc.nextLine();
        long arr[][] = new long[n][n];
        for (int i = 0; i < n; i++) {
            String str[] = sc.nextLine().split(" ");
            for (int j = 0; j < str.length; j++) {
                arr[i][j] = Integer.parseInt(str[j]);
            }
        }
        long[][] res = new long[n][n];
        res = divide(arr, b);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(res[i][j] % 1000 + " ");
            }
            System.out.println();
        }
    }

    public static long[][] divide(long[][] arr, long b) {
        if (b == 0) {
            long[][] tmp = new long[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    tmp[i][j] = 1;
                }
            }
            return tmp;
        }
        if (b == 1)
            return arr;

// 홀일떄.. divide 하고 b-1 상태로 짝수로 만듦.
        if (b % 2 == 1) {
            long[][] tmp = divide(arr, b - 1);
            return square(arr, tmp);
        }
// 짝수일때.. divide 도 계속 짝수로.
        else {
            long[][] tmp = divide(arr, b / 2);
            return square(tmp, tmp);
        }

    }

    public static long[][] square(long[][] arr, long[][] arr2) {
        long ans[][] = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                long tmp = 0;
                for (int k = 0; k < n; k++) {
                    tmp += arr[i][k] * arr2[k][j] % 1000; // 일반적인 행렬 곱셈 원리.
                }
                ans[i][j] = tmp % 1000;
            }
        }
        return ans;
    }

    public static void print(long arr[][]) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(arr[i][j] % 1000 + " ");
            }
            System.out.println();
        }
    }
}