# 136 lines · 3.4 KiB · Python
import math
|
|
from collections.abc import Iterable
|
|
from typing import Any, Self
|
|
|
|
# A 3-D integer coordinate (x, y, z), as parsed from one "x,y,z" input line.
Point = tuple[int, int, int]
|
|
|
|
|
|
def _parse_input(input_data: str) -> list[Point]:
    """Parse one ``x,y,z`` integer triple per line into a list of points.

    Blank or whitespace-only lines (e.g. an empty line left at the end of a
    pasted puzzle input) are skipped instead of failing in ``int("")``.
    Fields beyond the third on a line are ignored.

    Raises:
        ValueError: if a non-blank line has fewer than three fields or a
            field is not an integer.
    """
    points: list[Point] = []
    for line in input_data.splitlines():
        if not line.strip():
            continue
        # Extra comma-separated fields are tolerated and discarded.
        a, b, c, *_ = map(int, line.split(","))
        points.append((a, b, c))
    return points
|
|
|
|
|
|
def _get_sorted_pairs(points: list[Point]) -> list[tuple[int, Point, Point]]:
    """Return every unordered pair of points tagged with its squared distance.

    Each entry is ``(squared_distance, a, b)``, where ``a`` appears before
    ``b`` in *points*. The result is in ascending order, so the closest pairs
    come first; equal distances fall back to comparing the points themselves.
    """

    def squared_distance(a: Point, b: Point) -> int:
        (ax, ay, az), (bx, by, bz) = a, b
        return (ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2

    return sorted(
        (squared_distance(a, b), a, b)
        for i, a in enumerate(points)
        for b in points[i + 1 :]
    )
|
|
|
|
|
|
class UnionFind:
    """Disjoint-set forest over ``Point`` values, merged by component size.

    Every inserted point belongs to exactly one component. A component is
    identified by the integer id that was issued to the point currently at its
    root. ``component_size`` maps the id of each live component to the number
    of points it contains.
    """

    def __init__(self) -> None:
        # Last id handed out by ``_issue_component_id``; issued ids start at 1.
        self._component_id = 0

        # Component id -> number of points. Only ids of current roots are
        # kept; an absorbed component's id is removed when it is merged.
        self.component_size: dict[int, int] = {}

        # Point -> component id. Exact for roots; for other points it is
        # refreshed whenever ``_get_parent`` walks through them.
        self.points: dict[Point, int] = {}

        # Point -> parent point; a root is its own parent.
        self.parent: dict[Point, Point] = {}

    @classmethod
    def from_points(cls, points: Iterable[Point]) -> Self:
        """Build a structure in which every point starts as its own component.

        Raises:
            ValueError: if *points* contains the same point more than once.
        """
        uf = cls()
        for p in points:
            uf.insert_point(p)
        return uf

    def _issue_component_id(self) -> int:
        """Return a new, never-before-used component id."""
        self._component_id += 1
        return self._component_id

    def _get_parent(self, p: Point) -> Point:
        """Return the root of *p*'s component, compressing the path to it.

        Every point visited on the way is re-parented directly onto the root
        and has its component id refreshed to the root's id. The walk is
        iterative, so it does not depend on the interpreter's recursion limit.

        Raises:
            KeyError: if *p* has never been inserted.
        """
        root = p
        while self.parent[root] != root:
            root = self.parent[root]

        root_id = self.points[root]
        node = p
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            self.points[node] = root_id
            node = next_node
        return root

    def _update_parent(self, point: Point, parent: Point) -> None:
        """Attach the component containing *point* under the root of *parent*."""
        ppoint = self._get_parent(point)
        pparent = self._get_parent(parent)

        self.parent[ppoint] = pparent
        # The absorbed root takes the surviving component's id; its own id has
        # just been retired from ``component_size`` by the caller.
        self.points[ppoint] = self.points[pparent]

    def insert_point(self, p: Point) -> None:
        """Add *p* as a new single-point component with a freshly issued id.

        Raises:
            ValueError: if *p* has already been inserted. An explicit check is
                used rather than ``assert``, because a silently re-inserted
                point would reset its parent and corrupt the forest.
        """
        if p in self.points:
            raise ValueError(f"point {p!r} is already present")

        pid = self._issue_component_id()
        self.points[p] = pid
        self.parent[p] = p
        self.component_size[pid] = 1

    def connect(self, p1: Point, p2: Point) -> None:
        """Merge the components containing *p1* and *p2*.

        Points not seen before are inserted first. Connecting two points that
        already share a component changes nothing. Otherwise the smaller
        component is attached under the larger one's root (union by size), and
        the absorbed component's id is dropped from ``component_size``.
        """
        for p in (p1, p2):
            if p not in self.points:
                self.insert_point(p)

        pid1 = self.points[self._get_parent(p1)]
        pid2 = self.points[self._get_parent(p2)]
        if pid1 == pid2:
            return

        # Make p1's component the larger one so the merged tree stays shallow.
        if self.component_size[pid1] < self.component_size[pid2]:
            p1, p2 = p2, p1
            pid1, pid2 = pid2, pid1

        self.component_size[pid1] += self.component_size.pop(pid2)
        self._update_parent(p2, p1)
|
|
|
|
|
|
def part_1(input_data: str, connections: int = 1000, top: int = 3) -> Any:
    """Join the closest pairs, then multiply the largest component sizes.

    The first *connections* pairs in ascending order of squared distance are
    connected. A pair whose points already share a component still counts
    towards *connections*. The result is the product of the *top* largest
    component sizes; if there are fewer than *top* components, all of them
    are used.

    Args:
        input_data: Puzzle text, one ``x,y,z`` point per line.
        connections: How many of the closest pairs to connect (default 1000).
        top: How many of the largest components to multiply (default 3).
    """
    points = _parse_input(input_data)
    pairs = _get_sorted_pairs(points)

    g = UnionFind.from_points(points)
    for _, p1, p2 in pairs[:connections]:
        g.connect(p1, p2)

    # Sort descending and slice from the front so ``top=0`` yields the empty
    # product (1) instead of accidentally selecting every component.
    sizes = sorted(g.component_size.values(), reverse=True)
    return math.prod(sizes[:top])
|
|
|
|
|
|
def part_2(input_data: str) -> Any:
    """Connect closest pairs until every point lies in a single component.

    Pairs are processed in ascending order of squared distance. The first time
    only one component remains, the product of the x coordinates of the pair
    just connected is returned. Returns ``None`` if that never happens while
    walking the pairs (for example, with no points or a single point).
    """
    boxes = _parse_input(input_data)
    ranked = _get_sorted_pairs(boxes)
    forest = UnionFind.from_points(boxes)

    for _, a, b in ranked:
        forest.connect(a, b)
        if len(forest.component_size) != 1:
            continue
        return a[0] * b[0]

    return None
|
|
|
|
|
|
def test_part_1(example_data):
    """Check ``part_1`` against the expected answer for the example input."""
    # NOTE(review): ``example_data`` is presumably a pytest fixture defined
    # elsewhere. part_1 connects the 1000 closest pairs by default; if the
    # example has few enough points that this covers every pair, everything
    # merges into one component and the answer is just the point count.
    # Confirm 20 is the intended expectation.
    expected = 20
    result = part_1(example_data)
    assert result == expected
|
|
|
|
|
|
def test_part_2(example_data):
    """Check ``part_2`` against the expected answer for the example input."""
    # ``example_data`` is presumably a pytest fixture defined elsewhere.
    expected = 25272
    result = part_2(example_data)
    assert result == expected
|