565 lines
18 KiB
Python
565 lines
18 KiB
Python
|
import copy
|
||
|
import time
|
||
|
|
||
|
import numpy as np
|
||
|
import collections
|
||
|
|
||
|
from .. import util
|
||
|
from .. import caching
|
||
|
from .. import exceptions
|
||
|
from .. import transformations
|
||
|
|
||
|
# networkx is only required once the scene graph is actually used, so an
# import failure is captured here instead of breaking import of this module
try:
    import networkx as nx
    _ForestParent = nx.DiGraph
except BaseException as import_error:
    # stand-in module that raises the captured exception only when
    # something tries to use networkx
    nx = exceptions.ExceptionModule(import_error)
    # EnforcedForest needs some base class even without networkx
    _ForestParent = object
|
||
|
|
||
|
|
||
|
class TransformForest(object):
    """
    A scene graph of named coordinate frames.

    Frames are nodes of an EnforcedForest (a networkx DiGraph subclass).
    update() adds an edge from `frame_from` to `frame_to` storing a (4, 4)
    homogeneous 'matrix'; geometry names are stored as a 'geometry'
    attribute on the child node.
    """

    def __init__(self, base_frame='world'):
        # a graph structure, subclass of networkx DiGraph
        self.transforms = EnforcedForest()
        # hashable, the base or root frame
        self.base_frame = base_frame

        # save paths, keyed with tuple (from, to)
        self._paths = {}
        # random token returned by md5(), regenerated on every change
        self._updated = str(np.random.random())
        # cache tied to md5(): holds get() results keyed by
        # (frame_from, frame_to) and the cache_decorator properties
        self._cache = caching.Cache(self.md5)

    def update(self, frame_to, frame_from=None, **kwargs):
        """
        Update a transform in the tree.

        Parameters
        ------------
        frame_from : hashable object
          Usually a string (eg 'world').
          If left as None it will be set to self.base_frame
        frame_to : hashable object
          Usually a string (eg 'mesh_0')
        matrix : (4,4) float
          Homogeneous transformation matrix
        quaternion : (4,) float
          Quaternion ordered [w, x, y, z]
        axis : (3,) float
          Axis of rotation
        angle : float
          Angle of rotation, in radians
        translation : (3,) float
          Distance to translate
        geometry : hashable
          Geometry object name, e.g. 'mesh_0'
        """
        # any modification invalidates everything tied to md5()
        self._updated = str(np.random.random())
        self._cache.clear()

        # if no frame specified, use base frame
        if frame_from is None:
            frame_from = self.base_frame
        # convert various kwargs to a single matrix
        matrix = kwargs_to_matrix(**kwargs)

        # create the edge attributes
        attr = {'matrix': matrix, 'time': time.time()}
        # pass through geometry to edge attribute
        if 'geometry' in kwargs:
            attr['geometry'] = kwargs['geometry']

        # add the edge; True means an edge was removed to keep a forest
        changed = self.transforms.add_edge(frame_from, frame_to, **attr)
        # set the node attribute with the geometry information
        if 'geometry' in kwargs:
            nx.set_node_attributes(
                self.transforms,
                name='geometry',
                values={frame_to: kwargs['geometry']})
        # if the edge update changed our structure
        # dump our cache of shortest paths
        if changed:
            self._paths = {}

    def md5(self):
        """
        Return a token which changes every time the graph is modified.

        Returns
        ----------
        md5 : str
          Random string regenerated on each modification; it is
          not a digest of the graph contents
        """
        return self._updated

    def copy(self):
        """
        Return a copy of the current TransformForest

        Returns
        ------------
        copied : TransformForest
          Deep copy of the base frame and the transform graph,
          with empty path and transform caches
        """
        copied = TransformForest()
        copied.base_frame = copy.deepcopy(self.base_frame)
        copied.transforms = copy.deepcopy(self.transforms)

        return copied

    def to_flattened(self, base_frame=None):
        """
        Export the current transform graph as a flattened dict, with
        every node's transform expressed relative to a single frame.

        Parameters
        ------------
        base_frame : hashable or None
          Frame to express transforms relative to.
          If None self.base_frame is used

        Returns
        ------------
        flat : dict
          Keyed by node name (base_frame itself is excluded):
          {'transform': (4, 4) nested list of float,
           'geometry': geometry name or None}
          A node not connected to base_frame makes get() raise
        """
        if base_frame is None:
            base_frame = self.base_frame

        flat = {}
        for node in self.nodes:
            if node == base_frame:
                continue
            transform, geometry = self.get(
                frame_to=node, frame_from=base_frame)
            flat[node] = {
                'transform': transform.tolist(),
                'geometry': geometry
            }
        return flat

    def to_gltf(self, scene):
        """
        Export a transforms as the 'nodes' section of a GLTF dict.
        Flattens tree.

        Parameters
        ------------
        scene : object
          Must provide `geometry` (ordered mapping of geometry name to
          geometry) and `camera.name`; presumably a Scene -- confirm

        Returns
        --------
        gltf : dict
          with 'nodes' referencing a list of dicts
        """
        # geometry is an OrderedDict
        # map mesh name to index: {geometry key : index}
        mesh_index = {name: i for i, name
                      in enumerate(scene.geometry.keys())}

        # shortcut to graph
        graph = self.transforms
        # get the stored node data
        node_data = dict(graph.nodes(data=True))

        # list of dict, in gltf format
        # start with base frame as first node index
        result = [{'name': self.base_frame}]
        # {node name : node index in gltf}
        lookup = {self.base_frame: 0}

        # collect the nodes in order
        for node in node_data.keys():
            if node == self.base_frame:
                continue
            lookup[node] = len(result)
            result.append({'name': node})

        # then iterate through to collect data
        for info in result:
            # name of the scene node
            node = info['name']
            # the base frame is absent from an empty graph:
            # emit it as a bare node rather than failing
            if node not in node_data:
                continue
            # store children as indexes
            children = [lookup[k] for k in graph[node]]
            if len(children) > 0:
                info['children'] = children
            # if we have a mesh store by index
            if 'geometry' in node_data[node]:
                info['mesh'] = mesh_index[node_data[node]['geometry']]
            # check to see if we have camera node
            if node == scene.camera.name:
                info['camera'] = 0

            # parent-child transform is stored on the edge into the
            # child, so a root node (no parent) has no matrix to export
            parents = list(graph.predecessors(node))
            if len(parents) == 0:
                continue
            matrix = graph.get_edge_data(parents[0], node).get('matrix')
            if matrix is None:
                continue
            matrix = np.asanyarray(matrix, dtype=np.float64)
            # only include matrix if it is not an identity matrix
            if np.abs(matrix - np.eye(4)).max() > 1e-5:
                # GLTF stores node matrices in column-major order
                info['matrix'] = matrix.T.reshape(-1).tolist()

        return {'nodes': result}

    def to_edgelist(self):
        """
        Export the current transforms as a list of
        edge tuples, with each tuple having the format:
        (node_a, node_b, {metadata})

        The graph itself is not modified.

        Returns
        ---------
        edgelist : (n,) list
          Of edge tuples; 'matrix' is a nested list of float and
          'geometry' is copied from node_b when present
        """
        # save cleaned edges
        export = []
        # loop through (node, node, edge attributes)
        for a, b, stored in nx.to_edgelist(self.transforms):
            # networkx yields the live attribute dicts: copy before editing
            # so exporting never replaces the stored matrix with a list
            attr = dict(stored)
            # geometry is a node property but save it to the
            # edge so we don't need two dictionaries
            b_attr = self._node_attributes(b)
            if 'geometry' in b_attr:
                attr['geometry'] = b_attr['geometry']
            # save the matrix as a float list
            attr['matrix'] = np.asanyarray(
                attr['matrix'], dtype=np.float64).tolist()
            export.append((a, b, attr))
        return export

    def from_edgelist(self, edges, strict=True):
        """
        Load transform data from an edge list into the current
        scene graph.

        Parameters
        -------------
        edges : (n,) tuples
          (node_a, node_b, {key: value}) or (node_a, node_b);
          node_a is the parent frame, node_b the child frame
        strict : bool
          If true, raise a ValueError when a
          malformed edge is passed in a tuple.
        """
        # loop through each edge
        for edge in edges:
            # edge contains attributes
            if len(edge) == 3:
                self.update(edge[1], edge[0], **edge[2])
            # edge just contains nodes
            elif len(edge) == 2:
                self.update(edge[1], edge[0])
            # edge is broken
            elif strict:
                raise ValueError('edge incorrect shape: {}'.format(str(edge)))

    def load(self, edgelist):
        """
        Load transform data from an edge list into the current
        scene graph.

        Parameters
        -------------
        edgelist : (n,) tuples
          (node_a, node_b, {key: value})
        """
        self.from_edgelist(edgelist, strict=True)

    @caching.cache_decorator
    def nodes(self):
        """
        A list of every node in the graph.

        Returns
        -------------
        nodes : (n,) list
          All node names
        """
        nodes = list(self.transforms.nodes())
        return nodes

    @caching.cache_decorator
    def nodes_geometry(self):
        """
        The nodes in the scene graph with geometry attached.

        Returns
        ------------
        nodes_geometry : (m,) list
          Node names which have geometry associated
        """
        nodes = [n for n, attr in
                 self.transforms.nodes(data=True)
                 if 'geometry' in attr]
        return nodes

    @caching.cache_decorator
    def geometry_nodes(self):
        """
        Which nodes have this geometry?

        Returns
        ------------
        geometry_nodes : dict
          Geometry name : (n,) node names
        """
        res = collections.defaultdict(list)
        for node, attr in self.transforms.nodes(data=True):
            if 'geometry' in attr:
                res[attr['geometry']].append(node)
        return res

    def remove_geometries(self, geometries):
        """
        Remove the reference for specified geometries from nodes
        without deleting the node.

        Parameters
        ------------
        geometries : list or str
          Name of scene.geometry to dereference.
        """
        # make sure we have a set of geometries to remove
        if util.is_string(geometries):
            geometries = [geometries]
        geometries = set(geometries)

        # remove the geometry reference from the node without deleting nodes
        # this lets us keep our cached paths, and will not screw up children
        for node, attrib in self.transforms.nodes(data=True):
            if 'geometry' in attrib and attrib['geometry'] in geometries:
                attrib.pop('geometry')

        # geometry names are returned by get() and by the nodes_geometry
        # and geometry_nodes properties, so everything tied to md5() is
        # stale; cached paths depend only on edges and are kept
        self._updated = str(np.random.random())
        self._cache.clear()

    def get(self, frame_to, frame_from=None):
        """
        Get the transform from one frame to another, assuming they are connected
        in the transform tree.

        If the frames are not connected a NetworkXNoPath error will be raised.

        Parameters
        ------------
        frame_to : hashable
          Node name, usually a string (eg 'mesh_0')
        frame_from : hashable
          Node name, usually a string (eg 'world').
          If None it will be set to self.base_frame

        Returns
        ----------
        transform : (4, 4) float
          Homogeneous transformation matrix
        geometry : hashable or None
          Geometry name stored on frame_to, or None
        """
        if frame_from is None:
            frame_from = self.base_frame

        # look up transform to see if we have it already
        cache_key = (frame_from, frame_to)
        cached = self._cache[cache_key]
        if cached is not None:
            return cached

        # get the path in the graph
        path = self._get_path(frame_from, frame_to)

        # collect transforms along each consecutive pair of the path
        transforms = []
        for a, b in zip(path[:-1], path[1:]):
            # edges are stored parent -> child; traversing an
            # edge backwards uses the inverse of its matrix
            data, direction = self.transforms.get_edge_data_direction(a, b)
            matrix = data['matrix']
            if direction < 0:
                matrix = np.linalg.inv(matrix)
            transforms.append(matrix)

        # do all dot products at the end
        if len(transforms) == 0:
            transform = np.eye(4)
        elif len(transforms) == 1:
            # copy so callers cannot mutate the matrix stored on the edge
            transform = np.array(transforms[0], dtype=np.float64)
        else:
            transform = util.multi_dot(transforms)

        geometry = self._node_attributes(frame_to).get('geometry')

        self._cache[cache_key] = (transform, geometry)

        return transform, geometry

    def show(self):
        """
        Plot the graph layout of the scene.
        """
        import matplotlib.pyplot as plt
        nx.draw(self.transforms, with_labels=True)
        plt.show()

    def to_svg(self):
        """
        Render the transform graph with graph_to_svg from the
        package's graph module.

        Returns
        ----------
        svg : whatever graph_to_svg returns (presumably SVG data -- confirm)
        """
        from ..graph import graph_to_svg
        return graph_to_svg(self.transforms)

    def __contains__(self, key):
        try:
            return key in self.transforms.nodes
        except TypeError:
            # networkx 1.X API: `nodes` is a method, not a container
            util.log.warning('upgrade networkx to version 2.X!')
            return key in self.transforms.nodes()

    def __getitem__(self, key):
        # returns the (transform, geometry) tuple from get()
        return self.get(key)

    def __setitem__(self, key, value):
        value = np.asanyarray(value)
        if value.shape != (4, 4):
            raise ValueError('Matrix must be specified!')
        return self.update(key, matrix=value)

    def clear(self):
        """
        Remove every transform and reset all caches.
        """
        self.transforms = EnforcedForest()
        self._paths = {}
        self._updated = str(np.random.random())
        self._cache.clear()

    def _node_attributes(self, node):
        """
        Return the attribute dict stored on a node, supporting both
        the networkx 2.X (`G.nodes[n]`) and 1.X (`G.node[n]`) APIs.
        """
        try:
            return self.transforms.nodes[node]
        except TypeError:
            # networkx 1.X: `nodes` is a method and is not subscriptable
            return self.transforms.node[node]

    def _get_path(self, frame_from, frame_to):
        """
        Find a path between two frames, either from cached paths or
        from the transform graph.

        Parameters
        ------------
        frame_from: a frame key, usually a string
                    eg, 'world'
        frame_to:   a frame key, usually a string
                    eg, 'mesh_0'

        Returns
        ----------
        path: (n) list of frame keys ordered from frame_from to frame_to
              eg, ['world', 'mesh_hand', 'mesh_finger']
        """
        # store paths keyed as a tuple
        key = (frame_from, frame_to)
        if key not in self._paths:
            # compute once on the undirected copy and reuse afterwards
            self._paths[key] = self.transforms.shortest_path_undirected(
                frame_from, frame_to)
        return self._paths[key]
|
||
|
|
||
|
|
||
|
class EnforcedForest(_ForestParent):
    """
    A subclass of networkx.DiGraph whose undirected form is kept a forest.

    A parallel undirected copy of the graph is maintained by add_edge,
    remove_edge and remove_edges_from. When add_edge(u, v) would close a
    cycle, the first edge of the existing undirected path from u to v is
    removed instead, or a ValueError is raised if the 'strict' flag is set.
    """

    def __init__(self, *args, **kwargs):
        # behavior flags, optionally overridden by keyword arguments:
        # 'strict': raise ValueError instead of silently fixing bad edges
        # 'assert_forest': run a (slow) forest check after every add_edge
        self.flags = {'strict': False, 'assert_forest': False}
        for key in self.flags:
            if key in kwargs:
                # pop so the flag is not forwarded to networkx
                self.flags[key] = bool(kwargs.pop(key))

        super(EnforcedForest, self).__init__(*args, **kwargs)
        # keep a second parallel but undirected copy of the graph
        # all of the networkx methods for turning a directed graph
        # into an undirected graph are quite slow so we do minor bookkeeping
        self._undirected = nx.Graph()

    def add_edge(self, u, v, *args, **kwargs):
        """
        Add the directed edge u -> v while keeping the undirected graph
        a forest. Extra arguments are stored as edge attributes.

        Returns
        ----------
        changed : bool
          True if an existing edge was removed to avoid a cycle
        """
        changed = False
        if u == v:
            # self loops are never stored
            if self.flags['strict']:
                raise ValueError('Edge must be between two unique nodes!')
            return changed
        elif self._undirected.has_edge(u, v):
            # replacing an existing edge, stored in either direction
            self.remove_edges_from([[u, v], [v, u]])
        elif len(self.nodes()) > 0:
            try:
                path = nx.shortest_path(self._undirected, u, v)
            except (nx.NetworkXError, nx.NetworkXNoPath, nx.NetworkXException):
                # u and v are not connected (or not yet in the graph)
                path = None
            if path is not None:
                # adding u -> v would close a cycle
                if self.flags['strict']:
                    raise ValueError(
                        'Multiple edge path exists between nodes!')
                self.disconnect_path(path)
                changed = True

        self._undirected.add_edge(u, v)
        super(EnforcedForest, self).add_edge(u, v, *args, **kwargs)

        if self.flags['assert_forest']:
            # this is quite slow but makes very sure structure is correct
            # so is mainly used for testing
            assert nx.is_forest(nx.Graph(self))

        return changed

    def add_edges_from(self, *args, **kwargs):
        raise ValueError('EnforcedTree requires add_edge method to be used!')

    def add_path(self, *args, **kwargs):
        raise ValueError('EnforcedTree requires add_edge method to be used!')

    def remove_edge(self, *args, **kwargs):
        super(EnforcedForest, self).remove_edge(*args, **kwargs)
        self._undirected.remove_edge(*args, **kwargs)

    def remove_edges_from(self, ebunch, *args, **kwargs):
        # materialize first: the same edges are removed from two graphs,
        # and an iterator would be exhausted by the first call
        ebunch = list(ebunch)
        super(EnforcedForest, self).remove_edges_from(ebunch, *args, **kwargs)
        self._undirected.remove_edges_from(ebunch, *args, **kwargs)

    def disconnect_path(self, path):
        """
        Remove the edge between the first two nodes of `path`,
        in whichever direction it is stored.
        """
        u, v = path[0], path[1]
        # plain tuples keep node keys intact: packing them into a numpy
        # array coerces mixed-type keys (e.g. 1 and 'a') to strings
        self.remove_edges_from([(u, v), (v, u)])

    def shortest_path_undirected(self, u, v):
        """
        Shortest path from u to v ignoring edge direction.
        networkx exceptions (e.g. NetworkXNoPath) propagate to the caller.
        """
        return nx.shortest_path(self._undirected, u, v)

    def get_edge_data_direction(self, u, v):
        """
        Return (edge attributes, direction) for the edge between u and v:
        direction is 1 if stored as u -> v and -1 if stored as v -> u.
        """
        if self.has_edge(u, v):
            return self.get_edge_data(u, v), 1
        if self.has_edge(v, u):
            return self.get_edge_data(v, u), -1
        raise ValueError('Edge does not exist!')
|
||
|
|
||
|
|
||
|
def kwargs_to_matrix(**kwargs):
    """
    Turn a set of keyword arguments into a transformation matrix.

    Parameters
    ------------
    matrix : (4, 4) float
      Homogeneous transform; takes precedence over other options
    quaternion : (4,) float
      Rotation quaternion, used if no matrix is passed
    axis : (3,) float
      Rotation axis, used with `angle` if neither matrix nor
      quaternion is passed (both axis and angle are required)
    angle : float
      Rotation angle, in radians
    translation : (3,) float
      Added to the translation column of whichever matrix is produced

    Returns
    ------------
    matrix : (4, 4) float
      Homogeneous transformation matrix. A passed `matrix` is copied,
      so the caller's array is never modified or shared
    """
    if 'matrix' in kwargs:
        # a matrix takes precedence over other options; copy it so the
        # translation below cannot modify the caller's array in place
        # and the stored transform is not aliased to the caller's data
        matrix = np.array(kwargs['matrix'], dtype=np.float64)
    elif 'quaternion' in kwargs:
        matrix = transformations.quaternion_matrix(kwargs['quaternion'])
    elif ('axis' in kwargs) and ('angle' in kwargs):
        matrix = transformations.rotation_matrix(kwargs['angle'],
                                                 kwargs['axis'])
    else:
        matrix = np.eye(4)

    if 'translation' in kwargs:
        # translation can be used in conjunction with any of the methods of
        # specifying transforms. In the case a matrix and translation are passed,
        # we add the translations together rather than picking one.
        matrix[0:3, 3] += kwargs['translation']
    return matrix
|