// Copyright 2017 The Bazel Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.devtools.build.lib.skyframe.serialization.autocodec; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMultimap; import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationCodeGenerator.Context; import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationCodeGenerator.Marshaller; import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationCodeGenerator.PrimitiveValueSerializationCodeGenerator; import com.google.devtools.build.lib.skyframe.serialization.strings.StringCodecs; import com.squareup.javapoet.TypeName; import java.util.Map; import javax.annotation.processing.ProcessingEnvironment; import javax.lang.model.element.TypeElement; import javax.lang.model.type.ArrayType; import javax.lang.model.type.DeclaredType; import javax.lang.model.type.PrimitiveType; import javax.lang.model.type.TypeKind; import javax.lang.model.type.TypeMirror; import javax.lang.model.type.TypeVariable; import javax.lang.model.type.WildcardType; /** Class containing all {@link Marshaller} instances. */ class Marshallers { private final ProcessingEnvironment env; Marshallers(ProcessingEnvironment env) { this.env = env; } void writeSerializationCode(Context context) { SerializationCodeGenerator generator = getMatchingCodeGenerator(context.type); boolean needsNullHandling = context.canBeNull() && generator != contextMarshaller; if (needsNullHandling) { context.builder.beginControlFlow("if ($L != null)", context.name); context.builder.addStatement("codedOut.writeBoolNoTag(true)"); } generator.addSerializationCode(context); if (needsNullHandling) { context.builder.nextControlFlow("else"); context.builder.addStatement("codedOut.writeBoolNoTag(false)"); context.builder.endControlFlow(); } } void writeDeserializationCode(Context context) { SerializationCodeGenerator generator = getMatchingCodeGenerator(context.type); boolean needsNullHandling = context.canBeNull() && generator != contextMarshaller; // If we have a generic or a wildcard parameter we need to erase it when we write the code out. TypeName contextTypeName = context.getTypeName(); if (context.isDeclaredType() && !context.getDeclaredType().getTypeArguments().isEmpty()) { for (TypeMirror paramTypeMirror : context.getDeclaredType().getTypeArguments()) { if (isVariableOrWildcardType(paramTypeMirror)) { contextTypeName = TypeName.get(env.getTypeUtils().erasure(context.getDeclaredType())); } } // If we're just a generic or wildcard, get the erasure and use that. } else if (isVariableOrWildcardType(context.getTypeMirror())) { contextTypeName = TypeName.get(env.getTypeUtils().erasure(context.getTypeMirror())); } if (needsNullHandling) { context.builder.addStatement("$T $L = null", contextTypeName, context.name); context.builder.beginControlFlow("if (codedIn.readBool())"); } else { context.builder.addStatement("$T $L", contextTypeName, context.name); } generator.addDeserializationCode(context); if (needsNullHandling) { context.builder.endControlFlow(); } } private SerializationCodeGenerator getMatchingCodeGenerator(TypeMirror type) { if (type.getKind() == TypeKind.ARRAY) { return arrayCodeGenerator; } if (type instanceof PrimitiveType) { PrimitiveType primitiveType = (PrimitiveType) type; return primitiveGenerators .stream() .filter(generator -> generator.matches((PrimitiveType) type)) .findFirst() .orElseThrow(() -> new IllegalArgumentException("No generator for: " + primitiveType)); } // We're dealing with a generic. if (isVariableOrWildcardType(type)) { return contextMarshaller; } // TODO(cpeyser): Refactor primitive handling from AutoCodecProcessor.java if (!(type instanceof DeclaredType)) { throw new IllegalArgumentException( "Can only serialize primitive, array or declared fields, found " + type); } DeclaredType declaredType = (DeclaredType) type; return marshallers .stream() .filter(marshaller -> marshaller.matches(declaredType)) .findFirst() .orElseThrow( () -> new IllegalArgumentException( "No marshaller for: " + ((TypeElement) declaredType.asElement()).getQualifiedName())); } private final SerializationCodeGenerator arrayCodeGenerator = new SerializationCodeGenerator() { @Override public void addSerializationCode(Context context) { String length = context.makeName("length"); context.builder.addStatement("int $L = $L.length", length, context.name); context.builder.addStatement("codedOut.writeInt32NoTag($L)", length); Context repeated = context.with( ((ArrayType) context.type).getComponentType(), context.makeName("repeated")); String indexName = context.makeName("i"); context.builder.beginControlFlow( "for(int $L = 0; $L < $L; ++$L)", indexName, indexName, length, indexName); context.builder.addStatement( "$T $L = $L[$L]", repeated.getTypeName(), repeated.name, context.name, indexName); writeSerializationCode(repeated); context.builder.endControlFlow(); } @Override public void addDeserializationCode(Context context) { Context repeated = context.with( ((ArrayType) context.type).getComponentType(), context.makeName("repeated")); String lengthName = context.makeName("length"); context.builder.addStatement("int $L = codedIn.readInt32()", lengthName); String resultName = context.makeName("result"); context.builder.addStatement( "$T[] $L = new $T[$L]", repeated.getTypeName(), resultName, repeated.getTypeName(), lengthName); String indexName = context.makeName("i"); context.builder.beginControlFlow( "for (int $L = 0; $L < $L; ++$L)", indexName, indexName, lengthName, indexName); writeDeserializationCode(repeated); context.builder.addStatement("$L[$L] = $L", resultName, indexName, repeated.name); context.builder.endControlFlow(); context.builder.addStatement("$L = $L", context.name, resultName); } }; private static final PrimitiveValueSerializationCodeGenerator INT_CODE_GENERATOR = new PrimitiveValueSerializationCodeGenerator() { @Override public boolean matches(PrimitiveType type) { return type.getKind() == TypeKind.INT; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.writeInt32NoTag($L)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = codedIn.readInt32()", context.name); } }; private static final PrimitiveValueSerializationCodeGenerator LONG_CODE_GENERATOR = new PrimitiveValueSerializationCodeGenerator() { @Override public boolean matches(PrimitiveType type) { return type.getKind() == TypeKind.LONG; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.writeInt64NoTag($L)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = codedIn.readInt64()", context.name); } }; private static final PrimitiveValueSerializationCodeGenerator BYTE_CODE_GENERATOR = new PrimitiveValueSerializationCodeGenerator() { @Override public boolean matches(PrimitiveType type) { return type.getKind() == TypeKind.BYTE; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.write($L)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = codedIn.readRawByte()", context.name); } }; private static final PrimitiveValueSerializationCodeGenerator BOOLEAN_CODE_GENERATOR = new PrimitiveValueSerializationCodeGenerator() { @Override public boolean matches(PrimitiveType type) { return type.getKind() == TypeKind.BOOLEAN; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.writeBoolNoTag($L)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = codedIn.readBool()", context.name); } }; private static final PrimitiveValueSerializationCodeGenerator DOUBLE_CODE_GENERATOR = new PrimitiveValueSerializationCodeGenerator() { @Override public boolean matches(PrimitiveType type) { return type.getKind() == TypeKind.DOUBLE; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.writeDoubleNoTag($L)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = codedIn.readDouble()", context.name); } }; private final Marshaller charSequenceMarshaller = new Marshaller() { @Override public boolean matches(DeclaredType type) { return matchesType(type, CharSequence.class); } @Override public void addSerializationCode(Context context) { context.builder.addStatement( "$T.asciiOptimized().serialize(context, $L.toString(), codedOut)", StringCodecs.class, context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement( "$L = $T.asciiOptimized().deserialize(context, codedIn)", context.name, StringCodecs.class); } }; private final Marshaller supplierMarshaller = new Marshaller() { @Override public boolean matches(DeclaredType type) { return matchesErased(type, Supplier.class); } @Override public void addSerializationCode(Context context) { DeclaredType suppliedType = (DeclaredType) context.getDeclaredType().getTypeArguments().get(0); writeSerializationCode(context.with(suppliedType, context.name + ".get()")); } @Override public void addDeserializationCode(Context context) { DeclaredType suppliedType = (DeclaredType) context.getDeclaredType().getTypeArguments().get(0); String suppliedName = context.makeName("supplied"); writeDeserializationCode(context.with(suppliedType, suppliedName)); String suppliedFinalName = context.makeName("suppliedFinal"); context.builder.addStatement( "final $T $L = $L", suppliedType, suppliedFinalName, suppliedName); context.builder.addStatement("$L = () -> $L", context.name, suppliedFinalName); } }; private final Marshaller multimapMarshaller = new Marshaller() { @Override public boolean matches(DeclaredType type) { return matchesErased(type, ImmutableMultimap.class) || matchesErased(type, ImmutableListMultimap.class); } @Override public void addSerializationCode(Context context) { context.builder.addStatement("codedOut.writeInt32NoTag($L.size())", context.name); String entryName = context.makeName("entry"); Context key = context.with( context.getDeclaredType().getTypeArguments().get(0), entryName + ".getKey()"); Context value = context.with( context.getDeclaredType().getTypeArguments().get(1), entryName + ".getValue()"); context.builder.beginControlFlow( "for ($T<$T, $T> $L : $L.entries())", Map.Entry.class, key.getTypeName(), value.getTypeName(), entryName, context.name); writeSerializationCode(key); writeSerializationCode(value); context.builder.endControlFlow(); } @Override public void addDeserializationCode(Context context) { Context key = context.with( context.getDeclaredType().getTypeArguments().get(0), context.makeName("key")); Context value = context.with( context.getDeclaredType().getTypeArguments().get(1), context.makeName("value")); String builderName = context.makeName("builder"); context.builder.addStatement( "$T<$T, $T> $L = new $T<>()", ImmutableListMultimap.Builder.class, key.getTypeName(), value.getTypeName(), builderName, ImmutableListMultimap.Builder.class); String lengthName = context.makeName("length"); context.builder.addStatement("int $L = codedIn.readInt32()", lengthName); String indexName = context.makeName("i"); context.builder.beginControlFlow( "for (int $L = 0; $L < $L; ++$L)", indexName, indexName, lengthName, indexName); writeDeserializationCode(key); writeDeserializationCode(value); context.builder.addStatement("$L.put($L, $L)", builderName, key.name, value.name); context.builder.endControlFlow(); context.builder.addStatement("$L = $L.build()", context.name, builderName); } }; /** Delegates marshalling back to the context. */ private final Marshaller contextMarshaller = new Marshaller() { @Override public boolean matches(DeclaredType unusedType) { return true; } @Override public void addSerializationCode(Context context) { context.builder.addStatement("context.serialize($L, codedOut)", context.name); } @Override public void addDeserializationCode(Context context) { context.builder.addStatement("$L = context.deserialize(codedIn)", context.name); } }; private final ImmutableList primitiveGenerators = ImmutableList.of( INT_CODE_GENERATOR, LONG_CODE_GENERATOR, BYTE_CODE_GENERATOR, BOOLEAN_CODE_GENERATOR, DOUBLE_CODE_GENERATOR); private final ImmutableList marshallers = ImmutableList.of( charSequenceMarshaller, supplierMarshaller, multimapMarshaller, contextMarshaller); /** True when {@code type} has the same type as {@code clazz}. */ private boolean matchesType(TypeMirror type, Class clazz) { return env.getTypeUtils().isSameType(type, getType(clazz)); } /** True when erasure of {@code type} matches erasure of {@code clazz}. */ private boolean matchesErased(TypeMirror type, Class clazz) { return env.getTypeUtils() .isSameType(env.getTypeUtils().erasure(type), env.getTypeUtils().erasure(getType(clazz))); } /** Returns the TypeMirror corresponding to {@code clazz}. */ private TypeMirror getType(Class clazz) { return AutoCodecUtil.getType(clazz, env); } static boolean isVariableOrWildcardType(TypeMirror type) { return type instanceof TypeVariable || type instanceof WildcardType; } }