diff options
author | 2018-03-21 17:10:27 -0700 | |
---|---|---|
committer | 2018-03-21 17:11:39 -0700 | |
commit | 2410e1ab3e035382abe519003c618271a69a7b8e (patch) | |
tree | d39cc4379b69b2371efc2f093ed0482b3a5254fa /src/test | |
parent | 0f5679ef95611e457a6e39313cf88feac8b2278f (diff) |
Clean up unnecessary "additional data" from memoizing deserialization. Since memoization is now a simple on-off switch, change semantics to have at most one memoizing frame: starting memoization is now an idempotent operation.
PiperOrigin-RevId: 189993914
Diffstat (limited to 'src/test')
2 files changed, 81 insertions, 16 deletions
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java index b80f005467..31efc088b0 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java @@ -83,11 +83,7 @@ public class DeserializationContextTest { DeserializationContext deserializationContext = new DeserializationContext(registry, ImmutableMap.of()); when(codedInputStream.readSInt32()).thenReturn(0); - assertThat( - (Object) - deserializationContext - .newMemoizingContext(new Object()) - .deserialize(codedInputStream)) + assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream)) .isEqualTo(null); Mockito.verify(codedInputStream).readSInt32(); Mockito.verifyZeroInteractions(registry); @@ -102,11 +98,7 @@ public class DeserializationContextTest { DeserializationContext deserializationContext = new DeserializationContext(registry, ImmutableMap.of()); when(codedInputStream.readSInt32()).thenReturn(1); - assertThat( - (Object) - deserializationContext - .newMemoizingContext(new Object()) - .deserialize(codedInputStream)) + assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream)) .isEqualTo(constant); Mockito.verify(codedInputStream).readSInt32(); Mockito.verify(registry).maybeGetConstantByTag(1); @@ -127,7 +119,7 @@ public class DeserializationContextTest { when(registry.getCodecDescriptorByTag(1)).thenReturn(codecDescriptor); CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class); DeserializationContext deserializationContext = - new DeserializationContext(registry, ImmutableMap.of()).newMemoizingContext(new Object()); + new DeserializationContext(registry, ImmutableMap.of()).getMemoizingContext(); when(codec.deserialize(deserializationContext, codedInputStream)).thenReturn(returned); when(codedInputStream.readSInt32()).thenReturn(1); assertThat((Object) deserializationContext.deserialize(codedInputStream)).isEqualTo(returned); diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java index bd8ed19a1d..0508bb5e7d 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContextTest.java @@ -14,10 +14,13 @@ package com.google.devtools.build.lib.skyframe.serialization; +import static com.google.common.truth.Truth.assertThat; import static com.google.devtools.build.lib.testutil.MoreAsserts.assertThrows; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.when; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec.MemoizationStrategy; import com.google.devtools.build.lib.skyframe.serialization.testutils.TestUtils; @@ -25,6 +28,7 @@ import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -81,7 +85,7 @@ public class SerializationContextTest { CodedOutputStream codedOutputStream = Mockito.mock(CodedOutputStream.class); SerializationContext serializationContext = new SerializationContext(registry, ImmutableMap.of()); - serializationContext.newMemoizingContext().serialize(null, codedOutputStream); + serializationContext.getMemoizingContext().serialize(null, codedOutputStream); Mockito.verify(codedOutputStream).writeSInt32NoTag(0); Mockito.verifyZeroInteractions(registry); } @@ -94,7 +98,7 @@ public class SerializationContextTest { SerializationContext serializationContext = new SerializationContext(registry, ImmutableMap.of()); Object constant = new Object(); - serializationContext.newMemoizingContext().serialize(constant, codedOutputStream); + serializationContext.getMemoizingContext().serialize(constant, codedOutputStream); Mockito.verify(codedOutputStream).writeSInt32NoTag(1); Mockito.verify(registry).maybeGetTagForConstant(constant); } @@ -113,7 +117,7 @@ public class SerializationContextTest { when(registry.getCodecDescriptor(String.class)).thenReturn(codecDescriptor); CodedOutputStream codedOutputStream = Mockito.mock(CodedOutputStream.class); SerializationContext underTest = - new SerializationContext(registry, ImmutableMap.of()).newMemoizingContext(); + new SerializationContext(registry, ImmutableMap.of()).getMemoizingContext(); underTest.serialize("string", codedOutputStream); Mockito.verify(codedOutputStream).writeSInt32NoTag(1); Mockito.verify(registry).maybeGetTagForConstant("string"); @@ -124,6 +128,75 @@ public class SerializationContextTest { } @Test + public void startMemoizingIsIdempotent() throws IOException, SerializationException { + ObjectCodecRegistry registry = + ObjectCodecRegistry.newBuilder() + .add(new CodecMemoizing()) + .add(new CalledOnlyOnce()) + .build(); + + String repeated = "repeated string"; + ImmutableList<Object> obj = ImmutableList.of(ImmutableList.of(repeated, repeated), repeated); + assertThat(TestUtils.roundTrip(obj, registry)).isEqualTo(obj); + } + + private static class CodecMemoizing implements ObjectCodec<ImmutableList<Object>> { + @SuppressWarnings("unchecked") + @Override + public Class<ImmutableList<Object>> getEncodedClass() { + return (Class<ImmutableList<Object>>) (Class<?>) ImmutableList.class; + } + + @Override + public void serialize( + SerializationContext context, ImmutableList<Object> obj, CodedOutputStream codedOut) + throws SerializationException, IOException { + context = context.getMemoizingContext(); + codedOut.writeInt32NoTag(obj.size()); + for (Object item : obj) { + context.serialize(item, codedOut); + } + } + + @Override + public ImmutableList<Object> deserialize( + DeserializationContext context, CodedInputStream codedIn) + throws SerializationException, IOException { + context = context.getMemoizingContext(); + int size = codedIn.readInt32(); + ImmutableList.Builder<Object> builder = ImmutableList.builder(); + for (int i = 0; i < size; i++) { + builder.add(context.<Object>deserialize(codedIn)); + } + return builder.build(); + } + } + + private static class CalledOnlyOnce implements ObjectCodec<String> { + private final AtomicBoolean serializationCalled = new AtomicBoolean(false); + private final AtomicBoolean deserializationCalled = new AtomicBoolean(false); + + @Override + public Class<String> getEncodedClass() { + return String.class; + } + + @Override + public void serialize(SerializationContext context, String obj, CodedOutputStream codedOut) + throws IOException { + Preconditions.checkState(!serializationCalled.getAndSet(true)); + codedOut.writeStringNoTag(obj); + } + + @Override + public String deserialize(DeserializationContext context, CodedInputStream codedIn) + throws IOException { + Preconditions.checkState(!deserializationCalled.getAndSet(true)); + return codedIn.readString(); + } + } + + @Test public void mismatchMemoizingRoundtrip() { ArrayList<Object> repeatedObject = new ArrayList<>(); repeatedObject.add(null); @@ -136,7 +209,7 @@ public class SerializationContextTest { assertThrows( SerializationException.class, () -> - TestUtils.roundTripMemoized( + TestUtils.roundTrip( toSerialize, ObjectCodecRegistry.newBuilder() .add(new BadCodecOnlyMemoizesWhenDeserializing()) @@ -163,7 +236,7 @@ public class SerializationContextTest { @Override public ArrayList<?> deserialize(DeserializationContext context, CodedInputStream codedIn) throws SerializationException, IOException { - context = context.newMemoizingContext(new Object()); + context = context.getMemoizingContext(); int size = codedIn.readInt32(); ArrayList<?> result = new ArrayList<>(); for (int i = 0; i < size; i++) { |