# matsim-proto/matsim_visualizer.py
#
# 208 lines
# 7.4 KiB
# Python

import gzip
from lxml import etree
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
from matplotlib.animation import FuncAnimation
from collections import defaultdict
class MatsimVisualizer:
    """Render an animated traffic heatmap from gzipped MATSim-style XML output.

    The visualizer reads ``output_network.xml.gz``, ``output_events.xml.gz``
    and (optionally) ``output_facilities.xml.gz`` from ``input_file_path``,
    keeps only the nodes that fall strictly inside ``viewport`` and the links
    whose two end nodes were both kept, and then draws one animation frame per
    "tick".  A tick is one distinct timestamp among the link events that
    concern a kept link.

    Args:
        input_file_path: Directory containing the three ``output_*.xml.gz``
            files.
        output_file_path: Directory where ``traffic_animation.gif`` is written.
        viewport: Indexable with four numbers ``(x_min, y_max, x_max, y_min)``;
            a point is kept when ``x_min < x < x_max`` and
            ``y_min < y < y_max`` (boundaries excluded).
        show_facilities: When true, facilities are loaded and drawn in red.
        show_nearby_nodes: When true, network nodes within a fixed radius of a
            facility are drawn in blue.
    """

    # Event types that put a vehicle onto a link / take one off it.
    _ENTER_EVENTS = frozenset(('entered link', 'vehicle enters traffic'))
    _LEAVE_EVENTS = frozenset(('left link', 'vehicle leaves traffic'))

    def __init__(self, input_file_path, output_file_path, viewport, show_facilities=False, show_nearby_nodes=False):
        self._nodes = {}            # node id -> (x, y); facilities use a "facility_" id prefix
        self._links = []            # dicts with keys 'id', 'from', 'to', 'vehicles'
        self._pos = None            # node id -> (x, y), filled by _create_graph
        self.norm = None            # matplotlib Normalize, filled by _setup_color_mapping
        # Bounding box of ALL network nodes (not only those inside the viewport).
        self._max_x = None
        self._min_x = None
        self._max_y = None
        self._min_y = None
        self._viewport = viewport
        self._output_file_path = output_file_path
        self._network_file_path = f"{input_file_path}/output_network.xml.gz"
        self._events_file_path = f"{input_file_path}/output_events.xml.gz"
        self._facilities_file_path = f"{input_file_path}/output_facilities.xml.gz"
        self._G = nx.Graph()
        # tick -> link id -> net change in vehicles on that link during the tick
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._close_to_facility_nodes = set()
        self._tick = 0
        # Index of the last tick whose changes are reflected in link['vehicles'].
        self._applied_tick = -1
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None
        self._show_facilities = show_facilities
        self._show_nearby_nodes = show_nearby_nodes

    def _in_viewport(self, x, y):
        """Return True when (x, y) lies strictly inside the configured viewport."""
        return (self._viewport[0] < x < self._viewport[2]
                and self._viewport[1] > y > self._viewport[3])

    def _load_data(self):
        """Load facilities (if enabled), then the network, then the events.

        The order matters: links are filtered against the loaded nodes, and
        events are filtered against the loaded links.
        """
        if self._show_facilities:
            self._load_facilities()
        self._load_network()
        self._load_events()

    def _load_facilities(self):
        """Register every facility inside the viewport as a ``facility_<id>`` node."""
        with gzip.open(self._facilities_file_path, 'rb') as file:
            facilities_doc = etree.parse(file)
        for facility in facilities_doc.xpath('/facilities/facility'):
            x = float(facility.get('x'))
            y = float(facility.get('y'))
            if self._in_viewport(x, y):
                self._nodes[f"facility_{facility.get('id')}"] = (x, y)

    def _load_network(self):
        """Read network nodes and links, keeping only what lies inside the viewport."""
        with gzip.open(self._network_file_path, 'rb') as file:
            network_doc = etree.parse(file)

        for node in network_doc.xpath('/network/nodes/node'):
            x = float(node.get('x'))
            y = float(node.get('y'))
            # The bounding box covers every node, even those outside the viewport.
            if self._max_x is None or x > self._max_x:
                self._max_x = x
            if self._min_x is None or x < self._min_x:
                self._min_x = x
            if self._max_y is None or y > self._max_y:
                self._max_y = y
            if self._min_y is None or y < self._min_y:
                self._min_y = y
            if self._in_viewport(x, y):
                self._nodes[node.get('id')] = (x, y)

        for link in network_doc.xpath('/network/links/link'):
            start_node = link.get('from')
            end_node = link.get('to')
            # A link is only drawable when both of its end nodes were kept.
            if start_node in self._nodes and end_node in self._nodes:
                self._links.append({
                    'id': link.get('id'),
                    'from': start_node,
                    'to': end_node,
                    'vehicles': 0,
                })

    def _load_events(self):
        """Parse the events file and bucket its link events into ticks."""
        with gzip.open(self._events_file_path, 'rb') as file:
            events_doc = etree.parse(file)
        self._process_events(events_doc.xpath('/events/event'))

    def _process_events(self, events):
        """Bucket enter/leave events into per-tick traffic changes.

        Args:
            events: Iterable of objects exposing ``get(name)`` for the
                attributes ``'link'``, ``'type'`` and ``'time'`` (XML elements
                or plain dicts both work).

        Only events whose link is one of the loaded links and whose type is an
        enter or leave type are considered.  All considered events sharing a
        timestamp land in the same tick; the tick advances *before* the first
        event of a new timestamp is recorded, so a timestamp is never split
        across two ticks.  The highest simultaneous vehicle count seen on any
        link is tracked for the colour scale.
        """
        # Start from a clean slate so the method can be called repeatedly.
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._tick = 0
        self._applied_tick = -1
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None

        # Set lookup instead of scanning the whole link list for every event.
        known_links = {link['id'] for link in self._links}
        occupancy = defaultdict(int)  # link id -> vehicles currently on the link
        current_time = None

        for event in events:
            link_id = event.get('link')
            if link_id is None or link_id not in known_links:
                continue
            event_type = event.get('type')
            if event_type in self._ENTER_EVENTS:
                delta = 1
            elif event_type in self._LEAVE_EVENTS:
                delta = -1
            else:
                continue

            time = float(event.get('time'))
            if current_time is None:
                current_time = time
            elif time != current_time:
                # New timestamp: open a new tick before recording the event.
                self._tick += 1
                current_time = time

            self._traffic_per_tick[self._tick][link_id] += delta
            occupancy[link_id] += delta
            # The peak occupancy drives the upper bound of the heatmap scale.
            if delta > 0 and occupancy[link_id] > self._max_traffic:
                self._max_traffic = occupancy[link_id]
                self._max_traffic_link = link_id
                self._max_traffic_tick = self._tick

    def _create_graph(self):
        """Build the networkx graph from the loaded nodes and links."""
        for node_id, coords in self._nodes.items():
            self._G.add_node(node_id, pos=coords, is_facility=node_id.startswith('facility_'))
        for link in self._links:
            self._G.add_edge(link['from'], link['to'], id=link['id'])
        self._pos = nx.get_node_attributes(self._G, 'pos')
        if self._show_nearby_nodes:
            self._identify_close_nodes(radius=100)  # TODO: move to load data

    def _identify_close_nodes(self, radius):
        """Collect non-facility nodes lying within ``radius`` of any facility."""
        facility_coords = [
            self._pos[node]
            for node in self._G.nodes
            if self._G.nodes[node].get('is_facility')
        ]
        for node, coords in self._pos.items():
            if self._G.nodes[node].get('is_facility'):
                continue
            if any(self._euclidean_distance(coords, facility_coord) <= radius
                   for facility_coord in facility_coords):
                self._close_to_facility_nodes.add(node)

    def _setup_color_mapping(self):
        """Create the colour normalisation, saturating at half the peak traffic."""
        self.norm = colors.Normalize(vmin=0, vmax=self._max_traffic/2)

    def _advance_traffic(self, frame):
        """Bring every link's ``'vehicles'`` count to its value at tick ``frame``.

        The update is idempotent for a repeated frame and replays from zero
        when an earlier frame is requested, so the animation callback may be
        invoked for the same frame twice or restarted without corrupting the
        counts.
        """
        if frame < self._applied_tick:
            for link in self._links:
                link['vehicles'] = 0
            self._applied_tick = -1
        for tick in range(self._applied_tick + 1, frame + 1):
            # .get() avoids creating empty ticks in the defaultdict.
            changes = self._traffic_per_tick.get(tick)
            if not changes:
                continue
            for link in self._links:
                link['vehicles'] += changes.get(link['id'], 0)
        self._applied_tick = frame

    def _update(self, frame):
        """Draw animation frame ``frame`` (one frame per tick).

        The frame index, not internal state, selects the tick: FuncAnimation
        may call this more than once for the same frame (for example for its
        initial draw when no ``init_func`` is supplied) and replays the
        sequence when the figure is shown after saving.
        """
        self._advance_traffic(frame)

        edgelist = [(link['from'], link['to']) for link in self._links]
        edge_colors = [self._cmap(self.norm(link['vehicles'])) for link in self._links]
        edge_widths = [1 + self.norm(link['vehicles']) for link in self._links]

        plt.cla()
        # Draw Facilities
        if self._show_facilities:
            facility_nodes = {node: pos for node, pos in self._pos.items() if self._G.nodes[node]['is_facility']}
            nx.draw_networkx_nodes(self._G, facility_nodes, nodelist=list(facility_nodes), node_size=1, node_color='red')
        # Draw Network
        if self._show_nearby_nodes:
            close_nodes = [node for node in self._close_to_facility_nodes if node in self._pos]
            nx.draw_networkx_nodes(self._G, self._pos, nodelist=close_nodes, node_size=1, node_color='blue')
        nx.draw_networkx_edges(self._G, self._pos, edgelist=edgelist, width=edge_widths, edge_color=edge_colors, edge_cmap=self._cmap)
        plt.title(f"Time: {frame}")
        self._tick = frame + 1

    def visualize(self, fps=15, colormap='inferno'):
        """Load the data, build the animation, save it as a GIF and show it.

        Args:
            fps: Frames per second of the saved animation.
            colormap: Name of a colormap exposed as an attribute of
                ``matplotlib.cm``; anything else falls back to ``inferno``
                with a printed notice.
        """
        self._load_data()
        self._create_graph()
        self._setup_color_mapping()
        fig, ax = plt.subplots()

        # Only set the aspect when the network extent is known and non-degenerate;
        # otherwise keep matplotlib's default instead of failing on the division.
        if None not in (self._min_x, self._max_x, self._min_y, self._max_y):
            width = self._max_x - self._min_x
            height = self._max_y - self._min_y
            if width > 0 and height > 0:
                ax.set_aspect(width / height)

        # matplotlib.cm also exposes non-colormap attributes, so check the type.
        candidate = getattr(cm, colormap, None) if isinstance(colormap, str) else None
        if isinstance(candidate, colors.Colormap):
            self._cmap = candidate
        else:
            print(f"Colormap '{colormap}' not recognized. Falling back to 'inferno'. Please select from: 'inferno', 'plasma', 'magma', 'spring', 'summer', 'autumn', or 'winter'.")
            self._cmap = cm.inferno

        sm = cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Traffic Density')

        self._tick = 0
        ani = FuncAnimation(fig, self._update, frames=len(self._traffic_per_tick), repeat=False)
        ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=fps)
        plt.show()

    def _euclidean_distance(self, coord1, coord2):
        """Return the straight-line distance between two (x, y) points."""
        return ((coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2) ** 0.5