matsim-proto/matsim_visualizer.py

132 lines
4.1 KiB
Python
Raw Normal View History

import gzip
from lxml import etree
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.colors as colors
2024-02-14 16:54:41 -05:00
class MatsimVisualizer:
    """Render MATSim link traffic inside a viewport as an animated heat map.

    The network and events files are read as gzip-compressed XML. Each traffic
    event (a vehicle entering or leaving a visible link) becomes one animation
    tick; edge colour and width encode the number of vehicles on each link at
    that tick.
    """

    # Change in a link's vehicle count caused by each handled event type.
    _TRAFFIC_DELTAS = {
        'entered link': 1,
        'vehicle enters traffic': 1,
        'left link': -1,
        'vehicle leaves traffic': -1,
    }

    def __init__(self, network_file_path, events_file_path, output_file_path, viewport):
        """Store the input/output paths and the viewport; nothing is read yet.

        :param network_file_path: path to the gzip-compressed network XML.
        :param events_file_path: path to the gzip-compressed events XML.
        :param output_file_path: directory ``traffic_animation.gif`` is written to.
        :param viewport: ``(min_x, min_y, max_x, max_y)``; only nodes strictly
            inside it, and links between two such nodes, are drawn.
        """
        self._nodes = {}  # node id -> (x, y), only nodes inside the viewport
        self._links = []  # dicts with keys 'id', 'from', 'to', 'vehicles'

        self._pos = None
        self.norm = None

        self._viewport = viewport

        self._output_file_path = output_file_path
        self._network_file_path = network_file_path
        self._events_file_path = events_file_path
        self._G = nx.Graph()
        # tick -> {link id: change in vehicle count}; one event per tick.
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._cmap = cm.viridis
        # Number of ticks already applied to the links' 'vehicles' counts.
        self._tick = 0
        self._max_traffic = 0
        self._ax = None  # axes the animation draws into (set by visualize)

    def load_data(self):
        """Read the network and the events, keeping only what lies in the viewport.

        Nodes strictly inside the viewport are kept, and links whose two
        endpoints are both kept. Every enter (+1) / leave (-1) event on a kept
        link that carries a ``time`` attribute becomes one tick in
        ``self._traffic_per_tick``. Events on other links are ignored so they
        neither add frames nor inflate ``self._max_traffic`` (the peak
        simultaneous count on any visible link, used to scale the colours).
        """
        # Reset, so calling this twice does not duplicate links or ticks.
        self._nodes = {}
        self._links = []
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._max_traffic = 0
        self._tick = 0

        # ====== LOAD NETWORK ====== #
        min_x, min_y = self._viewport[0], self._viewport[1]
        max_x, max_y = self._viewport[2], self._viewport[3]
        with gzip.open(self._network_file_path, 'rb') as file:
            network_doc = etree.parse(file)
        for node in network_doc.xpath('/network/nodes/node'):
            x = float(node.get('x'))
            y = float(node.get('y'))
            if min_x < x < max_x and min_y < y < max_y:
                self._nodes[node.get('id')] = (x, y)
        for link in network_doc.xpath('/network/links/link'):
            start_node = link.get('from')
            end_node = link.get('to')
            if start_node in self._nodes and end_node in self._nodes:
                self._links.append({
                    'id': link.get('id'),
                    'from': start_node,
                    'to': end_node,
                    'vehicles': 0,
                })
        visible_link_ids = {link['id'] for link in self._links}

        # ====== LOAD EVENTS ====== #
        cumulative_traffic = defaultdict(int)
        tick = 0
        with gzip.open(self._events_file_path, 'rb') as file:
            events_doc = etree.parse(file)
        for event in events_doc.xpath('/events/event'):
            link_id = event.get('link')
            delta = self._TRAFFIC_DELTAS.get(event.get('type'))
            if (delta is None or event.get('time') is None
                    or link_id is None or link_id not in visible_link_ids):
                continue
            self._traffic_per_tick[tick][link_id] += delta
            cumulative_traffic[link_id] += delta
            # The heat map needs the maximum traffic on any link at any point.
            if cumulative_traffic[link_id] > self._max_traffic:
                self._max_traffic = cumulative_traffic[link_id]
            tick += 1

    def create_graph(self):
        """Build ``self._G`` from the loaded nodes/links and cache node positions.

        ``self._G`` is undirected, so two opposite links between the same nodes
        collapse into one graph edge. ``update`` therefore draws from an explicit
        edge list built from ``self._links`` rather than ``self._G.edges()``.
        """
        for node_id, coords in self._nodes.items():
            self._G.add_node(node_id, pos=coords)
        for link in self._links:
            self._G.add_edge(link['from'], link['to'])
        self._pos = nx.get_node_attributes(self._G, 'pos')

    def setup_color_mapping(self):
        """Create the normaliser mapping vehicle counts ``0..max_traffic`` onto [0, 1]."""
        # Avoid a degenerate 0..0 range when no visible traffic was loaded.
        self.norm = colors.Normalize(vmin=0, vmax=max(self._max_traffic, 1))

    def update(self, frame):
        """Draw the network as it stands after applying ticks ``0..frame``.

        Link counts are advanced incrementally. If ``frame`` is behind the ticks
        already applied (FuncAnimation's initial draw, or the replay by
        ``plt.show()`` after saving), the counts are rebuilt from tick 0. The
        picture thus depends only on ``frame``, not on how often this is called.
        """
        if frame < self._tick:
            for link in self._links:
                link['vehicles'] = 0
            self._tick = 0
        while self._tick <= frame:
            # .get: do not insert empty ticks into the defaultdict.
            traffic_change = self._traffic_per_tick.get(self._tick, {})
            for link in self._links:
                link['vehicles'] += traffic_change.get(link['id'], 0)
            self._tick += 1

        edgelist = []
        edge_colors = []
        edge_widths = []
        for link in self._links:
            level = float(self.norm(link['vehicles']))  # count scaled to [0, 1]
            edgelist.append((link['from'], link['to']))
            # The colormap must get the normalised value: an integer argument
            # would be treated as a lookup-table index, not a fraction.
            edge_colors.append(self._cmap(level))
            edge_widths.append(1 + level * 10)

        ax = self._ax if self._ax is not None else plt.gca()
        ax.clear()
        # Explicit edgelist keeps colours/widths aligned with self._links; the
        # undirected graph's own edge order (and edge count) differ.
        nx.draw(self._G, self._pos, ax=ax, edgelist=edgelist, node_size=0,
                node_color='blue', width=edge_widths, edge_color=edge_colors,
                with_labels=False)
        ax.set_title(f"Time: {frame}")

    def visualize(self):
        """Load the data, render the animation to a GIF, then display it.

        Writes ``<output_file_path>/traffic_animation.gif`` with matplotlib's
        ``ffmpeg`` writer at 5 fps, one frame per tick (at least one frame).
        """
        self.load_data()
        self.create_graph()
        self.setup_color_mapping()
        fig, ax = plt.subplots()
        self._ax = ax

        sm = cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Traffic Density')
        self._tick = 0
        ani = FuncAnimation(fig, self.update,
                            frames=max(len(self._traffic_per_tick), 1),
                            repeat=False)

        ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=5)
        plt.show()