diff options
Diffstat (limited to 'java/core/src/main/java/com/google/protobuf/CodedInputStream.java')
-rw-r--r-- | java/core/src/main/java/com/google/protobuf/CodedInputStream.java | 299 |
1 files changed, 214 insertions, 85 deletions
diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 3dfbcb0a..d6a941b1 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -34,8 +34,8 @@ import static com.google.protobuf.Internal.EMPTY_BYTE_ARRAY; import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; import static com.google.protobuf.Internal.UTF_8; import static com.google.protobuf.Internal.checkNotNull; -import static com.google.protobuf.WireFormat.FIXED_32_SIZE; -import static com.google.protobuf.WireFormat.FIXED_64_SIZE; +import static com.google.protobuf.WireFormat.FIXED32_SIZE; +import static com.google.protobuf.WireFormat.FIXED64_SIZE; import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; import java.io.ByteArrayOutputStream; @@ -372,6 +372,64 @@ public abstract class CodedInputStream { return oldLimit; } + + private boolean explicitDiscardUnknownFields = false; + + /** TODO(liujisi): flip the default.*/ + private static volatile boolean proto3DiscardUnknownFieldsDefault = true; + + static void setProto3DiscardUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = true; + } + + static void setProto3KeepUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = false; + } + + static boolean getProto3DiscardUnknownFieldsDefault() { + return proto3DiscardUnknownFieldsDefault; + } + + /** + * Sets this {@code CodedInputStream} to discard unknown fields. Only applies to full runtime + * messages; lite messages will always preserve unknowns. + * + * <p>Note calling this function alone will have NO immediate effect on the underlying input data. + * The unknown fields will be discarded during parsing. This affects both Proto2 and Proto3 full + * runtime. + */ + final void discardUnknownFields() { + explicitDiscardUnknownFields = true; + } + + /** + * Reverts the unknown fields preservation behavior for Proto2 and Proto3 full runtime to their + * default. + */ + final void unsetDiscardUnknownFields() { + explicitDiscardUnknownFields = false; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing into full + * runtime messages. + */ + final boolean shouldDiscardUnknownFields() { + return explicitDiscardUnknownFields; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing for proto3 full + * runtime messages. + * + * <p>This function was temporarily introduced before proto3 unknown fields behavior is changed. + * TODO(liujisi): remove this and related code in GeneratedMessage after proto3 unknown + * fields migration is done. + */ + final boolean shouldDiscardUnknownFieldsProto3() { + return explicitDiscardUnknownFields ? true : proto3DiscardUnknownFieldsDefault; + } + /** * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). Only valid for {@link * InputStream}-backed streams. @@ -572,7 +630,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -585,7 +643,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -1064,12 +1122,12 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { int tempPos = pos; - if (limit - tempPos < FIXED_32_SIZE) { + if (limit - tempPos < FIXED32_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((buffer[tempPos] & 0xff)) | ((buffer[tempPos + 1] & 0xff) << 8) | ((buffer[tempPos + 2] & 0xff) << 16) @@ -1080,12 +1138,12 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { int tempPos = pos; - if (limit - tempPos < FIXED_64_SIZE) { + if (limit - tempPos < FIXED64_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((buffer[tempPos] & 0xffL)) | ((buffer[tempPos + 1] & 0xffL) << 8) | ((buffer[tempPos + 2] & 0xffL) << 16) @@ -1290,7 +1348,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -1303,7 +1361,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -1429,7 +1487,9 @@ public abstract class CodedInputStream { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { // TODO(nathanmittler): Is there a way to avoid this copy? - byte[] bytes = copyToArray(pos, pos + size); + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); String result = new String(bytes, UTF_8); pos += size; return result; @@ -1449,7 +1509,9 @@ public abstract class CodedInputStream { final int size = readRawVarint32(); if (size >= 0 && size <= remaining()) { // TODO(nathanmittler): Is there a way to avoid this copy? - byte[] bytes = copyToArray(pos, pos + size); + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); // TODO(martinrb): We could save a pass by validating while decoding. if (!Utf8.isValidUtf8(bytes)) { throw InvalidProtocolBufferException.invalidUtf8(); @@ -1545,14 +1607,17 @@ public abstract class CodedInputStream { public ByteString readBytes() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { - ByteBuffer result; if (immutable && enableAliasing) { - result = slice(pos, pos + size); + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return ByteString.wrap(result); } else { - result = copy(pos, pos + size); + // Use UnsafeUtil to copy the memory to bytes instead of using ByteBuffer ways. + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteString.wrap(bytes); } - pos += size; - return ByteString.wrap(result); } if (size == 0) { @@ -1573,18 +1638,21 @@ public abstract class CodedInputStream { public ByteBuffer readByteBuffer() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { - ByteBuffer result; // "Immutable" implies that buffer is backing a ByteString. // Disallow slicing in this case to prevent the caller from modifying the contents // of the ByteString. if (!immutable && enableAliasing) { - result = slice(pos, pos + size); + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return result; } else { - result = copy(pos, pos + size); + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteBuffer.wrap(bytes); } - pos += size; // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only - return result; } if (size == 0) { @@ -1785,11 +1853,11 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { long tempPos = pos; - if (limit - tempPos < FIXED_32_SIZE) { + if (limit - tempPos < FIXED32_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((UnsafeUtil.getByte(tempPos) & 0xff)) | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) @@ -1800,11 +1868,11 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { long tempPos = pos; - if (limit - tempPos < FIXED_64_SIZE) { + if (limit - tempPos < FIXED64_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((UnsafeUtil.getByte(tempPos) & 0xffL)) | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) @@ -1943,27 +2011,6 @@ public abstract class CodedInputStream { buffer.limit(prevLimit); } } - - private ByteBuffer copy(long begin, long end) throws IOException { - return ByteBuffer.wrap(copyToArray(begin, end)); - } - - private byte[] copyToArray(long begin, long end) throws IOException { - int prevPos = buffer.position(); - int prevLimit = buffer.limit(); - try { - buffer.position(bufferPos(begin)); - buffer.limit(bufferPos(end)); - byte[] bytes = new byte[(int) (end - begin)]; - buffer.get(bytes); - return bytes; - } catch (IllegalArgumentException e) { - throw InvalidProtocolBufferException.truncatedMessage(); - } finally { - buffer.position(prevPos); - buffer.limit(prevLimit); - } - } } /** @@ -2034,7 +2081,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -2047,7 +2094,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -2332,8 +2379,7 @@ public abstract class CodedInputStream { if (size == 0) { return ByteString.EMPTY; } - // Slow path: Build a byte array first then copy it. - return ByteString.wrap(readRawBytesSlowPath(size)); + return readBytesSlowPath(size); } @Override @@ -2558,13 +2604,13 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { int tempPos = pos; - if (bufferSize - tempPos < FIXED_32_SIZE) { - refillBuffer(FIXED_32_SIZE); + if (bufferSize - tempPos < FIXED32_SIZE) { + refillBuffer(FIXED32_SIZE); tempPos = pos; } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((buffer[tempPos] & 0xff)) | ((buffer[tempPos + 1] & 0xff) << 8) | ((buffer[tempPos + 2] & 0xff) << 16) @@ -2575,13 +2621,13 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { int tempPos = pos; - if (bufferSize - tempPos < FIXED_64_SIZE) { - refillBuffer(FIXED_64_SIZE); + if (bufferSize - tempPos < FIXED64_SIZE) { + refillBuffer(FIXED64_SIZE); tempPos = pos; } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((buffer[tempPos] & 0xffL)) | ((buffer[tempPos + 1] & 0xffL) << 8) | ((buffer[tempPos + 2] & 0xffL) << 16) @@ -2675,7 +2721,13 @@ public abstract class CodedInputStream { */ private void refillBuffer(int n) throws IOException { if (!tryRefillBuffer(n)) { - throw InvalidProtocolBufferException.truncatedMessage(); + // We have to distinguish the exception between sizeLimitExceeded and truncatedMessage. So + // we just throw an sizeLimitExceeded exception here if it exceeds the sizeLimit + if (n > sizeLimit - totalBytesRetired - pos) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } else { + throw InvalidProtocolBufferException.truncatedMessage(); + } } } @@ -2684,8 +2736,8 @@ public abstract class CodedInputStream { * buffer. Caller must ensure that the requested space is not yet available, and that the * requested space is less than BUFFER_SIZE. * - * @return {@code true} if the bytes could be made available; {@code false} if the end of the - * stream or the current limit was reached. + * @return {@code true} If the bytes could be made available; {@code false} 1. Current at the + * end of the stream 2. The current limit was reached 3. The total size limit was reached */ private boolean tryRefillBuffer(int n) throws IOException { if (pos + n <= bufferSize) { @@ -2693,6 +2745,14 @@ public abstract class CodedInputStream { "refillBuffer() called when " + n + " bytes were already available in buffer"); } + // Check whether the size of total message needs to read is bigger than the size limit. + // We shouldn't throw an exception here as isAtEnd() function needs to get this function's + // return as the result. + if (n > sizeLimit - totalBytesRetired - pos) { + return false; + } + + // Shouldn't throw the exception here either. if (totalBytesRetired + pos + n > currentLimit) { // Oops, we hit a limit. return false; @@ -2712,7 +2772,16 @@ public abstract class CodedInputStream { pos = 0; } - int bytesRead = input.read(buffer, bufferSize, buffer.length - bufferSize); + // Here we should refill the buffer as many bytes as possible. + int bytesRead = + input.read( + buffer, + bufferSize, + Math.min( + // the size of allocated but unused bytes in the buffer + buffer.length - bufferSize, + // do not exceed the total bytes limit + sizeLimit - totalBytesRetired - bufferSize)); if (bytesRead == 0 || bytesRead < -1 || bytesRead > buffer.length) { throw new IllegalStateException( "InputStream#read(byte[]) returned invalid result: " @@ -2721,10 +2790,6 @@ public abstract class CodedInputStream { } if (bytesRead > 0) { bufferSize += bytesRead; - // Integer-overflow-conscious check against sizeLimit - if (totalBytesRetired + n - sizeLimit > 0) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); - } recomputeBufferSizeAfterLimit(); return (bufferSize >= n) ? true : tryRefillBuffer(n); } @@ -2756,6 +2821,49 @@ public abstract class CodedInputStream { * (bufferSize - pos) && size > 0) */ private byte[] readRawBytesSlowPath(final int size) throws IOException { + // Attempt to read the data in one byte array when it's safe to do. + byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return result; + } + + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // OK, got everything. Now concatenate it all into one buffer. + final byte[] bytes = new byte[size]; + + // Start by copying the leftover bytes from this.buffer. + System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + + // And now all the chunks. + int tempPos = bufferedBytes; + for (final byte[] chunk : chunks) { + System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); + tempPos += chunk.length; + } + + // Done. + return bytes; + } + + /** + * Attempts to read the data in one byte array when it's safe to do. Returns null if the size to + * read is too large and needs to be allocated in smaller chunks for security reasons. + */ + private byte[] readRawBytesSlowPathOneChunk(final int size) throws IOException { if (size == 0) { return Internal.EMPTY_BYTE_ARRAY; } @@ -2776,14 +2884,7 @@ public abstract class CodedInputStream { throw InvalidProtocolBufferException.truncatedMessage(); } - final int originalBufferPos = pos; final int bufferedBytes = bufferSize - pos; - - // Mark the current buffer consumed. - totalBytesRetired += bufferSize; - pos = 0; - bufferSize = 0; - // Determine the number of bytes we need to read from the input stream. int sizeLeft = size - bufferedBytes; // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. @@ -2793,7 +2894,10 @@ public abstract class CodedInputStream { final byte[] bytes = new byte[size]; // Copy all of the buffered bytes to the result buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + System.arraycopy(buffer, pos, bytes, 0, bufferedBytes); + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; // Fill the remaining bytes from the input stream. int tempPos = bufferedBytes; @@ -2809,6 +2913,11 @@ public abstract class CodedInputStream { return bytes; } + return null; + } + + /** Reads the remaining data in small chunks from the input stream. */ + private List<byte[]> readRawBytesSlowPathRemainingChunks(int sizeLeft) throws IOException { // The size is very large. For security reasons, we can't allocate the // entire byte array yet. The size comes directly from the input, so a // maliciously-crafted message could provide a bogus very large size in @@ -2834,21 +2943,41 @@ public abstract class CodedInputStream { chunks.add(chunk); } - // OK, got everything. Now concatenate it all into one buffer. - final byte[] bytes = new byte[size]; - - // Start by copying the leftover bytes from this.buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + return chunks; + } - // And now all the chunks. - int tempPos = bufferedBytes; - for (final byte[] chunk : chunks) { - System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); - tempPos += chunk.length; + /** + * Like readBytes, but caller must have already checked the fast path: (size <= (bufferSize - + * pos) && size > 0 || size == 0) + */ + private ByteString readBytesSlowPath(final int size) throws IOException { + final byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return ByteString.wrap(result); } - // Done. - return bytes; + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // Wrap the byte arrays into a single ByteString. + List<ByteString> byteStrings = new ArrayList<ByteString>(1 + chunks.size()); + byteStrings.add(ByteString.copyFrom(buffer, originalBufferPos, bufferedBytes)); + for (byte[] chunk : chunks) { + byteStrings.add(ByteString.wrap(chunk)); + } + return ByteString.copyFrom(byteStrings); } @Override |