diff options
3 files changed, 151 insertions, 7 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/Memoizer.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/Memoizer.java index c60b62fb5f..234c74f7d7 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/Memoizer.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/Memoizer.java @@ -135,6 +135,8 @@ class Memoizer { /** A context for serializing; wraps a memo table. Not thread-safe. */ static class Serializer { private final SerializingMemoTable memo = new SerializingMemoTable(); + @Nullable private String lastString = null; + /** * Serializes an object using the given codec and current memo table state. * @@ -149,18 +151,59 @@ class Memoizer { throws SerializationException, IOException { MemoizationStrategy strategy = codec.getStrategy(); if (strategy == MemoizationStrategy.DO_NOT_MEMOIZE) { - codec.serialize(context, obj, codedOut); + // TODO(janakr): there is no reason this is limited to the DO_NOT_MEMOIZE case, but we don't + // memoize Strings, so putting the code here saves a tiny bit of work in the other cases. If + // the StringCodec#getStrategy changes, this block of code will have to move. + if (!maybeEmitString(context, obj, codec, codedOut)) { + codec.serialize(context, obj, codedOut); + } } else { // The caller already checked the table, so this is definitely a new value. serializeMemoContent(context, obj, codec, codedOut, strategy); } } + private <T> boolean maybeEmitString( + SerializationContext context, + T obj, + ObjectCodec<? super T> codec, + CodedOutputStream codedOut) + throws SerializationException, IOException { + if (!(obj instanceof String)) { + return false; + } + int commonPrefixLen = -1; + String str = (String) obj; + if (lastString != null) { + commonPrefixLen = commonPrefixLen(str, lastString); + if (commonPrefixLen != 0) { + @SuppressWarnings("unchecked") + T checkObj = (T) codec.getEncodedClass().cast(str.substring(commonPrefixLen)); + obj = checkObj; + } + } + lastString = str; + codec.serialize(context, obj, codedOut); + if (commonPrefixLen > -1) { + codedOut.writeInt32NoTag(commonPrefixLen); + } + return true; + } + @Nullable Integer getMemoizedIndex(Object obj) { return memo.lookupNullable(obj); } + private static int commonPrefixLen(String first, String second) { + int shared = 0; + int max = Math.min(first.length(), second.length()); + while (shared < max && first.charAt(shared) == second.charAt(shared)) { + ++shared; + } + return shared; + } + // Corresponds to MemoContent in the abstract grammar. private <T> void serializeMemoContent( SerializationContext context, @@ -224,6 +267,7 @@ class Memoizer { */ static class Deserializer { private final DeserializingMemoTable memo = new DeserializingMemoTable(); + @Nullable private String lastString = null; @Nullable private Integer tagForMemoizedBefore = null; private final Deque<Object> memoizedBeforeStackForSanityChecking = new ArrayDeque<>(); @@ -243,7 +287,7 @@ class Memoizer { codec); MemoizationStrategy strategy = codec.getStrategy(); if (strategy == MemoizationStrategy.DO_NOT_MEMOIZE) { - return codec.deserialize(context, codedIn); + return maybeTransformString(codec.deserialize(context, codedIn), codec, codedIn); } else { switch (strategy) { case MEMOIZE_BEFORE: @@ -256,6 +300,42 @@ class Memoizer { } } + private <T> T maybeTransformString( + T value, ObjectCodec<? extends T> codec, CodedInputStream codedIn) throws IOException { + if (!(value instanceof String)) { + return value; + } + String str = (String) value; + if (lastString != null) { + int commonPrefixLen = codedIn.readInt32(); + Preconditions.checkState( + commonPrefixLen > -1, "Bad data for %s and %s (%s)", str, lastString, commonPrefixLen); + if (commonPrefixLen > 0) { + int lastLen = lastString.length(); + Preconditions.checkState( + lastLen >= commonPrefixLen, + "Bad data for %s (%s and %s)", + str, + lastString, + commonPrefixLen); + if (str.isEmpty()) { + // This is a substring or the same string. Save some garbage by re-using if possible. + if (commonPrefixLen < lastLen) { + str = lastString.substring(0, commonPrefixLen); + } else { + // commonPrefixLen == lastLen. + str = lastString; + } + } else { + str = lastString.substring(0, commonPrefixLen) + str; + } + value = codec.getEncodedClass().cast(str); + } + } + lastString = str; + return value; + } + Object getMemoized(int memoIndex) { return Preconditions.checkNotNull(memo.lookup(memoIndex), memoIndex); } diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java index f151c5aa54..776cf2da6c 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java @@ -15,6 +15,7 @@ package com.google.devtools.build.lib.skyframe.serialization; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static com.google.devtools.build.lib.testutil.MoreAsserts.assertThrows; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.when; @@ -23,9 +24,11 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec.MemoizationStrategy; +import com.google.devtools.build.lib.skyframe.serialization.strings.StringCodecs; import com.google.devtools.build.lib.skyframe.serialization.testutils.TestUtils; import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicBoolean; @@ -128,6 +131,69 @@ public class SerializationContextTest { } @Test + public void memoizingStringPrefixes_EndtoEnd() throws Exception { + ObjectCodec<String> codec = StringCodecs.simple(); + Memoizer.Serializer serializer = new Memoizer.Serializer(); + SerializationContext serializationContext = Mockito.mock(SerializationContext.class); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CodedOutputStream codedOutputStream = CodedOutputStream.newInstance(byteArrayOutputStream); + + // Serialize all the things. + serializer.serialize(serializationContext, "string", codec, codedOutputStream); + serializer.serialize(serializationContext, "string2", codec, codedOutputStream); + serializer.serialize(serializationContext, "string2", codec, codedOutputStream); + serializer.serialize(serializationContext, "strip", codec, codedOutputStream); + serializer.serialize(serializationContext, "banana", codec, codedOutputStream); + serializer.serialize(serializationContext, "", codec, codedOutputStream); + serializer.serialize(serializationContext, "peach", codec, codedOutputStream); + + // SerializationContext not used for simple string serialization. + Mockito.verifyZeroInteractions(serializationContext); + + // Flush outputs and assert not too much data was written. + codedOutputStream.flush(); + byteArrayOutputStream.flush(); + byte[] bytes = byteArrayOutputStream.toByteArray(); + int stringOverhead = 1; + int prefixOverhead = 1; + // Every string but the first has prefixOverhead, every string has stringOverhead. + assertThat(bytes.length) + .isEqualTo( + "string".length() + + "2".length() + + "p".length() + + "banana".length() + + "peach".length() + + 7 * stringOverhead + + 6 * prefixOverhead); + + // Prepare inputs. + CodedInputStream codedInputStream = CodedInputStream.newInstance(bytes); + Memoizer.Deserializer deserializer = new Memoizer.Deserializer(); + DeserializationContext deserializationContext = Mockito.mock(DeserializationContext.class); + + // Deserialize and assert fidelity. + assertThat(deserializer.deserialize(deserializationContext, codec, codedInputStream)) + .isEqualTo("string"); + String returnedString = + deserializer.deserialize(deserializationContext, codec, codedInputStream); + assertThat(returnedString).isEqualTo("string2"); + String newReturnedString = + deserializer.deserialize(deserializationContext, codec, codedInputStream); + assertThat(newReturnedString).isEqualTo("string2"); + assertWithMessage("Same string twice in a row should be the same object") + .that(newReturnedString) + .isSameAs(returnedString); + assertThat(deserializer.deserialize(deserializationContext, codec, codedInputStream)) + .isEqualTo("strip"); + assertThat(deserializer.deserialize(deserializationContext, codec, codedInputStream)) + .isEqualTo("banana"); + assertThat(deserializer.deserialize(deserializationContext, codec, codedInputStream)).isEmpty(); + assertThat(deserializer.deserialize(deserializationContext, codec, codedInputStream)) + .isEqualTo("peach"); + } + + @Test public void startMemoizingIsIdempotent() throws IOException, SerializationException { ObjectCodecRegistry registry = ObjectCodecRegistry.newBuilder() diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/strings/StringCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/strings/StringCodecTest.java index 8e8b80ea0b..178071c345 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/strings/StringCodecTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/strings/StringCodecTest.java @@ -14,7 +14,7 @@ package com.google.devtools.build.lib.skyframe.serialization.strings; -import com.google.devtools.build.lib.skyframe.serialization.testutils.ObjectCodecTester; +import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -22,11 +22,9 @@ import org.junit.runners.JUnit4; /** Basic tests for {@link StringCodec}. */ @RunWith(JUnit4.class) public class StringCodecTest { - @Test public void testCodec() throws Exception { - ObjectCodecTester.newBuilder(new StringCodec()) - .addSubjects("usually precomputed and supports weird unicodes: (╯°□°)╯︵┻━┻ ") - .buildAndRunTests(); + new SerializationTester("usually precomputed and supports weird unicodes: (╯°□°)╯︵┻━┻ ") + .runTests(); } } |