import math from collections.abc import Iterable from typing import Any, Self Point = tuple[int, int, int] def _parse_input(input_data: str) -> list[Point]: points = [] for line in input_data.splitlines(): a, b, c, *_ = list(map(int, line.split(","))) points.append((a, b, c)) return points def _get_sorted_pairs(points: list[Point]) -> list[tuple[int, Point, Point]]: pairs: list[tuple[int, Point, Point]] = [] for idx, p1 in enumerate(points): x1, y1, z1 = p1 for p2 in points[idx + 1 :]: x2, y2, z2 = p2 dist = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 pairs.append((dist, p1, p2)) pairs.sort() return pairs class UnionFind: def __init__(self): # internal counter for ids self._component_id = 0 # component id to size mapper self.component_size: dict[int, int] = {} # point to id mapper self.points: dict[Point, int] = {} # point to parent point mapper self.parent: dict[Point, Point] = {} @classmethod def from_points(cls, points: Iterable[Point]) -> Self: c = cls() for p in points: c.insert_point(p) return c def _issue_component_id(self) -> int: self._component_id += 1 return self._component_id def _get_parent(self, p: Point) -> Point: if self.parent[p] == p: return p parent = self._get_parent(self.parent[p]) self.parent[p] = parent self.points[p] = self.points[parent] return parent def _update_parent(self, point: Point, parent: Point): ppoint = self._get_parent(point) pparent = self._get_parent(parent) self.parent[ppoint] = pparent self.points[pparent] = self.points[pparent] def insert_point(self, p: Point) -> None: pid = self._issue_component_id() assert p not in self.points self.points[p] = pid self.parent[p] = p self.component_size[pid] = 1 def connect(self, p1: Point, p2: Point): if p1 not in self.points: self.insert_point(p1) if p2 not in self.points: self.insert_point(p2) pid1 = self.points[self._get_parent(p1)] pid2 = self.points[self._get_parent(p2)] if pid1 == pid2: return # swap components so we can assume component pid1 is larger if self.component_size[pid1] < self.component_size[pid2]: p1, p2 = p2, p1 pid1, pid2 = pid2, pid1 self.component_size[pid1] += self.component_size[pid2] self.component_size.pop(pid2) self._update_parent(p2, p1) def part_1(input_data: str) -> Any: points = _parse_input(input_data) pairs = _get_sorted_pairs(points) g = UnionFind.from_points(points) for _, p1, p2 in pairs[:1000]: g.connect(p1, p2) return math.prod(sorted(g.component_size.values())[-3:]) def part_2(input_data: str) -> Any: points = _parse_input(input_data) pairs = _get_sorted_pairs(points) g = UnionFind.from_points(points) for _, p1, p2 in pairs: g.connect(p1, p2) if len(g.component_size) == 1: return p1[0] * p2[0] return None def test_part_1(example_data): assert part_1(example_data) == 20 def test_part_2(example_data): assert part_2(example_data) == 25272