import gzip
from collections import defaultdict

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import networkx as nx
from lxml import etree
from matplotlib.animation import FuncAnimation


class MatsimVisualizer:
    """Render an animated per-link traffic heatmap from MATSim output files.

    Reads ``output_network.xml.gz`` and ``output_events.xml.gz`` (and, when
    *show_facilities* is set, ``output_facilities.xml.gz``) from
    *input_file_path*, keeps only points inside *viewport*, and writes
    ``traffic_animation.gif`` into *output_file_path*.

    A point ``(x, y)`` is inside *viewport* when
    ``viewport[0] < x < viewport[2]`` and ``viewport[1] > y > viewport[3]``.
    """

    # Event types that add a vehicle to / remove a vehicle from a link.
    _ENTER_EVENT_TYPES = frozenset({'entered link', 'vehicle enters traffic'})
    _LEAVE_EVENT_TYPES = frozenset({'left link', 'vehicle leaves traffic'})

    def __init__(self, input_file_path, output_file_path, viewport, show_facilities=False, show_nearby_nodes=False):
        self._nodes = {}  # node id -> (x, y); facilities are keyed "facility_<id>"
        self._links = []  # dicts with keys 'id', 'from', 'to', 'vehicles'
        self._links_by_id = {}  # link id -> the same dict stored in self._links
        self._pos = None
        self.norm = None
        self._cmap = None
        # Extents over *all* network nodes (not only those in the viewport).
        self._max_x = None
        self._min_x = None
        self._max_y = None
        self._min_y = None
        self._viewport = viewport
        self._output_file_path = output_file_path
        self._network_file_path = f"{input_file_path}/output_network.xml.gz"
        self._events_file_path = f"{input_file_path}/output_events.xml.gz"
        self._facilities_file_path = f"{input_file_path}/output_facilities.xml.gz"
        self._G = nx.Graph()
        # tick -> link id -> net change in vehicle count during that tick
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._close_to_facility_nodes = set()
        self._tick = 0
        # Highest tick whose changes are currently applied to the links' 'vehicles'.
        self._applied_tick = -1
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None
        self._show_facilities = show_facilities
        self._show_nearby_nodes = show_nearby_nodes

    def _in_viewport(self, x, y):
        """Return True if (x, y) lies strictly inside the configured viewport."""
        vp = self._viewport
        return vp[0] < x < vp[2] and vp[1] > y > vp[3]

    def _load_data(self):
        """Load facilities (optional), then the network, then the events."""
        if self._show_facilities:
            self._load_facilities()
        self._load_network()
        self._load_events()

    def _load_facilities(self):
        """Add facilities inside the viewport to ``self._nodes`` as ``facility_<id>``."""
        with gzip.open(self._facilities_file_path, 'rb') as file:
            facilities_doc = etree.parse(file)
        for facility in facilities_doc.xpath('/facilities/facility'):
            x = float(facility.get('x'))
            y = float(facility.get('y'))
            if self._in_viewport(x, y):
                self._nodes[f"facility_{facility.get('id')}"] = (x, y)

    def _load_network(self):
        """Load viewport nodes, whole-network extents, and links between kept nodes."""
        with gzip.open(self._network_file_path, 'rb') as file:
            network_doc = etree.parse(file)

        for node in network_doc.xpath('/network/nodes/node'):
            x = float(node.get('x'))
            y = float(node.get('y'))
            self._max_x = x if self._max_x is None else max(self._max_x, x)
            self._min_x = x if self._min_x is None else min(self._min_x, x)
            self._max_y = y if self._max_y is None else max(self._max_y, y)
            self._min_y = y if self._min_y is None else min(self._min_y, y)
            if self._in_viewport(x, y):
                self._nodes[node.get('id')] = (x, y)

        for link in network_doc.xpath('/network/links/link'):
            start_node = link.get('from')
            end_node = link.get('to')
            # Only links with both endpoints inside the viewport are drawn.
            if start_node in self._nodes and end_node in self._nodes:
                entry = {
                    'id': link.get('id'),
                    'from': start_node,
                    'to': end_node,
                    'vehicles': 0,
                }
                self._links.append(entry)
                self._links_by_id[entry['id']] = entry

    def _load_events(self):
        """Bucket net per-link vehicle changes from the events file into ticks.

        Only enter/leave events on loaded links are considered. Consecutive such
        events sharing a timestamp form one tick; a new tick starts when an
        event's time differs from the previous considered event's time, so
        ticks are numbered contiguously from 0. Also records the peak running
        (entries minus exits) count on any link, used to scale the colour map.
        Parsed elements are cleared as they are consumed, so the events file is
        streamed instead of being held in memory as a whole tree.
        """
        cumulative_traffic = defaultdict(int)
        current_time = None
        self._tick = 0
        with gzip.open(self._events_file_path, 'rb') as file:
            for _, event in etree.iterparse(file, events=('end',), tag='event'):
                link_id = event.get('link')
                event_type = event.get('type')
                time_attr = event.get('time')
                # Release this element and already-processed siblings.
                event.clear()
                while event.getprevious() is not None:
                    del event.getparent()[0]

                if event_type in self._ENTER_EVENT_TYPES:
                    delta = 1
                elif event_type in self._LEAVE_EVENT_TYPES:
                    delta = -1
                else:
                    continue
                if link_id is None or link_id not in self._links_by_id:
                    continue

                time = float(time_attr)
                # Advance *before* recording, so the first event at a new
                # timestamp is counted in the new tick rather than the old one.
                if current_time is not None and time != current_time:
                    self._tick += 1
                current_time = time

                self._traffic_per_tick[self._tick][link_id] += delta
                cumulative_traffic[link_id] += delta
                if cumulative_traffic[link_id] > self._max_traffic:
                    self._max_traffic = cumulative_traffic[link_id]
                    self._max_traffic_link = link_id
                    self._max_traffic_tick = self._tick

    def _create_graph(self):
        """Build the networkx graph from loaded nodes/links and cache node positions."""
        for node_id, coords in self._nodes.items():
            self._G.add_node(node_id, pos=coords, is_facility=node_id.startswith('facility_'))
        for link in self._links:
            self._G.add_edge(link['from'], link['to'], id=link['id'])
        self._pos = nx.get_node_attributes(self._G, 'pos')
        if self._show_nearby_nodes:
            self._identify_close_nodes(radius=100)  # TODO: move to load data

    def _identify_close_nodes(self, radius):
        """Collect non-facility nodes within *radius* (Euclidean) of any facility."""
        facility_coords = [
            coords for node, coords in self._pos.items() if self._G.nodes[node].get('is_facility')
        ]
        for node, coords in self._pos.items():
            if self._G.nodes[node].get('is_facility'):
                continue
            if any(self._euclidean_distance(coords, fc) <= radius for fc in facility_coords):
                self._close_to_facility_nodes.add(node)

    def _setup_color_mapping(self):
        """Create the traffic colour norm.

        vmax is half the peak count, so counts at or above half the peak map to
        the top of the colour map. Falls back to 1 when no traffic was seen, to
        avoid a degenerate vmin == vmax norm.
        """
        vmax = self._max_traffic / 2 if self._max_traffic > 0 else 1
        self.norm = colors.Normalize(vmin=0, vmax=vmax)

    def _seek(self, frame):
        """Set every link's 'vehicles' to the state after applying ticks 0..frame."""
        if frame < self._applied_tick:
            # Going backwards (animation restarted): rebuild from an empty state.
            for link in self._links:
                link['vehicles'] = 0
            self._applied_tick = -1
        for tick in range(self._applied_tick + 1, frame + 1):
            # .get() avoids inserting empty ticks into the defaultdict.
            for link_id, change in self._traffic_per_tick.get(tick, {}).items():
                self._links_by_id[link_id]['vehicles'] += change
        self._applied_tick = max(self._applied_tick, frame)

    def _update(self, frame):
        """Draw animation frame *frame* (link state after ticks 0..frame).

        FuncAnimation can call this several times for the same or an earlier
        frame (initial draw without ``init_func``, ``save()``, ``plt.show()``),
        so link state is derived from *frame* rather than a running counter.
        """
        self._seek(frame)

        edgelist = [(link['from'], link['to']) for link in self._links]
        levels = [self.norm(link['vehicles']) for link in self._links]
        edge_colors = [self._cmap(level) for level in levels]
        edge_widths = [1 + level for level in levels]

        plt.cla()

        # Draw Facilities
        if self._show_facilities:
            facility_nodes = {node: pos for node, pos in self._pos.items() if self._G.nodes[node]['is_facility']}
            nx.draw_networkx_nodes(self._G, facility_nodes, facility_nodes.keys(), node_size=1, node_color='red')

        # Draw Network
        if self._show_nearby_nodes:
            close_nodes = {node: self._pos[node] for node in self._close_to_facility_nodes if node in self._pos}
            nx.draw_networkx_nodes(self._G, self._pos, close_nodes.keys(), node_size=1, node_color='blue')

        nx.draw_networkx_edges(self._G, self._pos, edgelist=edgelist, width=edge_widths,
                               edge_color=edge_colors, edge_cmap=self._cmap)
        plt.title(f"Time: {frame}")
        self._tick = frame + 1

    def visualize(self, fps=15, colormap='inferno'):
        """Load data, render the animation to ``traffic_animation.gif`` and show it.

        Args:
            fps: Frames per second of the saved animation.
            colormap: Name of a colormap attribute of ``matplotlib.cm``; falls
                back to ``inferno`` if the name is not a Colormap.
        """
        self._load_data()
        self._create_graph()
        self._setup_color_mapping()

        fig, ax = plt.subplots()
        extents = (self._max_x, self._min_x, self._max_y, self._min_y)
        if None not in extents:
            width = self._max_x - self._min_x
            height = self._max_y - self._min_y
            if width > 0 and height > 0:
                ax.set_aspect(width / height)

        # getattr alone would also accept non-colormap names (e.g. 'ScalarMappable').
        candidate = getattr(cm, colormap, None)
        if isinstance(candidate, colors.Colormap):
            self._cmap = candidate
        else:
            print(f"Colormap '{colormap}' not recognized. Falling back to 'inferno'. Please select from: 'inferno', 'plasma', 'magma', 'spring', 'summer', 'autumn', or 'winter'.")
            self._cmap = cm.inferno

        sm = plt.cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Traffic Density')

        self._tick = 0
        self._applied_tick = -1
        ani = FuncAnimation(fig, self._update, frames=len(self._traffic_per_tick), repeat=False)
        ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=fps)
        plt.show()

    def _euclidean_distance(self, coord1, coord2):
        """Return the Euclidean distance between two (x, y) points."""
        return ((coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2) ** 0.5