import collections import functools from typing import Any def _parse_input(input_data: str) -> tuple[dict[str, int], dict[int, list[int]]]: counter = 0 ids = {} graph = collections.defaultdict(list) for line in input_data.splitlines(): node, _neighs = line.split(": ") neighs = _neighs.split() if node not in ids: ids[node] = counter counter += 1 for n in filter(lambda x: x not in ids, neighs): ids[n] = counter counter += 1 for n in neighs: graph[ids[n]].append(ids[node]) return ids, graph def _count_paths( graph: dict[int, list[int]], start: int, end: int, skip: set[int] | None = None, ) -> int: if skip is None: skip = set() @functools.cache def traverse(node: int): if node in skip: return 0 if node == start: return 1 return sum(traverse(n) for n in graph[node]) return traverse(end) def part_1(input_data: str) -> Any: ids, graph = _parse_input(input_data) print(graph) you = ids["you"] out = ids["out"] return _count_paths(graph, you, out) def part_2(input_data: str) -> Any: ids, graph = _parse_input(input_data) svr = ids["svr"] out = ids["out"] dac = ids["dac"] fft = ids["fft"] svr_dac = _count_paths(graph, svr, dac, {out, fft}) svr_fft = _count_paths(graph, svr, fft, {out, dac}) dac_fft = _count_paths(graph, dac, fft, {out}) fft_dac = _count_paths(graph, fft, dac, {out}) dac_out = _count_paths(graph, dac, out, {fft}) fft_out = _count_paths(graph, fft, out, {dac}) return svr_dac * dac_fft * fft_out + svr_fft * fft_dac * dac_out def test_part_1(example_data): assert part_1(example_data) == 5 def test_part_2(example_data): assert part_2(example_data) == 1