티스토리 뷰

PS/BOJ

[Python] [백준] 11438번: LCA 2

개발을해보자 2021. 3. 5. 15:59

www.acmicpc.net/problem/11438

 

11438번: LCA 2

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

www.acmicpc.net

 

문제

N(2 ≤ N ≤ 100,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.

입력

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.

출력

M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.

 

 

 

 

풀이

동빈나 님의 최소 공통 조상 알고리즘 영상을 참고하였습니다.

youtu.be/O895NbxirM8

노드마다 2**i번째 부모에 대한 정보까지 저장할 수 있는 parent 리스트를 만든다.

이후 각 노드의 2**i 부모 정보까지 모두 갱신하고 lca 알고리즘을 수행한다.

11437번 문제처럼 알고리즘 자체는 크게 다르지 않지만, 조상을 찾기 위해 거슬러 올라가는 깊이가 2**i라는 것에 차이가 있다.

트리를 그리면서 푸니까 겨우 이해가 됐던 어려운 문제이다..

 

[파이썬으로 구현한 코드]

import sys
sys.setrecursionlimit(100000)
input = sys.stdin.readline
LENGTH = 21

n = int(input())
parent = [[0] * LENGTH for _ in range(n + 1)]
visited = [False] * (n + 1)
d = [0] * (n + 1)
graph = [[] for _ in range(n + 1)]

for _ in range(n - 1):
    a, b = map(int, input().split())
    graph[a].append(b)
    graph[b].append(a)


def dfs(x, depth):
    visited[x] = True
    d[x] = depth

    for node in graph[x]:
        if visited[node]:
            continue
        # 우선 바로 위에 있는 부모 정보만 갱신
        parent[node][0] = x
        dfs(node, depth + 1)


# 모든 노드의 전체 부모 관계 갱신하기
def set_parent():
    dfs(1, 0)
    for i in range(1, LENGTH):
        for j in range(1, n + 1):
            # 각 노드에 대해 2**i번째 부모 정보 갱신
            parent[j][i] = parent[parent[j][i - 1]][i - 1]


def lca(a, b):
    # 무조건 b의 깊이가 더 깊도록 설정
    if d[a] > d[b]:
        a, b = b, a

    # a와 b의 깊이가 동일해주도록 설정
    for i in range(LENGTH - 1, -1, -1):
        if d[b] - d[a] >= 2**i:
            b = parent[b][i]

    if a == b:
        return a

    # 올라가면서 공통 조상 찾기
    for i in range(LENGTH - 1, -1, -1):
        if parent[a][i] != parent[b][i]:
            a = parent[a][i]
            b = parent[b][i]

    return parent[a][0]


set_parent()

m = int(input())

for _ in range(m):
    a, b = map(int, input().split())
    print(lca(a, b))

 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
글 보관함