본문 바로가기

Algorithm/Algorithm Study

트라이(Trie)

문자열 관련 알고리즘은 정말 다루기 힘든편에 속한다. 하지만 그중에도 잘만 이해하면 유용하게 사용할 수 있는 알고리즘이 몇있다. 그중에서 트라이(Trie)에 대해서 얘기해보고자 한다.


트라이 - 나무위키

https://namu.wiki/w/%ED%8A%B8%EB%9D%BC%EC%9D%B4


트라이 알고리즘은 여러 단어들이 주어진 경우, 특정 단어가 예시로 주어진 단어들 중에 포함되는가를 빠르게 파악할 수 있는 알고리즘이다. 예를들어 n개의 단어가 주어져있고, 각 문자열의 길이가 m인 경우 단순하게 이 단어들중에 특정 단어가 포함되어있는지 확인하려면 최대 O(nm)의 시간 복잡도를 갖게된다. 하지만 트라이를 사용하면 시간복잡도가 O(m)으로 압도적으로 단축되게 된다. 어떻게 이렇게 시간이 줄어들 수 있는지 생각해보자.


트라이 알고리즘에서는 단어들이 주어지면 단어들을 문자단위로 분해하여 자료구조를 형성하게 된다. 문자열의 각 문자들이 개별노드가 되는데 개별노드를 형성해서 추가될 때는 기존에 존재하는 노드들까지 순차적으로 타고가다가 더이상 타고갈 수 있는 노드가 없게되면 새로운 노드를 생성해서 기존의 노드에 추가해주게 된다. 일종의 가지치기인 셈이다. 예를 들어 print, python, pytorch라는 세 단어가 주어졌다고 하자. print가 처음 트라이에 추가될 때는 어떠한 노드들도 없었기 때문에 루트를 시작점으로하여 각 글자들을 노드로 하여 트라이 구조를 만들어주게된다. 그다음 단어인 python을 보면 print의 경우와 조금 다르다. p가 겹치기 때문이다. 이경우 p까지는 print의 p를 사용해도 된다는 것을 알 수 있다. 그렇기 때문에 신규로 추가될 노드들은 p를 제외한 ython이라는 노드들이다. 마지막 pytorch를 보면 좀더 확실히 이해할 수 있다. pytorch는 python와 pyt 부분이 겹친다는 것을 알 수 있다. 그렇기 때문에 pytorch는 python의 pyt부분까지 같은 노드를 공유할 수 있다. 그렇기 때문에 orch 노드만 추가해주면 된다. 최종적으로 이 세단어의 트라이는 아래와 같이 구성된다.




붉은색의 루트노드에서 p가 파생되고 print의 p를 제외한 부분이 하나의 가지를 이루며, python의 ython이 또 하나의 가지를 이루고 pytorch의 pyt를 제외한 orch 부분이 하나의 가지를 이루게 된다.


트라이가 완성되면 특정 단어가 이 트라이 구조 안에 포함되어있는지 아주 쉽게 확인할 수 있다. 에를 들어 python이 포함되어있는지 확인한다고 하면 루트에서 p를 가르키는 주소에 노드가 존재하는지 확인한다. 만일 p 노드가 존재하면 그 다음 글자인 y를 가르키는 주소에 노드가 존재하는지 확인하고 이런식으로 마지막 글자인 n까지 확인하게 되면 python이라는 단어가 트라이에 포함되어 있는지 알 수 있게된다. 특정 글자에 노드가 준비되어있는지 바로바로 확인하기 위해서는 나올수 있는 글자 숫자만큼 배열을 준비해서 관리하면 된다. 나올 수 있는 모든 글자가 소문자이면 26개 크기의 배열을 준비하면 되고, 0 ~ 9까지의 숫자라면 10개 크기의 배열을 사용하면 된다.


실제 문제에서는 트라이가 어떻게 이용되는지 확인해보자.


5052번: 전화번호 목록 - Baekjoon Online Judge

https://www.acmicpc.net/problem/5052


import java.io.BufferedReader;
import java.io.InputStreamReader;

public class Main {
	private static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	private static String str;
	private static int i;
	private static int res;

	static int nextInt() {
		try {
			i = Integer.parseInt(br.readLine().trim());
		} catch (Exception e) {
		}
		return i;
	}

	static String nextLine() {
		try {
			str = br.readLine().trim();
		} catch (Exception e) {
		}
		return str;
	}

	public static void main(String[] args) {
		int T = nextInt();

		for (int t = 1; t <= T; t++) {
			int N = nextInt();

			Node root = new Node('r', 0);
			for (int n = 1; n <= N; n++) {
				Node node = root;
				String str = nextLine();

				for (int i = 0; i < str.length(); i++) {
					char c = str.charAt(i);
					if (node.nxt[c - '0'] == null) {
						node.cnt++;
						node.nxt[c - '0'] = new Node(c, 0);
					}
					node = node.nxt[c - '0'];
				}

			}
			res = 0;
			searchLeaf(root);
			if (res != N) {
				System.out.println("NO");
			} else {
				System.out.println("YES");
			}
		}
	}

	private static void searchLeaf(Node node) {
		if (node.cnt == 0) {
			res++;
			return;
		}
		for (int i = 0; i < 10; i++) {
			if (node.nxt[i] != null) {
				searchLeaf(node.nxt[i]);
			}
		}
	}
}

class Node {
	char data;
	int cnt;
	Node[] nxt = new Node[10];

	public Node(char data, int cnt) {
		this.data = data;
		this.cnt = cnt;
	}
}