aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar shahan <shahan@google.com>2018-03-27 13:45:26 -0700
committerGravatar Copybara-Service <copybara-piper@google.com>2018-03-27 13:47:07 -0700
commit7a84f6532428f0b93d323f43e9831c00f854475e (patch)
treed87172f98d3d77b2473235463d765db44e37c3af
parent8bbb6c23d25b85fe13b41edd16010c2b5fafe2ea (diff)
DynamicCodec class.
PiperOrigin-RevId: 190667019
-rw-r--r--src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java251
-rw-r--r--src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java378
2 files changed, 629 insertions, 0 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
new file mode 100644
index 0000000000..48bc8223c6
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
@@ -0,0 +1,251 @@
+// Copyright 2018 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.skyframe.serialization;
+
+import com.google.common.collect.ImmutableSortedMap;
+import com.google.devtools.build.lib.skyframe.serialization.autocodec.UnsafeProvider;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
+import java.nio.ByteBuffer;
+import java.util.Comparator;
+import java.util.Map;
+import sun.reflect.ReflectionFactory;
+
+/**
+ * A codec that serializes arbitrary types.
+ *
+ * <p>TODO(shahan): replace Unsafe with VarHandle once it's available.
+ */
+public class DynamicCodec<T> implements ObjectCodec<T> {
+
+ private final Class<T> type;
+ private final Constructor<T> constructor;
+ private final ImmutableSortedMap<Field, Long> offsets;
+ private final ObjectCodec.MemoizationStrategy strategy;
+
+ public DynamicCodec(Class<T> type) throws ReflectiveOperationException {
+ this(type, ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE);
+ }
+
+ public DynamicCodec(Class<T> type, ObjectCodec.MemoizationStrategy strategy)
+ throws ReflectiveOperationException {
+ this.type = type;
+ this.constructor = getConstructor(type);
+ this.offsets = getOffsets(type);
+ this.strategy = strategy;
+ }
+
+ @Override
+ public Class<T> getEncodedClass() {
+ return type;
+ }
+
+ @Override
+ public MemoizationStrategy getStrategy() {
+ return strategy;
+ }
+
+ @Override
+ public void serialize(SerializationContext context, T obj, CodedOutputStream codedOut)
+ throws SerializationException, IOException {
+ for (Map.Entry<Field, Long> entry : offsets.entrySet()) {
+ serializeField(context, codedOut, obj, entry.getKey().getType(), entry.getValue());
+ }
+ }
+
+ /**
+ * Serializes a field.
+ *
+ * @param obj the object containing the field to serialize. Can be an array or plain object.
+ * @param type class of the field to serialize
+ * @param offset unsafe offset into obj where the field will be found
+ */
+ private void serializeField(
+ SerializationContext context,
+ CodedOutputStream codedOut,
+ Object obj,
+ Class<?> type,
+ long offset)
+ throws SerializationException, IOException {
+ if (type.isPrimitive()) {
+ if (type.equals(boolean.class)) {
+ codedOut.writeBoolNoTag(UnsafeProvider.getInstance().getBoolean(obj, offset));
+ } else if (type.equals(byte.class)) {
+ codedOut.writeRawByte(UnsafeProvider.getInstance().getByte(obj, offset));
+ } else if (type.equals(short.class)) {
+ ByteBuffer buffer =
+ ByteBuffer.allocate(2).putShort(UnsafeProvider.getInstance().getShort(obj, offset));
+ codedOut.writeRawBytes(buffer);
+ } else if (type.equals(char.class)) {
+ ByteBuffer buffer =
+ ByteBuffer.allocate(2).putChar(UnsafeProvider.getInstance().getChar(obj, offset));
+ codedOut.writeRawBytes(buffer);
+ } else if (type.equals(int.class)) {
+ codedOut.writeInt32NoTag(UnsafeProvider.getInstance().getInt(obj, offset));
+ } else if (type.equals(long.class)) {
+ codedOut.writeInt64NoTag(UnsafeProvider.getInstance().getLong(obj, offset));
+ } else if (type.equals(float.class)) {
+ codedOut.writeFloatNoTag(UnsafeProvider.getInstance().getFloat(obj, offset));
+ } else if (type.equals(double.class)) {
+ codedOut.writeDoubleNoTag(UnsafeProvider.getInstance().getDouble(obj, offset));
+ } else if (type.equals(void.class)) {
+ // Does nothing for void type.
+ } else {
+ throw new UnsupportedOperationException("Unknown primitive type: " + type);
+ }
+ } else if (type.isArray()) {
+ Object arr = UnsafeProvider.getInstance().getObject(obj, offset);
+ if (arr == null) {
+ codedOut.writeInt32NoTag(-1);
+ return;
+ }
+ int length = Array.getLength(arr);
+ codedOut.writeInt32NoTag(length);
+ int base = UnsafeProvider.getInstance().arrayBaseOffset(type);
+ int scale = UnsafeProvider.getInstance().arrayIndexScale(type);
+ if (scale == 0) {
+ throw new SerializationException("Failed to get index scale for type: " + type);
+ }
+ for (int i = 0; i < length; ++i) {
+ // Serializes the ith array field directly from array memory.
+ serializeField(context, codedOut, arr, type.getComponentType(), base + scale * i);
+ }
+ } else {
+ context.serialize(UnsafeProvider.getInstance().getObject(obj, offset), codedOut);
+ }
+ }
+
+ @Override
+ public T deserialize(DeserializationContext context, CodedInputStream codedIn)
+ throws SerializationException, IOException {
+ T instance;
+ try {
+ instance = constructor.newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new SerializationException("Could not instantiate object of type: " + type, e);
+ }
+ if (strategy.equals(ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE)) {
+ context.registerInitialValue(instance);
+ }
+ for (Map.Entry<Field, Long> entry : offsets.entrySet()) {
+ deserializeField(context, codedIn, instance, entry.getKey().getType(), entry.getValue());
+ }
+ return instance;
+ }
+
+ /**
+ * Deserializes a field directly into the supplied object.
+ *
+ * @param obj the object containing the field to deserialize. Can be an array or a plain object.
+ * @param type class of the field to deserialize
+ * @param offset unsafe offset into obj where the field should be written
+ */
+ private void deserializeField(
+ DeserializationContext context,
+ CodedInputStream codedIn,
+ Object obj,
+ Class<?> type,
+ long offset)
+ throws SerializationException, IOException {
+ if (type.isPrimitive()) {
+ if (type.equals(boolean.class)) {
+ UnsafeProvider.getInstance().putBoolean(obj, offset, codedIn.readBool());
+ } else if (type.equals(byte.class)) {
+ UnsafeProvider.getInstance().putByte(obj, offset, codedIn.readRawByte());
+ } else if (type.equals(short.class)) {
+ ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
+ UnsafeProvider.getInstance().putShort(obj, offset, buffer.getShort(0));
+ } else if (type.equals(char.class)) {
+ ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
+ UnsafeProvider.getInstance().putChar(obj, offset, buffer.getChar(0));
+ } else if (type.equals(int.class)) {
+ UnsafeProvider.getInstance().putInt(obj, offset, codedIn.readInt32());
+ } else if (type.equals(long.class)) {
+ UnsafeProvider.getInstance().putLong(obj, offset, codedIn.readInt64());
+ } else if (type.equals(float.class)) {
+ UnsafeProvider.getInstance().putFloat(obj, offset, codedIn.readFloat());
+ } else if (type.equals(double.class)) {
+ UnsafeProvider.getInstance().putDouble(obj, offset, codedIn.readDouble());
+ } else if (type.equals(void.class)) {
+ // Does nothing for void type.
+ } else {
+ throw new UnsupportedOperationException("Unknown primitive type: " + type);
+ }
+ } else if (type.isArray()) {
+ int length = codedIn.readInt32();
+ if (length < 0) {
+ UnsafeProvider.getInstance().putObject(obj, offset, null);
+ return;
+ }
+ Object arr = Array.newInstance(type.getComponentType(), length);
+ UnsafeProvider.getInstance().putObject(obj, offset, arr);
+ int base = UnsafeProvider.getInstance().arrayBaseOffset(type);
+ int scale = UnsafeProvider.getInstance().arrayIndexScale(type);
+ if (scale == 0) {
+ throw new SerializationException("Failed to get index scale for type: " + type);
+ }
+ for (int i = 0; i < length; ++i) {
+ // Deserializes type directly into array memory.
+ deserializeField(context, codedIn, arr, type.getComponentType(), base + scale * i);
+ }
+ } else {
+ UnsafeProvider.getInstance().putObject(obj, offset, context.deserialize(codedIn));
+ }
+ }
+
+ private static <T> ImmutableSortedMap<Field, Long> getOffsets(Class<T> type) {
+ ImmutableSortedMap.Builder<Field, Long> offsets =
+ new ImmutableSortedMap.Builder<>(new FieldComparator());
+ for (Class<? super T> next = type; next != null; next = next.getSuperclass()) {
+ for (Field field : next.getDeclaredFields()) {
+ if ((field.getModifiers() & (Modifier.STATIC | Modifier.TRANSIENT)) != 0) {
+ continue; // Skips static or transient fields.
+ }
+ field.setAccessible(true);
+ offsets.put(field, UnsafeProvider.getInstance().objectFieldOffset(field));
+ }
+ }
+ return offsets.build();
+ }
+
+ @SuppressWarnings("unchecked")
+ private static <T> Constructor<T> getConstructor(Class<T> type)
+ throws ReflectiveOperationException {
+ Constructor<T> constructor =
+ (Constructor<T>)
+ ReflectionFactory.getReflectionFactory()
+ .newConstructorForSerialization(type, Object.class.getDeclaredConstructor());
+ constructor.setAccessible(true);
+ return constructor;
+ }
+
+ private static class FieldComparator implements Comparator<Field> {
+
+ @Override
+ public int compare(Field f1, Field f2) {
+ int classCompare =
+ f1.getDeclaringClass().getName().compareTo(f2.getDeclaringClass().getName());
+ if (classCompare != 0) {
+ return classCompare;
+ }
+ return f1.getName().compareTo(f2.getName());
+ }
+ }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
new file mode 100644
index 0000000000..3e48781def
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
@@ -0,0 +1,378 @@
+// Copyright 2018 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.skyframe.serialization;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
+import java.util.Arrays;
+import java.util.Objects;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link DynamicCodec}. */
+@RunWith(JUnit4.class)
+public final class DynamicCodecTest {
+
+ private static class SimpleExample {
+ private final String elt;
+ private final String elt2;
+ private final int x;
+
+ private SimpleExample(String elt, String elt2, int x) {
+ this.elt = elt;
+ this.elt2 = elt2;
+ this.x = x;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof SimpleExample)) {
+ return false;
+ }
+ SimpleExample that = (SimpleExample) other;
+ return Objects.equals(elt, that.elt) && Objects.equals(elt2, that.elt2) && x == that.x;
+ }
+ }
+
+ @Test
+ public void testExample() throws Exception {
+ new SerializationTester(new SimpleExample("a", "b", -5), new SimpleExample("a", null, 10))
+ .addCodec(new DynamicCodec<>(SimpleExample.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class ExampleSubclass extends SimpleExample {
+ private final String elt; // duplicate name with superclass
+
+ private ExampleSubclass(String elt1, String elt2, String elt3, int x) {
+ super(elt1, elt2, x);
+ this.elt = elt3;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ExampleSubclass)) {
+ return false;
+ }
+ if (!super.equals(other)) {
+ return false;
+ }
+ ExampleSubclass that = (ExampleSubclass) other;
+ return Objects.equals(elt, that.elt);
+ }
+ }
+
+ @Test
+ public void testExampleSubclass() throws Exception {
+ new SerializationTester(
+ new ExampleSubclass("a", "b", "c", 0), new ExampleSubclass("a", null, null, 15))
+ .addCodec(new DynamicCodec<>(ExampleSubclass.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class ExampleSmallPrimitives {
+ private final Void v;
+ private final boolean bit;
+ private final byte b;
+ private final short s;
+ private final char c;
+
+ private ExampleSmallPrimitives(boolean bit, byte b, short s, char c) {
+ this.v = null;
+ this.bit = bit;
+ this.b = b;
+ this.s = s;
+ this.c = c;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ExampleSmallPrimitives)) {
+ return false;
+ }
+ ExampleSmallPrimitives that = (ExampleSmallPrimitives) other;
+ return v == that.v && bit == that.bit && b == that.b && s == that.s && c == that.c;
+ }
+ }
+
+ @Test
+ public void testExampleSmallPrimitives() throws Exception {
+ new SerializationTester(
+ new ExampleSmallPrimitives(false, (byte) 0, (short) 0, 'a'),
+ new ExampleSmallPrimitives(false, (byte) 120, (short) 18000, 'x'),
+ new ExampleSmallPrimitives(true, Byte.MIN_VALUE, Short.MIN_VALUE, Character.MIN_VALUE),
+ new ExampleSmallPrimitives(true, Byte.MAX_VALUE, Short.MAX_VALUE, Character.MAX_VALUE))
+ .addCodec(new DynamicCodec<>(ExampleSmallPrimitives.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class ExampleMediumPrimitives {
+ private final int i;
+ private final float f;
+
+ private ExampleMediumPrimitives(int i, float f) {
+ this.i = i;
+ this.f = f;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ExampleMediumPrimitives)) {
+ return false;
+ }
+ ExampleMediumPrimitives that = (ExampleMediumPrimitives) other;
+ return i == that.i && f == that.f;
+ }
+ }
+
+ @Test
+ public void testExampleMediumPrimitives() throws Exception {
+ new SerializationTester(
+ new ExampleMediumPrimitives(12345, 1e12f),
+ new ExampleMediumPrimitives(67890, -6e9f),
+ new ExampleMediumPrimitives(Integer.MIN_VALUE, Float.MIN_VALUE),
+ new ExampleMediumPrimitives(Integer.MAX_VALUE, Float.MAX_VALUE))
+ .addCodec(new DynamicCodec<>(ExampleMediumPrimitives.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class ExampleLargePrimitives {
+ private final long l;
+ private final double d;
+
+ private ExampleLargePrimitives(long l, double d) {
+ this.l = l;
+ this.d = d;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ExampleLargePrimitives)) {
+ return false;
+ }
+ ExampleLargePrimitives that = (ExampleLargePrimitives) other;
+ return l == that.l && d == that.d;
+ }
+ }
+
+ @Test
+ public void testExampleLargePrimitives() throws Exception {
+ new SerializationTester(
+ new ExampleLargePrimitives(12345346523453L, 1e300),
+ new ExampleLargePrimitives(678900093045L, -9e180),
+ new ExampleLargePrimitives(Long.MIN_VALUE, Double.MIN_VALUE),
+ new ExampleLargePrimitives(Long.MAX_VALUE, Double.MAX_VALUE))
+ .addCodec(new DynamicCodec<>(ExampleLargePrimitives.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class ArrayExample {
+ String[] text;
+ byte[] numbers;
+ char[] chars;
+ long[] longs;
+
+ private ArrayExample(String[] text, byte[] numbers, char[] chars, long[] longs) {
+ this.text = text;
+ this.numbers = numbers;
+ this.chars = chars;
+ this.longs = longs;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ArrayExample)) {
+ return false;
+ }
+ ArrayExample that = (ArrayExample) other;
+ return Arrays.equals(text, that.text)
+ && Arrays.equals(numbers, that.numbers)
+ && Arrays.equals(chars, that.chars)
+ && Arrays.equals(longs, that.longs);
+ }
+ }
+
+ @Test
+ public void testArray() throws Exception {
+ new SerializationTester(
+ new ArrayExample(null, null, null, null),
+ new ArrayExample(new String[] {}, new byte[] {}, new char[] {}, new long[] {}),
+ new ArrayExample(
+ new String[] {"a", "b", "cde"},
+ new byte[] {-1, 0, 1},
+ new char[] {'a', 'b', 'c', 'x', 'y', 'z'},
+ new long[] {Long.MAX_VALUE, Long.MIN_VALUE, 27983741982341L, 52893748523495834L}))
+ .addCodec(new DynamicCodec<>(ArrayExample.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class NestedArrayExample {
+ int[][] numbers;
+
+ private NestedArrayExample(int[][] numbers) {
+ this.numbers = numbers;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof NestedArrayExample)) {
+ return false;
+ }
+ NestedArrayExample that = (NestedArrayExample) other;
+ return Arrays.deepEquals(numbers, that.numbers);
+ }
+ }
+
+ @Test
+ public void testNestedArray() throws Exception {
+ new SerializationTester(
+ new NestedArrayExample(null),
+ new NestedArrayExample(
+ new int[][] {
+ {1, 2, 3},
+ {4, 5, 6, 9},
+ {7}
+ }),
+ new NestedArrayExample(new int[][] {{1, 2, 3}, null, {7}}))
+ .addCodec(new DynamicCodec<>(NestedArrayExample.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ private static class CycleA {
+ private final int value;
+ private CycleB b;
+
+ private CycleA(int value) {
+ this.value = value;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object other) {
+ // Integrity check. Not really part of equals.
+ assertThat(b.a).isEqualTo(this);
+ if (!(other instanceof CycleA)) {
+ return false;
+ }
+ CycleA that = (CycleA) other;
+ // Consistency check. Not really part of equals.
+ assertThat(that.b.a).isEqualTo(that);
+ return value == that.value && b.value() == that.b.value;
+ }
+ }
+
+ private static class CycleB {
+ private final int value;
+ private CycleA a;
+
+ private CycleB(int value) {
+ this.value = value;
+ }
+
+ public int value() {
+ return value;
+ }
+ }
+
+ private static CycleA createCycle(int valueA, int valueB) {
+ CycleA a = new CycleA(valueA);
+ a.b = new CycleB(valueB);
+ a.b.a = a;
+ return a;
+ }
+
+ @Test
+ public void testCyclic() throws Exception {
+ new SerializationTester(createCycle(1, 2), createCycle(3, 4))
+ .addCodec(new DynamicCodec<>(CycleA.class))
+ .addCodec(new DynamicCodec<>(CycleB.class))
+ .makeMemoizing()
+ .runTests();
+ }
+
+ enum EnumExample {
+ ZERO,
+ ONE,
+ TWO,
+ THREE
+ }
+
+ static class PrimitiveExample {
+
+ private final boolean booleanValue;
+ private final int intValue;
+ private final double doubleValue;
+ private final EnumExample enumValue;
+ private final String stringValue;
+
+ PrimitiveExample(
+ boolean booleanValue,
+ int intValue,
+ double doubleValue,
+ EnumExample enumValue,
+ String stringValue) {
+ this.booleanValue = booleanValue;
+ this.intValue = intValue;
+ this.doubleValue = doubleValue;
+ this.enumValue = enumValue;
+ this.stringValue = stringValue;
+ }
+
+ @SuppressWarnings("EqualsHashCode") // Testing
+ @Override
+ public boolean equals(Object object) {
+ if (object == null) {
+ return false;
+ }
+ PrimitiveExample that = (PrimitiveExample) object;
+ return booleanValue == that.booleanValue
+ && intValue == that.intValue
+ && doubleValue == that.doubleValue
+ && Objects.equals(enumValue, that.enumValue)
+ && Objects.equals(stringValue, that.stringValue);
+ }
+ }
+
+ @Test
+ public void testPrimitiveExample() throws Exception {
+ new SerializationTester(
+ new PrimitiveExample(true, 1, 1.1, EnumExample.ZERO, "foo"),
+ new PrimitiveExample(false, -1, -5.5, EnumExample.ONE, "bar"),
+ new PrimitiveExample(true, 5, 20.0, EnumExample.THREE, null),
+ new PrimitiveExample(true, 100, 100, null, "hello"))
+ .addCodec(
+ new DynamicCodec<>(
+ PrimitiveExample.class, ObjectCodec.MemoizationStrategy.DO_NOT_MEMOIZE))
+ .addCodec(new EnumCodec<>(EnumExample.class))
+ .setRepetitions(100000)
+ .runTests();
+ }
+}