changeset 3077:169287a814ae

insert loop memory merging
author Thomas Wuerthinger <thomas@wuerthinger.net>
date Tue, 28 Jun 2011 12:20:31 +0200
parents 36b6bb73a5cf
children b15101f82e2d
files graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/MemoryPhase.java graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/schedule/Block.java
diffstat 2 files changed, 146 insertions(+), 98 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/MemoryPhase.java	Mon Jun 27 17:38:43 2011 +0200
+++ b/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/MemoryPhase.java	Tue Jun 28 12:20:31 2011 +0200
@@ -37,75 +37,107 @@
     public static class MemoryMap {
 
         private final Block block;
-        private HashMap<Object, Node> locationToWrite;
-        private HashMap<Object, List<Node>> locationToReads;
-        private Node lastReadWriteMerge;
-        private Node lastWriteMerge;
-        private int mergeOperations;
+        private HashMap<Object, Node> locationForWrite;
+        private HashMap<Object, Node> locationForRead;
+        private Node mergeForWrite;
+        private Node mergeForRead;
+        private int mergeOperationCount;
+        private Node loopCheckPoint;
+        private MemoryMap loopEntryMap;
 
         public MemoryMap(Block b, MemoryMap memoryMap) {
             this(b);
-            for (Entry<Object, Node> e : memoryMap.locationToWrite.entrySet()) {
-                locationToWrite.put(e.getKey(), e.getValue());
+            if (b.firstNode() instanceof LoopBegin) {
+                loopCheckPoint = new WriteMemoryCheckpointNode(b.firstNode().graph());
+                mergeForWrite = loopCheckPoint;
+                mergeForRead = loopCheckPoint;
+                this.loopEntryMap = memoryMap;
+            } else {
+                memoryMap.locationForWrite.putAll(memoryMap.locationForWrite);
+                memoryMap.locationForRead.putAll(memoryMap.locationForRead);
+                mergeForWrite = memoryMap.mergeForWrite;
+                mergeForRead = memoryMap.mergeForRead;
             }
-//            for (Entry<Object, List<Node>> e : memoryMap.locationToReads.entrySet()) {
-//                locationToReads.put(e.getKey(), new ArrayList<Node>(e.getValue()));
-//            }
-            lastReadWriteMerge = memoryMap.lastReadWriteMerge;
-            lastWriteMerge = memoryMap.lastWriteMerge;
         }
 
         public MemoryMap(Block b) {
-            block = b;
-            locationToWrite = new HashMap<Object, Node>();
-            locationToReads = new HashMap<Object, List<Node>>();
             if (GraalOptions.TraceMemoryMaps) {
                 TTY.println("Creating new memory map for block B" + b.blockID());
             }
+
+            block = b;
+            locationForWrite = new HashMap<Object, Node>();
+            locationForRead = new HashMap<Object, Node>();
             StartNode startNode = b.firstNode().graph().start();
             if (b.firstNode() == startNode) {
                 WriteMemoryCheckpointNode checkpoint = new WriteMemoryCheckpointNode(startNode.graph());
                 checkpoint.setNext((FixedNode) startNode.start());
                 startNode.setStart(checkpoint);
-                lastReadWriteMerge = checkpoint;
-                lastWriteMerge = checkpoint;
+                mergeForWrite = checkpoint;
+                mergeForRead = checkpoint;
             }
         }
 
-        public void mergeWith(MemoryMap memoryMap) {
+        public Node getLoopCheckPoint() {
+            return loopCheckPoint;
+        }
+
+        public MemoryMap getLoopEntryMap() {
+            return loopEntryMap;
+        }
+
+        public void resetMergeOperationCount() {
+            mergeOperationCount = 0;
+        }
+
+        public void mergeWith(MemoryMap memoryMap, Block block) {
             if (GraalOptions.TraceMemoryMaps) {
                 TTY.println("Merging with memory map of block B" + memoryMap.block.blockID());
             }
+            mergeForWrite = mergeNodes(mergeForWrite, memoryMap.mergeForWrite, locationForWrite, memoryMap.locationForWrite, block);
+            mergeForRead = mergeNodes(mergeForRead, memoryMap.mergeForRead, locationForRead, memoryMap.locationForRead, block);
+            mergeOperationCount++;
+        }
 
-            lastReadWriteMerge = mergeNodes(lastReadWriteMerge, memoryMap.lastReadWriteMerge);
-            lastWriteMerge = mergeNodes(lastWriteMerge, memoryMap.lastWriteMerge);
+        private Node mergeNodes(Node mergeLeft, Node mergeRight, HashMap<Object, Node> locationLeft, HashMap<Object, Node> locationRight, Block block) {
+            if (GraalOptions.TraceMemoryMaps) {
+                TTY.println("Merging main merge nodes: " + mergeLeft.id() + " and " + mergeRight.id());
+            }
 
-            List<Object> toRemove = new ArrayList<Object>();
-            for (Entry<Object, Node> e : locationToWrite.entrySet()) {
-                if (memoryMap.locationToWrite.containsKey(e.getKey())) {
+            for (Entry<Object, Node> e : locationRight.entrySet()) {
+                if (!locationLeft.containsKey(e.getKey())) {
+                    // Only available in right map => create correct node for left map.
+                    if (GraalOptions.TraceMemoryMaps) {
+                        TTY.println("Only right map " + e.getKey());
+                    }
+                    Node leftNode = mergeLeft;
+                    if (leftNode instanceof Phi && ((Phi) leftNode).merge() == block.firstNode()) {
+                        leftNode = leftNode.copyWithEdges();
+                    }
+                    locationLeft.put(e.getKey(), leftNode);
+                }
+            }
+
+            for (Entry<Object, Node> e : locationLeft.entrySet()) {
+                if (locationRight.containsKey(e.getKey())) {
+                    // Available in both maps.
                     if (GraalOptions.TraceMemoryMaps) {
                         TTY.println("Merging entries for location " + e.getKey());
                     }
-                    locationToWrite.put(e.getKey(), mergeNodes(e.getValue(), memoryMap.locationToWrite.get(e.getKey())));
+                    locationLeft.put(e.getKey(), mergeNodes(e.getValue(), locationRight.get(e.getKey()), block));
                 } else {
-                    toRemove.add(e.getKey());
+                    // Only available in left map.
+                    if (GraalOptions.TraceMemoryMaps) {
+                        TTY.println("Only available in left map " + e.getKey());
+                    }
+                    locationLeft.put(e.getKey(), mergeNodes(e.getValue(), mergeRight, block));
                 }
             }
 
-            for (Object o : toRemove) {
-                locationToWrite.remove(o);
-            }
-
-//            for (Entry<Object, List<Node>> e : memoryMap.locationToReads.entrySet()) {
-//                for (Node n : e.getValue()) {
-//                    addRead(n, e.getKey());
-//                }
-//            }
-
-            mergeOperations++;
+            return mergeNodes(mergeLeft, mergeRight, block);
         }
 
-        private Node mergeNodes(Node original, Node newValue) {
+        private Node mergeNodes(Node original, Node newValue, Block block) {
             if (original == newValue) {
                 // Nothing to merge.
                 if (GraalOptions.TraceMemoryMaps) {
@@ -115,15 +147,24 @@
             }
             Merge m = (Merge) block.firstNode();
             if (original instanceof Phi && ((Phi) original).merge() == m) {
-                ((Phi) original).addInput(newValue);
+                Phi phi = (Phi) original;
+                phi.addInput(newValue);
+                if (GraalOptions.TraceMemoryMaps) {
+                    TTY.println("Add new input to phi " + original.id());
+                }
+                assert phi.valueCount() <= phi.merge().endCount();
                 return original;
             } else {
                 Phi phi = new Phi(CiKind.Illegal, m, m.graph());
                 phi.makeDead(); // Phi does not produce a value, it is only a memory phi.
-                for (int i = 0; i < mergeOperations + 1; ++i) {
+                for (int i = 0; i < mergeOperationCount + 1; ++i) {
                     phi.addInput(original);
                 }
                 phi.addInput(newValue);
+                if (GraalOptions.TraceMemoryMaps) {
+                    TTY.println("Creating new phi " + phi.id());
+                }
+                assert phi.valueCount() <= phi.merge().endCount() + ((phi.merge() instanceof LoopBegin) ? 1 : 0) : phi.merge() + "/" + phi.valueCount() + "/" + phi.merge().endCount() + "/" + mergeOperationCount;
                 return phi;
             }
         }
@@ -134,33 +175,20 @@
             }
 
             // Merge in all writes.
-            for (Entry<Object, Node> writeEntry : locationToWrite.entrySet()) {
+            for (Entry<Object, Node> writeEntry : locationForWrite.entrySet()) {
                 memMerge.mergedNodes().add(writeEntry.getValue());
-
-                // Register the merge point as a read such that subsequent writes to this location will depend on it (but subsequent reads do not).
-//                addRead(memMerge, writeEntry.getKey());
             }
-            lastWriteMerge = memMerge;
+            locationForWrite.clear();
+            mergeForWrite = memMerge;
         }
 
         public void createReadWriteMemoryCheckpoint(AbstractMemoryCheckpointNode memMerge) {
             if (GraalOptions.TraceMemoryMaps) {
                 TTY.println("Creating readwrite memory checkpoint at node " + memMerge.id());
             }
-
-            // Merge in all writes.
-            for (Entry<Object, Node> writeEntry : locationToWrite.entrySet()) {
-                memMerge.mergedNodes().add(writeEntry.getValue());
-            }
-            locationToWrite.clear();
-
-            // Merge in all reads.
-//            for (Entry<Object, List<Node>> readEntry : locationToReads.entrySet()) {
-//                memMerge.mergedNodes().addAll(readEntry.getValue());
-//            }
-            locationToReads.clear();
-            lastWriteMerge = memMerge;
-            lastReadWriteMerge = memMerge;
+            createWriteMemoryMerge(memMerge);
+            locationForRead.clear();
+            mergeForRead = memMerge;
         }
 
         public void registerWrite(WriteNode node) {
@@ -169,53 +197,57 @@
                 TTY.println("Register write to " + location + " at node " + node.id());
             }
 
-            boolean connectionAdded = false;
-//            if (locationToReads.containsKey(location)) {
-//                for (Node prevRead : locationToReads.get(location)) {
-//                    node.inputs().variablePart().add(prevRead);
-//                    connectionAdded = true;
-//                }
-//            }
+            // Create dependency on previous write to same location.
+            node.inputs().variablePart().add(getLocationForWrite(node));
 
-            if (!connectionAdded) {
-                if (locationToWrite.containsKey(location)) {
-                    Node prevWrite = locationToWrite.get(location);
-                    node.inputs().variablePart().add(prevWrite);
-                    connectionAdded = true;
-                }
+            locationForWrite.put(location, node);
+            locationForRead.put(location, node);
+        }
+
+        public Node getLocationForWrite(WriteNode node) {
+            Object location = node.location().locationIdentity();
+            if (locationForWrite.containsKey(location)) {
+                Node prevWrite = locationForWrite.get(location);
+                return prevWrite;
+            } else {
+                return mergeForWrite;
             }
-
-            node.inputs().variablePart().add(lastWriteMerge);
-
-            locationToWrite.put(location, node);
-            locationToReads.remove(location);
         }
 
         public void registerRead(ReadNode node) {
-            Object location = node.location().locationIdentity();
             if (GraalOptions.TraceMemoryMaps) {
-                TTY.println("Register read to " + location + " at node " + node.id());
+                TTY.println("Register read to node " + node.id());
             }
 
-            boolean connectionAdded = false;
-            if (locationToWrite.containsKey(location)) {
-                Node prevWrite = locationToWrite.get(location);
-                node.inputs().variablePart().add(prevWrite);
-                connectionAdded = true;
-            }
-
-            if (!connectionAdded) {
-                node.inputs().variablePart().add(lastReadWriteMerge);
-            }
-
-            //addRead(node, location);
+            // Create dependency on previous node that creates the memory state for this location.
+            node.inputs().variablePart().add(getLocationForRead(node));
         }
 
-        private void addRead(Node node, Object location) {
-            if (!locationToReads.containsKey(location)) {
-                locationToReads.put(location, new ArrayList<Node>());
+        public Node getLocationForRead(ReadNode node) {
+            Object location = node.location().locationIdentity();
+            if (locationForRead.containsKey(location)) {
+                return locationForRead.get(location);
             }
-            locationToReads.get(location).add(node);
+            return mergeForRead;
+        }
+
+        public void replaceCheckPoint(Node loopCheckPoint) {
+            List<Node> usages = new ArrayList<Node>(loopCheckPoint.usages());
+            for (Node n : usages) {
+                replaceCheckPoint(loopCheckPoint, n);
+            }
+        }
+
+        private void replaceCheckPoint(Node loopCheckPoint, Node n) {
+            if (n instanceof ReadNode) {
+                n.inputs().replace(loopCheckPoint, getLocationForRead((ReadNode) n));
+            } else if (n instanceof WriteNode) {
+                n.inputs().replace(loopCheckPoint, getLocationForWrite((WriteNode) n));
+            } else if (n instanceof WriteMemoryCheckpointNode) {
+                n.inputs().replace(loopCheckPoint, mergeForWrite);
+            } else {
+                n.inputs().replace(loopCheckPoint, mergeForRead);
+            }
         }
     }
 
@@ -227,11 +259,11 @@
         List<Block> blocks = s.getBlocks();
         MemoryMap[] memoryMaps = new MemoryMap[blocks.size()];
         for (final Block b : blocks) {
-            process(b, memoryMaps);
+            process(b, memoryMaps, s.getNodeToBlock());
         }
     }
 
-    private void process(final Block b, MemoryMap[] memoryMaps) {
+    private void process(final Block b, MemoryMap[] memoryMaps, NodeMap<Block> nodeMap) {
         // Visit every block at most once.
         if (memoryMaps[b.blockID()] != null) {
             return;
@@ -239,7 +271,7 @@
 
         // Process predecessors before this block.
         for (Block pred : b.getPredecessors()) {
-            process(pred, memoryMaps);
+            process(pred, memoryMaps, nodeMap);
         }
 
         // Create initial memory map for the block.
@@ -249,7 +281,9 @@
         } else {
             map = new MemoryMap(b, memoryMaps[b.getPredecessors().get(0).blockID()]);
             for (int i = 1; i < b.getPredecessors().size(); ++i) {
-                map.mergeWith(memoryMaps[b.getPredecessors().get(i).blockID()]);
+                assert b.firstNode() instanceof Merge : b.firstNode();
+                Block block = b.getPredecessors().get(i);
+                map.mergeWith(memoryMaps[block.blockID()], b);
             }
         }
 
@@ -276,5 +310,15 @@
         }
 
         memoryMaps[b.blockID()] = map;
+        if (b.lastNode() instanceof LoopEnd) {
+            LoopEnd end = (LoopEnd) b.lastNode();
+            LoopBegin begin = end.loopBegin();
+            Block beginBlock = nodeMap.get(begin);
+            MemoryMap memoryMap = memoryMaps[beginBlock.blockID()];
+            memoryMap.getLoopEntryMap().resetMergeOperationCount();
+            memoryMap.getLoopEntryMap().mergeWith(map, beginBlock);
+            Node loopCheckPoint = memoryMap.getLoopCheckPoint();
+            memoryMap.getLoopEntryMap().replaceCheckPoint(loopCheckPoint);
+        }
     }
 }
--- a/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/schedule/Block.java	Mon Jun 27 17:38:43 2011 +0200
+++ b/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/schedule/Block.java	Tue Jun 28 12:20:31 2011 +0200
@@ -171,6 +171,10 @@
         return "B" + blockID;
     }
 
+    public boolean isLoopHeader() {
+        return firstNode instanceof LoopBegin;
+    }
+
     public Block dominator() {
         return dominator;
     }