2024-02-01 14:35:55 -05:00
import gzip
2024-02-15 14:32:49 -05:00
from lxml import etree
2024-02-01 14:35:55 -05:00
import networkx as nx
import matplotlib . pyplot as plt
import matplotlib . cm as cm
import matplotlib . colors as colors
2024-03-05 15:14:36 -05:00
from matplotlib . animation import FuncAnimation
from collections import defaultdict
2024-02-01 14:35:55 -05:00
2024-02-14 16:54:41 -05:00
2024-03-05 15:14:36 -05:00
class MatsimVisualizer:
    """Animate per-link traffic from a MATSim output directory as a heat map.

    Reads ``output_network.xml.gz``, ``output_events.xml.gz`` and, when
    ``show_facilities`` is set, ``output_facilities.xml.gz`` from
    ``input_file_path``. Replays link enter/leave events tick by tick and
    saves the animation to ``<output_file_path>/traffic_animation.gif``.

    ``viewport`` is indexed as ``(min_x, max_y, max_x, min_y)``. A point is
    kept only when ``viewport[0] < x < viewport[2]`` and
    ``viewport[3] < y < viewport[1]`` (strict bounds).
    """

    # Occupancy change applied to a link for each tracked event type; every
    # other event type is ignored.
    _EVENT_DELTAS = {
        'entered link': 1,
        'vehicle enters traffic': 1,
        'left link': -1,
        'vehicle leaves traffic': -1,
    }

    def __init__(self, input_file_path, output_file_path, viewport, show_facilities=False, show_nearby_nodes=False):
        # node/facility id -> (x, y) for points inside the viewport.
        self._nodes = {}
        # Link records {'id', 'from', 'to', 'vehicles'} whose two endpoints
        # are both in self._nodes.
        self._links = []
        # node id -> (x, y); filled by _create_graph.
        self._pos = None
        # matplotlib Normalize shared by edge colour and edge width.
        self.norm = None
        # Bounding box of *all* network nodes, not only those in the viewport.
        self._max_x = None
        self._min_x = None
        self._max_y = None
        self._min_y = None
        self._viewport = viewport
        self._output_file_path = output_file_path
        self._network_file_path = f"{input_file_path}/output_network.xml.gz"
        self._events_file_path = f"{input_file_path}/output_events.xml.gz"
        self._facilities_file_path = f"{input_file_path}/output_facilities.xml.gz"
        self._G = nx.Graph()
        # tick index -> {link id -> net vehicle change during that tick}
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        # Non-facility nodes within the search radius of some facility.
        self._close_to_facility_nodes = set()
        # While loading events: index of the tick being filled.
        # While animating: number of ticks already applied to link counts.
        self._tick = 0
        # Peak simultaneous occupancy on any single link, and where/when.
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None
        self._show_facilities = show_facilities
        self._show_nearby_nodes = show_nearby_nodes
        # Assigned in visualize(); initialised so the attributes always exist.
        self._cmap = None
        self._ax = None

    def _in_viewport(self, x, y):
        """Return True when (x, y) lies strictly inside the configured viewport."""
        return (self._viewport[0] < x < self._viewport[2]
                and self._viewport[3] < y < self._viewport[1])

    def _update_bounds(self, x, y):
        """Grow the network bounding box (_min_x/_max_x/_min_y/_max_y) to include (x, y)."""
        self._min_x = x if self._min_x is None else min(self._min_x, x)
        self._max_x = x if self._max_x is None else max(self._max_x, x)
        self._min_y = y if self._min_y is None else min(self._min_y, y)
        self._max_y = y if self._max_y is None else max(self._max_y, y)

    def _load_data(self):
        """Load facilities (optional), the network and the events from disk."""
        # ====== LOAD FACILITIES ====== #
        if self._show_facilities:
            with gzip.open(self._facilities_file_path, 'rb') as file:
                facilities_doc = etree.parse(file)
            for facility in facilities_doc.xpath('/facilities/facility'):
                x = float(facility.get('x'))
                y = float(facility.get('y'))
                if self._in_viewport(x, y):
                    # Prefixed so facilities can be told apart from network nodes.
                    self._nodes[f"facility_{facility.get('id')}"] = (x, y)

        # ====== LOAD NETWORK ====== #
        with gzip.open(self._network_file_path, 'rb') as file:
            network_doc = etree.parse(file)

        for node in network_doc.xpath('/network/nodes/node'):
            x = float(node.get('x'))
            y = float(node.get('y'))
            self._update_bounds(x, y)
            if self._in_viewport(x, y):
                self._nodes[node.get('id')] = (x, y)

        for link in network_doc.xpath('/network/links/link'):
            start_node = link.get('from')
            end_node = link.get('to')
            # Only keep links that are fully visible.
            if start_node in self._nodes and end_node in self._nodes:
                self._links.append({
                    'id': link.get('id'),
                    'from': start_node,
                    'to': end_node,
                    'vehicles': 0,
                })

        # ====== LOAD EVENTS ====== #
        with gzip.open(self._events_file_path, 'rb') as file:
            events_doc = etree.parse(file)
        self._record_events(events_doc.xpath('/events/event'))

    def _record_events(self, events):
        """Bucket tracked link events into ticks and find the peak link occupancy.

        ``events`` is an iterable of objects exposing ``.get(attr)`` (lxml
        elements or plain dicts) with ``'link'``, ``'type'`` and ``'time'``.
        Only events on loaded links whose type is in ``_EVENT_DELTAS`` are used.
        A new tick starts whenever such an event's time differs from the
        previous tracked event's time.

        Resets and fills ``_traffic_per_tick``, ``_max_traffic``,
        ``_max_traffic_link``, ``_max_traffic_tick``, and leaves ``_tick`` at
        the index of the last tick.
        """
        known_links = {link['id'] for link in self._links}  # O(1) membership
        self._traffic_per_tick = defaultdict(lambda: defaultdict(int))
        self._max_traffic = 0
        self._max_traffic_link = None
        self._max_traffic_tick = None
        self._tick = 0
        occupancy = defaultdict(int)  # vehicles currently on each link
        current_time = None

        for event in events:
            link_id = event.get('link')
            if link_id not in known_links:
                continue
            delta = self._EVENT_DELTAS.get(event.get('type'))
            if delta is None:
                continue
            time = float(event.get('time'))

            # Advance *before* recording so the first event of a new timestamp
            # lands in its own tick rather than the previous one.
            if current_time is not None and time != current_time:
                self._tick += 1
            current_time = time

            self._traffic_per_tick[self._tick][link_id] += delta
            occupancy[link_id] += delta
            # Track the peak occupancy for the heat-map scale.
            if delta > 0 and occupancy[link_id] > self._max_traffic:
                self._max_traffic = occupancy[link_id]
                self._max_traffic_link = link_id
                self._max_traffic_tick = self._tick

    def _create_graph(self):
        """Populate the networkx graph and the node position map from loaded data."""
        for node_id, coords in self._nodes.items():
            self._G.add_node(node_id, pos=coords, is_facility=node_id.startswith('facility_'))
        for link in self._links:
            self._G.add_edge(link['from'], link['to'], id=link['id'])
        self._pos = nx.get_node_attributes(self._G, 'pos')

        if self._show_nearby_nodes:
            self._identify_close_nodes(radius=100)  # TODO: move to load data

    def _identify_close_nodes(self, radius):
        """Add every non-facility node within ``radius`` of a facility to _close_to_facility_nodes."""
        facility_coords = [
            coords for node, coords in self._pos.items()
            if self._G.nodes[node].get('is_facility')
        ]
        for node, coords in self._pos.items():
            if self._G.nodes[node].get('is_facility'):
                continue
            if any(self._euclidean_distance(coords, facility_coord) <= radius
                   for facility_coord in facility_coords):
                self._close_to_facility_nodes.add(node)

    def _setup_color_mapping(self):
        """Create the normaliser; half the peak occupancy maps to the top of the colormap."""
        self.norm = colors.Normalize(vmin=0, vmax=self._max_traffic / 2)

    def _reset_playback(self):
        """Rewind playback: no ticks applied and every link empty."""
        self._tick = 0
        for link in self._links:
            link['vehicles'] = 0

    def _advance_to(self, frame):
        """Set each link's 'vehicles' to its value after ticks 0..frame inclusive."""
        if frame < self._tick:
            # Going backwards: rebuild from scratch so a frame always renders
            # the same state no matter how often or in what order it is drawn.
            self._reset_playback()
        while self._tick <= frame:
            changes = self._traffic_per_tick.get(self._tick)  # .get: no insertion
            if changes:
                for link in self._links:
                    link['vehicles'] += changes.get(link['id'], 0)
            self._tick += 1

    def _update(self, frame):
        """Draw animation frame ``frame`` (a tick index) onto the plot axes.

        State is derived from ``frame``, not from an internal counter. Without
        an ``init_func``, FuncAnimation draws the first frame once for
        initialisation, and the animation may be replayed (e.g. shown after
        saving), so the same frame can be requested more than once.
        """
        self._advance_to(frame)

        # The same normalised value drives both edge colour and edge width.
        scaled = [self.norm(link['vehicles']) for link in self._links]
        edgelist = [(link['from'], link['to']) for link in self._links]
        edge_colors = [self._cmap(value) for value in scaled]
        edge_widths = [1 + value for value in scaled]

        ax = self._ax if self._ax is not None else plt.gca()
        ax.cla()
        # Draw Facilities
        if self._show_facilities:
            facility_nodes = {node: pos for node, pos in self._pos.items() if self._G.nodes[node]['is_facility']}
            nx.draw_networkx_nodes(self._G, facility_nodes, facility_nodes.keys(), node_size=1, node_color='red', ax=ax)
        # Draw Network
        if self._show_nearby_nodes:
            close_nodes = {node: self._pos[node] for node in self._close_to_facility_nodes if node in self._pos}
            nx.draw_networkx_nodes(self._G, self._pos, close_nodes.keys(), node_size=1, node_color='blue', ax=ax)
        nx.draw_networkx_edges(self._G, self._pos, edgelist=edgelist, width=edge_widths, edge_color=edge_colors, edge_cmap=self._cmap, ax=ax)
        ax.set_title(f"Time: {frame}")

    def visualize(self, fps=15, colormap='inferno'):
        """Load the data, render one frame per tick and save/show the animation.

        Args:
            fps: frames per second of the saved GIF.
            colormap: name of a matplotlib colormap; unknown names fall back
                to 'inferno' with a printed notice.
        """
        self._load_data()
        self._create_graph()
        self._setup_color_mapping()

        fig, ax = plt.subplots()
        self._ax = ax
        # Stretch the axes by the network's width/height ratio. Skip this when
        # the extent is unknown (no nodes) or zero, since no valid ratio exists.
        if None not in (self._min_x, self._max_x, self._min_y, self._max_y):
            width = self._max_x - self._min_x
            height = self._max_y - self._min_y
            if width > 0 and height > 0:
                ax.set_aspect(width / height)

        # Accept only real colormaps; a bare hasattr() would also accept other
        # module attributes such as classes and functions.
        cmap = getattr(cm, colormap, None)
        if isinstance(cmap, colors.Colormap):
            self._cmap = cmap
        else:
            print(f"Colormap '{colormap}' not recognized. Falling back to 'inferno'. Please select from: 'inferno', 'plasma', 'magma', 'spring', 'summer', 'autumn', or 'winter'.")
            self._cmap = cm.inferno

        sm = plt.cm.ScalarMappable(cmap=self._cmap, norm=self.norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Traffic Density')

        self._reset_playback()
        ani = FuncAnimation(fig, self._update, frames=len(self._traffic_per_tick), repeat=False)
        ani.save(f"{self._output_file_path}/traffic_animation.gif", writer='ffmpeg', fps=fps)
        plt.show()

    def _euclidean_distance(self, coord1, coord2):
        """Return the straight-line distance between two (x, y) points."""
        return ((coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2) ** 0.5