changeset 4168:9465df6b2a0b

Modified SkinningMesh implementation for efficiency improvements and fixed some bugs based on wrong assumptions Changed to simpler data structures for faster update. Added an event handler for joints to only update points if the joints' transforms changed. Fixed bug that assumed that the mesh and the parent joint have the same transform. Fixed bug that assumed that the joints transforms are defined w.r.t. the scene.
author Alex X. Lee
date Wed, 03 Jul 2013 12:01:02 -0700
parents b43eb1f37da3
children 15f45099ae7a
files apps/experiments/3DViewer/src/main/java/com/javafx/experiments/importers/maya/Loader.java apps/experiments/3DViewer/src/main/java/com/javafx/experiments/shape3d/SkinningMesh.java
diffstat 2 files changed, 135 insertions(+), 71 deletions(-) [+]
line wrap: on
line diff
--- a/apps/experiments/3DViewer/src/main/java/com/javafx/experiments/importers/maya/Loader.java	Wed Jul 03 11:32:29 2013 -0700
+++ b/apps/experiments/3DViewer/src/main/java/com/javafx/experiments/importers/maya/Loader.java	Wed Jul 03 12:01:02 2013 -0700
@@ -47,11 +47,13 @@
 import com.javafx.experiments.shape3d.PolygonMeshView;
 import com.javafx.experiments.shape3d.SkinningMesh;
 import com.sun.javafx.geom.Vec3f;
+import java.util.HashSet;
+import java.util.Set;
 import javafx.animation.AnimationTimer;
 import javafx.beans.value.ChangeListener;
 import javafx.beans.value.ObservableValue;
+import javafx.scene.Parent;
 import javafx.scene.Scene;
-import javafx.scene.transform.Transform;
 
 /** Loader */
 class Loader {
@@ -216,11 +218,18 @@
         MArray ma = (MArray) n.getAttr("ma");
 
         List<Joint> jointNodes = new ArrayList<Joint>();
+        Set<Parent> jointForest = new HashSet<Parent>(); // root's children that have joints in their trees
         for (int i = 0; i < ma.getSize(); i++) {
             // hack... ?
             MNode c = n.getIncomingConnectionToType("ma[" + i + "]", "joint");
             Joint jn = (Joint) resolveNode(c);
             jointNodes.add(jn);
+            
+            Parent rootChild = jn; // root's child, which is an ancestor of joint jn
+            while (rootChild.getParent() != null) {
+                rootChild = rootChild.getParent();
+            }
+            jointForest.add(rootChild);
         }
         
         MNode outputMeshMNode = resolveOutputMesh(n);
@@ -231,11 +240,12 @@
         // We must be able to find the original converter in the meshConverters map
         MNode origOrigMesh = resolveOrigInputMesh(n);
         //               println("ORIG ORIG={origOrigMesh}");
-
+        
         // TODO: What is with this? origMesh
         resolveNode(origOrigMesh).setVisible(false);
 
         MArray bindPreMatrixArray = (MArray) n.getAttr("pm");
+        Affine bindGlobalMatrix = convertMatrix((MFloatArray) n.getAttr("gm"));
 
         Affine[] bindPreMatrix = new Affine[bindPreMatrixArray.getSize()];
         for (int i = 0; i < bindPreMatrixArray.getSize(); i++) {
@@ -243,11 +253,11 @@
         }
 
         MArray mayaWeights = (MArray) n.getAttr("wl");
-        float[][] weights = new float [mayaWeights.getSize()][jointNodes.size()];
+        float[][] weights = new float [jointNodes.size()][mayaWeights.getSize()];
         for (int i=0; i<mayaWeights.getSize(); i++) {
             MFloatArray curWeights = (MFloatArray) mayaWeights.getData(i).getData("w");
             for (int j = 0; j < jointNodes.size(); j++) {
-                weights[i][j] = j < curWeights.getSize() ? curWeights.get(j) : 0;
+                weights[j][i] = j < curWeights.getSize() ? curWeights.get(j) : 0;
             }
         }
         
@@ -259,8 +269,7 @@
             PolygonMeshView targetMayaMeshView = (PolygonMeshView) targetMayaMeshNode;
             
             PolygonMesh sourceMesh = (PolygonMesh) sourceMayaMeshView.getMesh();
-            Transform meshTransform = targetMayaMeshView.getLocalToSceneTransform();
-            SkinningMesh targetMesh = new SkinningMesh(sourceMesh, meshTransform, weights, bindPreMatrix, jointNodes);
+            SkinningMesh targetMesh = new SkinningMesh(sourceMesh, weights, bindPreMatrix, bindGlobalMatrix, jointNodes, new ArrayList(jointForest));
             targetMayaMeshView.setMesh(targetMesh);
 
             final SkinningMeshTimer skinningMeshTimer = new SkinningMeshTimer(targetMesh);
@@ -462,10 +471,8 @@
             mv.setMesh((PolygonMesh) mesh);
 //            mv.setCullFace(CullFace.NONE); //TODO
             loaded.put(n, mv);
-            if (((PolygonMesh)mesh).getPoints().size() > 0) {
-                if (node != null) {
-                    ((Group) node).getChildren().add(mv);
-                }
+            if (node != null) {
+                ((Group) node).getChildren().add(mv);
             }
         } else {
             MeshView mv = new MeshView();
@@ -478,10 +485,8 @@
             mv.setMesh((TriangleMesh) mesh);
 
             loaded.put(n, mv);
-            if (((TriangleMesh)mesh).getPoints().size() > 0) {
-                if (node != null) {
-                    ((Group) node).getChildren().add(mv);
-                }
+            if (node != null) {
+                ((Group) node).getChildren().add(mv);
             }
         }
     }
--- a/apps/experiments/3DViewer/src/main/java/com/javafx/experiments/shape3d/SkinningMesh.java	Wed Jul 03 11:32:29 2013 -0700
+++ b/apps/experiments/3DViewer/src/main/java/com/javafx/experiments/shape3d/SkinningMesh.java	Wed Jul 03 12:01:02 2013 -0700
@@ -4,105 +4,164 @@
 import com.javafx.experiments.importers.maya.Joint;
 import java.util.ArrayList;
 import java.util.List;
+import javafx.beans.InvalidationListener;
+import javafx.beans.Observable;
 import javafx.collections.ObservableFloatArray;
 import javafx.geometry.Point3D;
+import javafx.scene.Node;
+import javafx.scene.Parent;
 import javafx.scene.shape.TriangleMesh;
 import javafx.scene.transform.Affine;
+import javafx.scene.transform.MatrixType;
 import javafx.scene.transform.NonInvertibleTransformException;
 import javafx.scene.transform.Transform;
 
 /**
- * PolygonMesh that updates itself when the joint transforms are updated.
- * Assumes that the dimensions of weights is nJoints x nPoints
+ * PolygonMesh that knows how to update itself given changes in joint transforms.
+ * The mesh can be updated with an AnimationTimer.
  */
 public class SkinningMesh extends PolygonMesh {
-    private final Point3D[][] relativePoints; // nPoints x nJoints
-    private final float[][] weights; // nPoints x nJoints
+    private final float[][] relativePoints; // nJoints x nPoints*3
+    private final float[][] weights; // nJoints x nPoints
     private final List<Integer>[] weightIndices;
-    private final List<Joint> joints;
+    private final List<JointIndex> jointIndexForest;
+    private boolean jointsTransformDirty = true;
+    private Transform bindGlobalInverseTransform;
+    private final Transform[] jointToRootTransforms; // the root refers to the group containing all the mesh skinning nodes (i.e. the parent of jointForest)
     private final int nPoints;
     private final int nJoints;
-    private Transform meshInverseTransform;
-
-    public SkinningMesh(PolygonMesh mesh, Transform meshTransform, float[][] weights, Affine[] bindTransforms, List<Joint> joints) {
+    
+    
+    /**
+     * SkinningMesh constructor
+     * 
+     * @param mesh The binding mesh
+     * @param weights A two-dimensional array (nJoints x nPoints) of the influence weights used for skinning
+     * @param bindTransforms The binding transforms for every joint
+     * @param bindGlobalTransform The global binding transform; all binding transforms are defined with respect to this frame
+     * @param joints A list of joints used for skinning; the order of these are associated with the respective attributes of @weights and @bindTransforms
+     * @param jointForest A list of the top level trees that contain the joints; all the @joints should be contained in this forest
+     */
+    public SkinningMesh(PolygonMesh mesh, float[][] weights, Affine[] bindTransforms, Affine bindGlobalTransform, List<Joint> joints, List<Parent> jointForest) {
         this.getPoints().addAll(mesh.getPoints());
         this.getTexCoords().addAll(mesh.getTexCoords());
         this.faces = mesh.faces;
         
         this.weights = weights;
-        this.joints = joints;
-        
+
         nJoints = joints.size();
         nPoints = getPoints().size()/ TriangleMesh.NUM_COMPONENTS_PER_POINT;
         
+        // Create the jointIndexForest forest. Its structure is the same as 
+        // jointForest, except that this forest have indices information and 
+        // some branches are pruned if they don't contain joints.
+        jointIndexForest = new ArrayList<JointIndex>(jointForest.size());
+        for (Parent jointRoot : jointForest) {
+            jointIndexForest.add(new JointIndex(jointRoot, joints.indexOf(jointRoot), joints));
+        }
+        
         try {
-            meshInverseTransform = meshTransform.createInverse();
+            bindGlobalInverseTransform = bindGlobalTransform.createInverse();
         } catch (NonInvertibleTransformException ex) {
             System.err.println("Caught NonInvertibleTransformException: " + ex.getMessage());
         }
         
-        weightIndices = new List[nPoints];
-        for (int i = 0; i < nPoints; i++) {
-            weightIndices[i] = new ArrayList<Integer>();
-            for (int j = 0; j < nJoints; j++) {
-                if (weights[i][j] != 0.0f) {
-                    weightIndices[i].add(new Integer(j));
+        jointToRootTransforms = new Transform[nJoints];
+        
+        // For optimization purposes, store the indices of the non-zero weights
+        weightIndices = new List[nJoints];
+        for (int j = 0; j < nJoints; j++) {
+            weightIndices[j] = new ArrayList<Integer>();
+            for (int i = 0; i < nPoints; i++) {
+                if (weights[j][i] != 0.0f) {
+                    weightIndices[j].add(new Integer(i));
                 }
             }
         }
         
+        // Compute the points of the binding mesh relative to the binding transforms
         ObservableFloatArray points = getPoints();
-        relativePoints = new Point3D[nPoints][nJoints];
+        relativePoints = new float[nJoints][nPoints*3];
         for (int j = 0; j < nJoints; j++) {
-            Transform postBindTransform = bindTransforms[j].createConcatenation(meshTransform);
+            Transform postBindTransform = bindTransforms[j].createConcatenation(bindGlobalTransform);
             for (int i = 0; i < nPoints; i++) {
-                relativePoints[i][j] = postBindTransform.transform(points.get(3*i), points.get(3*i+1), points.get(3*i+2));
+                Point3D relativePoint = postBindTransform.transform(points.get(3*i), points.get(3*i+1), points.get(3*i+2));
+                relativePoints[j][3*i]   = (float) relativePoint.getX();
+                relativePoints[j][3*i+1] = (float) relativePoint.getY();
+                relativePoints[j][3*i+2] = (float) relativePoint.getZ();
+            }
+        }
+        
+        // Add a listener to all the joints so that we can track when any of their transforms have changed
+        for (Joint joint : joints) {
+            joint.localToParentTransformProperty().addListener(new InvalidationListener() {
+                @Override
+                public void invalidated(Observable observable) {
+                    jointsTransformDirty = true;
+                }
+            });
+        }
+    }
+    
+    private class JointIndex {
+        public Node node;
+        public int index;
+        public List<JointIndex> children = new ArrayList<JointIndex>();
+        public JointIndex parent = null;
+        public Transform localToGlobalTransform;
+        public JointIndex(Node n, int ind, List<Joint> orderedJoints) {
+            node = n;
+            index = ind;
+            if (node instanceof Parent) {
+                for (Node childJoint : ((Parent)node).getChildrenUnmodifiable()) {
+                    if (childJoint instanceof Parent) { // is childJoint a joint or a node with children?
+                        int childInd = orderedJoints.indexOf(childJoint);
+                        JointIndex childJointIndex = new JointIndex(childJoint, childInd, orderedJoints);
+                        childJointIndex.parent = this;
+                        children.add(childJointIndex);
+                    }
+                }
             }
         }
     }
     
+    // Updates the jointToRootTransforms by doing a a depth-first search of the jointIndexForest
+    private void updateLocalToGlobalTransforms(List<JointIndex> jointIndexForest) {
+        for (JointIndex jointIndex : jointIndexForest) {
+            if (jointIndex.parent == null) {
+                jointIndex.localToGlobalTransform = bindGlobalInverseTransform.createConcatenation(jointIndex.node.getLocalToParentTransform());
+            } else {
+                jointIndex.localToGlobalTransform = jointIndex.parent.localToGlobalTransform.createConcatenation(jointIndex.node.getLocalToParentTransform());
+            }
+            if (jointIndex.index != -1) {
+                jointToRootTransforms[jointIndex.index] = jointIndex.localToGlobalTransform;
+            }
+            updateLocalToGlobalTransforms(jointIndex.children);
+        }
+    }
+    
+    // Updates its points only if any of the joints' transforms have changed
     public void update() {
-        Transform[] preJointTransforms = new Transform[nJoints];
+        if (!jointsTransformDirty) {
+            return;
+        }
+        
+        updateLocalToGlobalTransforms(jointIndexForest);
+        
+        float[] points = new float[nPoints*3];
+        double[] t = new double[12];
+        float[] relativePoint;
         for (int j = 0; j < nJoints; j++) {
-            preJointTransforms[j] = meshInverseTransform.createConcatenation(joints.get(j).getLocalToSceneTransform());
-        }
-
-        float[] points = new float [getPoints().size()];
-        
-        for (int i = 0; i < nPoints; i++) {
-            if (!weightIndices[i].isEmpty()) {
-                Point3D weightedPoint = new Point3D(0,0,0);
-                for (Integer j : weightIndices[i]) {
-                    Point3D absolutePoint = preJointTransforms[j].transform(relativePoints[i][j]);
-                    weightedPoint = weightedPoint.add(absolutePoint.multiply(weights[i][j]));
-                }
-                points[3*i] = (float) weightedPoint.getX();
-                points[3*i+1] = (float) weightedPoint.getY();
-                points[3*i+2] = (float) weightedPoint.getZ();
+            jointToRootTransforms[j].toArray(MatrixType.MT_3D_3x4, t);
+            relativePoint = relativePoints[j];
+            for (Integer i : weightIndices[j]) {
+                points[3*i]   += weights[j][i] * (t[0] * relativePoint[3*i] + t[1] * relativePoint[3*i+1] + t[2] * relativePoint[3*i+2] + t[3]);
+                points[3*i+1] += weights[j][i] * (t[4] * relativePoint[3*i] + t[5] * relativePoint[3*i+1] + t[6] * relativePoint[3*i+2] + t[7]);
+                points[3*i+2] += weights[j][i] * (t[8] * relativePoint[3*i] + t[9] * relativePoint[3*i+1] + t[10] * relativePoint[3*i+2] + t[11]);
             }
         }
         getPoints().set(0, points, 0, points.length);
         
-//        // The following loop is equivalent to the one above, the difference
-//        // being that this one is more straight-forward (it checks and skips
-//        // the zero weights).
-//        for (int i = 0; i < nPoints; i++) {
-//            Point3D weightedPoint = new Point3D(0,0,0);
-//            boolean isVertexInfluenced = false;
-//            for (int j = 0; j < nJoints; j++) {
-//                if (weights[i][j] != 0.0f) {
-//                    isVertexInfluenced = true;
-//                    Point3D absolutePoint = preJointTransforms[j].transform(relativePoints[i][j]);
-//                    weightedPoint = weightedPoint.add(absolutePoint.multiply(weights[i][j]));
-//                }
-//            }
-//            if (isVertexInfluenced) {
-//                points[3*i] = (float) weightedPoint.getX();
-//                points[3*i+1] = (float) weightedPoint.getY();
-//                points[3*i+2] = (float) weightedPoint.getZ();
-//            }
-//        }
-//        getPoints().set(0, points, 0, points.length);
-        
+        jointsTransformDirty = false;
     }
 }