diff options
author | 2018-03-27 13:45:26 -0700 | |
---|---|---|
committer | 2018-03-27 13:47:07 -0700 | |
commit | 7a84f6532428f0b93d323f43e9831c00f854475e (patch) | |
tree | d87172f98d3d77b2473235463d765db44e37c3af | |
parent | 8bbb6c23d25b85fe13b41edd16010c2b5fafe2ea (diff) |
DynamicCodec class.
PiperOrigin-RevId: 190667019
-rw-r--r-- | src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java | 251 | ||||
-rw-r--r-- | src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java | 378 |
2 files changed, 629 insertions, 0 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java new file mode 100644 index 0000000000..48bc8223c6 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java @@ -0,0 +1,251 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.skyframe.serialization; + +import com.google.common.collect.ImmutableSortedMap; +import com.google.devtools.build.lib.skyframe.serialization.autocodec.UnsafeProvider; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; +import java.io.IOException; +import java.lang.reflect.Array; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.util.Comparator; +import java.util.Map; +import sun.reflect.ReflectionFactory; + +/** + * A codec that serializes arbitrary types. + * + * <p>TODO(shahan): replace Unsafe with VarHandle once it's available. + */ +public class DynamicCodec<T> implements ObjectCodec<T> { + + private final Class<T> type; + private final Constructor<T> constructor; + private final ImmutableSortedMap<Field, Long> offsets; + private final ObjectCodec.MemoizationStrategy strategy; + + public DynamicCodec(Class<T> type) throws ReflectiveOperationException { + this(type, ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE); + } + + public DynamicCodec(Class<T> type, ObjectCodec.MemoizationStrategy strategy) + throws ReflectiveOperationException { + this.type = type; + this.constructor = getConstructor(type); + this.offsets = getOffsets(type); + this.strategy = strategy; + } + + @Override + public Class<T> getEncodedClass() { + return type; + } + + @Override + public MemoizationStrategy getStrategy() { + return strategy; + } + + @Override + public void serialize(SerializationContext context, T obj, CodedOutputStream codedOut) + throws SerializationException, IOException { + for (Map.Entry<Field, Long> entry : offsets.entrySet()) { + serializeField(context, codedOut, obj, entry.getKey().getType(), entry.getValue()); + } + } + + /** + * Serializes a field. + * + * @param obj the object containing the field to serialize. Can be an array or plain object. + * @param type class of the field to serialize + * @param offset unsafe offset into obj where the field will be found + */ + private void serializeField( + SerializationContext context, + CodedOutputStream codedOut, + Object obj, + Class<?> type, + long offset) + throws SerializationException, IOException { + if (type.isPrimitive()) { + if (type.equals(boolean.class)) { + codedOut.writeBoolNoTag(UnsafeProvider.getInstance().getBoolean(obj, offset)); + } else if (type.equals(byte.class)) { + codedOut.writeRawByte(UnsafeProvider.getInstance().getByte(obj, offset)); + } else if (type.equals(short.class)) { + ByteBuffer buffer = + ByteBuffer.allocate(2).putShort(UnsafeProvider.getInstance().getShort(obj, offset)); + codedOut.writeRawBytes(buffer); + } else if (type.equals(char.class)) { + ByteBuffer buffer = + ByteBuffer.allocate(2).putChar(UnsafeProvider.getInstance().getChar(obj, offset)); + codedOut.writeRawBytes(buffer); + } else if (type.equals(int.class)) { + codedOut.writeInt32NoTag(UnsafeProvider.getInstance().getInt(obj, offset)); + } else if (type.equals(long.class)) { + codedOut.writeInt64NoTag(UnsafeProvider.getInstance().getLong(obj, offset)); + } else if (type.equals(float.class)) { + codedOut.writeFloatNoTag(UnsafeProvider.getInstance().getFloat(obj, offset)); + } else if (type.equals(double.class)) { + codedOut.writeDoubleNoTag(UnsafeProvider.getInstance().getDouble(obj, offset)); + } else if (type.equals(void.class)) { + // Does nothing for void type. + } else { + throw new UnsupportedOperationException("Unknown primitive type: " + type); + } + } else if (type.isArray()) { + Object arr = UnsafeProvider.getInstance().getObject(obj, offset); + if (arr == null) { + codedOut.writeInt32NoTag(-1); + return; + } + int length = Array.getLength(arr); + codedOut.writeInt32NoTag(length); + int base = UnsafeProvider.getInstance().arrayBaseOffset(type); + int scale = UnsafeProvider.getInstance().arrayIndexScale(type); + if (scale == 0) { + throw new SerializationException("Failed to get index scale for type: " + type); + } + for (int i = 0; i < length; ++i) { + // Serializes the ith array field directly from array memory. + serializeField(context, codedOut, arr, type.getComponentType(), base + scale * i); + } + } else { + context.serialize(UnsafeProvider.getInstance().getObject(obj, offset), codedOut); + } + } + + @Override + public T deserialize(DeserializationContext context, CodedInputStream codedIn) + throws SerializationException, IOException { + T instance; + try { + instance = constructor.newInstance(); + } catch (ReflectiveOperationException e) { + throw new SerializationException("Could not instantiate object of type: " + type, e); + } + if (strategy.equals(ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE)) { + context.registerInitialValue(instance); + } + for (Map.Entry<Field, Long> entry : offsets.entrySet()) { + deserializeField(context, codedIn, instance, entry.getKey().getType(), entry.getValue()); + } + return instance; + } + + /** + * Deserializes a field directly into the supplied object. + * + * @param obj the object containing the field to deserialize. Can be an array or a plain object. + * @param type class of the field to deserialize + * @param offset unsafe offset into obj where the field should be written + */ + private void deserializeField( + DeserializationContext context, + CodedInputStream codedIn, + Object obj, + Class<?> type, + long offset) + throws SerializationException, IOException { + if (type.isPrimitive()) { + if (type.equals(boolean.class)) { + UnsafeProvider.getInstance().putBoolean(obj, offset, codedIn.readBool()); + } else if (type.equals(byte.class)) { + UnsafeProvider.getInstance().putByte(obj, offset, codedIn.readRawByte()); + } else if (type.equals(short.class)) { + ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2)); + UnsafeProvider.getInstance().putShort(obj, offset, buffer.getShort(0)); + } else if (type.equals(char.class)) { + ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2)); + UnsafeProvider.getInstance().putChar(obj, offset, buffer.getChar(0)); + } else if (type.equals(int.class)) { + UnsafeProvider.getInstance().putInt(obj, offset, codedIn.readInt32()); + } else if (type.equals(long.class)) { + UnsafeProvider.getInstance().putLong(obj, offset, codedIn.readInt64()); + } else if (type.equals(float.class)) { + UnsafeProvider.getInstance().putFloat(obj, offset, codedIn.readFloat()); + } else if (type.equals(double.class)) { + UnsafeProvider.getInstance().putDouble(obj, offset, codedIn.readDouble()); + } else if (type.equals(void.class)) { + // Does nothing for void type. + } else { + throw new UnsupportedOperationException("Unknown primitive type: " + type); + } + } else if (type.isArray()) { + int length = codedIn.readInt32(); + if (length < 0) { + UnsafeProvider.getInstance().putObject(obj, offset, null); + return; + } + Object arr = Array.newInstance(type.getComponentType(), length); + UnsafeProvider.getInstance().putObject(obj, offset, arr); + int base = UnsafeProvider.getInstance().arrayBaseOffset(type); + int scale = UnsafeProvider.getInstance().arrayIndexScale(type); + if (scale == 0) { + throw new SerializationException("Failed to get index scale for type: " + type); + } + for (int i = 0; i < length; ++i) { + // Deserializes type directly into array memory. + deserializeField(context, codedIn, arr, type.getComponentType(), base + scale * i); + } + } else { + UnsafeProvider.getInstance().putObject(obj, offset, context.deserialize(codedIn)); + } + } + + private static <T> ImmutableSortedMap<Field, Long> getOffsets(Class<T> type) { + ImmutableSortedMap.Builder<Field, Long> offsets = + new ImmutableSortedMap.Builder<>(new FieldComparator()); + for (Class<? super T> next = type; next != null; next = next.getSuperclass()) { + for (Field field : next.getDeclaredFields()) { + if ((field.getModifiers() & (Modifier.STATIC | Modifier.TRANSIENT)) != 0) { + continue; // Skips static or transient fields. + } + field.setAccessible(true); + offsets.put(field, UnsafeProvider.getInstance().objectFieldOffset(field)); + } + } + return offsets.build(); + } + + @SuppressWarnings("unchecked") + private static <T> Constructor<T> getConstructor(Class<T> type) + throws ReflectiveOperationException { + Constructor<T> constructor = + (Constructor<T>) + ReflectionFactory.getReflectionFactory() + .newConstructorForSerialization(type, Object.class.getDeclaredConstructor()); + constructor.setAccessible(true); + return constructor; + } + + private static class FieldComparator implements Comparator<Field> { + + @Override + public int compare(Field f1, Field f2) { + int classCompare = + f1.getDeclaringClass().getName().compareTo(f2.getDeclaringClass().getName()); + if (classCompare != 0) { + return classCompare; + } + return f1.getName().compareTo(f2.getName()); + } + } +} diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java new file mode 100644 index 0000000000..3e48781def --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java @@ -0,0 +1,378 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.skyframe.serialization; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester; +import java.util.Arrays; +import java.util.Objects; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DynamicCodec}. */ +@RunWith(JUnit4.class) +public final class DynamicCodecTest { + + private static class SimpleExample { + private final String elt; + private final String elt2; + private final int x; + + private SimpleExample(String elt, String elt2, int x) { + this.elt = elt; + this.elt2 = elt2; + this.x = x; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof SimpleExample)) { + return false; + } + SimpleExample that = (SimpleExample) other; + return Objects.equals(elt, that.elt) && Objects.equals(elt2, that.elt2) && x == that.x; + } + } + + @Test + public void testExample() throws Exception { + new SerializationTester(new SimpleExample("a", "b", -5), new SimpleExample("a", null, 10)) + .addCodec(new DynamicCodec<>(SimpleExample.class)) + .makeMemoizing() + .runTests(); + } + + private static class ExampleSubclass extends SimpleExample { + private final String elt; // duplicate name with superclass + + private ExampleSubclass(String elt1, String elt2, String elt3, int x) { + super(elt1, elt2, x); + this.elt = elt3; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof ExampleSubclass)) { + return false; + } + if (!super.equals(other)) { + return false; + } + ExampleSubclass that = (ExampleSubclass) other; + return Objects.equals(elt, that.elt); + } + } + + @Test + public void testExampleSubclass() throws Exception { + new SerializationTester( + new ExampleSubclass("a", "b", "c", 0), new ExampleSubclass("a", null, null, 15)) + .addCodec(new DynamicCodec<>(ExampleSubclass.class)) + .makeMemoizing() + .runTests(); + } + + private static class ExampleSmallPrimitives { + private final Void v; + private final boolean bit; + private final byte b; + private final short s; + private final char c; + + private ExampleSmallPrimitives(boolean bit, byte b, short s, char c) { + this.v = null; + this.bit = bit; + this.b = b; + this.s = s; + this.c = c; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof ExampleSmallPrimitives)) { + return false; + } + ExampleSmallPrimitives that = (ExampleSmallPrimitives) other; + return v == that.v && bit == that.bit && b == that.b && s == that.s && c == that.c; + } + } + + @Test + public void testExampleSmallPrimitives() throws Exception { + new SerializationTester( + new ExampleSmallPrimitives(false, (byte) 0, (short) 0, 'a'), + new ExampleSmallPrimitives(false, (byte) 120, (short) 18000, 'x'), + new ExampleSmallPrimitives(true, Byte.MIN_VALUE, Short.MIN_VALUE, Character.MIN_VALUE), + new ExampleSmallPrimitives(true, Byte.MAX_VALUE, Short.MAX_VALUE, Character.MAX_VALUE)) + .addCodec(new DynamicCodec<>(ExampleSmallPrimitives.class)) + .makeMemoizing() + .runTests(); + } + + private static class ExampleMediumPrimitives { + private final int i; + private final float f; + + private ExampleMediumPrimitives(int i, float f) { + this.i = i; + this.f = f; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof ExampleMediumPrimitives)) { + return false; + } + ExampleMediumPrimitives that = (ExampleMediumPrimitives) other; + return i == that.i && f == that.f; + } + } + + @Test + public void testExampleMediumPrimitives() throws Exception { + new SerializationTester( + new ExampleMediumPrimitives(12345, 1e12f), + new ExampleMediumPrimitives(67890, -6e9f), + new ExampleMediumPrimitives(Integer.MIN_VALUE, Float.MIN_VALUE), + new ExampleMediumPrimitives(Integer.MAX_VALUE, Float.MAX_VALUE)) + .addCodec(new DynamicCodec<>(ExampleMediumPrimitives.class)) + .makeMemoizing() + .runTests(); + } + + private static class ExampleLargePrimitives { + private final long l; + private final double d; + + private ExampleLargePrimitives(long l, double d) { + this.l = l; + this.d = d; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof ExampleLargePrimitives)) { + return false; + } + ExampleLargePrimitives that = (ExampleLargePrimitives) other; + return l == that.l && d == that.d; + } + } + + @Test + public void testExampleLargePrimitives() throws Exception { + new SerializationTester( + new ExampleLargePrimitives(12345346523453L, 1e300), + new ExampleLargePrimitives(678900093045L, -9e180), + new ExampleLargePrimitives(Long.MIN_VALUE, Double.MIN_VALUE), + new ExampleLargePrimitives(Long.MAX_VALUE, Double.MAX_VALUE)) + .addCodec(new DynamicCodec<>(ExampleLargePrimitives.class)) + .makeMemoizing() + .runTests(); + } + + private static class ArrayExample { + String[] text; + byte[] numbers; + char[] chars; + long[] longs; + + private ArrayExample(String[] text, byte[] numbers, char[] chars, long[] longs) { + this.text = text; + this.numbers = numbers; + this.chars = chars; + this.longs = longs; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof ArrayExample)) { + return false; + } + ArrayExample that = (ArrayExample) other; + return Arrays.equals(text, that.text) + && Arrays.equals(numbers, that.numbers) + && Arrays.equals(chars, that.chars) + && Arrays.equals(longs, that.longs); + } + } + + @Test + public void testArray() throws Exception { + new SerializationTester( + new ArrayExample(null, null, null, null), + new ArrayExample(new String[] {}, new byte[] {}, new char[] {}, new long[] {}), + new ArrayExample( + new String[] {"a", "b", "cde"}, + new byte[] {-1, 0, 1}, + new char[] {'a', 'b', 'c', 'x', 'y', 'z'}, + new long[] {Long.MAX_VALUE, Long.MIN_VALUE, 27983741982341L, 52893748523495834L})) + .addCodec(new DynamicCodec<>(ArrayExample.class)) + .makeMemoizing() + .runTests(); + } + + private static class NestedArrayExample { + int[][] numbers; + + private NestedArrayExample(int[][] numbers) { + this.numbers = numbers; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof NestedArrayExample)) { + return false; + } + NestedArrayExample that = (NestedArrayExample) other; + return Arrays.deepEquals(numbers, that.numbers); + } + } + + @Test + public void testNestedArray() throws Exception { + new SerializationTester( + new NestedArrayExample(null), + new NestedArrayExample( + new int[][] { + {1, 2, 3}, + {4, 5, 6, 9}, + {7} + }), + new NestedArrayExample(new int[][] {{1, 2, 3}, null, {7}})) + .addCodec(new DynamicCodec<>(NestedArrayExample.class)) + .makeMemoizing() + .runTests(); + } + + private static class CycleA { + private final int value; + private CycleB b; + + private CycleA(int value) { + this.value = value; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + // Integrity check. Not really part of equals. + assertThat(b.a).isEqualTo(this); + if (!(other instanceof CycleA)) { + return false; + } + CycleA that = (CycleA) other; + // Consistency check. Not really part of equals. + assertThat(that.b.a).isEqualTo(that); + return value == that.value && b.value() == that.b.value; + } + } + + private static class CycleB { + private final int value; + private CycleA a; + + private CycleB(int value) { + this.value = value; + } + + public int value() { + return value; + } + } + + private static CycleA createCycle(int valueA, int valueB) { + CycleA a = new CycleA(valueA); + a.b = new CycleB(valueB); + a.b.a = a; + return a; + } + + @Test + public void testCyclic() throws Exception { + new SerializationTester(createCycle(1, 2), createCycle(3, 4)) + .addCodec(new DynamicCodec<>(CycleA.class)) + .addCodec(new DynamicCodec<>(CycleB.class)) + .makeMemoizing() + .runTests(); + } + + enum EnumExample { + ZERO, + ONE, + TWO, + THREE + } + + static class PrimitiveExample { + + private final boolean booleanValue; + private final int intValue; + private final double doubleValue; + private final EnumExample enumValue; + private final String stringValue; + + PrimitiveExample( + boolean booleanValue, + int intValue, + double doubleValue, + EnumExample enumValue, + String stringValue) { + this.booleanValue = booleanValue; + this.intValue = intValue; + this.doubleValue = doubleValue; + this.enumValue = enumValue; + this.stringValue = stringValue; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object object) { + if (object == null) { + return false; + } + PrimitiveExample that = (PrimitiveExample) object; + return booleanValue == that.booleanValue + && intValue == that.intValue + && doubleValue == that.doubleValue + && Objects.equals(enumValue, that.enumValue) + && Objects.equals(stringValue, that.stringValue); + } + } + + @Test + public void testPrimitiveExample() throws Exception { + new SerializationTester( + new PrimitiveExample(true, 1, 1.1, EnumExample.ZERO, "foo"), + new PrimitiveExample(false, -1, -5.5, EnumExample.ONE, "bar"), + new PrimitiveExample(true, 5, 20.0, EnumExample.THREE, null), + new PrimitiveExample(true, 100, 100, null, "hello")) + .addCodec( + new DynamicCodec<>( + PrimitiveExample.class, ObjectCodec.MemoizationStrategy.DO_NOT_MEMOIZE)) + .addCodec(new EnumCodec<>(EnumExample.class)) + .setRepetitions(100000) + .runTests(); + } +} |