Visualizer cleanup
This commit is contained in:
parent
26f867093f
commit
268ceb5d0f
|
@ -8,28 +8,34 @@ from collections import defaultdict
|
|||
import matplotlib.cm as cm
|
||||
import matplotlib.colors as colors
|
||||
|
||||
|
||||
class MatsimVisualizer():
|
||||
def __init__(self, network_file_path, events_file_path):
|
||||
self.network_file_path = network_file_path
|
||||
self.events_file_path = events_file_path
|
||||
self.G = nx.Graph()
|
||||
self.pos = None
|
||||
self.traffic_per_tick = defaultdict(lambda: defaultdict(int))
|
||||
self.cumulative_traffic = defaultdict(lambda: defaultdict(int))
|
||||
self.cmap = cm.viridis
|
||||
def __init__(self, network_file_path, events_file_path, output_file_path):
|
||||
self._nodes = None
|
||||
self._links = None
|
||||
self._pos = None
|
||||
self.norm = None
|
||||
|
||||
self._output_file_path = output_file_path
|
||||
self._network_file_path = network_file_path
|
||||
self._events_file_path = events_file_path
|
||||
|
||||
self._G = nx.Graph()
|
||||
self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
|
||||
self._cumulative_traffic = defaultdict(lambda: defaultdict(int))
|
||||
self._cmap = cm.viridis
|
||||
|
||||
def load_data(self):
|
||||
# Load network data
|
||||
with gzip.open(self.network_file_path, 'rb') as file:
|
||||
with gzip.open(self._network_file_path, 'rb') as file:
|
||||
network_doc = xmltodict.parse(file.read().decode('utf-8'))
|
||||
|
||||
# Parse nodes
|
||||
self.nodes = {node['@id']: (float(node['@x']), float(node['@y'])) for node in
|
||||
network_doc['network']['nodes']['node']}
|
||||
self._nodes = {node['@id']: (float(node['@x']), float(node['@y'])) for node in
|
||||
network_doc['network']['nodes']['node']}
|
||||
|
||||
# Parse links
|
||||
self.links = [{
|
||||
self._links = [{
|
||||
'id': link['@id'],
|
||||
'from': link['@from'],
|
||||
'to': link['@to']
|
||||
|
@ -38,7 +44,7 @@ class MatsimVisualizer():
|
|||
link_state = defaultdict(list)
|
||||
|
||||
# Load and parse the events file
|
||||
with gzip.open(self.events_file_path, 'rb') as file:
|
||||
with gzip.open(self._events_file_path, 'rb') as file:
|
||||
events_doc = xmltodict.parse(file.read().decode('utf-8'))
|
||||
|
||||
for event in events_doc['events']['event']:
|
||||
|
@ -49,51 +55,52 @@ class MatsimVisualizer():
|
|||
|
||||
if link_id is not None and event_type is not None and tick is not None:
|
||||
if event_type == 'entered link' or event_type == 'vehicle enters traffic':
|
||||
self.traffic_per_tick[tick][link_id] += 1
|
||||
self._traffic_per_tick[tick][link_id] += 1
|
||||
link_state[link_id].append(vehicle_id)
|
||||
elif event_type == 'left link' or event_type == 'vehicle leaves traffic':
|
||||
self.traffic_per_tick[tick][link_id] -= 1
|
||||
self._traffic_per_tick[tick][link_id] -= 1
|
||||
link_state[link_id].remove(vehicle_id)
|
||||
|
||||
for link in self.links:
|
||||
self.cumulative_traffic[0][link['id']] = 0
|
||||
for link in self._links:
|
||||
self._cumulative_traffic[0][link['id']] = 0
|
||||
|
||||
# Accumulate the counts to get the total number of vehicles on each link up to each tick
|
||||
actual_tick = 0
|
||||
sorted_ticks = sorted(self.traffic_per_tick.keys())
|
||||
sorted_ticks = sorted(self._traffic_per_tick.keys())
|
||||
for tick in sorted_ticks:
|
||||
if actual_tick not in self.cumulative_traffic:
|
||||
if actual_tick not in self._cumulative_traffic:
|
||||
# Start with the vehicle counts of the previous tick
|
||||
self.cumulative_traffic[actual_tick] = defaultdict(int, self.cumulative_traffic.get(actual_tick - 1, {}))
|
||||
self._cumulative_traffic[actual_tick] = defaultdict(int, self._cumulative_traffic.get(actual_tick - 1, {}))
|
||||
|
||||
# Apply the changes recorded for the current tick
|
||||
for link_id, change in self.traffic_per_tick[tick].items():
|
||||
self.cumulative_traffic[actual_tick][link_id] += change
|
||||
for link_id, change in self._traffic_per_tick[tick].items():
|
||||
self._cumulative_traffic[actual_tick][link_id] += change
|
||||
|
||||
actual_tick += 1 # Move to the next tick
|
||||
|
||||
def create_graph(self):
|
||||
for node_id, coords in self.nodes.items():
|
||||
self.G.add_node(node_id, pos=coords)
|
||||
for link in self.links:
|
||||
self.G.add_edge(link['from'], link['to'])
|
||||
self.pos = nx.get_node_attributes(self.G, 'pos')
|
||||
for node_id, coords in self._nodes.items():
|
||||
self._G.add_node(node_id, pos=coords)
|
||||
for link in self._links:
|
||||
self._G.add_edge(link['from'], link['to'])
|
||||
self._pos = nx.get_node_attributes(self._G, 'pos')
|
||||
|
||||
def setup_color_mapping(self):
|
||||
# Find max traffic to setup the normalization instance
|
||||
max_traffic = max(max(self.cumulative_traffic[tick].values()) for tick in self.cumulative_traffic)
|
||||
max_traffic = max(max(self._cumulative_traffic[tick].values()) for tick in self._cumulative_traffic)
|
||||
self.norm = colors.Normalize(vmin=0, vmax=max_traffic)
|
||||
|
||||
def update(self, frame_number):
|
||||
tick = sorted(self.cumulative_traffic.keys())[frame_number]
|
||||
traffic_data = self.cumulative_traffic[tick]
|
||||
tick = sorted(self._cumulative_traffic.keys())[frame_number]
|
||||
traffic_data = self._cumulative_traffic[tick]
|
||||
|
||||
edge_colors = [self.cmap(self.norm(traffic_data.get(link['id'], 0))) for link in self.links]
|
||||
edge_widths = [2 + self.norm(traffic_data.get(link['id'], 0)) * 3 for link in self.links]
|
||||
edge_colors = [self._cmap(self.norm(traffic_data.get(link['id'], 0))) for link in self._links]
|
||||
edge_widths = [10 + self.norm(traffic_data.get(link['id'], 0)) * 10 for link in self._links]
|
||||
|
||||
plt.cla()
|
||||
nx.draw(self.G, self.pos, node_size=0, node_color='blue', width=edge_widths, edge_color=edge_colors, with_labels=False,
|
||||
edge_cmap=self.cmap)
|
||||
nx.draw(self._G, self._pos, node_size=0, node_color='blue', width=edge_widths, edge_color=edge_colors,
|
||||
with_labels=False,
|
||||
edge_cmap=self._cmap)
|
||||
|
||||
plt.title(f"Time: {tick}")
|
||||
|
||||
|
@ -104,10 +111,10 @@ class MatsimVisualizer():
|
|||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
sm = plt.cm.ScalarMappable(cmap=self.cmap, norm=self.norm)
|
||||
sm = plt.cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
|
||||
sm.set_array([])
|
||||
plt.colorbar(sm, ax=ax, label='Traffic Density')
|
||||
|
||||
ani = FuncAnimation(fig, self.update, frames=len(self.cumulative_traffic), repeat=False)
|
||||
ani.save('traffic_animation.gif', writer='ffmpeg', fps=5)
|
||||
ani = FuncAnimation(fig, self.update, frames=len(self._cumulative_traffic), repeat=False)
|
||||
ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=5)
|
||||
plt.show()
|
||||
|
|
Loading…
Reference in New Issue
Block a user