87 lines
1.8 KiB
Python
87 lines
1.8 KiB
Python
import collections
|
|
import functools
|
|
from typing import Any
|
|
|
|
|
|
def _parse_input(input_data: str) -> tuple[dict[str, int], dict[int, list[int]]]:
|
|
counter = 0
|
|
ids = {}
|
|
|
|
graph = collections.defaultdict(list)
|
|
|
|
for line in input_data.splitlines():
|
|
node, _neighs = line.split(": ")
|
|
neighs = _neighs.split()
|
|
|
|
if node not in ids:
|
|
ids[node] = counter
|
|
counter += 1
|
|
|
|
for n in filter(lambda x: x not in ids, neighs):
|
|
ids[n] = counter
|
|
counter += 1
|
|
|
|
for n in neighs:
|
|
graph[ids[n]].append(ids[node])
|
|
|
|
return ids, graph
|
|
|
|
|
|
def _count_paths(
|
|
graph: dict[int, list[int]],
|
|
start: int,
|
|
end: int,
|
|
skip: set[int] | None = None,
|
|
) -> int:
|
|
if skip is None:
|
|
skip = set()
|
|
|
|
@functools.cache
|
|
def traverse(node: int):
|
|
if node in skip:
|
|
return 0
|
|
if node == start:
|
|
return 1
|
|
return sum(traverse(n) for n in graph[node])
|
|
|
|
return traverse(end)
|
|
|
|
|
|
def part_1(input_data: str) -> Any:
    """Solve part 1: count the distinct paths from ``you`` to ``out``.

    Args:
        input_data: Raw puzzle input (see ``_parse_input`` for the format).

    Returns:
        The number of paths from the ``you`` node to the ``out`` node.

    Raises:
        KeyError: If ``you`` or ``out`` does not appear in the input.
    """
    ids, graph = _parse_input(input_data)
    # Debug output of the whole graph was removed; callers only need the count.
    you = ids["you"]
    out = ids["out"]
    return _count_paths(graph, you, out)
|
|
|
|
|
|
def part_2(input_data: str) -> Any:
    """Solve part 2: count paths from ``svr`` to ``out`` visiting both ``dac`` and ``fft``.

    The two checkpoints can be visited in either order. Each order is
    counted as the product of three segment counts:
    ``svr -> first``, ``first -> second`` and ``second -> out``.

    Within each segment, the node that must not appear yet (or any more) is
    skipped, so a path is attributed to exactly one ordering.

    Args:
        input_data: Raw puzzle input (see ``_parse_input`` for the format).

    Returns:
        The total number of qualifying paths.

    Raises:
        KeyError: If any of ``svr``, ``out``, ``dac`` or ``fft`` is absent.
    """
    ids, graph = _parse_input(input_data)
    svr, out, dac, fft = (ids[label] for label in ("svr", "out", "dac", "fft"))

    def through(first: int, second: int) -> int:
        # svr -> first must not touch `second` (or finish early at out).
        head = _count_paths(graph, svr, first, {out, second})
        middle = _count_paths(graph, first, second, {out})
        # second -> out must not revisit `first`.
        tail = _count_paths(graph, second, out, {first})
        return head * middle * tail

    return through(dac, fft) + through(fft, dac)
|
|
|
|
|
|
def test_part_1(example_data):
    """Part 1 on the example input should report 5 paths from ``you`` to ``out``."""
    expected_paths = 5
    assert part_1(example_data) == expected_paths
|
|
|
|
|
|
def test_part_2(example_data):
    """Part 2 on the example input should report exactly one qualifying path."""
    expected_paths = 1
    assert part_2(example_data) == expected_paths
|