208 lines
7.4 KiB
Python
208 lines
7.4 KiB
Python
import gzip
|
|
|
|
from lxml import etree
|
|
import networkx as nx
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.cm as cm
|
|
import matplotlib.colors as colors
|
|
from matplotlib.animation import FuncAnimation
|
|
from collections import defaultdict
|
|
|
|
|
|
class MatsimVisualizer:
    """Render MATSim simulation output as an animated traffic heat map.

    Reads ``output_network.xml.gz``, ``output_events.xml.gz`` and (optionally)
    ``output_facilities.xml.gz`` from ``input_file_path``, builds a networkx
    graph of the links inside ``viewport`` and animates the number of vehicles
    on each link, saving ``traffic_animation.gif`` into ``output_file_path``.
    """

    # Event types that add a vehicle to / remove a vehicle from a link.
    _ENTER_EVENT_TYPES = frozenset({'entered link', 'vehicle enters traffic'})
    _LEAVE_EVENT_TYPES = frozenset({'left link', 'vehicle leaves traffic'})

    def __init__(self, input_file_path, output_file_path, viewport, show_facilities=False,
                 show_nearby_nodes=False, nearby_radius=100):
        """Store configuration; no files are read until :meth:`visualize`.

        Args:
            input_file_path: Directory containing the MATSim output files.
            output_file_path: Directory the GIF is written to.
            viewport: Sequence indexed as ``(min_x, max_y, max_x, min_y)``;
                only nodes/facilities strictly inside it are kept.
            show_facilities: Load and draw facilities (red dots).
            show_nearby_nodes: Highlight network nodes within ``nearby_radius``
                of a loaded facility (blue dots). Since facilities are only
                loaded when ``show_facilities`` is set, this has no effect
                without it.
            nearby_radius: Euclidean distance (in the network's coordinate
                units) used for ``show_nearby_nodes``.
        """
        self._nodes = {}
        self._links = []
        self._pos = None
        self.norm = None
        self._cmap = None

        # Bounding box over *all* network nodes, not only those in the viewport.
        self._max_x = None
        self._min_x = None
        self._max_y = None
        self._min_y = None
        self._viewport = viewport

        self._output_file_path = output_file_path
        self._network_file_path = f"{input_file_path}/output_network.xml.gz"
        self._events_file_path = f"{input_file_path}/output_events.xml.gz"
        self._facilities_file_path = f"{input_file_path}/output_facilities.xml.gz"

        self._G = nx.Graph()
        # tick -> link id -> net change of the vehicle count during that tick
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._close_to_facility_nodes = set()

        self._tick = 0
        # Last tick whose changes are reflected in link['vehicles'] (-1 = none).
        self._applied_tick = -1
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None

        self._show_facilities = show_facilities
        self._show_nearby_nodes = show_nearby_nodes
        self._nearby_radius = nearby_radius

    def _in_viewport(self, x, y):
        """Return True if ``(x, y)`` lies strictly inside the configured viewport."""
        viewport = self._viewport
        return viewport[0] < x < viewport[2] and viewport[3] < y < viewport[1]

    def _load_data(self):
        """Load facilities (optional), the network and the events from disk.

        All previously loaded state is reset first, so running
        :meth:`visualize` more than once does not duplicate nodes, links or
        traffic counts.
        """
        self._nodes = {}
        self._links = []
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._max_x = self._min_x = self._max_y = self._min_y = None
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None
        self._tick = 0
        self._applied_tick = -1

        if self._show_facilities:
            self._load_facilities()
        self._load_network()
        self._load_events()

    def _load_facilities(self):
        """Add facilities inside the viewport to ``self._nodes`` as ``facility_<id>``."""
        with gzip.open(self._facilities_file_path, 'rb') as file:
            facilities_doc = etree.parse(file)

        for facility in facilities_doc.xpath('/facilities/facility'):
            x = float(facility.get('x'))
            y = float(facility.get('y'))
            if self._in_viewport(x, y):
                self._nodes[f"facility_{facility.get('id')}"] = (x, y)

    def _load_network(self):
        """Read network nodes and links.

        The bounding box is updated from every node; only nodes inside the
        viewport are kept, and only links whose two endpoints were kept.
        """
        with gzip.open(self._network_file_path, 'rb') as file:
            network_doc = etree.parse(file)

        for node in network_doc.xpath('/network/nodes/node'):
            x = float(node.get('x'))
            y = float(node.get('y'))
            self._min_x = x if self._min_x is None else min(self._min_x, x)
            self._max_x = x if self._max_x is None else max(self._max_x, x)
            self._min_y = y if self._min_y is None else min(self._min_y, y)
            self._max_y = y if self._max_y is None else max(self._max_y, y)
            if self._in_viewport(x, y):
                self._nodes[node.get('id')] = (x, y)

        for link in network_doc.xpath('/network/links/link'):
            start_node = link.get('from')
            end_node = link.get('to')
            if start_node in self._nodes and end_node in self._nodes:
                self._links.append({
                    'id': link.get('id'),
                    'from': start_node,
                    'to': end_node,
                    'vehicles': 0,
                })

    def _load_events(self):
        """Stream the events file and bucket per-link traffic changes into ticks."""
        # Set for O(1) membership; the old per-event scan over all links was O(links).
        link_ids = {link['id'] for link in self._links}
        with gzip.open(self._events_file_path, 'rb') as file:
            self._accumulate_events(self._iter_event_records(file), link_ids)

    @staticmethod
    def _iter_event_records(file):
        """Yield ``(link_id, event_type, time)`` for every ``<event>`` element in *file*.

        Parses incrementally and frees each element once read, so the whole
        events document is never held in memory at once.
        """
        for _, element in etree.iterparse(file, events=('end',), tag='event'):
            yield element.get('link'), element.get('type'), float(element.get('time'))
            element.clear()
            # Drop already-processed siblings so the root element does not keep growing.
            while element.getprevious() is not None:
                del element.getparent()[0]

    def _accumulate_events(self, records, link_ids):
        """Fill ``self._traffic_per_tick`` from ``(link_id, event_type, time)`` records.

        Only enter/leave events on links in *link_ids* are counted. All such
        events sharing a timestamp land in the same tick; a new tick starts
        when a counted event has a different time than the previous counted
        event. Also tracks the peak simultaneous vehicle count on any link
        (``_max_traffic`` / ``_max_traffic_link`` / ``_max_traffic_tick``).

        Args:
            records: Iterable of ``(link_id, event_type, time)`` tuples.
            link_ids: Collection of link ids to keep.
        """
        occupancy = defaultdict(int)
        last_time = None
        self._tick = 0

        for link_id, event_type, event_time in records:
            if link_id not in link_ids:
                continue
            if event_type in self._ENTER_EVENT_TYPES:
                delta = 1
            elif event_type in self._LEAVE_EVENT_TYPES:
                delta = -1
            else:
                continue

            # Advance *before* recording, so the first event of a new timestamp
            # is not attributed to the previous tick.
            if last_time is not None and event_time != last_time:
                self._tick += 1
            last_time = event_time

            self._traffic_per_tick[self._tick][link_id] += delta
            occupancy[link_id] += delta

            # Peak occupancy drives the heat-map normalisation.
            if delta > 0 and self._max_traffic < occupancy[link_id]:
                self._max_traffic = occupancy[link_id]
                self._max_traffic_link = link_id
                self._max_traffic_tick = self._tick

    def _create_graph(self):
        """Build ``self._G`` from the loaded nodes/links and cache node positions."""
        self._G.clear()
        self._close_to_facility_nodes.clear()
        for node_id, coords in self._nodes.items():
            self._G.add_node(node_id, pos=coords, is_facility=node_id.startswith('facility_'))
        for link in self._links:
            self._G.add_edge(link['from'], link['to'], id=link['id'])
        self._pos = nx.get_node_attributes(self._G, 'pos')

        if self._show_nearby_nodes:
            self._identify_close_nodes(radius=self._nearby_radius)  # TODO: move to load data

    def _identify_close_nodes(self, radius):
        """Add every non-facility node within *radius* of a facility node to the close set."""
        facility_coords = [
            self._pos[node] for node in self._G.nodes if self._G.nodes[node].get('is_facility')
        ]
        for node, coords in self._pos.items():
            if self._G.nodes[node].get('is_facility'):
                continue
            if any(self._euclidean_distance(coords, fc) <= radius for fc in facility_coords):
                self._close_to_facility_nodes.add(node)

    def _setup_color_mapping(self):
        """Create the colour normaliser for vehicle counts.

        vmax is half of the peak occupancy, so links above that value map past
        1.0 and are drawn at (or beyond) the top of the colormap.
        """
        self.norm = colors.Normalize(vmin=0, vmax=self._max_traffic / 2)

    def _advance_to_frame(self, frame):
        """Make every ``link['vehicles']`` reflect all changes up to and including tick *frame*.

        Deterministic in *frame*: calling it again for the same frame is a
        no-op, and going backwards (FuncAnimation's initial draw, ``save()``
        followed by ``show()``) resets the counts and replays from tick 0.
        """
        if frame < self._applied_tick:
            for link in self._links:
                link['vehicles'] = 0
            self._applied_tick = -1

        for tick in range(self._applied_tick + 1, frame + 1):
            # .get() so the defaultdict is not grown by lookups of empty ticks.
            changes = self._traffic_per_tick.get(tick)
            if not changes:
                continue
            for link in self._links:
                link['vehicles'] += changes.get(link['id'], 0)

        self._applied_tick = max(self._applied_tick, frame)
        self._tick = frame

    def _update(self, frame):
        """FuncAnimation callback: redraw the network as of tick *frame*."""
        self._advance_to_frame(frame)

        edgelist = [(link['from'], link['to']) for link in self._links]
        edge_colors = [self._cmap(self.norm(link['vehicles'])) for link in self._links]
        edge_widths = [1 + self.norm(link['vehicles']) for link in self._links]

        plt.cla()

        # Draw Facilities
        if self._show_facilities:
            facility_nodes = {
                node: pos for node, pos in self._pos.items() if self._G.nodes[node]['is_facility']
            }
            nx.draw_networkx_nodes(self._G, facility_nodes, facility_nodes.keys(),
                                   node_size=1, node_color='red')

        # Draw Network
        if self._show_nearby_nodes:
            close_nodes = [node for node in self._close_to_facility_nodes if node in self._pos]
            nx.draw_networkx_nodes(self._G, self._pos, close_nodes, node_size=1, node_color='blue')
        nx.draw_networkx_edges(self._G, self._pos, edgelist=edgelist, width=edge_widths,
                               edge_color=edge_colors, edge_cmap=self._cmap)
        plt.title(f"Time: {self._tick}")

    def visualize(self, fps=15, colormap='inferno'):
        """Load the data, render the animation, save it as a GIF and show it.

        Args:
            fps: Frames per second of the saved animation.
            colormap: Name of a matplotlib colormap; unknown names fall back
                to ``inferno``.
        """
        self._load_data()
        self._create_graph()
        self._setup_color_mapping()

        fig, ax = plt.subplots()

        # NOTE(review): aspect = width/height makes the data box square rather
        # than giving equal x/y scaling ('equal'); confirm this is intended.
        # Skipped for empty or degenerate networks instead of crashing.
        if None not in (self._max_x, self._min_x, self._max_y, self._min_y):
            width = self._max_x - self._min_x
            height = self._max_y - self._min_y
            if width > 0 and height > 0:
                ax.set_aspect(width / height)

        # Only accept real colormaps, not arbitrary attributes of matplotlib.cm.
        candidate = getattr(cm, colormap, None)
        if isinstance(candidate, colors.Colormap):
            self._cmap = candidate
        else:
            print(f"Colormap '{colormap}' not recognized. Falling back to 'inferno'. Please select from: 'inferno', 'plasma', 'magma', 'spring', 'summer', 'autumn', or 'winter'.")
            self._cmap = cm.inferno

        sm = plt.cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Traffic Density')

        self._tick = 0
        self._applied_tick = -1
        ani = FuncAnimation(fig, self._update, frames=len(self._traffic_per_tick), repeat=False)
        ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=fps)
        plt.show()

    def _euclidean_distance(self, coord1, coord2):
        """Return the Euclidean distance between two ``(x, y)`` points."""
        return ((coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2) ** 0.5
|