1629번: 곱셈

첫째 줄에 A, B, C가 빈 칸을 사이에 두고 순서대로 주어진다. A, B, C는 모두 2,147,483,647 이하의 자연수이다.

www.acmicpc.net

문제는 직관적이다. A, B, C를 받아 A^B % C를 구하는 문제이다.

문제는 시간 제한이 0.5초, 메모리 제한이 128MB라는 것이다.

 

A, B, C = map(int, input().split())

ans = 1
for _ in range(B):
    ans *= A
    
print(ans % C)

 

당연히 정답은 아니겠지만 그래도 확인차 테스트했다.

A를 B만큼 반복해서 곱한 다음, 마지막에 C로 나눈 나머지를 구하는 것이다.

이 경우에는 A, B, C가 각각 2,147,483,647까지 가능하여 엄청나게 수가 커질 수 있다.

이 경우에는 '메모리 초과'가 뜬다.

 

A, B, C = map(int, input().split())

ans = 1
for _ in range(B):
    ans = ans * A % C
    
print(ans)

메모리를 절약하기 위해서 매 계산마다 C로 나눠주었다.

이는 정수론에서 MOD라고 생각하면 이해가 쉽다.

나머지를 곱해서 또 나머지를 구하나, 마지막에서 한 번에 나머지를 구하나 값은 같다.

 

A = 7, B = 4, C = 3이라고 가정해보자.

7^4는 2401이고 3으로 나누면 나머지가 1이 나온다.

7을 3으로 나눈 나머지는 1, 7을 곱하고 또 나누고, 반복하다보면 1로 같음을 확인할 수 있다.

 

하지만 이 방법은 B번만큼 코드가 실행하기에 시간복잡도는 O(B)가 된다.

// 우리에게 친숙한 형태는 O(N)이다.

메모리는 아꼈지만 이 경우에는 '시간 초과'가 뜬다.

 

수학적 사고가 필요한 순간이다.

정수론 모듈러 연산에서 '고속 지수 연산법'이 있다.

 

위의 사진을 예를 들면 A = 11, B = 7, C = 13이라고 해보자.

> 11^2 = 121 ≡ 4 mod 13

> 11^4 = 11^2 * 11^2 = 4 * 4 = 16 ≡ 3 mod 13

> 11^7 = 11^4 * 11^2 * 11 = 3 * 4 * 11 = 132 ≡ 2 mod 13

즉 2가 된다.

 

다른 방법으로 이분 탐색처럼 a^N의 값을 a^(N/2)로 나누어 2번 곱하는 식으로 잘게 나누는 거다.

7은 홀수이기에 6으로 계산한 뒤, 곱하려는 값을 한 번 더 곱해준다.

> 11^7 = 11^3 * 11^3 * 11

> 11^3 ≡ 11^2 * 11 = 121 mod 13 * 11 = 4 * 11 = 44 = 5 mod 13

> 11^7 = 5 * 5 * 11 = 275 ≡ 2 mod 13

즉 2가 된다.

 

컴퓨터에서는 위의 경우처럼 적절하게 올라가는 것이 힘들다.

고로 위에서부터 이분 탐색하듯이 작은 상황으로 잘게 나누는 방식을 사용한다.

이것이 곧 분할 정복이다.

 

if B odd:	then A^(B//2) * A^(B//2) * A
if B even:	then A^(B//2) * A^(B//2)

식을 이렇게 정의할수 있지만, 같은 식이 자주 나타나기에 아래처럼 바꿀 수 있다.

 

tmp = A^(B//2)

if B odd:	then tmp * tmp * A
if B even:	then tmp * tmp

이를 함수로 바꾸면 아래처럼 된다.

 

def func(A, B):

	tmp = func(A, (B//2))

    if B odd:	then tmp * tmp * A
    if B even:	then tmp * tmp

이를 사용해 코드를 작성하면 아래와 같다.

 

# ---------- Import ----------
import sys
# import time
input = sys.stdin.readline

# ---------- Function ---------
def POW(A, B, C):
    if B == 1:
        return A % C
    
    tmp = POW(A, B//2, C)
    if B % 2 == 0:
        return (tmp * tmp % C)
    else:
        return (tmp * tmp * A % C)

# ---------- Main ----------
A, B, C = map(int, input().split())

# s = time.perf_counter()
result = POW(A, B, C)
# e = time.perf_counter()
print(result)
#print(e-s)

주석으로 바꾼 부분은 시간을 확인하기 위해서 작성한 코드이다.

단순 구현으로는 점점 한계가 다가와 수학이 필요함을 느끼고 있다.

+ Recent posts