advent-of-code-2019/day20/py/main.py

192 lines
7.8 KiB
Python

import networkx
import matplotlib.pyplot as plt
import itertools
import sys
from typing import List, Tuple, Optional
# Label of the maze entrance: both parts start from the node carrying this label.
START_NODE_LABEL = 'AA'
# Label of the maze exit: both parts measure the distance to the node carrying this label.
END_NODE_LABEL = 'ZZ'
def connect_adjacent_nodes(graph: networkx.Graph):
    """Link every pair of nodes whose (row, col) coordinates differ by one step.

    Only the four orthogonal directions count (no diagonals). No edge
    attributes are set here. The graph is modified in place.
    """
    offsets = ((0, 1), (1, 0), (0, -1), (-1, 0))
    for row, col in graph.nodes:
        here = (row, col)
        candidates = ((row + step_row, col + step_col) for step_row, step_col in offsets)
        for other in candidates:
            if other in graph.nodes:
                graph.add_edge(here, other)
def is_outer_node(graph: networkx.Graph, node: Tuple[int, int]) -> bool:
    """Report whether `node` lies on the bounding box of all nodes in `graph`.

    The bounding box is the smallest/largest row and column found among the
    graph's (row, col) nodes. A node counts as outer when its row equals one of
    the extreme rows or its column equals one of the extreme columns.
    """
    all_rows = [coords[0] for coords in graph.nodes]
    all_cols = [coords[1] for coords in graph.nodes]
    row_edges = (min(all_rows), max(all_rows))
    col_edges = (min(all_cols), max(all_cols))
    return node[0] in row_edges or node[1] in col_edges
def label_portal_nodes(graph: networkx.Graph, input_lines: List[str]):
    """Tag open cells that touch a two-letter portal label.

    For every node (row, col), look one and two cells away in each orthogonal
    direction. If both cells are letters, they form the portal label. Labels are
    read left-to-right / top-to-bottom, so when the letters lie above or to the
    left of the cell the pair is reversed back into reading order.

    Matching nodes get two attributes:
      * 'label': the two-letter portal id
      * 'outer': the result of is_outer_node(graph, node)

    Cells outside the text grid are treated as blank. This matters when lines
    have different lengths (for example when trailing spaces were stripped from
    the input file). It also stops negative indices from silently wrapping to
    the other side of the grid.

    :param graph: graph whose nodes are (row, col) positions of open cells;
        modified in place.
    :param input_lines: the maze text the graph was built from.
    """
    def char_at(row: int, col: int) -> str:
        # Out-of-range positions (including negative ones) behave like empty space.
        if 0 <= row < len(input_lines) and 0 <= col < len(input_lines[row]):
            return input_lines[row][col]
        return ' '

    for node in graph.nodes:
        row, col = node
        for d_row, d_col in ((0, 1), (1, 0), (0, -1), (-1, 0)):
            portal_id = char_at(row + d_row, col + d_col) + char_at(row + 2 * d_row, col + 2 * d_col)
            if not portal_id.isalpha():
                continue
            # Going up or left we read the letters nearest-first, i.e. backwards.
            if d_row < 0 or d_col < 0:
                portal_id = portal_id[::-1]
            graph.nodes[node]['label'] = portal_id
            graph.nodes[node]['outer'] = is_outer_node(graph, node)
def connect_portal_nodes(graph: networkx.Graph):
    """Join the two ends of each portal (nodes sharing a label) with a distance-1 edge.

    Nodes without a 'label' attribute are skipped. Otherwise they would all
    share the label None and each would be linked to the first unlabelled node.
    A label seen only once (such as the entrance/exit) gets no edge. Every later
    occurrence of a label is linked to its first occurrence.

    If the two ends were already joined by a walking edge, networkx's add_edge
    updates that edge, so its 'distance' becomes 1.

    :param graph: graph with optional 'label' node attributes; modified in place.
    """
    first_end_by_label = {}
    for node, label in graph.nodes.data('label'):
        if label is None:
            continue
        first_end = first_end_by_label.setdefault(label, node)
        if first_end != node:
            graph.add_edge(node, first_end, distance=1)
def make_graph_from_input_lines(input_lines: List[str]) -> networkx.Graph:
    """Build the portal graph for the maze described by `input_lines`.

    Every '.' cell becomes a (row, col) node. Orthogonal neighbours are linked,
    and portal cells are labelled. The grid is then collapsed to the labelled
    nodes only, with walking distances as edges. Finally, matching portal ends
    are joined.
    """
    OPEN_CHAR = '.'
    open_cells = [
        (row, col)
        for row, line in enumerate(input_lines)
        for col, char in enumerate(line)
        if char == OPEN_CHAR
    ]
    maze = networkx.Graph()
    for cell in open_cells:
        maze.add_node(cell)
    connect_adjacent_nodes(maze)
    label_portal_nodes(maze, input_lines)
    portal_graph = reduce_graph_to_labelled_nodes(maze)
    connect_portal_nodes(portal_graph)
    return portal_graph
def reduce_graph_to_labelled_nodes(graph: networkx.Graph) -> networkx.Graph:
    """Collapse the grid graph to its labelled nodes with walking distances.

    The result contains only nodes that have a 'label' attribute. Any two of
    them that are connected in `graph` get an edge whose 'distance' is the
    number of steps on the shortest (unweighted) walk between them.

    Each labelled node runs a single BFS, giving its distance to every other
    node. This replaces one BFS per pair. Pairs are visited in the same order
    as itertools.combinations over the reduced graph's nodes, so edges are
    inserted in the same order as before.

    :param graph: grid graph whose portal nodes carry a 'label' attribute.
    :returns: a new graph; `graph` itself is not modified.
    """
    labelled = [node for node, label in graph.nodes.data('label') if label is not None]
    reduced_graph = networkx.subgraph(graph, labelled).copy()
    portal_nodes = list(reduced_graph.nodes)
    for index, source in enumerate(portal_nodes):
        steps_from_source = networkx.single_source_shortest_path_length(graph, source)
        for target in portal_nodes[index + 1:]:
            # Unreachable targets are simply absent from the BFS result.
            if target in steps_from_source:
                reduced_graph.add_edge(source, target, distance=steps_from_source[target])
    return reduced_graph
def get_opposite_end_of_portal_node(graph: networkx.Graph, node: Tuple[int, int]) -> Tuple[int, int]:
    """Return the first other node whose 'label' equals `node`'s label.

    :raises ValueError: if no other node carries the same label.
    """
    wanted = graph.nodes[node]['label']
    matches = (
        other for other in graph.nodes
        if other != node and graph.nodes[other]['label'] == wanted
    )
    opposite = next(matches, None)
    if opposite is None:
        raise ValueError(f"Given node {graph.nodes[node]['label']} has no opposite")
    return opposite
def part1(graph: networkx.Graph) -> int:
    """Return the fewest steps from the entrance to the exit in the flat maze.

    The path is weighted by each edge's 'distance' attribute. Without a weight,
    networkx would return the path with the fewest hops. In the reduced graph
    that is usually the direct entrance-exit edge, even when going through
    portals is shorter.

    :raises networkx.NetworkXNoPath: if the exit is unreachable.
    """
    start_node = next(node for node, label in graph.nodes.data('label') if label == START_NODE_LABEL)
    end_node = next(node for node, label in graph.nodes.data('label') if label == END_NODE_LABEL)
    return networkx.shortest_path_length(graph, start_node, end_node, weight='distance')
def part2(graph: networkx.Graph, max_depth: Optional[int] = None) -> int:
    """Return the fewest steps from entrance to exit in the recursive maze.

    Inner portals lead one level deeper and outer portals lead one level back
    up. On level 0, outer portals are walls and only the entrance/exit are
    usable. On deeper levels, the entrance and exit are walls.

    The search runs Dijkstra over (portal node, depth) states, so the first
    time the exit is popped its distance is minimal. A state is a portal node
    the walker has just arrived on (or the entrance at the start). One move
    walks to another labelled node on the same level, costing that edge's
    'distance', and then steps through that portal for 1 more step.

    Walking from a portal end to its own partner is not considered. That
    edge's 'distance' holds the portal cost (1), not the walking distance.

    :param graph: reduced graph from make_graph_from_input_lines (nodes carry
        'label' and 'outer'; edges carry 'distance').
    :param max_depth: optional cap on recursion depth; deeper states are pruned.
        With no cap, the search does not terminate when the exit is unreachable
        but inner portals remain usable.
    :raises ValueError: if a used portal has no partner, or every reachable
        state is exhausted without reaching the exit.
    """
    import heapq

    labels = dict(graph.nodes.data('label'))
    start_node = next(node for node, label in labels.items() if label == START_NODE_LABEL)
    end_node = next(node for node, label in labels.items() if label == END_NODE_LABEL)

    # Pair up the two ends of every portal once, instead of searching per move.
    ends_by_label = {}
    for node, label in labels.items():
        ends_by_label.setdefault(label, []).append(node)
    partner = {}
    for ends in ends_by_label.values():
        if len(ends) == 2:
            first, second = ends
            partner[first] = second
            partner[second] = first

    best = {(start_node, 0): 0}
    frontier = [(0, 0, start_node)]  # (steps, depth, node)
    while frontier:
        steps, depth, node = heapq.heappop(frontier)
        if steps > best[(node, depth)]:
            continue  # stale entry; a shorter route to this state was already found
        if node == end_node:
            return steps

        for neighbor in graph.neighbors(node):
            # The entrance is never a destination; the same-label edge is the portal itself.
            if neighbor == start_node or labels[neighbor] == labels[node]:
                continue
            walked = steps + graph.edges[node, neighbor]['distance']
            if neighbor == end_node:
                if depth != 0:
                    continue  # the exit is a wall below the top level
                state, cost = (end_node, 0), walked
            else:
                is_outer = graph.nodes[neighbor]['outer']
                if is_outer and depth == 0:
                    continue  # outer portals are walls on the top level
                next_depth = depth - 1 if is_outer else depth + 1
                if max_depth is not None and next_depth > max_depth:
                    continue
                if neighbor not in partner:
                    raise ValueError(f"Given node {labels[neighbor]} has no opposite")
                state, cost = (partner[neighbor], next_depth), walked + 1
            if cost < best.get(state, float('inf')):
                best[state] = cost
                heapq.heappush(frontier, (cost, state[1], state[0]))

    raise ValueError(f"No path from {START_NODE_LABEL} to {END_NODE_LABEL}")
def draw_graph(graph: networkx.Graph) -> None:
    """Debug helper: plot the graph with node labels and edge distances, then show it.

    Node text is the 'label' attribute and edge text is the 'distance'
    attribute. Positions come from networkx.spring_layout.
    """
    layout = networkx.spring_layout(graph)
    node_text = dict(graph.nodes.data('label'))
    edge_text = {(u, v): dist for u, v, dist in graph.edges.data('distance')}
    networkx.draw_networkx_nodes(graph, layout)
    networkx.draw_networkx_edges(graph, layout)
    networkx.draw_networkx_labels(graph, layout, node_text)
    networkx.draw_networkx_edge_labels(graph, layout, edge_text)
    plt.show()
if __name__ == "__main__":
    # Script entry: read the maze file named on the command line and print
    # the part 1 and part 2 answers, one per line.
    if len(sys.argv) != 2:
        # Usage errors go to stderr so stdout carries only the answers.
        print("Usage: ./main.py in_file", file=sys.stderr)
        sys.exit(1)
    with open(sys.argv[1]) as f:
        # Strip only the newline: leading/trailing spaces are part of the maze layout.
        input_lines = [line.rstrip('\n') for line in f]
    graph = make_graph_from_input_lines(input_lines)
    print(part1(graph))
    print(part2(graph))