changeset 1701:0a0392566a68

Partial (all security checks not yet implemented) implementation of $ method. No regressions in existing tests, but otherwise untested -- readResolve is needed.
author rfield
date Mon, 17 Dec 2012 01:59:19 -0800
parents 6acec3010e26
children 1f2fbcd0de7e
files src/share/classes/com/sun/tools/javac/code/Symtab.java src/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java src/share/classes/com/sun/tools/javac/util/Names.java
diffstat 3 files changed, 173 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/com/sun/tools/javac/code/Symtab.java	Fri Dec 14 22:11:05 2012 +0000
+++ b/src/share/classes/com/sun/tools/javac/code/Symtab.java	Mon Dec 17 01:59:19 2012 -0800
@@ -126,6 +126,7 @@
     public final Type stringBuilderType;
     public final Type cloneableType;
     public final Type serializableType;
+    public final Type serializedLambdaType;
     public final Type methodHandleType;
     public final Type methodHandleLookupType;
     public final Type methodTypeType;
@@ -458,6 +459,7 @@
         cloneableType = enterClass("java.lang.Cloneable");
         throwableType = enterClass("java.lang.Throwable");
         serializableType = enterClass("java.io.Serializable");
+        serializedLambdaType = enterClass("java.lang.invoke.SerializedLambda");
         methodHandleType = enterClass("java.lang.invoke.MethodHandle");
         methodHandleLookupType = enterClass("java.lang.invoke.MethodHandles$Lookup");
         methodTypeType = enterClass("java.lang.invoke.MethodType");
--- a/src/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java	Fri Dec 14 22:11:05 2012 +0000
+++ b/src/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java	Mon Dec 17 01:59:19 2012 -0800
@@ -89,13 +89,55 @@
 
     /** current translation context (visitor argument) */
     private TranslationContext<?> context;
+    
+    /** info about the current class being processed */
+    private KlassInfo kInfo;
 
-    /** list of translated methods
-     **/
-    private ListBuffer<JCTree> translatedMethodList;
-    
     /** Flag for alternate metafactories indicating the lambda object is intended to be serializable */
     public static final int FLAG_SERIALIZABLE = 1 << 0;
+    
+    private class KlassInfo {
+
+        /**
+         * list of methods to append
+         */
+        private ListBuffer<JCTree> appendedMethodList;
+        
+        /**
+         * list of deserialization cases
+         */
+        private final Map<String, ListBuffer<JCStatement>> deserializeCases;
+        
+       /**
+         * deserialize method symbol
+         */
+        private final MethodSymbol deserMethodSym;
+        
+        /**
+         * deserialize method parameter symbol
+         */
+        private final VarSymbol deserParamSym;
+        
+        /**
+         * deserialize method captured arguments local variable symbol
+         */
+        private final VarSymbol deserCaptArgsSym;
+        
+        private KlassInfo(Symbol kSym) {
+            appendedMethodList = ListBuffer.lb();
+            deserializeCases = new HashMap<>();
+            long flags = PRIVATE | STATIC | SYNTHETIC;
+            MethodType type = new MethodType(List.of(syms.serializedLambdaType), syms.objectType,
+                    List.<Type>nil(), syms.methodClass);
+            deserMethodSym = makeSyntheticMethod(flags, names.deserialize, type, kSym);
+            deserParamSym = new VarSymbol(FINAL, names.fromString("lambda"), syms.serializedLambdaType, deserMethodSym);
+            deserCaptArgsSym = new VarSymbol(FINAL, names.fromString("captargs"), new Type.ArrayType(syms.objectType, syms.arrayClass), deserMethodSym);
+        }
+        
+        private void addMethod(JCTree decl) {
+            appendedMethodList = appendedMethodList.prepend(decl);
+        }
+    }
 
     // <editor-fold defaultstate="collapsed" desc="Instantiating">
     private static final Context.Key<LambdaToMethod> unlambdaKey =
@@ -172,18 +214,22 @@
             //analyze class
             analyzer.analyzeClass(tree);
         }
-        ListBuffer<JCTree> prevTranslated = translatedMethodList;
+        KlassInfo prevKlassInfo = kInfo;
         try {
-            translatedMethodList = ListBuffer.lb();
+            kInfo = new KlassInfo(tree.sym);
             super.visitClassDef(tree);
+            if (!kInfo.deserializeCases.isEmpty()) {
+                kInfo.addMethod(makeDeserializeMethod(tree.sym));
+            }
             //add all translated instance methods here
-            tree.defs = tree.defs.appendList(translatedMethodList.toList());
-            for (JCTree lambda : translatedMethodList) {
+            List<JCTree> newMethods = kInfo.appendedMethodList.toList();
+            tree.defs = tree.defs.appendList(newMethods);
+            for (JCTree lambda : newMethods) {
                 tree.sym.members().enter(((JCMethodDecl)lambda).sym);
             }
             result = tree;
         } finally {
-            translatedMethodList = prevTranslated;
+            kInfo = prevKlassInfo;
         }
     }
 
@@ -221,7 +267,7 @@
         lambdaDecl.body = translate(makeLambdaBody(tree, lambdaDecl));
 
         //Add the method to the list of methods to be added to this class.
-        translatedMethodList = translatedMethodList.prepend(lambdaDecl);
+        kInfo.addMethod(lambdaDecl);
 
         //now that we have generated a method for the lambda expression,
         //we can translate the lambda into a method reference pointing to the newly
@@ -453,6 +499,105 @@
         }
         return trans_block;
     }
+    
+    private JCMethodDecl makeDeserializeMethod(Symbol kSym) {
+        ListBuffer<JCCase> cases = ListBuffer.lb();
+        ListBuffer<JCBreak> breaks = ListBuffer.lb();
+        for (Map.Entry<String, ListBuffer<JCStatement>> entry : kInfo.deserializeCases.entrySet()) {
+            JCBreak br = make.Break(null);
+            breaks.add(br);
+            List<JCStatement> stmts = entry.getValue().append(br).toList();
+            cases.add(make.Case(make.Literal(entry.getKey()), stmts));
+        }
+        JCSwitch sw = make.Switch(deserGetter("getImplMethodName", syms.stringType), cases.toList());
+        for (JCBreak br : breaks) {
+            br.target = sw;
+        }
+        JCBlock body = make.Block(0L, List.<JCStatement>of(
+                make.VarDef(kInfo.deserCaptArgsSym, deserGetter("getCapturedArgs", kInfo.deserCaptArgsSym.type)),
+                sw,
+                make.Throw(makeNewClass(
+                    syms.illegalArgumentExceptionType, 
+                    List.<JCExpression>of(make.Literal("Invalid lambda deserialization"))))));
+        JCMethodDecl deser = make.MethodDef(make.Modifiers(kInfo.deserMethodSym.flags()),
+                        names.deserialize,
+                        make.QualIdent(kInfo.deserMethodSym.getReturnType().tsym),
+                        List.<JCTypeParameter>nil(),
+                        List.of(make.VarDef(kInfo.deserParamSym, null)),
+                        List.<JCExpression>nil(),
+                        body,
+                        null);
+        deser.sym = kInfo.deserMethodSym;
+        deser.type = kInfo.deserMethodSym.type;
+        return deser;
+    }
+
+    /** Make an attributed class instance creation expression.
+     *  @param ctype    The class type.
+     *  @param args     The constructor arguments.
+     */
+    JCNewClass makeNewClass(Type ctype, List<JCExpression> args) {
+        JCNewClass tree = make.NewClass(null,
+            null, make.QualIdent(ctype.tsym), args, null);
+        tree.constructor = rs.resolveConstructor(
+            null, attrEnv, ctype, TreeInfo.types(args), List.<Type>nil());
+        tree.type = ctype;
+        return tree;
+    }
+
+    private void addDeserializationCase(int refKind, Symbol refSym, Type targetType, MethodSymbol samSym, 
+            DiagnosticPosition pos, List<Object> staticArgs, MethodType indyType) {
+        JCBinary kindTest = make.Binary(JCTree.Tag.EQ, deserGetter("getImplMethodKind", syms.intType), make.Literal(refKind));
+        kindTest.operator = rs.resolveBinaryOperator(null, JCTree.Tag.EQ, attrEnv, syms.intType, syms.intType);
+        kindTest.setType(syms.booleanType);
+        String key = refSym.getQualifiedName().toString();
+        ListBuffer<JCExpression> serArgs = ListBuffer.lb();
+        int i = 0;
+        for (Type t : indyType.getParameterTypes()) {
+            serArgs.add(make.TypeCast(t, make.Indexed(kInfo.deserCaptArgsSym, make.Literal(i))));
+            ++i;
+        }
+        JCStatement stmt = make.If(
+                deserTest(deserTest( //deserTest(deserTest(deserTest(
+                    kindTest, 
+                    "getFunctionalInterfaceClass", types.erasure(targetType).toString()),
+                    "getFunctionalInterfaceMethodName", samSym.getSimpleName().toString()),
+                make.Exec(makeIndyCall(
+                    pos, 
+                    syms.lambdaMetafactory, 
+                    names.altMetaFactory, 
+                    staticArgs, indyType, serArgs.toList())),
+                null);
+        ListBuffer<JCStatement> stmts = kInfo.deserializeCases.get(key);
+        if (stmts == null) {
+            stmts = ListBuffer.lb();
+            kInfo.deserializeCases.put(key, stmts);
+        }
+        stmts.append(stmt);
+    }
+    
+    private JCExpression deserTest(JCExpression prev, String func, String lit) {
+        MethodType eqmt = new MethodType(List.of(syms.objectType), syms.booleanType, List.<Type>nil(), syms.methodClass);
+        Symbol eqsym = rs.resolveQualifiedMethod(null, attrEnv, syms.objectType, names.equals, List.of(syms.objectType), List.<Type>nil());
+        JCMethodInvocation eqtest = make.Apply(
+                List.<JCExpression>nil(), 
+                make.Select(deserGetter(func, syms.stringType), eqsym).setType(eqmt),
+                List.<JCExpression>of(make.Literal(lit)));
+        eqtest.setType(syms.booleanType);
+        JCBinary compound = make.Binary(JCTree.Tag.AND, prev, eqtest);
+        compound.operator = rs.resolveBinaryOperator(null, JCTree.Tag.AND, attrEnv, syms.booleanType, syms.booleanType);
+        compound.setType(syms.booleanType);
+        return compound;
+    }
+    
+    private JCExpression deserGetter(String func, Type type) {
+        MethodType getmt = new MethodType(List.<Type>nil(), type, List.<Type>nil(), syms.methodClass);
+        Symbol getsym = rs.resolveQualifiedMethod(null, attrEnv, syms.serializedLambdaType, names.fromString(func), List.<Type>nil(), List.<Type>nil());
+        return make.Apply(
+                    List.<JCExpression>nil(), 
+                    make.Select(make.Ident(kInfo.deserParamSym).setType(syms.serializedLambdaType), getsym).setType(getmt), 
+                    List.<JCExpression>nil()).setType(type);
+    }
 
     /**
      * Create new synthetic method with given flags, name, type, owner
@@ -685,8 +830,7 @@
      * * super is used
      */
     private void bridgeMemberReference(JCMemberReference tree, ReferenceTranslationContext localContext) {
-        JCMethodDecl bridgeDecl = (new MemberReferenceBridger(tree, localContext).bridge());
-        translatedMethodList = translatedMethodList.prepend(bridgeDecl);
+        kInfo.addMethod(new MemberReferenceBridger(tree, localContext).bridge());
     }
 
     /**
@@ -695,8 +839,9 @@
     private JCExpression makeMetaFactoryIndyCall(JCExpression tree, FunctionalInfo fInfo, int refKind, Symbol refSym, List<JCExpression> indy_args) {
         //determine the static bsm args
         Type mtype = makeFunctionalDescriptorType(fInfo.targetType, true);
+        MethodSymbol samSym = (MethodSymbol) types.findDescriptorSymbol(fInfo.targetType.tsym);
         List<Object> staticArgs = List.<Object>of(
-                new Pool.MethodHandle(ClassFile.REF_invokeInterface, types.findDescriptorSymbol(fInfo.targetType.tsym)),
+                new Pool.MethodHandle(ClassFile.REF_invokeInterface, samSym),
                 new Pool.MethodHandle(refKind, refSym),
                 new MethodType(mtype.getParameterTypes(),
                         mtype.getReturnType(),
@@ -706,14 +851,6 @@
         boolean altMetafactory =
                 fInfo.isSerializable || fInfo.targets.tail.nonEmpty();
         
-        if (altMetafactory) {
-            int flags = fInfo.isSerializable ? FLAG_SERIALIZABLE : 0;
-            staticArgs = staticArgs.append(flags);
-            for (Symbol t : fInfo.targets.tail) {
-                staticArgs = staticArgs.append(t);
-            }
-        }
-
         //computed indy arg types
         ListBuffer<Type> indy_args_types = ListBuffer.lb();
         for (JCExpression arg : indy_args) {
@@ -729,9 +866,19 @@
         Name metafactoryName = altMetafactory ?
                 names.altMetaFactory : names.metaFactory;
 
+        if (altMetafactory) {
+            int flags = fInfo.isSerializable ? FLAG_SERIALIZABLE : 0;
+            staticArgs = staticArgs.append(flags);
+            for (Symbol t : fInfo.targets.tail) {
+                staticArgs = staticArgs.append(t);
+            }
+            addDeserializationCase(refKind, refSym, fInfo.targetType, samSym, 
+                    tree, staticArgs, indyType);
+        }
+
         return makeIndyCall(tree, syms.lambdaMetafactory, metafactoryName, staticArgs, indyType, indy_args);
     }
-
+    
     /**
      * Generate an indy method call with given name, type and static bootstrap
      * arguments types
--- a/src/share/classes/com/sun/tools/javac/util/Names.java	Fri Dec 14 22:11:05 2012 +0000
+++ b/src/share/classes/com/sun/tools/javac/util/Names.java	Mon Dec 17 01:59:19 2012 -0800
@@ -73,6 +73,7 @@
     public final Name clone;
     public final Name close;
     public final Name compareTo;
+    public final Name deserialize;
     public final Name desiredAssertionStatus;
     public final Name equals;
     public final Name error;
@@ -207,6 +208,7 @@
         clone = fromString("clone");
         close = fromString("close");
         compareTo = fromString("compareTo");
+        deserialize = fromString("$deserialize$");
         desiredAssertionStatus = fromString("desiredAssertionStatus");
         equals = fromString("equals");
         error = fromString("<error>");