hub/venv/lib/python3.7/site-packages/trimesh/scene/transforms.py

565 lines
18 KiB
Python
Raw Normal View History

import copy
import time
import numpy as np
import collections
from .. import util
from .. import caching
from .. import exceptions
from .. import transformations
try:
    # networkx provides the DiGraph base class used by EnforcedForest
    import networkx as nx
    _ForestParent = nx.DiGraph
except BaseException as E:
    # create a dummy module which will raise the ImportError
    # or other exception only when someone tries to use networkx
    nx = exceptions.ExceptionModule(E)
    # fall back to a plain base class so this module still imports;
    # the stored exception surfaces once networkx is actually used
    _ForestParent = object
class TransformForest(object):
    """
    A tree (or forest) of named coordinate frames connected by
    homogeneous transformation matrices. Transforms are stored as
    directed edges from parent frame to child frame in an
    `EnforcedForest`; geometry names are stored as node attributes.
    """

    def __init__(self, base_frame='world'):
        # a graph structure, subclass of networkx DiGraph
        self.transforms = EnforcedForest()
        # hashable, the base or root frame
        self.base_frame = base_frame
        # save paths, keyed with tuple (from, to)
        self._paths = {}
        # random token replaced on every modification; returned by
        # `md5` which is the identifier function of the cache below
        self._updated = str(np.random.random())
        # cache transformation matrices keyed with tuples
        self._cache = caching.Cache(self.md5)

    def update(self, frame_to, frame_from=None, **kwargs):
        """
        Update a transform in the tree.

        Parameters
        ------------
        frame_to : hashable object
          Usually a string (eg 'mesh_0')
        frame_from : hashable object
          Usually a string (eg 'world').
          If left as None it will be set to self.base_frame
        matrix : (4,4) float
          Homogeneous transformation matrix
        quaternion : (4,) float
          Quaternion ordered [w, x, y, z]
        axis : (3,) float
          Axis of rotation
        angle : float
          Angle of rotation, in radians
        translation : (3,) float
          Distance to translate
        geometry : hashable
          Geometry object name, e.g. 'mesh_0'
        """
        # any modification invalidates every cached value
        self._updated = str(np.random.random())
        self._cache.clear()
        # if no frame specified, use base frame
        if frame_from is None:
            frame_from = self.base_frame
        # convert various kwargs to a single matrix
        matrix = kwargs_to_matrix(**kwargs)
        # create the edge attributes
        attr = {'matrix': matrix, 'time': time.time()}
        # pass through geometry to edge attribute
        if 'geometry' in kwargs:
            attr['geometry'] = kwargs['geometry']
        # add the edge; `changed` is True when an edge on an existing
        # path had to be removed to keep the graph a forest
        changed = self.transforms.add_edge(frame_from,
                                           frame_to,
                                           **attr)
        # set the node attribute with the geometry information
        if 'geometry' in kwargs:
            nx.set_node_attributes(
                self.transforms,
                name='geometry',
                values={frame_to: kwargs['geometry']})
        # if the edge update changed our structure
        # dump our cache of shortest paths
        if changed:
            self._paths = {}

    def md5(self):
        """
        Return a token identifying the current state of the forest.

        Returns
        ----------
        md5 : str
          Random string replaced whenever the forest is modified;
          it is not a digest of the contents.
        """
        return self._updated

    def copy(self):
        """
        Return a copy of the current TransformForest.

        Returns
        ------------
        copied : TransformForest
          Holds deep copies of the base frame and graph; cached
          paths and transforms are not carried over.
        """
        copied = TransformForest()
        copied.base_frame = copy.deepcopy(self.base_frame)
        copied.transforms = copy.deepcopy(self.transforms)
        return copied

    def to_flattened(self, base_frame=None):
        """
        Export the current transform graph as a flattened dict in
        which every node's transform is relative to one frame.

        Parameters
        ------------
        base_frame : None or hashable
          Frame the transforms are expressed in;
          defaults to self.base_frame

        Returns
        ------------
        flat : dict
          {node : {'transform': (4, 4) nested list,
                   'geometry': geometry name or None}}
          for every node except `base_frame`
        """
        if base_frame is None:
            base_frame = self.base_frame
        flat = {}
        for node in self.nodes:
            if node == base_frame:
                continue
            transform, geometry = self.get(
                frame_to=node, frame_from=base_frame)
            flat[node] = {
                'transform': transform.tolist(),
                'geometry': geometry
            }
        return flat

    def to_gltf(self, scene):
        """
        Export a transforms as the 'nodes' section of a GLTF dict.
        Flattens tree.

        Parameters
        ------------
        scene : trimesh.Scene
          Scene whose `geometry` keys index the meshes and whose
          `camera.name` identifies the camera node

        Returns
        --------
        gltf : dict
          with 'nodes' referencing a list of dicts
        """
        # geometry is an OrderedDict
        # map mesh name to index: {geometry key : index}
        mesh_index = {name: i for i, name
                      in enumerate(scene.geometry.keys())}
        # shortcut to graph
        graph = self.transforms
        # get the stored node data
        node_data = dict(graph.nodes(data=True))
        # list of dict, in gltf format
        # start with base frame as first node index
        result = [{'name': self.base_frame}]
        # {node name : node index in gltf}
        lookup = {self.base_frame: 0}
        # collect the nodes in order
        for node in node_data.keys():
            if node == self.base_frame:
                continue
            lookup[node] = len(result)
            result.append({'name': node})
        # then iterate through to collect data
        for info in result:
            # name of the scene node
            node = info['name']
            # the base frame is absent when the graph is empty
            if node not in graph:
                continue
            # store children as indexes
            children = [lookup[k] for k in graph[node].keys()]
            if len(children) > 0:
                info['children'] = children
            # if we have a mesh store by index
            if 'geometry' in node_data[node]:
                info['mesh'] = mesh_index[node_data[node]['geometry']]
            # check to see if we have camera node
            if node == scene.camera.name:
                info['camera'] = 0
            # parent-child transform is stored on the edge into the
            # child, so a root node without a parent has no matrix
            parent = next(iter(graph.predecessors(node)), None)
            if parent is None:
                continue
            edge = graph.get_edge_data(parent, node)
            if edge is None or 'matrix' not in edge:
                continue
            # get the (4, 4) homogeneous transform; accept lists too
            matrix = np.asanyarray(edge['matrix'], dtype=np.float64)
            if matrix.shape != (4, 4):
                continue
            # only include matrix if it is not an identity matrix
            if np.abs(matrix - np.eye(4)).max() > 1e-5:
                # transpose then row-major flatten: column-major order
                info['matrix'] = matrix.T.reshape(-1).tolist()
        return {'nodes': result}

    def to_edgelist(self):
        """
        Export the current transforms as a list of
        edge tuples, with each tuple having the format:
        (node_a, node_b, {metadata})

        The graph itself is not modified.

        Returns
        ---------
        edgelist : (n,) list
          Of edge tuples
        """
        # save cleaned edges
        export = []
        # loop through (node, node, edge attributes)
        for a, b, attr in nx.to_edgelist(self.transforms):
            # the attribute dicts are the graph's own storage: copy
            # so the list conversion below doesn't leak into the graph
            attr = dict(attr)
            # geometry is a node property but save it to the
            # edge so we don't need two dictionaries
            b_attr = self._node_attributes(b)
            # apply node geometry to edge attributes
            if 'geometry' in b_attr:
                attr['geometry'] = b_attr['geometry']
            # save the matrix as a float list
            attr['matrix'] = np.asanyarray(
                attr['matrix'], dtype=np.float64).tolist()
            export.append((a, b, attr))
        return export

    def from_edgelist(self, edges, strict=True):
        """
        Load transform data from an edge list into the current
        scene graph.

        Parameters
        -------------
        edges : (n,) tuples
          (node_a, node_b, {key: value}) or (node_a, node_b)
        strict : bool
          If true, raise a ValueError when a
          malformed edge is passed in a tuple;
          otherwise malformed edges are skipped.
        """
        # loop through each edge
        for edge in edges:
            # edge contains attributes
            if len(edge) == 3:
                self.update(edge[1], edge[0], **edge[2])
            # edge just contains nodes
            elif len(edge) == 2:
                self.update(edge[1], edge[0])
            # edge is broken
            elif strict:
                raise ValueError('edge incorrect shape: {}'.format(str(edge)))

    def load(self, edgelist):
        """
        Load transform data from an edge list into the current
        scene graph, raising ValueError on malformed edges.

        Parameters
        -------------
        edgelist : (n,) tuples
          (node_a, node_b, {key: value})
        """
        self.from_edgelist(edgelist, strict=True)

    @caching.cache_decorator
    def nodes(self):
        """
        A list of every node in the graph.

        Returns
        -------------
        nodes : (n,) list
          All node names
        """
        nodes = list(self.transforms.nodes())
        return nodes

    @caching.cache_decorator
    def nodes_geometry(self):
        """
        The nodes in the scene graph with geometry attached.

        Returns
        ------------
        nodes_geometry : (m,) list
          Node names which have geometry associated
        """
        nodes = [n for n, attr in
                 self.transforms.nodes(data=True)
                 if 'geometry' in attr]
        return nodes

    @caching.cache_decorator
    def geometry_nodes(self):
        """
        Which nodes have this geometry?

        Returns
        ------------
        geometry_nodes : dict
          Geometry name : (n,) node names
        """
        res = collections.defaultdict(list)
        for node, attr in self.transforms.nodes(data=True):
            if 'geometry' in attr:
                res[attr['geometry']].append(node)
        return res

    def remove_geometries(self, geometries):
        """
        Remove the reference for specified geometries from nodes
        without deleting the node.

        Parameters
        ------------
        geometries : list or str
          Name of scene.geometry to dereference.
        """
        # make sure we have a set of geometries to remove
        if util.is_string(geometries):
            geometries = [geometries]
        geometries = set(geometries)
        # remove the geometry reference from the node without deleting nodes
        # this lets us keep our cached paths, and will not screw up children
        for node, attrib in self.transforms.nodes(data=True):
            if 'geometry' in attrib and attrib['geometry'] in geometries:
                attrib.pop('geometry')
        # geometry is cached in `nodes_geometry`, `geometry_nodes` and
        # in the (transform, geometry) tuples stored by `get`, so every
        # cached value must go; the structure is unchanged so the
        # shortest paths in `self._paths` remain valid
        self._updated = str(np.random.random())
        self._cache.clear()

    def get(self, frame_to, frame_from=None):
        """
        Get the transform from one frame to another, assuming they are
        connected in the transform tree.

        If the frames are not connected a NetworkXNoPath error will be raised.

        Parameters
        ------------
        frame_to : hashable
          Node name, usually a string (eg 'mesh_0')
        frame_from : hashable
          Node name, usually a string (eg 'world').
          If None it will be set to self.base_frame

        Returns
        ----------
        transform : (4, 4) float
          Homogeneous transformation matrix
        geometry : hashable or None
          Geometry name attached to `frame_to`, if any
        """
        if frame_from is None:
            frame_from = self.base_frame
        # look up transform to see if we have it already
        cache_key = (frame_from, frame_to)
        cached = self._cache[cache_key]
        if cached is not None:
            return cached
        # get the path in the graph
        path = self._get_path(frame_from, frame_to)
        # collect transforms along each consecutive pair of the path
        transforms = []
        for start, end in zip(path[:-1], path[1:]):
            # get the matrix and edge direction
            data, direction = self.transforms.get_edge_data_direction(
                start, end)
            matrix = data['matrix']
            # walking an edge against its direction uses the inverse
            if direction < 0:
                matrix = np.linalg.inv(matrix)
            transforms.append(matrix)
        # do all dot products at the end
        if len(transforms) == 0:
            transform = np.eye(4)
        elif len(transforms) == 1:
            transform = np.asanyarray(transforms[0], dtype=np.float64)
        else:
            transform = util.multi_dot(transforms)
        geometry = self._node_attributes(frame_to).get('geometry')
        self._cache[cache_key] = (transform, geometry)
        return transform, geometry

    def show(self):
        """
        Plot the graph layout of the scene.
        """
        import matplotlib.pyplot as plt
        nx.draw(self.transforms, with_labels=True)
        plt.show()

    def to_svg(self):
        """
        Render the transform graph with `graph_to_svg` and return
        its result (presumably SVG data; see that helper).
        """
        from ..graph import graph_to_svg
        return graph_to_svg(self.transforms)

    def __contains__(self, key):
        try:
            return key in self.transforms.nodes
        except TypeError:
            # networkx 1.X API: `nodes` is a method, not a view
            util.log.warning('upgrade networkx to version 2.X!')
            return key in self.transforms.nodes()

    def __getitem__(self, key):
        """
        Return (transform, geometry) for `key` from the base frame.
        """
        return self.get(key)

    def __setitem__(self, key, value):
        """
        Set the (4, 4) transform from the base frame to `key`.
        """
        value = np.asanyarray(value)
        if value.shape != (4, 4):
            raise ValueError('Matrix must be specified!')
        return self.update(key, matrix=value)

    def clear(self):
        """
        Remove every frame and transform and reset all caches.
        """
        self.transforms = EnforcedForest()
        self._paths = {}
        self._updated = str(np.random.random())
        self._cache.clear()

    def _node_attributes(self, node):
        """
        Return the attribute dict of `node`, supporting both the
        networkx 2.X (`nodes[n]`) and 1.X (`node[n]`) APIs.
        """
        try:
            return self.transforms.nodes[node]
        except TypeError:
            # networkx 1.X API: `nodes` is a method, not subscriptable
            return self.transforms.node[node]

    def _get_path(self, frame_from, frame_to):
        """
        Find a path between two frames, either from cached paths or
        from the transform graph.

        Parameters
        ------------
        frame_from: a frame key, usually a string
            eg, 'world'
        frame_to: a frame key, usually a string
            eg, 'mesh_0'

        Returns
        ----------
        path: (n) list of frame keys
            eg, ['mesh_finger', 'mesh_hand', 'world']
        """
        # store paths keyed as a tuple
        key = (frame_from, frame_to)
        if key not in self._paths:
            # compute the undirected shortest path once and remember it
            self._paths[key] = self.transforms.shortest_path_undirected(
                frame_from, frame_to)
        return self._paths[key]
class EnforcedForest(_ForestParent):
    """
    A subclass of networkx.DiGraph which keeps its undirected
    structure a forest (no cycles).

    Adding an edge between nodes that are already connected either
    raises a ValueError (when the 'strict' flag is set) or first
    removes the initial edge of the existing path between them.
    """

    def __init__(self, *args, **kwargs):
        # behavior flags, each may be overridden by a keyword argument:
        #   strict: raise ValueError instead of repairing bad edges
        #   assert_forest: verify the forest property after every add
        self.flags = {'strict': False, 'assert_forest': False}
        for key in list(self.flags):
            if key in kwargs:
                self.flags[key] = bool(kwargs.pop(key))
        # name the class explicitly: `super(self.__class__, self)`
        # recurses forever as soon as this class is subclassed
        super(EnforcedForest, self).__init__(*args, **kwargs)
        # keep a second parallel but undirected copy of the graph
        # all of the networkx methods for turning a directed graph
        # into an undirected graph are quite slow so we do minor bookkeeping
        self._undirected = nx.Graph()

    def add_edge(self, u, v, *args, **kwargs):
        """
        Add a directed edge u -> v while keeping the graph a forest.

        Self-loops are ignored (ValueError in strict mode). An existing
        edge between u and v, in either direction, is replaced. If u and
        v are otherwise connected, the first edge of that path is removed
        (ValueError in strict mode).

        Returns
        ----------
        changed : bool
          True if an edge on an existing path was removed
        """
        changed = False
        if u == v:
            if self.flags['strict']:
                raise ValueError('Edge must be between two unique nodes!')
            return changed
        elif self._undirected.has_edge(u, v):
            # replacing an existing edge leaves the structure unchanged
            self.remove_edges_from([[u, v], [v, u]])
        elif len(self.nodes()) > 0:
            try:
                path = nx.shortest_path(self._undirected, u, v)
            except nx.NetworkXException:
                # no path, or a node not yet in the graph: no cycle
                path = None
            if path is not None:
                if self.flags['strict']:
                    raise ValueError(
                        'Multiple edge path exists between nodes!')
                # break the existing path so the new edge makes no cycle
                self.disconnect_path(path)
                changed = True
        self._undirected.add_edge(u, v)
        super(EnforcedForest, self).add_edge(u, v, *args, **kwargs)
        if self.flags['assert_forest']:
            # this is quite slow but makes very sure structure is correct
            # so is mainly used for testing
            assert nx.is_forest(nx.Graph(self))
        return changed

    def add_edges_from(self, *args, **kwargs):
        raise ValueError('EnforcedTree requires add_edge method to be used!')

    def add_path(self, *args, **kwargs):
        raise ValueError('EnforcedTree requires add_edge method to be used!')

    def remove_edge(self, *args, **kwargs):
        """
        Remove an edge from both the directed graph and its mirror.
        """
        super(EnforcedForest, self).remove_edge(*args, **kwargs)
        self._undirected.remove_edge(*args, **kwargs)

    def remove_edges_from(self, ebunch, *args, **kwargs):
        """
        Remove edges from both the directed graph and its mirror.
        """
        # materialize so a one-shot iterator reaches both graphs
        ebunch = list(ebunch)
        super(EnforcedForest, self).remove_edges_from(
            ebunch, *args, **kwargs)
        self._undirected.remove_edges_from(ebunch, *args, **kwargs)

    def disconnect_path(self, path):
        """
        Remove the first edge of `path`, in either direction.
        """
        # keep node names as-is: packing them into a numpy array would
        # coerce mixed types (e.g. 1 and 'a' both become strings)
        first, second = path[0], path[1]
        self.remove_edges_from([(first, second), (second, first)])

    def shortest_path_undirected(self, u, v):
        """
        Shortest path between u and v ignoring edge direction.
        networkx exceptions (e.g. NetworkXNoPath) propagate.
        """
        return nx.shortest_path(self._undirected, u, v)

    def get_edge_data_direction(self, u, v):
        """
        Return the data of the edge joining u and v and its direction:
        1 if stored as u -> v, -1 if stored as v -> u.
        Raises ValueError if neither edge exists.
        """
        if self.has_edge(u, v):
            direction = 1
        elif self.has_edge(v, u):
            direction = -1
        else:
            raise ValueError('Edge does not exist!')
        data = self.get_edge_data(*[u, v][::direction])
        return data, direction
def kwargs_to_matrix(**kwargs):
    """
    Turn a set of keyword arguments into a transformation matrix.

    Parameters
    ------------
    matrix : (4, 4) float
      Homogeneous transformation matrix; takes precedence over
      `quaternion` and `axis`/`angle`. It is always copied, so the
      caller's array is never modified or aliased.
    quaternion : (4,) float
      Quaternion ordered [w, x, y, z]
    axis : (3,) float
      Axis of rotation, only used together with `angle`
    angle : float
      Angle of rotation, in radians
    translation : (3,) float
      Added to the translation column of whatever matrix the
      other arguments produce (identity if none were given)

    Any other keyword argument is ignored.

    Returns
    ------------
    matrix : (4, 4) float
      Homogeneous transformation matrix
    """
    if 'matrix' in kwargs:
        # a matrix takes precedence over other options; copy it so the
        # `translation` update below cannot mutate the caller's array
        # and later caller edits cannot change the returned matrix
        matrix = np.array(kwargs['matrix'], dtype=np.float64)
    elif 'quaternion' in kwargs:
        matrix = transformations.quaternion_matrix(kwargs['quaternion'])
    elif ('axis' in kwargs) and ('angle' in kwargs):
        matrix = transformations.rotation_matrix(kwargs['angle'],
                                                 kwargs['axis'])
    else:
        matrix = np.eye(4)
    if 'translation' in kwargs:
        # translation can be used in conjunction with any of the methods of
        # specifying transforms. In the case a matrix and translation are passed,
        # we add the translations together rather than picking one.
        matrix[0:3, 3] += kwargs['translation']
    return matrix