Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r-- tensorflow/core/lib/core/arena.cc | 246
-rw-r--r-- tensorflow/core/lib/core/arena.h | 90
-rw-r--r-- tensorflow/core/lib/core/arena_test.cc | 92
-rw-r--r-- tensorflow/core/lib/core/bit_cast_test.cc | 95
-rw-r--r-- tensorflow/core/lib/core/bits.h | 84
-rw-r--r-- tensorflow/core/lib/core/blocking_counter.h | 41
-rw-r--r-- tensorflow/core/lib/core/blocking_counter_test.cc | 36
-rw-r--r-- tensorflow/core/lib/core/casts.h | 85
-rw-r--r-- tensorflow/core/lib/core/coding.cc | 164
-rw-r--r-- tensorflow/core/lib/core/coding.h | 55
-rw-r--r-- tensorflow/core/lib/core/coding_test.cc | 168
-rw-r--r-- tensorflow/core/lib/core/command_line_flags.cc | 94
-rw-r--r-- tensorflow/core/lib/core/command_line_flags.h | 60
-rw-r--r-- tensorflow/core/lib/core/error_codes.proto | 145
-rw-r--r-- tensorflow/core/lib/core/errors.h | 131
-rw-r--r-- tensorflow/core/lib/core/notification.h | 42
-rw-r--r-- tensorflow/core/lib/core/notification_test.cc | 64
-rw-r--r-- tensorflow/core/lib/core/raw_coding.h | 43
-rw-r--r-- tensorflow/core/lib/core/refcount.cc | 35
-rw-r--r-- tensorflow/core/lib/core/refcount.h | 63
-rw-r--r-- tensorflow/core/lib/core/refcount_test.cc | 92
-rw-r--r-- tensorflow/core/lib/core/status.cc | 107
-rw-r--r-- tensorflow/core/lib/core/status_test.cc | 84
-rw-r--r-- tensorflow/core/lib/core/status_test_util.h | 20
-rw-r--r-- tensorflow/core/lib/core/stringpiece.cc | 57
-rw-r--r-- tensorflow/core/lib/core/stringpiece.h | 159
-rw-r--r-- tensorflow/core/lib/core/threadpool.cc | 108
-rw-r--r-- tensorflow/core/lib/core/threadpool.h | 59
-rw-r--r-- tensorflow/core/lib/core/threadpool_test.cc | 93
-rw-r--r-- tensorflow/core/lib/gtl/array_slice.h | 299
-rw-r--r-- tensorflow/core/lib/gtl/array_slice_internal.h | 253
-rw-r--r-- tensorflow/core/lib/gtl/array_slice_test.cc | 646
-rw-r--r-- tensorflow/core/lib/gtl/edit_distance.h | 82
-rw-r--r-- tensorflow/core/lib/gtl/edit_distance_test.cc | 125
-rw-r--r-- tensorflow/core/lib/gtl/inlined_vector.h | 839
-rw-r--r-- tensorflow/core/lib/gtl/inlined_vector_test.cc | 905
-rw-r--r-- tensorflow/core/lib/gtl/int_type.h | 343
-rw-r--r-- tensorflow/core/lib/gtl/int_type_test.cc | 282
-rw-r--r-- tensorflow/core/lib/gtl/iterator_range.h | 49
-rw-r--r-- tensorflow/core/lib/gtl/iterator_range_test.cc | 60
-rw-r--r-- tensorflow/core/lib/gtl/manual_constructor.h | 230
-rw-r--r-- tensorflow/core/lib/gtl/manual_constructor_test.cc | 113
-rw-r--r-- tensorflow/core/lib/gtl/map_util.h | 123
-rw-r--r-- tensorflow/core/lib/gtl/map_util_test.cc | 47
-rw-r--r-- tensorflow/core/lib/gtl/stl_util.h | 130
-rw-r--r-- tensorflow/core/lib/gtl/top_n.h | 324
-rw-r--r-- tensorflow/core/lib/gtl/top_n_test.cc | 249
-rw-r--r-- tensorflow/core/lib/hash/crc32c.cc | 244
-rw-r--r-- tensorflow/core/lib/hash/crc32c.h | 39
-rw-r--r-- tensorflow/core/lib/hash/crc32c_test.cc | 51
-rw-r--r-- tensorflow/core/lib/hash/hash.cc | 113
-rw-r--r-- tensorflow/core/lib/hash/hash.h | 28
-rw-r--r-- tensorflow/core/lib/hash/hash_test.cc | 64
-rw-r--r-- tensorflow/core/lib/histogram/histogram.cc | 247
-rw-r--r-- tensorflow/core/lib/histogram/histogram.h | 119
-rw-r--r-- tensorflow/core/lib/histogram/histogram_test.cc | 112
-rw-r--r-- tensorflow/core/lib/io/block.cc | 236
-rw-r--r-- tensorflow/core/lib/io/block.h | 45
-rw-r--r-- tensorflow/core/lib/io/block_builder.cc | 107
-rw-r--r-- tensorflow/core/lib/io/block_builder.h | 57
-rw-r--r-- tensorflow/core/lib/io/format.cc | 148
-rw-r--r-- tensorflow/core/lib/io/format.h | 99
-rw-r--r-- tensorflow/core/lib/io/inputbuffer.cc | 112
-rw-r--r-- tensorflow/core/lib/io/inputbuffer.h | 62
-rw-r--r-- tensorflow/core/lib/io/inputbuffer_test.cc | 174
-rw-r--r-- tensorflow/core/lib/io/iterator.cc | 72
-rw-r--r-- tensorflow/core/lib/io/iterator.h | 93
-rw-r--r-- tensorflow/core/lib/io/match.cc | 31
-rw-r--r-- tensorflow/core/lib/io/match.h | 24
-rw-r--r-- tensorflow/core/lib/io/match_test.cc | 51
-rw-r--r-- tensorflow/core/lib/io/path.cc | 92
-rw-r--r-- tensorflow/core/lib/io/path.h | 47
-rw-r--r-- tensorflow/core/lib/io/path_test.cc | 65
-rw-r--r-- tensorflow/core/lib/io/record_reader.cc | 80
-rw-r--r-- tensorflow/core/lib/io/record_reader.h | 36
-rw-r--r-- tensorflow/core/lib/io/record_writer.cc | 42
-rw-r--r-- tensorflow/core/lib/io/record_writer.h | 34
-rw-r--r-- tensorflow/core/lib/io/recordio_test.cc | 245
-rw-r--r-- tensorflow/core/lib/io/table.cc | 169
-rw-r--r-- tensorflow/core/lib/io/table.h | 76
-rw-r--r-- tensorflow/core/lib/io/table_builder.cc | 263
-rw-r--r-- tensorflow/core/lib/io/table_builder.h | 87
-rw-r--r-- tensorflow/core/lib/io/table_format.txt | 8
-rw-r--r-- tensorflow/core/lib/io/table_options.h | 53
-rw-r--r-- tensorflow/core/lib/io/table_test.cc | 601
-rw-r--r-- tensorflow/core/lib/io/two_level_iterator.cc | 148
-rw-r--r-- tensorflow/core/lib/io/two_level_iterator.h | 30
-rw-r--r-- tensorflow/core/lib/jpeg/jpeg_handle.cc | 162
-rw-r--r-- tensorflow/core/lib/jpeg/jpeg_handle.h | 51
-rw-r--r-- tensorflow/core/lib/jpeg/jpeg_mem.cc | 557
-rw-r--r-- tensorflow/core/lib/jpeg/jpeg_mem.h | 130
-rw-r--r-- tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc | 304
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg | bin 0 -> 15416 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/corrupt.jpg | bin 0 -> 1552 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg | bin 0 -> 755 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg | bin 0 -> 5505 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg | bin 0 -> 5092 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg | bin 0 -> 3771 bytes
-rw-r--r-- tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg | bin 0 -> 5324 bytes
-rw-r--r-- tensorflow/core/lib/png/png_io.cc | 385
-rw-r--r-- tensorflow/core/lib/png/png_io.h | 88
-rw-r--r-- tensorflow/core/lib/png/testdata/lena_gray.png | bin 0 -> 1491 bytes
-rw-r--r-- tensorflow/core/lib/png/testdata/lena_rgba.png | bin 0 -> 4032 bytes
-rw-r--r-- tensorflow/core/lib/random/distribution_sampler.cc | 80
-rw-r--r-- tensorflow/core/lib/random/distribution_sampler.h | 79
-rw-r--r-- tensorflow/core/lib/random/distribution_sampler_test.cc | 90
-rw-r--r-- tensorflow/core/lib/random/exact_uniform_int.h | 68
-rw-r--r-- tensorflow/core/lib/random/philox_random.h | 232
-rw-r--r-- tensorflow/core/lib/random/philox_random_test.cc | 58
-rw-r--r-- tensorflow/core/lib/random/philox_random_test_utils.h | 36
-rw-r--r-- tensorflow/core/lib/random/random.cc | 22
-rw-r--r-- tensorflow/core/lib/random/random.h | 16
-rw-r--r-- tensorflow/core/lib/random/random_distributions.h | 361
-rw-r--r-- tensorflow/core/lib/random/random_distributions_test.cc | 270
-rw-r--r-- tensorflow/core/lib/random/random_test.cc | 21
-rw-r--r-- tensorflow/core/lib/random/simple_philox.cc | 24
-rw-r--r-- tensorflow/core/lib/random/simple_philox.h | 61
-rw-r--r-- tensorflow/core/lib/random/simple_philox_test.cc | 120
-rw-r--r-- tensorflow/core/lib/random/weighted_picker.cc | 203
-rw-r--r-- tensorflow/core/lib/random/weighted_picker.h | 118
-rw-r--r-- tensorflow/core/lib/random/weighted_picker_test.cc | 254
-rw-r--r-- tensorflow/core/lib/strings/numbers.cc | 260
-rw-r--r-- tensorflow/core/lib/strings/numbers.h | 92
-rw-r--r-- tensorflow/core/lib/strings/numbers_test.cc | 113
-rw-r--r-- tensorflow/core/lib/strings/ordered_code.cc | 515
-rw-r--r-- tensorflow/core/lib/strings/ordered_code.h | 77
-rw-r--r-- tensorflow/core/lib/strings/ordered_code_test.cc | 1183
-rw-r--r-- tensorflow/core/lib/strings/str_util.cc | 312
-rw-r--r-- tensorflow/core/lib/strings/str_util.h | 149
-rw-r--r-- tensorflow/core/lib/strings/str_util_test.cc | 258
-rw-r--r-- tensorflow/core/lib/strings/strcat.cc | 194
-rw-r--r-- tensorflow/core/lib/strings/strcat.h | 229
-rw-r--r-- tensorflow/core/lib/strings/strcat_test.cc | 324
-rw-r--r-- tensorflow/core/lib/strings/stringprintf.cc | 85
-rw-r--r-- tensorflow/core/lib/strings/stringprintf.h | 37
-rw-r--r-- tensorflow/core/lib/strings/stringprintf_test.cc | 113
136 files changed, 19846 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc
new file mode 100644
index 0000000000..ceb1001af0
--- /dev/null
+++ b/tensorflow/core/lib/core/arena.cc
@@ -0,0 +1,246 @@
+// This approach to arenas overcomes many of the limitations described
+// in the "Specialized allocators" section of
+// http://www.pdos.lcs.mit.edu/~dm/c++-new.html
+//
+// A somewhat similar approach to Gladiator, but for heap-detection, was
+// suggested by Ron van der Wal and Scott Meyers at
+// http://www.aristeia.com/BookErrata/M27Comments_frames.html
+
+#include "tensorflow/core/lib/core/arena.h"
+
+#include <assert.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace core {
+
+static const int kPageSize = getpagesize();
+
+// ----------------------------------------------------------------------
+// Arena::Arena()
+// Arena::~Arena()
+// Destroying the arena automatically calls Reset()
+// ----------------------------------------------------------------------
+
+Arena::Arena(const size_t block_size)
+ : remaining_(0),
+ block_size_(block_size),
+ freestart_(NULL), // set for real in Reset()
+ blocks_alloced_(1),
+ overflow_blocks_(NULL) {
+ assert(block_size > kDefaultAlignment);
+
+ first_blocks_[0].mem = reinterpret_cast<char*>(malloc(block_size_));
+
+ first_blocks_[0].size = block_size_;
+
+ Reset();
+}
+
+Arena::~Arena() {
+ FreeBlocks();
+ assert(overflow_blocks_ == NULL); // FreeBlocks() should do that
+ // The first X blocks stay allocated always by default. Delete them now.
+ for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem);
+}
+
+// Returns true iff it advances freestart_ to the first position
+// satisfying alignment without exhausting the current block.
+bool Arena::SatisfyAlignment(size_t alignment) {
+ const size_t overage = reinterpret_cast<size_t>(freestart_) & (alignment - 1);
+ if (overage > 0) {
+ const size_t waste = alignment - overage;
+ if (waste >= remaining_) {
+ return false;
+ }
+ freestart_ += waste;
+ remaining_ -= waste;
+ }
+ DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1));
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// Arena::Reset()
+// Clears all the memory an arena is using.
+// ----------------------------------------------------------------------
+
+void Arena::Reset() {
+ FreeBlocks();
+ freestart_ = first_blocks_[0].mem;
+ remaining_ = first_blocks_[0].size;
+
+ // There is no guarantee the first block is properly aligned, so
+ // enforce that now.
+ CHECK(SatisfyAlignment(kDefaultAlignment));
+
+ freestart_when_empty_ = freestart_;
+}
+
+// ----------------------------------------------------------------------
+// Arena::MakeNewBlock()
+// Our sbrk() equivalent. We always make blocks of the same size
+// (though GetMemory() can also make a new block for really big
+// data).
+// ----------------------------------------------------------------------
+
+void Arena::MakeNewBlock(const uint32 alignment) {
+ AllocatedBlock* block = AllocNewBlock(block_size_, alignment);
+ freestart_ = block->mem;
+ remaining_ = block->size;
+ CHECK(SatisfyAlignment(alignment));
+}
+
+// The following simple numeric routines also exist in util/math/mathutil.h
+// but we don't want to depend on that library.
+
+// Euclid's algorithm for Greatest Common Divisor.
+static uint32 GCD(uint32 x, uint32 y) {
+ while (y != 0) {
+ uint32 r = x % y;
+ x = y;
+ y = r;
+ }
+ return x;
+}
+
+static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
+ if (a > b) {
+ return (a / GCD(a, b)) * b;
+ } else if (a < b) {
+ return (b / GCD(b, a)) * a;
+ } else {
+ return a;
+ }
+}
+
+// -------------------------------------------------------------
+// Arena::AllocNewBlock()
+// Adds and returns an AllocatedBlock.
+// The returned AllocatedBlock* is valid until the next call
+// to AllocNewBlock or Reset. (i.e. anything that might
+// affect overflow_blocks_).
+// -------------------------------------------------------------
+
+Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size,
+ const uint32 alignment) {
+ AllocatedBlock* block;
+ // Find the next block.
+ if (blocks_alloced_ < TF_ARRAYSIZE(first_blocks_)) {
+ // Use one of the pre-allocated blocks
+ block = &first_blocks_[blocks_alloced_++];
+ } else { // oops, out of space, move to the vector
+ if (overflow_blocks_ == NULL)
+ overflow_blocks_ = new std::vector<AllocatedBlock>;
+ // Adds another block to the vector.
+ overflow_blocks_->resize(overflow_blocks_->size() + 1);
+ // block points to the last block of the vector.
+ block = &overflow_blocks_->back();
+ }
+
+ // NOTE(tucker): this utility is made slightly more complex by
+ // not disallowing the case where alignment > block_size.
+  // Could we disallow that case without breaking existing code?
+
+ // Must be a multiple of kDefaultAlignment, unless requested
+ // alignment is 1, in which case we don't care at all.
+ const uint32 adjusted_alignment =
+ (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1);
+
+ CHECK_LE(adjusted_alignment, 1 << 20)
+ << "Alignment on boundaries greater than 1MB not supported.";
+
+ // If block_size > alignment we force block_size to be a multiple
+ // of alignment; if block_size < alignment we make no adjustment.
+ size_t adjusted_block_size = block_size;
+ if (adjusted_alignment > 1) {
+ if (adjusted_block_size > adjusted_alignment) {
+ const uint32 excess = adjusted_block_size % adjusted_alignment;
+ adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
+ }
+ block->mem = reinterpret_cast<char*>(
+ port::aligned_malloc(adjusted_block_size, adjusted_alignment));
+ } else {
+ block->mem = reinterpret_cast<char*>(malloc(adjusted_block_size));
+ }
+ block->size = adjusted_block_size;
+ CHECK(NULL != block->mem) << "block_size=" << block_size
+ << " adjusted_block_size=" << adjusted_block_size
+ << " alignment=" << alignment
+ << " adjusted_alignment=" << adjusted_alignment;
+
+ return block;
+}
+
+// ----------------------------------------------------------------------
+// Arena::GetMemoryFallback()
+// We take memory out of our pool, aligned on the byte boundary
+// requested. If we don't have space in our current pool, we
+// allocate a new block (wasting the remaining space in the
+// current block) and give you that. If your memory needs are
+// too big for a single block, we make a special your-memory-only
+// allocation -- this is equivalent to not using the arena at all.
+// ----------------------------------------------------------------------
+
+void* Arena::GetMemoryFallback(const size_t size, const int alignment) {
+ if (0 == size) {
+ return NULL; // stl/stl_alloc.h says this is okay
+ }
+
+ // alignment must be a positive power of 2.
+ CHECK(alignment > 0 && 0 == (alignment & (alignment - 1)));
+
+ // If the object is more than a quarter of the block size, allocate
+ // it separately to avoid wasting too much space in leftover bytes.
+ if (block_size_ == 0 || size > block_size_ / 4) {
+ return AllocNewBlock(size, alignment)->mem;
+ }
+
+ // Enforce alignment on freestart_ then check for adequate space,
+ // which may require starting a new block.
+ if (!SatisfyAlignment(alignment) || size > remaining_) {
+ MakeNewBlock(alignment);
+ }
+ CHECK_LE(size, remaining_);
+
+ remaining_ -= size;
+ void* result = freestart_;
+ freestart_ += size;
+
+ return result;
+}
+
+// ----------------------------------------------------------------------
+// Arena::ReturnMemoryFallback()
+// Arena::FreeBlocks()
+// Unlike GetMemory(), which does actual work, ReturnMemory() is a
+// no-op: we don't "free" memory until Reset() is called. We do
+// update some stats, though. Note we do no checking that the
+// pointer you pass in was actually allocated by us, or that it
+// was allocated for the size you say, so be careful here!
+// FreeBlocks() does the work for Reset(), actually freeing all
+// memory allocated in one fell swoop.
+// ----------------------------------------------------------------------
+
+void Arena::FreeBlocks() {
+ for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced
+ free(first_blocks_[i].mem);
+ first_blocks_[i].mem = NULL;
+ first_blocks_[i].size = 0;
+ }
+ blocks_alloced_ = 1;
+ if (overflow_blocks_ != NULL) {
+ std::vector<AllocatedBlock>::iterator it;
+ for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) {
+ free(it->mem);
+ }
+ delete overflow_blocks_; // These should be used very rarely
+ overflow_blocks_ = NULL;
+ }
+}
+
+} // namespace core
+} // namespace tensorflow
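As a worked instance of the size/alignment adjustment in AllocNewBlock() above, the following sketch mirrors that arithmetic for block_size = 1024 and a requested alignment of 48 (values chosen purely for illustration):

```cpp
#include <cassert>

// Mirrors the adjustment logic in Arena::AllocNewBlock() for
// block_size = 1024, alignment = 48, kDefaultAlignment = 8.
void AlignmentAdjustmentExample() {
  // LCM(48, 8) = 48, since GCD(48, 8) = 8 and (48 / 8) * 8 = 48.
  const unsigned adjusted_alignment = 48;

  // block_size > adjusted_alignment, so round up to a multiple of it:
  // 1024 % 48 == 16, so the block grows by 48 - 16 = 32 bytes.
  unsigned adjusted_block_size = 1024;
  const unsigned excess = adjusted_block_size % adjusted_alignment;
  adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);

  assert(adjusted_block_size == 1056);
  assert(adjusted_block_size % adjusted_alignment == 0);
}
```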
diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h
new file mode 100644
index 0000000000..59896803bb
--- /dev/null
+++ b/tensorflow/core/lib/core/arena.h
@@ -0,0 +1,90 @@
+// TODO(vrv): Switch this to an open-sourced version of Arena.
+
+#ifndef TENSORFLOW_LIB_CORE_ARENA_H_
+#define TENSORFLOW_LIB_CORE_ARENA_H_
+
+#include <assert.h>
+
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// This class is "thread-compatible": different threads can access the
+// arena at the same time without locking, as long as they use only
+// const methods.
+class Arena {
+ public:
+ // Allocates a thread-compatible arena with the specified block size.
+ explicit Arena(const size_t block_size);
+ ~Arena();
+
+ char* Alloc(const size_t size) {
+ return reinterpret_cast<char*>(GetMemory(size, 1));
+ }
+
+ void Reset();
+
+// This should be the worst-case alignment for any type. This is
+// good for IA-32, SPARC version 7 (the last one I know), and
+// supposedly Alpha. i386 would be more time-efficient with a
+// default alignment of 8, but ::operator new() uses alignment of 4,
+// and an assertion will fail below after the call to MakeNewBlock()
+// if you try to use a larger alignment.
+#ifdef __i386__
+ static const int kDefaultAlignment = 4;
+#else
+ static const int kDefaultAlignment = 8;
+#endif
+
+ protected:
+ bool SatisfyAlignment(const size_t alignment);
+ void MakeNewBlock(const uint32 alignment);
+ void* GetMemoryFallback(const size_t size, const int align);
+ void* GetMemory(const size_t size, const int align) {
+ assert(remaining_ <= block_size_); // an invariant
+ if (size > 0 && size < remaining_ && align == 1) { // common case
+ void* result = freestart_;
+ freestart_ += size;
+ remaining_ -= size;
+ return result;
+ }
+ return GetMemoryFallback(size, align);
+ }
+
+ size_t remaining_;
+
+ private:
+ struct AllocatedBlock {
+ char* mem;
+ size_t size;
+ };
+
+  // Allocates a new block of at least block_size bytes, with the specified
+ // alignment.
+ // The returned AllocatedBlock* is valid until the next call to AllocNewBlock
+ // or Reset (i.e. anything that might affect overflow_blocks_).
+ AllocatedBlock* AllocNewBlock(const size_t block_size,
+ const uint32 alignment);
+
+ const size_t block_size_;
+ char* freestart_; // beginning of the free space in most recent block
+ char* freestart_when_empty_; // beginning of the free space when we're empty
+ // STL vector isn't as efficient as it could be, so we use an array at first
+ size_t blocks_alloced_; // how many of the first_blocks_ have been alloced
+ AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary
+ // if the first_blocks_ aren't enough, expand into overflow_blocks_.
+ std::vector<AllocatedBlock>* overflow_blocks_;
+
+ void FreeBlocks(); // Frees all except first block
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Arena);
+};
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_ARENA_H_
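Before the tests, a minimal usage sketch based only on the interface declared in this header (Arena, Alloc(), Reset(), and kDefaultAlignment; nothing else is assumed):

```cpp
#include "tensorflow/core/lib/core/arena.h"

void ArenaExample() {
  // The block size must exceed kDefaultAlignment (asserted in the ctor).
  tensorflow::core::Arena arena(1024);

  // Small requests are served from the current block via the inlined
  // GetMemory() fast path; requests larger than block_size/4 get a
  // dedicated block in GetMemoryFallback().
  char* small = arena.Alloc(64);
  char* large = arena.Alloc(4096);
  (void)small;
  (void)large;

  // Reset() frees every block except the first and rewinds the arena.
  // All pointers previously returned by Alloc() are invalidated.
  arena.Reset();
}
```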
diff --git a/tensorflow/core/lib/core/arena_test.cc b/tensorflow/core/lib/core/arena_test.cc
new file mode 100644
index 0000000000..fa147c3014
--- /dev/null
+++ b/tensorflow/core/lib/core/arena_test.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/core/arena.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+namespace {
+
+// Write random data to allocated memory
+static void TestMemory(void* mem, int size) {
+ // Check that we can memset the entire memory
+ memset(mem, 0xaa, size);
+
+ // Do some memory allocation to check that the arena doesn't mess up
+ // the internal memory allocator
+ char* tmp[100];
+ for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) {
+ tmp[i] = new char[i * i + 1];
+ }
+
+ memset(mem, 0xcc, size);
+
+  // Free up the allocated memory.
+ for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) {
+ delete[] tmp[i];
+ }
+
+ // Check that we can memset the entire memory
+ memset(mem, 0xee, size);
+}
+
+TEST(ArenaTest, TestBasicArena) {
+ Arena a(1024);
+ char* memory = a.Alloc(100);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 100);
+
+ // Allocate again
+ memory = a.Alloc(100);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 100);
+}
+
+TEST(ArenaTest, TestVariousArenaSizes) {
+ {
+ Arena a(1024);
+
+ // Allocate blocksize
+ char* memory = a.Alloc(1024);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 1024);
+
+ // Allocate another blocksize
+ char* memory2 = a.Alloc(1024);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 1024);
+ }
+
+ // Allocate an arena and allocate two blocks
+ // that together exceed a block size
+ {
+ Arena a(1024);
+
+    // Allocate most of a block
+ char* memory = a.Alloc(768);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 768);
+
+    // Allocate another 768 bytes; this no longer fits in the first block
+ char* memory2 = a.Alloc(768);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 768);
+ }
+
+ // Allocate larger than a blocksize
+ {
+ Arena a(1024);
+
+ char* memory = a.Alloc(10240);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 10240);
+
+    // Allocate more than a block size again
+ char* memory2 = a.Alloc(1234);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 1234);
+ }
+}
+
+} // namespace
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc
new file mode 100644
index 0000000000..0ea583e96f
--- /dev/null
+++ b/tensorflow/core/lib/core/bit_cast_test.cc
@@ -0,0 +1,95 @@
+// Unit test for bit_cast template.
+
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// Marshall and unmarshall.
+// The ISO C++ spec, section 3.9, promises this will work.
+
+template <int N>
+struct marshall {
+ char buf[N];
+};
+
+template <class T>
+void TestMarshall(const T values[], int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ T t0 = values[i];
+ marshall<sizeof(T)> m0 = bit_cast<marshall<sizeof(T)> >(t0);
+ T t1 = bit_cast<T>(m0);
+ marshall<sizeof(T)> m1 = bit_cast<marshall<sizeof(T)> >(t1);
+ ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
+ ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T)));
+ }
+}
+
+// Convert back and forth to an integral type. The C++ standard does
+// not guarantee this will work.
+//
+// There are implicit assumptions about sizeof(float) and
+// sizeof(double). These assumptions are widespread in practice.
+
+template <class T, class I>
+void TestIntegral(const T values[], int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ T t0 = values[i];
+ I i0 = bit_cast<I>(t0);
+ T t1 = bit_cast<T>(i0);
+ I i1 = bit_cast<I>(t1);
+ ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
+ ASSERT_EQ(i0, i1);
+ }
+}
+
+TEST(BitCast, Bool) {
+ LOG(INFO) << "Test bool";
+ static const bool bool_list[] = {false, true};
+ TestMarshall<bool>(bool_list, TF_ARRAYSIZE(bool_list));
+}
+
+TEST(BitCast, Int32) {
+ static const int32 int_list[] = {0, 1, 100, 2147483647,
+ -1, -100, -2147483647, -2147483647 - 1};
+ TestMarshall<int32>(int_list, TF_ARRAYSIZE(int_list));
+}
+
+TEST(BitCast, Int64) {
+ static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)};
+ TestMarshall<int64>(int64_list, TF_ARRAYSIZE(int64_list));
+}
+
+TEST(BitCast, Uint64) {
+ static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63};
+ TestMarshall<uint64>(uint64_list, TF_ARRAYSIZE(uint64_list));
+}
+
+TEST(BitCast, Float) {
+ static const float float_list[] = {0.0, 1.0, -1.0, 10.0, -10.0, 1e10,
+ 1e20, 1e-10, 1e-20, 2.71828, 3.14159};
+ TestMarshall<float>(float_list, TF_ARRAYSIZE(float_list));
+ TestIntegral<float, int32>(float_list, TF_ARRAYSIZE(float_list));
+ TestIntegral<float, uint32>(float_list, TF_ARRAYSIZE(float_list));
+}
+
+TEST(BitCast, Double) {
+ static const double double_list[] = {
+ 0.0,
+ 1.0,
+ -1.0,
+ 10.0,
+ -10.0,
+ 1e10,
+ 1e100,
+ 1e-10,
+ 1e-100,
+ 2.718281828459045,
+ 3.141592653589793238462643383279502884197169399375105820974944};
+ TestMarshall<double>(double_list, TF_ARRAYSIZE(double_list));
+ TestIntegral<double, int64>(double_list, TF_ARRAYSIZE(double_list));
+ TestIntegral<double, uint64>(double_list, TF_ARRAYSIZE(double_list));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h
new file mode 100644
index 0000000000..5456a63168
--- /dev/null
+++ b/tensorflow/core/lib/core/bits.h
@@ -0,0 +1,84 @@
+#ifndef TENSORFLOW_LIB_CORE_BITS_H_
+#define TENSORFLOW_LIB_CORE_BITS_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+int Log2Floor(uint32 n);
+int Log2Floor64(uint64 n);
+
+// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0.
+int Log2Ceiling(uint32 n);
+int Log2Ceiling64(uint64 n);
+
+// ------------------------------------------------------------------------
+// Implementation details follow
+// ------------------------------------------------------------------------
+
+#if defined(__GNUC__)
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor(uint32 n) {
+ return n == 0 ? -1 : 31 ^ __builtin_clz(n);
+}
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor64(uint64 n) {
+ return n == 0 ? -1 : 63 ^ __builtin_clzll(n);
+}
+
+#else
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor(uint32 n) {
+ if (n == 0)
+ return -1;
+ int log = 0;
+ uint32 value = n;
+ for (int i = 4; i >= 0; --i) {
+ int shift = (1 << i);
+ uint32 x = value >> shift;
+ if (x != 0) {
+ value = x;
+ log += shift;
+ }
+ }
+ assert(value == 1);
+ return log;
+}
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+// Log2Floor64() is defined in terms of Log2Floor()
+inline int Log2Floor64(uint64 n) {
+ const uint32 topbits = static_cast<uint32>(n >> 32);
+ if (topbits == 0) {
+ // Top bits are zero, so scan in bottom bits
+ return Log2Floor(static_cast<uint32>(n));
+ } else {
+ return 32 + Log2Floor(topbits);
+ }
+}
+
+#endif
+
+inline int Log2Ceiling(uint32 n) {
+ int floor = Log2Floor(n);
+ if (n == (n & ~(n - 1))) // zero or a power of two
+ return floor;
+ else
+ return floor + 1;
+}
+
+inline int Log2Ceiling64(uint64 n) {
+ int floor = Log2Floor64(n);
+ if (n == (n & ~(n - 1))) // zero or a power of two
+ return floor;
+ else
+ return floor + 1;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_BITS_H_
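The floor/ceiling relationship is easiest to see with small values; a short sketch using only the functions declared above:

```cpp
#include "tensorflow/core/lib/core/bits.h"

void Log2Example() {
  // For an exact power of two, floor and ceiling agree:
  int a = tensorflow::Log2Floor(16);    // 4
  int b = tensorflow::Log2Ceiling(16);  // 4

  // For a value strictly between powers of two, the ceiling is one larger:
  int c = tensorflow::Log2Floor(17);    // 4
  int d = tensorflow::Log2Ceiling(17);  // 5

  // Both return -1 for n == 0, per the contract above.
  int e = tensorflow::Log2Floor(0);     // -1
  (void)a; (void)b; (void)c; (void)d; (void)e;
}
```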
diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h
new file mode 100644
index 0000000000..f141be2c76
--- /dev/null
+++ b/tensorflow/core/lib/core/blocking_counter.h
@@ -0,0 +1,41 @@
+#ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
+#define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class BlockingCounter {
+ public:
+ BlockingCounter(int initial_count) : count_(initial_count) {
+ CHECK_GE(count_, 0);
+ }
+
+ ~BlockingCounter() {}
+
+ inline void DecrementCount() {
+ mutex_lock l(mu_);
+ --count_;
+ CHECK(count_ >= 0);
+ if (count_ == 0) {
+ cond_var_.notify_all();
+ }
+ }
+
+ inline void Wait() {
+ mutex_lock l(mu_);
+ while (count_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ private:
+ int count_;
+ mutex mu_;
+ condition_variable cond_var_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
diff --git a/tensorflow/core/lib/core/blocking_counter_test.cc b/tensorflow/core/lib/core/blocking_counter_test.cc
new file mode 100644
index 0000000000..feb0342086
--- /dev/null
+++ b/tensorflow/core/lib/core/blocking_counter_test.cc
@@ -0,0 +1,36 @@
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(BlockingCounterTest, TestZero) {
+ BlockingCounter bc(0);
+ bc.Wait();
+}
+
+TEST(BlockingCounterTest, TestSingleThread) {
+ BlockingCounter bc(2);
+ bc.DecrementCount();
+ bc.DecrementCount();
+ bc.Wait();
+}
+
+TEST(BlockingCounterTest, TestMultipleThread) {
+ int N = 3;
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", N);
+
+ BlockingCounter bc(N);
+ for (int i = 0; i < N; ++i) {
+ thread_pool->Schedule([&bc] { bc.DecrementCount(); });
+ }
+
+ bc.Wait();
+ delete thread_pool;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h
new file mode 100644
index 0000000000..5b72048ac5
--- /dev/null
+++ b/tensorflow/core/lib/core/casts.h
@@ -0,0 +1,85 @@
+// Various Google-specific casting templates.
+//
+// This code is compiled directly on many platforms, including client
+// platforms like Windows, Mac, and embedded systems. Before making
+// any changes here, make sure that you're not breaking any platforms.
+//
+
+#ifndef TENSORFLOW_LIB_CORE_CASTS_H_
+#define TENSORFLOW_LIB_CORE_CASTS_H_
+
+#include <string.h> // for memcpy
+
+namespace tensorflow {
+
+// bit_cast<Dest,Source> is a template function that implements the
+// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
+// very low-level functions like the protobuf library and fast math
+// support.
+//
+// float f = 3.14159265358979;
+// int i = bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10 -15-. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10 -15- is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// bit_cast<> calls memcpy() which is blessed by the standard,
+// especially by the example in section 3.9 . Also, of course,
+// bit_cast<> wraps up the nasty logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+//
+// -- mec 2005-10-17
+
+template <class Dest, class Source>
+inline Dest bit_cast(const Source& source) {
+ static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match");
+
+ Dest dest;
+ memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_CASTS_H_
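A small sketch of the pattern the comment above describes; the 0x40490fdb value is the one quoted in that comment, and the only assumption is a 4-byte unsigned int:

```cpp
#include "tensorflow/core/lib/core/casts.h"

#include <cstdio>

void BitCastExample() {
  float f = 3.14159265358979f;

  // Well-defined: the bytes of f are copied via memcpy rather than
  // aliased through a pointer of the wrong type.
  unsigned int bits = tensorflow::bit_cast<unsigned int>(f);
  std::printf("0x%08x\n", bits);  // prints 0x40490fdb

  // The round trip is exact: both types are 4 bytes of plain data
  // (the static_assert in bit_cast enforces the size match).
  float back = tensorflow::bit_cast<float>(bits);
  (void)back;
}
```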
diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc
new file mode 100644
index 0000000000..efff554742
--- /dev/null
+++ b/tensorflow/core/lib/core/coding.cc
@@ -0,0 +1,164 @@
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace core {
+
+void EncodeFixed32(char* buf, uint32 value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ }
+}
+
+void EncodeFixed64(char* buf, uint64 value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ buf[4] = (value >> 32) & 0xff;
+ buf[5] = (value >> 40) & 0xff;
+ buf[6] = (value >> 48) & 0xff;
+ buf[7] = (value >> 56) & 0xff;
+ }
+}
+
+void PutFixed32(string* dst, uint32 value) {
+ char buf[sizeof(value)];
+ EncodeFixed32(buf, value);
+ dst->append(buf, sizeof(buf));
+}
+
+void PutFixed64(string* dst, uint64 value) {
+ char buf[sizeof(value)];
+ EncodeFixed64(buf, value);
+ dst->append(buf, sizeof(buf));
+}
+
+char* EncodeVarint32(char* dst, uint32 v) {
+ // Operate on characters as unsigneds
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ static const int B = 128;
+ if (v < (1 << 7)) {
+ *(ptr++) = v;
+ } else if (v < (1 << 14)) {
+ *(ptr++) = v | B;
+ *(ptr++) = v >> 7;
+ } else if (v < (1 << 21)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = v >> 14;
+ } else if (v < (1 << 28)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = v >> 21;
+ } else {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = (v >> 21) | B;
+ *(ptr++) = v >> 28;
+ }
+ return reinterpret_cast<char*>(ptr);
+}
+
+void PutVarint32(string* dst, uint32 v) {
+ char buf[5];
+ char* ptr = EncodeVarint32(buf, v);
+ dst->append(buf, ptr - buf);
+}
+
+char* EncodeVarint64(char* dst, uint64 v) {
+ static const int B = 128;
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ while (v >= B) {
+ *(ptr++) = (v & (B - 1)) | B;
+ v >>= 7;
+ }
+ *(ptr++) = static_cast<unsigned char>(v);
+ return reinterpret_cast<char*>(ptr);
+}
+
+void PutVarint64(string* dst, uint64 v) {
+ char buf[10];
+ char* ptr = EncodeVarint64(buf, v);
+ dst->append(buf, ptr - buf);
+}
+
+int VarintLength(uint64_t v) {
+ int len = 1;
+ while (v >= 128) {
+ v >>= 7;
+ len++;
+ }
+ return len;
+}
+
+const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32* value) {
+ uint32 result = 0;
+ for (uint32 shift = 0; shift <= 28 && p < limit; shift += 7) {
+ uint32 byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return NULL;
+}
+
+bool GetVarint32(StringPiece* input, uint32* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint32Ptr(p, limit, value);
+ if (q == NULL) {
+ return false;
+ } else {
+ *input = StringPiece(q, limit - q);
+ return true;
+ }
+}
+
+const char* GetVarint64Ptr(const char* p, const char* limit, uint64* value) {
+ uint64 result = 0;
+ for (uint32 shift = 0; shift <= 63 && p < limit; shift += 7) {
+ uint64 byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return NULL;
+}
+
+bool GetVarint64(StringPiece* input, uint64* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint64Ptr(p, limit, value);
+ if (q == NULL) {
+ return false;
+ } else {
+ *input = StringPiece(q, limit - q);
+ return true;
+ }
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h
new file mode 100644
index 0000000000..0c14bf1bbf
--- /dev/null
+++ b/tensorflow/core/lib/core/coding.h
@@ -0,0 +1,55 @@
+// Endian-neutral encoding:
+// * Fixed-length numbers are encoded with least-significant byte first
+// * In addition we support variable length "varint" encoding
+// * Strings are encoded prefixed by their length in varint format
+
+#ifndef TENSORFLOW_LIB_CORE_CODING_H_
+#define TENSORFLOW_LIB_CORE_CODING_H_
+
+#include "tensorflow/core/lib/core/raw_coding.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// Lower-level versions of Put... that write directly into a character buffer
+// REQUIRES: dst has enough space for the value being written
+extern void EncodeFixed32(char* dst, uint32 value);
+extern void EncodeFixed64(char* dst, uint64 value);
+extern void PutFixed32(string* dst, uint32 value);
+extern void PutFixed64(string* dst, uint64 value);
+
+extern void PutVarint32(string* dst, uint32 value);
+extern void PutVarint64(string* dst, uint64 value);
+
+extern bool GetVarint32(StringPiece* input, uint32* value);
+extern bool GetVarint64(StringPiece* input, uint64* value);
+
+extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v);
+extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
+
+// Internal routine for use by fallback path of GetVarint32Ptr
+extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32* value);
+inline const char* GetVarint32Ptr(const char* p, const char* limit,
+ uint32* value) {
+ if (p < limit) {
+ uint32 result = *(reinterpret_cast<const unsigned char*>(p));
+ if ((result & 128) == 0) {
+ *value = result;
+ return p + 1;
+ }
+ }
+ return GetVarint32PtrFallback(p, limit, value);
+}
+
+extern char* EncodeVarint64(char* dst, uint64 v);
+
+// Returns the length of the varint32 or varint64 encoding of "v"
+extern int VarintLength(uint64_t v);
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_CODING_H_
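As a worked example of the varint format (seven payload bits per byte, least-significant group first, high bit set on every byte but the last), encoding 300 takes two bytes. A sketch using the helpers above, assuming StringPiece's usual constructor from string:

```cpp
#include "tensorflow/core/lib/core/coding.h"

namespace tensorflow {

void VarintExample() {
  string s;
  core::PutVarint32(&s, 300);
  // 300 = 0b1'0010'1100. The low seven bits (0b010'1100 = 0x2C) go in
  // the first byte with the continuation bit set (0xAC); the remaining
  // bits (0b10 = 0x02) form the second byte. So s == "\xAC\x02" and
  // core::VarintLength(300) == 2.

  StringPiece input(s);
  uint32 value = 0;
  if (core::GetVarint32(&input, &value)) {
    // value == 300; input now starts just past the two consumed bytes.
  }
}

}  // namespace tensorflow
```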
diff --git a/tensorflow/core/lib/core/coding_test.cc b/tensorflow/core/lib/core/coding_test.cc
new file mode 100644
index 0000000000..5e9e2c5e96
--- /dev/null
+++ b/tensorflow/core/lib/core/coding_test.cc
@@ -0,0 +1,168 @@
+#include "tensorflow/core/lib/core/coding.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+
+TEST(Coding, Fixed32) {
+ static const int N = 100000;
+
+ string s;
+ for (uint32 v = 0; v < N; v++) {
+ char buf[sizeof(uint32)];
+ EncodeFixed32(buf, v);
+ s.append(buf, sizeof(buf));
+ }
+
+ const char* p = s.data();
+ for (uint32 v = 0; v < N; v++) {
+ uint32 actual = DecodeFixed32(p);
+ ASSERT_EQ(v, actual);
+ p += sizeof(uint32);
+ }
+}
+
+TEST(Coding, Fixed64) {
+ string s;
+ for (int power = 0; power <= 63; power++) {
+ uint64 v = static_cast<uint64>(1) << power;
+ char buf[sizeof(uint64)];
+ EncodeFixed64(buf, v - 1);
+ s.append(buf, sizeof(buf));
+ EncodeFixed64(buf, v + 0);
+ s.append(buf, sizeof(buf));
+ EncodeFixed64(buf, v + 1);
+ s.append(buf, sizeof(buf));
+ }
+
+ const char* p = s.data();
+ for (int power = 0; power <= 63; power++) {
+ uint64 v = static_cast<uint64>(1) << power;
+ uint64 actual;
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v - 1, actual);
+ p += sizeof(uint64);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 0, actual);
+ p += sizeof(uint64);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 1, actual);
+ p += sizeof(uint64);
+ }
+}
+
+// Test that encoding routines generate little-endian encodings
+TEST(Coding, EncodingOutput) {
+ char dst[8];
+ EncodeFixed32(dst, 0x04030201);
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+
+ EncodeFixed64(dst, 0x0807060504030201ull);
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+ ASSERT_EQ(0x05, static_cast<int>(dst[4]));
+ ASSERT_EQ(0x06, static_cast<int>(dst[5]));
+ ASSERT_EQ(0x07, static_cast<int>(dst[6]));
+ ASSERT_EQ(0x08, static_cast<int>(dst[7]));
+}
+
+TEST(Coding, Varint32) {
+ string s;
+ for (uint32 i = 0; i < (32 * 32); i++) {
+ uint32 v = (i / 32) << (i % 32);
+ PutVarint32(&s, v);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (uint32 i = 0; i < (32 * 32); i++) {
+ uint32 expected = (i / 32) << (i % 32);
+ uint32 actual;
+ p = GetVarint32Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != NULL);
+ ASSERT_EQ(expected, actual);
+ }
+ ASSERT_EQ(p, s.data() + s.size());
+}
+
+TEST(Coding, Varint64) {
+ // Construct the list of values to check
+ std::vector<uint64> values;
+ // Some special values
+ values.push_back(0);
+ values.push_back(100);
+ values.push_back(~static_cast<uint64>(0));
+ values.push_back(~static_cast<uint64>(0) - 1);
+ for (uint32 k = 0; k < 64; k++) {
+ // Test values near powers of two
+ const uint64 power = 1ull << k;
+ values.push_back(power);
+ values.push_back(power - 1);
+ values.push_back(power + 1);
+ }
+
+ string s;
+ for (size_t i = 0; i < values.size(); i++) {
+ PutVarint64(&s, values[i]);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (size_t i = 0; i < values.size(); i++) {
+ ASSERT_TRUE(p < limit);
+ uint64 actual;
+ p = GetVarint64Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != NULL);
+ ASSERT_EQ(values[i], actual);
+ }
+ ASSERT_EQ(p, limit);
+}
+
+TEST(Coding, Varint32Overflow) {
+ uint32 result;
+ string input("\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(),
+ &result) == NULL);
+}
+
+TEST(Coding, Varint32Truncation) {
+ uint32 large_value = (1u << 31) + 100;
+ string s;
+ PutVarint32(&s, large_value);
+ uint32 result;
+ for (size_t len = 0; len < s.size() - 1; len++) {
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == NULL);
+ }
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) != NULL);
+ ASSERT_EQ(large_value, result);
+}
+
+TEST(Coding, Varint64Overflow) {
+ uint64 result;
+ string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(),
+ &result) == NULL);
+}
+
+TEST(Coding, Varint64Truncation) {
+ uint64 large_value = (1ull << 63) + 100ull;
+ string s;
+ PutVarint64(&s, large_value);
+ uint64 result;
+ for (size_t len = 0; len < s.size() - 1; len++) {
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == NULL);
+ }
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) != NULL);
+ ASSERT_EQ(large_value, result);
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc
new file mode 100644
index 0000000000..0f1072ffaa
--- /dev/null
+++ b/tensorflow/core/lib/core/command_line_flags.cc
@@ -0,0 +1,94 @@
+#include "tensorflow/core/lib/core/command_line_flags.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace tensorflow {
+namespace {
+
+// Templated function to convert a string to target values.
+// Return true if the conversion is successful. Otherwise, return false.
+template <typename T>
+bool StringToValue(const string& content, T* value);
+
+template <>
+bool StringToValue<int32>(const string& content, int* value) {
+ return str_util::NumericParse32(content, value);
+}
+
+// Parse a single argument by linearly searching through the command table.
+// The input format is: --argument=value.
+// Return OK if the argument is used. It stores the extracted value into the
+// matching flag.
+// Return NOT_FOUND if the argument is not recognized.
+// Return INVALID_ARGUMENT if the command is recognized, but fails to extract
+// its value.
+template <typename T>
+Status ParseArgument(const string& argument) {
+ for (auto& command :
+ internal::CommandLineFlagRegistry<int>::Instance()->commands) {
+ string prefix = strings::StrCat("--", command.name, "=");
+ if (tensorflow::StringPiece(argument).starts_with(prefix)) {
+ string content = argument.substr(prefix.length());
+ if (StringToValue<T>(content, command.value)) {
+ return Status::OK();
+ }
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Cannot parse integer in: ", argument));
+ }
+ }
+ return Status(error::NOT_FOUND,
+ strings::StrCat("Unknown command: ", argument));
+}
+
+// A specialization for booleans. The input format is:
+// "--argument" or "--noargument".
+// Parse a single argument by linearly searching through the command table.
+// Return OK if the argument is used. The value is stored in the matching flag.
+// Return NOT_FOUND if the argument is not recognized.
+template <>
+Status ParseArgument<bool>(const string& argument) {
+ for (auto& command :
+ internal::CommandLineFlagRegistry<bool>::Instance()->commands) {
+ if (argument == strings::StrCat("--", command.name)) {
+ *command.value = true;
+ return Status::OK();
+ } else if (argument == strings::StrCat("--no", command.name)) {
+ *command.value = false;
+ return Status::OK();
+ }
+ }
+ return Status(error::NOT_FOUND,
+ strings::StrCat("Unknown command: ", argument));
+}
+} // namespace
+
+Status ParseCommandLineFlags(int* argc, char* argv[]) {
+ int unused_argc = 1;
+ for (int index = 1; index < *argc; ++index) {
+ Status s;
+ // Search bool commands.
+ s = ParseArgument<bool>(argv[index]);
+ if (s.ok()) {
+ continue;
+ }
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ }
+ // Search int32 commands.
+ s = ParseArgument<int32>(argv[index]);
+ if (s.ok()) {
+ continue;
+ }
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ }
+ // Pointer swap the unused argument to the front.
+ std::swap(argv[unused_argc++], argv[index]);
+ }
+ *argc = unused_argc;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h
new file mode 100644
index 0000000000..f1a94c11f9
--- /dev/null
+++ b/tensorflow/core/lib/core/command_line_flags.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
+#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace internal {
+
+template <typename T>
+struct CommandLineFlagRegistry {
+ static CommandLineFlagRegistry* Instance() {
+ static CommandLineFlagRegistry instance_;
+ return &instance_;
+ }
+ struct Command {
+ string name;
+ T* value;
+ string text;
+ };
+ std::vector<Command> commands;
+
+ private:
+ CommandLineFlagRegistry() {}
+ TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry);
+};
+
+template <typename T>
+struct CommandLineFlagRegister {
+ CommandLineFlagRegister(const string& name, T* val, const string& text) {
+ CommandLineFlagRegistry<T>::Instance()->commands.push_back(
+ {name, val, text});
+ }
+};
+
+#define TF_DEFINE_variable(type, name, default_value, text) \
+ type FLAGS_##name = default_value; \
+ namespace TF_flags_internal { \
+ tensorflow::internal::CommandLineFlagRegister<type> \
+ TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \
+ } // namespace TF_flags_internal
+
+} // namespace internal
+
+#define TF_DEFINE_int32(name, default_value, text) \
+ TF_DEFINE_variable(int32, name, default_value, text);
+
+#define TF_DEFINE_bool(name, default_value, text) \
+ TF_DEFINE_variable(bool, name, default_value, text);
+
+// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from argv.
+// Returns the number of unused arguments in *argc.
+// Returns an error Status if the parsing encounters errors.
+// TODO(opensource): switch to a command line argument parser that can be
+// shared with other tests.
+Status ParseCommandLineFlags(int* argc, char* argv[]);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
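A sketch of how these macros are meant to be used; the flag names here are hypothetical, chosen only for illustration:

```cpp
#include "tensorflow/core/lib/core/command_line_flags.h"

namespace tensorflow {
// Each macro defines FLAGS_<name> and registers it with the registry
// that ParseCommandLineFlags() searches.
TF_DEFINE_int32(num_threads, 4, "Number of worker threads.");
TF_DEFINE_bool(verbose, false, "Enable verbose output.");
}  // namespace tensorflow

int main(int argc, char* argv[]) {
  // Understands "--num_threads=8" and "--verbose" / "--noverbose".
  // Used arguments are swapped out of argv; *argc becomes the count of
  // unused arguments.
  tensorflow::Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
  if (!s.ok()) return 1;
  // tensorflow::FLAGS_num_threads and tensorflow::FLAGS_verbose now
  // hold the parsed values.
  return 0;
}
```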
diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto
new file mode 100644
index 0000000000..6735fd8f88
--- /dev/null
+++ b/tensorflow/core/lib/core/error_codes.proto
@@ -0,0 +1,145 @@
+syntax = "proto3";
+
+package tensorflow.error;
+// option cc_enable_arenas = true;
+
+// The canonical error codes for TensorFlow APIs.
+//
+// Warnings:
+//
+// - Do not change any numeric assignments.
+// - Changes to this list should only be made if there is a compelling
+// need that can't be satisfied in another way. Such changes
+// must be approved by at least two OWNERS.
+//
+// Sometimes multiple error codes may apply. Services should return
+// the most specific error code that applies. For example, prefer
+// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply.
+// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION.
+enum Code {
+ // Not an error; returned on success
+ OK = 0;
+
+ // The operation was cancelled (typically by the caller).
+ CANCELLED = 1;
+
+ // Unknown error. An example of where this error may be returned is
+ // if a Status value received from another address space belongs to
+ // an error-space that is not known in this address space. Also
+ // errors raised by APIs that do not return enough error information
+ // may be converted to this error.
+ UNKNOWN = 2;
+
+ // Client specified an invalid argument. Note that this differs
+ // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments
+ // that are problematic regardless of the state of the system
+ // (e.g., a malformed file name).
+ INVALID_ARGUMENT = 3;
+
+ // Deadline expired before operation could complete. For operations
+ // that change the state of the system, this error may be returned
+ // even if the operation has completed successfully. For example, a
+ // successful response from a server could have been delayed long
+ // enough for the deadline to expire.
+ DEADLINE_EXCEEDED = 4;
+
+ // Some requested entity (e.g., file or directory) was not found.
+ // For privacy reasons, this code *may* be returned when the client
+ // does not have the access right to the entity.
+ NOT_FOUND = 5;
+
+ // Some entity that we attempted to create (e.g., file or directory)
+ // already exists.
+ ALREADY_EXISTS = 6;
+
+ // The caller does not have permission to execute the specified
+ // operation. PERMISSION_DENIED must not be used for rejections
+ // caused by exhausting some resource (use RESOURCE_EXHAUSTED
+ // instead for those errors). PERMISSION_DENIED must not be
+ // used if the caller can not be identified (use UNAUTHENTICATED
+ // instead for those errors).
+ PERMISSION_DENIED = 7;
+
+ // The request does not have valid authentication credentials for the
+ // operation.
+ UNAUTHENTICATED = 16;
+
+ // Some resource has been exhausted, perhaps a per-user quota, or
+ // perhaps the entire file system is out of space.
+ RESOURCE_EXHAUSTED = 8;
+
+ // Operation was rejected because the system is not in a state
+  // required for the operation's execution. For example, the directory
+  // to be deleted may be non-empty, or an rmdir operation may be applied
+  // to a non-directory, etc.
+ //
+ // A litmus test that may help a service implementor in deciding
+ // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE:
+ // (a) Use UNAVAILABLE if the client can retry just the failing call.
+ // (b) Use ABORTED if the client should retry at a higher-level
+ // (e.g., restarting a read-modify-write sequence).
+ // (c) Use FAILED_PRECONDITION if the client should not retry until
+ // the system state has been explicitly fixed. E.g., if an "rmdir"
+ // fails because the directory is non-empty, FAILED_PRECONDITION
+ // should be returned since the client should not retry unless
+ // they have first fixed up the directory by deleting files from it.
+ // (d) Use FAILED_PRECONDITION if the client performs conditional
+ // REST Get/Update/Delete on a resource and the resource on the
+ // server does not match the condition. E.g., conflicting
+ // read-modify-write on the same resource.
+ FAILED_PRECONDITION = 9;
+
+ // The operation was aborted, typically due to a concurrency issue
+ // like sequencer check failures, transaction aborts, etc.
+ //
+ // See litmus test above for deciding between FAILED_PRECONDITION,
+ // ABORTED, and UNAVAILABLE.
+ ABORTED = 10;
+
+ // Operation was attempted past the valid range. E.g., seeking or
+ // reading past end of file.
+ //
+ // Unlike INVALID_ARGUMENT, this error indicates a problem that may
+ // be fixed if the system state changes. For example, a 32-bit file
+ // system will generate INVALID_ARGUMENT if asked to read at an
+ // offset that is not in the range [0,2^32-1], but it will generate
+ // OUT_OF_RANGE if asked to read from an offset past the current
+ // file size.
+ //
+ // There is a fair bit of overlap between FAILED_PRECONDITION and
+ // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific
+ // error) when it applies so that callers who are iterating through
+ // a space can easily look for an OUT_OF_RANGE error to detect when
+ // they are done.
+ OUT_OF_RANGE = 11;
+
+ // Operation is not implemented or not supported/enabled in this service.
+ UNIMPLEMENTED = 12;
+
+  // Internal errors. Means some invariant expected by the underlying
+  // system has been broken. If you see one of these errors,
+ // something is very broken.
+ INTERNAL = 13;
+
+  // The service is currently unavailable. This is most likely a
+ // transient condition and may be corrected by retrying with
+ // a backoff.
+ //
+ // See litmus test above for deciding between FAILED_PRECONDITION,
+ // ABORTED, and UNAVAILABLE.
+ UNAVAILABLE = 14;
+
+ // Unrecoverable data loss or corruption.
+ DATA_LOSS = 15;
+
+ // An extra enum entry to prevent people from writing code that
+ // fails to compile when a new code is added.
+ //
+ // Nobody should ever reference this enumeration entry. In particular,
+ // if you write C++ code that switches on this enumeration, add a default:
+ // case instead of a case that mentions this enumeration entry.
+ //
+ // Nobody should rely on the value (currently 20) listed here. It
+ // may change in the future.
+ DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20;
+}
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
new file mode 100644
index 0000000000..b0badd8c4d
--- /dev/null
+++ b/tensorflow/core/lib/core/errors.h
@@ -0,0 +1,131 @@
+#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_
+#define TENSORFLOW_LIB_CORE_ERRORS_H_
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace errors {
+
+typedef ::tensorflow::error::Code Code;
+
+// Append some context to an error message. Each time we append
+// context, we put it on a new line, since it is possible for there
+// to be several layers of additional context.
+template <typename... Args>
+void AppendToMessage(::tensorflow::Status* status, Args... args) {
+ *status = ::tensorflow::Status(
+ status->code(),
+ strings::StrCat(status->error_message(), "\n\t", args...));
+}
+
+// For propagating errors when calling a function.
+#define TF_RETURN_IF_ERROR(expr) \
+ do { \
+ const ::tensorflow::Status _status = (expr); \
+ if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
+ } while (0)
+
+#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \
+ do { \
+ ::tensorflow::Status _status = (expr); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \
+ return _status; \
+ } \
+ } while (0)
+
+// Convenience functions for generating and using error status.
+// Example usage:
+// status.Update(errors::InvalidArgument("The ", foo, " isn't right."));
+// if (errors::IsInvalidArgument(status)) { ... }
+// switch (status.code()) { case error::INVALID_ARGUMENT: ... }
+
+#define DECLARE_ERROR(FUNC, CONST) \
+ template <typename... Args> \
+ inline ::tensorflow::Status FUNC(Args... args) { \
+ return ::tensorflow::Status(::tensorflow::error::CONST, \
+ strings::StrCat(args...)); \
+ } \
+ inline bool Is##FUNC(const ::tensorflow::Status& status) { \
+ return status.code() == ::tensorflow::error::CONST; \
+ }
+
+DECLARE_ERROR(Cancelled, CANCELLED)
+DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT)
+DECLARE_ERROR(NotFound, NOT_FOUND)
+DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS)
+DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED)
+DECLARE_ERROR(Unavailable, UNAVAILABLE)
+DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION)
+DECLARE_ERROR(OutOfRange, OUT_OF_RANGE)
+DECLARE_ERROR(Unimplemented, UNIMPLEMENTED)
+DECLARE_ERROR(Internal, INTERNAL)
+DECLARE_ERROR(Aborted, ABORTED)
+DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED)
+DECLARE_ERROR(DataLoss, DATA_LOSS)
+DECLARE_ERROR(Unknown, UNKNOWN)
+DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED)
+DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED)
+
+#undef DECLARE_ERROR
+
+// The CanonicalCode() for non-errors.
+using ::tensorflow::error::OK;
+
+// Convenience macros for asserting and handling exceptional conditions.
+// Analogous to the CHECK* macros provided by logging.h.
+//
+// Example use:
+// void Compute(OperationContext* context) {
+// OP_REQUIRES(context, context->num_inputs() == 2,
+// errors::InvalidArgument("FooOp requires 2 arguments"));
+// ...
+// Status status = SomeUncertainMethod();
+// OP_REQUIRES_OK(context, status);
+// ...
+// }
+
+#define OP_REQUIRES(CTX, EXP, STATUS) \
+ if (!(EXP)) { \
+ ::tensorflow::Status _s(STATUS); \
+ VLOG(1) << _s; \
+ (CTX)->SetStatus(_s); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK(CTX, STATUS) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ LOG(WARNING) << _s; \
+ (CTX)->SetStatus(_s); \
+ return; \
+ } \
+ } while (0)
+
+#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
+ if (!(EXP)) { \
+ ::tensorflow::Status _s(STATUS); \
+ VLOG(1) << _s; \
+ (CTX)->SetStatus(_s); \
+ (CALLBACK)(); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ LOG(WARNING) << _s; \
+ (CTX)->SetStatus(_s); \
+ (CALLBACK)(); \
+ return; \
+ } \
+ } while (0)
+
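+// Example use of the *_ASYNC variants, which also invoke the given callback
+// before returning on failure (an illustrative sketch; ComputeAsync and
+// DoneCallback are hypothetical):
+//
+//   void ComputeAsync(OperationContext* context, DoneCallback done) {
+//     OP_REQUIRES_ASYNC(context, context->num_inputs() == 2,
+//                       errors::InvalidArgument("FooOp requires 2 arguments"),
+//                       done);
+//     ...
+//     done();
+//   }
+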
+} // namespace errors
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_ERRORS_H_
diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h
new file mode 100644
index 0000000000..071e24285a
--- /dev/null
+++ b/tensorflow/core/lib/core/notification.h
@@ -0,0 +1,42 @@
+#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_
+#define TENSORFLOW_UTIL_NOTIFICATION_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
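+// A Notification lets one thread signal a one-time event to other threads.
+// Notify() may be called at most once; WaitForNotification() blocks until
+// it has been called. Example usage (an illustrative sketch; `pool` is a
+// hypothetical thread pool):
+//
+//   Notification done;
+//   pool->Schedule([&done] {
+//     DoWork();  // DoWork is hypothetical.
+//     done.Notify();
+//   });
+//   done.WaitForNotification();  // Returns once DoWork has finished.
+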
+class Notification {
+ public:
+ Notification() : notified_(false) {}
+ ~Notification() {}
+
+ void Notify() {
+ mutex_lock l(mu_);
+ assert(!notified_);
+ notified_ = true;
+ cv_.notify_all();
+ }
+
+ bool HasBeenNotified() {
+ mutex_lock l(mu_);
+ return notified_;
+ }
+
+ void WaitForNotification() {
+ mutex_lock l(mu_);
+ while (!notified_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ mutex mu_;
+ condition_variable cv_;
+ bool notified_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_NOTIFICATION_H_
diff --git a/tensorflow/core/lib/core/notification_test.cc b/tensorflow/core/lib/core/notification_test.cc
new file mode 100644
index 0000000000..a9e8942f05
--- /dev/null
+++ b/tensorflow/core/lib/core/notification_test.cc
@@ -0,0 +1,64 @@
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(NotificationTest, TestSingleNotification) {
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", 1);
+
+ int counter = 0;
+ Notification start;
+ Notification proceed;
+ thread_pool->Schedule([&start, &proceed, &counter] {
+ start.Notify();
+ proceed.WaitForNotification();
+ ++counter;
+ });
+
+ // Wait for the thread to start
+ start.WaitForNotification();
+
+ // The thread should be waiting for the 'proceed' notification.
+ EXPECT_EQ(0, counter);
+
+ // Unblock the thread
+ proceed.Notify();
+
+ delete thread_pool; // Wait for closure to finish.
+
+ // Verify the counter has been incremented
+ EXPECT_EQ(1, counter);
+}
+
+TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) {
+ const int num_closures = 4;
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", num_closures);
+
+ mutex lock;
+ int counter = 0;
+ Notification n;
+
+ for (int i = 0; i < num_closures; ++i) {
+ thread_pool->Schedule([&n, &lock, &counter] {
+ n.WaitForNotification();
+ mutex_lock l(lock);
+ ++counter;
+ });
+ }
+  // Give the worker threads a chance to run; all of them should still be
+  // blocked in WaitForNotification(), so the counter must be unchanged.
+  sleep(1);
+
+  EXPECT_EQ(0, counter);
+
+ n.Notify();
+ delete thread_pool; // Wait for all closures to finish.
+ EXPECT_EQ(4, counter);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h
new file mode 100644
index 0000000000..1fe49b75bb
--- /dev/null
+++ b/tensorflow/core/lib/core/raw_coding.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#define TENSORFLOW_LIB_CORE_RAW_CODING_H_
+
+#include <string.h>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+
+inline uint32 DecodeFixed32(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint32 result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
+ }
+}
+
+inline uint64 DecodeFixed64(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint64 result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ uint64 lo = DecodeFixed32(ptr);
+ uint64 hi = DecodeFixed32(ptr + 4);
+ return (hi << 32) | lo;
+ }
+}
+
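+// Example (an illustrative sketch; the byte values are arbitrary). Both the
+// little-endian and the byte-by-byte paths decode the same value:
+//
+//   const char buf[4] = {0x78, 0x56, 0x34, 0x12};
+//   uint32 v = DecodeFixed32(buf);  // v == 0x12345678
+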
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_
diff --git a/tensorflow/core/lib/core/refcount.cc b/tensorflow/core/lib/core/refcount.cc
new file mode 100644
index 0000000000..3ed8c58eb8
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount.cc
@@ -0,0 +1,35 @@
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace core {
+
+RefCounted::RefCounted() : ref_(1) {}
+
+RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); }
+
+void RefCounted::Ref() const {
+ DCHECK_GE(ref_.load(), 1);
+ ref_.fetch_add(1, std::memory_order_relaxed);
+}
+
+bool RefCounted::Unref() const {
+ DCHECK_GT(ref_.load(), 0);
+ // If ref_==1, this object is owned only by the caller. Bypass a locked op
+ // in that case.
+ if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) {
+ // Make DCHECK in ~RefCounted happy
+ DCHECK((ref_.store(0), true));
+ delete this;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool RefCounted::RefCountIsOne() const {
+ return (ref_.load(std::memory_order_acquire) == 1);
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h
new file mode 100644
index 0000000000..f727750f9e
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount.h
@@ -0,0 +1,63 @@
+#ifndef TENSORFLOW_LIB_CORE_REFCOUNT_H_
+#define TENSORFLOW_LIB_CORE_REFCOUNT_H_
+
+#include <atomic>
+
+namespace tensorflow {
+namespace core {
+
+class RefCounted {
+ public:
+ // Initial reference count is one.
+ RefCounted();
+
+ // Increments reference count by one.
+ void Ref() const;
+
+ // Decrements reference count by one. If the count remains
+ // positive, returns false. When the count reaches zero, returns
+ // true and deletes this, in which case the caller must not access
+ // the object afterward.
+ bool Unref() const;
+
+ // Return whether the reference count is one.
+ // If the reference count is used in the conventional way, a
+ // reference count of 1 implies that the current thread owns the
+ // reference and no other thread shares it.
+ // This call performs the test for a reference count of one, and
+ // performs the memory barrier needed for the owning thread
+ // to act on the object, knowing that it has exclusive access to the
+ // object.
+ bool RefCountIsOne() const;
+
+ protected:
+ // Make destructor protected so that RefCounted objects cannot
+ // be instantiated directly. Only subclasses can be instantiated.
+ virtual ~RefCounted();
+
+ private:
+ mutable std::atomic_int_fast32_t ref_;
+
+ RefCounted(const RefCounted&) = delete;
+ void operator=(const RefCounted&) = delete;
+};
+
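+// Example usage (an illustrative sketch; MyObj is a hypothetical subclass):
+//
+//   class MyObj : public RefCounted { ... };
+//
+//   MyObj* obj = new MyObj;  // refcount == 1, owned by the creator.
+//   obj->Ref();              // refcount == 2, e.g. to share with a peer.
+//   obj->Unref();            // refcount == 1.
+//   obj->Unref();            // refcount == 0; obj deletes itself here.
+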
+// Helper class to unref an object when out-of-scope.
+class ScopedUnref {
+ public:
+ explicit ScopedUnref(RefCounted* o) : obj_(o) {}
+ ~ScopedUnref() {
+ if (obj_) obj_->Unref();
+ }
+
+ private:
+ RefCounted* obj_;
+
+ ScopedUnref(const ScopedUnref&) = delete;
+ void operator=(const ScopedUnref&) = delete;
+};
+
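+// Example usage (an illustrative sketch; Validate and Use are hypothetical):
+//
+//   void Process(RefCounted* obj) {
+//     ScopedUnref unref(obj);  // obj->Unref() runs on every return path.
+//     if (!Validate(obj)) return;
+//     Use(obj);
+//   }
+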
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_REFCOUNT_H_
diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc
new file mode 100644
index 0000000000..c042be2d61
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount_test.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/core/refcount.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+namespace {
+
+static int constructed = 0;
+static int destroyed = 0;
+
+class MyRef : public RefCounted {
+ public:
+ MyRef() { constructed++; }
+ ~MyRef() override { destroyed++; }
+};
+
+class RefTest : public testing::Test {
+ public:
+ RefTest() {
+ constructed = 0;
+ destroyed = 0;
+ }
+};
+
+TEST_F(RefTest, New) {
+ MyRef* ref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, RefUnref) {
+ MyRef* ref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ ref->Ref();
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, RefCountOne) {
+ MyRef* ref = new MyRef;
+ ASSERT_TRUE(ref->RefCountIsOne());
+ ref->Unref();
+}
+
+TEST_F(RefTest, RefCountNotOne) {
+ MyRef* ref = new MyRef;
+ ref->Ref();
+ ASSERT_FALSE(ref->RefCountIsOne());
+ ref->Unref();
+ ref->Unref();
+}
+
+TEST_F(RefTest, ConstRefUnref) {
+ const MyRef* cref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ cref->Ref();
+ ASSERT_EQ(0, destroyed);
+ cref->Unref();
+ ASSERT_EQ(0, destroyed);
+ cref->Unref();
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, ReturnOfUnref) {
+ MyRef* ref = new MyRef;
+ ref->Ref();
+ EXPECT_FALSE(ref->Unref());
+ EXPECT_TRUE(ref->Unref());
+}
+
+TEST_F(RefTest, ScopedUnref) {
+ { ScopedUnref unref(new MyRef); }
+ EXPECT_EQ(destroyed, 1);
+}
+
+TEST_F(RefTest, ScopedUnref_Nullptr) {
+ { ScopedUnref unref(nullptr); }
+ EXPECT_EQ(destroyed, 0);
+}
+
+} // namespace
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc
new file mode 100644
index 0000000000..24ce842560
--- /dev/null
+++ b/tensorflow/core/lib/core/status.cc
@@ -0,0 +1,107 @@
+#include "tensorflow/core/public/status.h"
+#include <stdio.h>
+
+namespace tensorflow {
+
+Status::Status(tensorflow::error::Code code, StringPiece msg) {
+ assert(code != tensorflow::error::OK);
+ state_ = new State;
+ state_->code = code;
+ state_->msg = msg.ToString();
+}
+Status::~Status() { delete state_; }
+
+void Status::Update(const Status& new_status) {
+ if (ok()) {
+ *this = new_status;
+ }
+}
+
+void Status::SlowCopyFrom(const State* src) {
+ delete state_;
+ if (src == nullptr) {
+ state_ = nullptr;
+ } else {
+ state_ = new State(*src);
+ }
+}
+
+const string& Status::empty_string() {
+ static string* empty = new string;
+ return *empty;
+}
+
+string Status::ToString() const {
+  if (state_ == nullptr) {
+ return "OK";
+ } else {
+ char tmp[30];
+ const char* type;
+ switch (code()) {
+ case tensorflow::error::CANCELLED:
+ type = "Cancelled";
+ break;
+ case tensorflow::error::UNKNOWN:
+ type = "Unknown";
+ break;
+ case tensorflow::error::INVALID_ARGUMENT:
+ type = "Invalid argument";
+ break;
+ case tensorflow::error::DEADLINE_EXCEEDED:
+ type = "Deadline exceeded";
+ break;
+ case tensorflow::error::NOT_FOUND:
+ type = "Not found";
+ break;
+ case tensorflow::error::ALREADY_EXISTS:
+ type = "Already exists";
+ break;
+ case tensorflow::error::PERMISSION_DENIED:
+ type = "Permission denied";
+ break;
+ case tensorflow::error::UNAUTHENTICATED:
+ type = "Unauthenticated";
+ break;
+ case tensorflow::error::RESOURCE_EXHAUSTED:
+ type = "Resource exhausted";
+ break;
+ case tensorflow::error::FAILED_PRECONDITION:
+ type = "Failed precondition";
+ break;
+ case tensorflow::error::ABORTED:
+ type = "Aborted";
+ break;
+ case tensorflow::error::OUT_OF_RANGE:
+ type = "Out of range";
+ break;
+ case tensorflow::error::UNIMPLEMENTED:
+ type = "Unimplemented";
+ break;
+ case tensorflow::error::INTERNAL:
+ type = "Internal";
+ break;
+ case tensorflow::error::UNAVAILABLE:
+ type = "Unavailable";
+ break;
+ case tensorflow::error::DATA_LOSS:
+ type = "Data loss";
+ break;
+ default:
+ snprintf(tmp, sizeof(tmp), "Unknown code(%d)",
+ static_cast<int>(code()));
+ type = tmp;
+ break;
+ }
+ string result(type);
+ result += ": ";
+ result += state_->msg;
+ return result;
+ }
+}
+
+std::ostream& operator<<(std::ostream& os, const Status& x) {
+ os << x.ToString();
+ return os;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc
new file mode 100644
index 0000000000..3ef6b3302a
--- /dev/null
+++ b/tensorflow/core/lib/core/status_test.cc
@@ -0,0 +1,84 @@
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Status, OK) {
+ EXPECT_EQ(Status::OK().code(), error::OK);
+ EXPECT_EQ(Status::OK().error_message(), "");
+ EXPECT_OK(Status::OK());
+ ASSERT_OK(Status::OK());
+ EXPECT_EQ(Status::OK(), Status());
+ Status s;
+ EXPECT_TRUE(s.ok());
+}
+
+TEST(DeathStatus, CheckOK) {
+ Status status(errors::InvalidArgument("Invalid"));
+ ASSERT_DEATH(TF_CHECK_OK(status), "Invalid");
+}
+
+TEST(Status, Set) {
+ Status status;
+ status = Status(error::CANCELLED, "Error message");
+ EXPECT_EQ(status.code(), error::CANCELLED);
+ EXPECT_EQ(status.error_message(), "Error message");
+}
+
+TEST(Status, Copy) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b(a);
+ ASSERT_EQ(a.ToString(), b.ToString());
+}
+
+TEST(Status, Assign) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b;
+ b = a;
+ ASSERT_EQ(a.ToString(), b.ToString());
+}
+
+TEST(Status, Update) {
+ Status s;
+ s.Update(Status::OK());
+ ASSERT_TRUE(s.ok());
+ Status a(errors::InvalidArgument("Invalid"));
+ s.Update(a);
+ ASSERT_EQ(s.ToString(), a.ToString());
+ Status b(errors::Internal("Internal"));
+ s.Update(b);
+ ASSERT_EQ(s.ToString(), a.ToString());
+ s.Update(Status::OK());
+ ASSERT_EQ(s.ToString(), a.ToString());
+ ASSERT_FALSE(s.ok());
+}
+
+TEST(Status, EqualsOK) { ASSERT_EQ(Status::OK(), Status()); }
+
+TEST(Status, EqualsSame) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b(errors::InvalidArgument("Invalid"));
+ ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsCopy) {
+ const Status a(errors::InvalidArgument("Invalid"));
+ const Status b = a;
+ ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsDifferentCode) {
+ const Status a(errors::InvalidArgument("message"));
+ const Status b(errors::Internal("message"));
+ ASSERT_NE(a, b);
+}
+
+TEST(Status, EqualsDifferentMessage) {
+ const Status a(errors::InvalidArgument("message"));
+ const Status b(errors::InvalidArgument("another"));
+ ASSERT_NE(a, b);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
new file mode 100644
index 0000000000..b3b4db429f
--- /dev/null
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -0,0 +1,20 @@
+#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+// Macros for testing the results of functions that return util::Status.
+
+#define EXPECT_OK(statement) EXPECT_EQ(::tensorflow::Status::OK(), (statement))
+#define ASSERT_OK(statement) ASSERT_EQ(::tensorflow::Status::OK(), (statement))
+
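+// Example usage (an illustrative sketch; SomeOperation is hypothetical):
+//   Status s = SomeOperation();
+//   ASSERT_OK(s);  // Fails the test and prints s if s is not OK.
+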
+// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not
+// provide much value (when they fail, they would just print the OK status,
+// which conveys no more information than EXPECT_FALSE(status.ok())).
+// If you want to check for particular errors, better alternatives are:
+// EXPECT_EQ(::util::Status(...expected error...), status.StripMessage());
+// EXPECT_THAT(status.ToString(), HasSubstr("expected error"));
+// Also, see testing/lib/util/status_util.h.
+
+#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
new file mode 100644
index 0000000000..57c5139f47
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+#include <iostream>
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+size_t StringPiece::Hasher::operator()(StringPiece s) const {
+ return Hash64(s.data(), s.size());
+}
+
+std::ostream& operator<<(std::ostream& o, StringPiece piece) {
+ o.write(piece.data(), piece.size());
+ return o;
+}
+
+bool StringPiece::contains(StringPiece s) const {
+ return memmem(data_, size_, s.data_, s.size_) != nullptr;
+}
+
+size_t StringPiece::find(char c, size_t pos) const {
+ if (pos >= size_) {
+ return npos;
+ }
+ const char* result =
+ reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
+  return result != nullptr ? result - data_ : npos;
+}
+
+// Search range is [0..pos] inclusive. If pos == npos, search everything.
+size_t StringPiece::rfind(char c, size_t pos) const {
+ if (size_ == 0) return npos;
+ for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
+ if (*p == c) {
+ return p - data_;
+ }
+ }
+ return npos;
+}
+
+bool StringPiece::Consume(StringPiece x) {
+ if (starts_with(x)) {
+ remove_prefix(x.size_);
+ return true;
+ }
+ return false;
+}
+
+StringPiece StringPiece::substr(size_t pos, size_t n) const {
+ if (pos > size_) pos = size_;
+ if (n > size_ - pos) n = size_ - pos;
+ return StringPiece(data_ + pos, n);
+}
+
+const StringPiece::size_type StringPiece::npos = size_type(-1);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
new file mode 100644
index 0000000000..17d4b294e9
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -0,0 +1,159 @@
+// StringPiece is a simple structure containing a pointer into some external
+// storage and a size. The user of a StringPiece must ensure that the slice
+// is not used after the corresponding external storage has been
+// deallocated.
+//
+// Multiple threads can invoke const methods on a StringPiece without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same StringPiece must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+
+#include <assert.h>
+#include <stddef.h>
+#include <string.h>
+#include <iosfwd>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class StringPiece {
+ public:
+ typedef size_t size_type;
+
+ // Create an empty slice.
+ StringPiece() : data_(""), size_(0) {}
+
+ // Create a slice that refers to d[0,n-1].
+ StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
+
+ // Create a slice that refers to the contents of "s"
+ StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
+
+ // Create a slice that refers to s[0,strlen(s)-1]
+ StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
+
+ void set(const void* data, size_t len) {
+ data_ = reinterpret_cast<const char*>(data);
+ size_ = len;
+ }
+
+ // Return a pointer to the beginning of the referenced data
+ const char* data() const { return data_; }
+
+ // Return the length (in bytes) of the referenced data
+ size_t size() const { return size_; }
+
+ // Return true iff the length of the referenced data is zero
+ bool empty() const { return size_ == 0; }
+
+ typedef const char* const_iterator;
+ typedef const char* iterator;
+ iterator begin() const { return data_; }
+ iterator end() const { return data_ + size_; }
+
+ static const size_t npos;
+
+ // Return the ith byte in the referenced data.
+ // REQUIRES: n < size()
+ char operator[](size_t n) const {
+ assert(n < size());
+ return data_[n];
+ }
+
+ // Change this slice to refer to an empty array
+ void clear() {
+ data_ = "";
+ size_ = 0;
+ }
+
+ // Drop the first "n" bytes from this slice.
+ void remove_prefix(size_t n) {
+ assert(n <= size());
+ data_ += n;
+ size_ -= n;
+ }
+
+ void remove_suffix(size_t n) {
+ assert(size_ >= n);
+ size_ -= n;
+ }
+
+ size_t find(char c, size_t pos = 0) const;
+ size_t rfind(char c, size_t pos = npos) const;
+ bool contains(StringPiece s) const;
+
+ // Checks whether StringPiece starts with x and if so advances the beginning
+ // of it to past the match. It's basically a shortcut for starts_with
+ // followed by remove_prefix.
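+  // Example (an illustrative sketch):
+  //   StringPiece s("name=value");
+  //   if (s.Consume("name=")) { ... s now refers to "value" ... }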
+ bool Consume(StringPiece x);
+
+ StringPiece substr(size_t pos, size_t n = npos) const;
+
+ struct Hasher {
+ size_t operator()(StringPiece arg) const;
+ };
+
+  // Return a string that contains a copy of the referenced data.
+ std::string ToString() const { return std::string(data_, size_); }
+
+ // Three-way comparison. Returns value:
+ // < 0 iff "*this" < "b",
+ // == 0 iff "*this" == "b",
+ // > 0 iff "*this" > "b"
+ int compare(StringPiece b) const;
+
+ // Return true iff "x" is a prefix of "*this"
+ bool starts_with(StringPiece x) const {
+ return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0));
+ }
+ // Return true iff "x" is a suffix of "*this"
+ bool ends_with(StringPiece x) const {
+ return ((size_ >= x.size_) &&
+ (memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0));
+ }
+
+ private:
+ const char* data_;
+ size_t size_;
+
+ // Intentionally copyable
+};
+
+inline bool operator==(StringPiece x, StringPiece y) {
+ return ((x.size() == y.size()) &&
+ (memcmp(x.data(), y.data(), x.size()) == 0));
+}
+
+inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
+
+inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
+inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
+inline bool operator<=(StringPiece x, StringPiece y) {
+ return x.compare(y) <= 0;
+}
+inline bool operator>=(StringPiece x, StringPiece y) {
+ return x.compare(y) >= 0;
+}
+
+inline int StringPiece::compare(StringPiece b) const {
+ const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
+ int r = memcmp(data_, b.data_, min_len);
+ if (r == 0) {
+ if (size_ < b.size_)
+ r = -1;
+ else if (size_ > b.size_)
+ r = +1;
+ }
+ return r;
+}
+
+// allow StringPiece to be logged
+extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
new file mode 100644
index 0000000000..e9b84d3102
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -0,0 +1,108 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+namespace thread {
+
+struct ThreadPool::Waiter {
+ condition_variable cv;
+ bool ready;
+};
+
+ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
+ : ThreadPool(env, ThreadOptions(), name, num_threads) {}
+
+ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads)
+ : name_(name) {
+ CHECK_GE(num_threads, 1);
+ string name_prefix = "tf_" + name_;
+ for (int i = 0; i < num_threads; i++) {
+ threads_.push_back(env->StartThread(thread_options, name_prefix,
+ [this]() { WorkerLoop(); }));
+ }
+}
+
+ThreadPool::~ThreadPool() {
+ {
+ // Wait for all work to get done.
+ mutex_lock l(mu_);
+
+ // Inform every thread to exit.
+ for (size_t i = 0; i < threads_.size(); ++i) {
+ pending_.push_back({nullptr, 0});
+ }
+
+ // Wakeup all waiters.
+ for (auto w : waiters_) {
+ w->ready = true;
+ w->cv.notify_one();
+ }
+ }
+
+ // Wait for threads to finish.
+ for (auto t : threads_) {
+ delete t;
+ }
+}
+
+bool ThreadPool::HasPendingClosures() const {
+ mutex_lock l(mu_);
+  return !pending_.empty();
+}
+
+void ThreadPool::Schedule(std::function<void()> fn) {
+ CHECK(fn != nullptr);
+ uint64 id = 0;
+ if (port::Tracing::IsActive()) {
+ id = port::Tracing::UniqueId();
+ port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
+ id);
+ }
+
+ mutex_lock l(mu_);
+ pending_.push_back({fn, id});
+ if (!waiters_.empty()) {
+ Waiter* w = waiters_.back();
+ waiters_.pop_back();
+ w->ready = true;
+ w->cv.notify_one();
+ }
+}
+
+void ThreadPool::WorkerLoop() {
+ port::Tracing::RegisterCurrentThread(name_.c_str());
+ mutex_lock l(mu_);
+ Waiter w;
+ while (true) {
+ while (pending_.empty()) {
+ // Wait for work to be assigned to me
+ w.ready = false;
+ waiters_.push_back(&w);
+ while (!w.ready) {
+ w.cv.wait(l);
+ }
+ }
+ // Pick up pending work
+ Item item = pending_.front();
+ pending_.pop_front();
+ if (item.fn == nullptr) {
+ break;
+ }
+ mu_.unlock();
+ if (item.id != 0) {
+ port::Tracing::ScopedActivity region(
+ port::Tracing::EventCategory::kRunClosure, item.id);
+ item.fn();
+ } else {
+ item.fn();
+ }
+ mu_.lock();
+ }
+}
+
+} // namespace thread
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
new file mode 100644
index 0000000000..5cf780fa86
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -0,0 +1,59 @@
+#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
+
+#include <deque>
+#include <functional>
+#include <thread>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace thread {
+
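+// Example usage (an illustrative sketch; DoSomeWork is hypothetical). The
+// destructor blocks until all scheduled closures have run:
+//
+//   thread::ThreadPool pool(Env::Default(), "example", 4 /* num_threads */);
+//   for (int i = 0; i < 10; i++) {
+//     pool.Schedule([i] { DoSomeWork(i); });
+//   }
+//   // ~ThreadPool waits here for the 10 closures to finish.
+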
+class ThreadPool {
+ public:
+ // Construct a pool that contains "num_threads" threads with specified "name".
+ // env->StartThread() is used to create individual threads.
+ //
+ // REQUIRES: num_threads > 0
+ ThreadPool(Env* env, const string& name, int num_threads);
+
+ // Construct a pool that contains "num_threads" threads with specified "name".
+ // env->StartThread() is used to create individual threads.
+ //
+ // REQUIRES: num_threads > 0
+ ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name,
+ int num_threads);
+
+ // Wait until all scheduled work has finished and then destroy the
+ // set of threads.
+ virtual ~ThreadPool();
+
+ // Schedule fn() for execution in the pool of threads.
+ virtual void Schedule(std::function<void()> fn);
+
+ virtual bool HasPendingClosures() const;
+
+ private:
+ struct Waiter;
+ struct Item {
+ std::function<void()> fn;
+ uint64 id;
+ };
+
+ void WorkerLoop();
+
+ const string name_;
+ mutable mutex mu_;
+ std::vector<Thread*> threads_; // All threads
+ std::vector<Waiter*> waiters_; // Stack of waiting threads.
+ std::deque<Item> pending_; // Queue of pending work
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
+};
+
+} // namespace thread
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
new file mode 100644
index 0000000000..f4909c445c
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -0,0 +1,93 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include <atomic>
+
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace thread {
+
+static const int kNumThreads = 30;
+
+TEST(ThreadPool, Empty) {
+ for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+ fprintf(stderr, "Testing with %d threads\n", num_threads);
+ ThreadPool pool(Env::Default(), "test", num_threads);
+ }
+}
+
+TEST(ThreadPool, DoWork) {
+ for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+ fprintf(stderr, "Testing with %d threads\n", num_threads);
+ const int kWorkItems = 15;
+ bool work[kWorkItems];
+ for (int i = 0; i < kWorkItems; i++) {
+ work[i] = false;
+ }
+ {
+ ThreadPool pool(Env::Default(), "test", num_threads);
+ for (int i = 0; i < kWorkItems; i++) {
+ pool.Schedule([&work, i]() {
+ ASSERT_FALSE(work[i]);
+ work[i] = true;
+ });
+ }
+ }
+ for (int i = 0; i < kWorkItems; i++) {
+ ASSERT_TRUE(work[i]);
+ }
+ }
+}
+
+static void BM_Sequential(int iters) {
+ ThreadPool pool(Env::Default(), "test", kNumThreads);
+ // Decrement count sequentially until 0.
+ int count = iters;
+ mutex done_lock;
+ condition_variable done;
+ bool done_flag = false;
+ std::function<void()> work = [&pool, &count, &done_lock, &done, &done_flag,
+ &work]() {
+ if (count--) {
+ pool.Schedule(work);
+ } else {
+ mutex_lock l(done_lock);
+ done_flag = true;
+ done.notify_all();
+ }
+ };
+ work();
+ mutex_lock l(done_lock);
+ if (!done_flag) {
+ done.wait(l);
+ }
+}
+BENCHMARK(BM_Sequential);
+
+static void BM_Parallel(int iters) {
+ ThreadPool pool(Env::Default(), "test", kNumThreads);
+ // Decrement count concurrently until 0.
+ std::atomic_int_fast32_t count(iters);
+ mutex done_lock;
+ condition_variable done;
+ bool done_flag = false;
+ for (int i = 0; i < iters; ++i) {
+ pool.Schedule([&count, &done_lock, &done, &done_flag]() {
+ if (count.fetch_sub(1) == 1) {
+ mutex_lock l(done_lock);
+ done_flag = true;
+ done.notify_all();
+ }
+ });
+ }
+ mutex_lock l(done_lock);
+ if (!done_flag) {
+ done.wait(l);
+ }
+}
+BENCHMARK(BM_Parallel);
+
+} // namespace thread
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
new file mode 100644
index 0000000000..813fb126e3
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -0,0 +1,299 @@
+// An ArraySlice<T> represents an immutable array of elements of type
+// T. It has a length "length", and a base pointer "ptr", and the
+// array it represents contains the elements "ptr[0] .. ptr[len-1]".
+// The backing store for the array is *not* owned by the ArraySlice
+// object, and clients must arrange for the backing store to remain
+// live while the ArraySlice object is in use.
+//
+// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
+// array elements of type T.
+//
+// Implicit conversion operations are provided from types such as
+// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
+// objects constructed from types in this way may be invalidated by
+// any operations that mutate the underlying vector.
+//
+// One common use for ArraySlice is when passing arguments to a
+// routine where you want to be able to accept a variety of array
+// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
+// etc.). The usual approach here is to have the client explicitly
+// pass in a pointer and a length, as in:
+//
+// void MyRoutine(const int* elems, int N) {
+// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
+// }
+//
+// Unfortunately, this leads to ugly and error-prone code at the call site:
+//
+// std::vector<int> my_vector;
+// MyRoutine(vector_as_array(&my_vector), my_vector.size());
+//
+// util::gtl::InlinedVector<int, 4> my_inline_vector;
+// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
+//
+// int my_array[10];
+// MyRoutine(my_array, 10);
+//
+// Instead, you can use an ArraySlice as the argument to the routine:
+//
+// void MyRoutine(ArraySlice<int> a) {
+// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
+// }
+//
+// This makes the call sites cleaner, for the most part:
+//
+// std::vector<int> my_vector;
+// MyRoutine(my_vector);
+//
+// util::gtl::InlinedVector<int, 4> my_inline_vector;
+// MyRoutine(my_inline_vector);
+//
+// int my_array[10];
+// MyRoutine(my_array);
+//
+// int* my_array = new int[10];
+// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
+//
+// MutableArraySlice<T> represents a mutable array of elements, and, like
+// ArraySlice, does not own the backing store. The implicit constructors it
+// provides allow functions not to worry about whether their mutable arguments
+// refer to vectors, arrays, proto2::RepeatedFields, etc.:
+//
+// void MyMutatingRoutine(MutableArraySlice<int> a) {
+// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
+// }
+//
+// std::vector<int> my_vector;
+// MyMutatingRoutine(&my_vector);
+//
+// int my_array[10];
+// MyMutatingRoutine(my_array);
+//
+// int* my_array = new int[10];
+// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
+//
+// MyProto my_proto;
+// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
+// MyMutatingRoutine(my_proto.mutable_value());
+
+#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+
+#include <initializer_list>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename T>
+class ArraySlice {
+ private:
+ typedef array_slice_internal::ArraySliceImpl<T> Impl;
+
+ public:
+ typedef T value_type;
+ typedef typename Impl::pointer pointer;
+ typedef typename Impl::const_pointer const_pointer;
+ typedef typename Impl::reference reference;
+ typedef typename Impl::const_reference const_reference;
+ typedef typename Impl::iterator iterator;
+ typedef typename Impl::const_iterator const_iterator;
+ typedef typename Impl::reverse_iterator reverse_iterator;
+ typedef typename Impl::const_reverse_iterator const_reverse_iterator;
+ typedef typename Impl::size_type size_type;
+ typedef typename Impl::difference_type difference_type;
+
+ static const size_type npos = Impl::npos;
+
+ ArraySlice() : impl_(nullptr, 0) {}
+ ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
+
+ // Implicit conversion constructors
+ ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
+ : impl_(v.data(), v.size()) {}
+
+ template <size_t N>
+ ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
+ : impl_(a, N) {}
+
+ template <int N>
+ ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
+ : impl_(v.array(), v.size()) {}
+
+ // The constructor for any class supplying 'data() const' that returns either
+ // const T* or a less const-qualified version of it, and 'some_integral_type
+ // size() const'. proto2::RepeatedField<T>, string and (since C++11)
+ // std::vector<T,A> and std::array<T, N> are examples of this. See
+ // array_slice_internal.h for details.
+ template <typename V,
+ typename = typename Impl::template EnableIfConvertibleFrom<V>>
+ ArraySlice(const V& v) // NOLINT(runtime/explicit)
+ : impl_(v) {}
+
+ // Implicitly constructs an ArraySlice from an initializer list. This makes it
+ // possible to pass a brace-enclosed initializer list to a function expecting
+ // an ArraySlice:
+ // void Process(ArraySlice<int> x);
+ // Process({1, 2, 3});
+ // The data referenced by the initializer_list must outlive this
+ // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
+ // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
+ // reference data that is no longer valid.
+ ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
+ : impl_(v.begin(), v.size()) {}
+
+ // Substring of another ArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ // If len==npos, the substring continues till the end of x.
+ ArraySlice(const ArraySlice& x, size_type pos, size_type len)
+ : impl_(x.impl_, pos, len) {}
+
+ const_pointer data() const { return impl_.data(); }
+ size_type size() const { return impl_.size(); }
+ size_type length() const { return size(); }
+ bool empty() const { return size() == 0; }
+
+ void clear() { impl_.clear(); }
+
+ const_reference operator[](size_type i) const { return impl_[i]; }
+ const_reference at(size_type i) const { return impl_.at(i); }
+ const_reference front() const { return impl_.front(); }
+ const_reference back() const { return impl_.back(); }
+
+ const_iterator begin() const { return impl_.begin(); }
+ const_iterator end() const { return impl_.end(); }
+ const_reverse_iterator rbegin() const { return impl_.rbegin(); }
+ const_reverse_iterator rend() const { return impl_.rend(); }
+
+ void remove_prefix(size_type n) { impl_.remove_prefix(n); }
+ void remove_suffix(size_type n) { impl_.remove_suffix(n); }
+ void pop_back() { remove_suffix(1); }
+ void pop_front() { remove_prefix(1); }
+
+ // These relational operators have the same semantics as the
+ // std::vector<T> relational operators: they do deep (elementwise)
+ // comparisons. Array slices are equal iff their size is the same
+ // and all their elements are equal.
+ bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
+ bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
+
+ private:
+ Impl impl_;
+};
+
+// Mutable version of ArraySlice, which allows the clients to mutate the
+// underlying data. It is implicitly convertible to ArraySlice since it provides
+// the data() and size() methods with correct signatures. When a
+// MutableArraySlice is created from a pointer to a container (as opposed to raw
+// memory pointer), the pointer must not be null.
+//
+// A note on const-ness: "mutable" here refers to the mutability of the
+// underlying data, not of the slice itself. It is perfectly reasonable to have
+// a variable of type "const MutableArraySlice<T>"; this means that the bounds
+// of the view on the array cannot be changed, but the underlying data in the
+// array still may be modified. This is akin to a "T* const" pointer, as opposed
+// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
+template <typename T>
+class MutableArraySlice {
+ private:
+ typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
+
+ public:
+ typedef T value_type;
+ typedef typename Impl::pointer pointer;
+ typedef typename Impl::const_pointer const_pointer;
+ typedef typename Impl::reference reference;
+ typedef typename Impl::const_reference const_reference;
+ typedef typename Impl::iterator iterator;
+ typedef typename Impl::const_iterator const_iterator;
+ typedef typename Impl::reverse_iterator reverse_iterator;
+ typedef typename Impl::const_reverse_iterator const_reverse_iterator;
+ typedef typename Impl::size_type size_type;
+ typedef typename Impl::difference_type difference_type;
+
+ static const size_type npos = Impl::npos;
+
+ MutableArraySlice() : impl_(nullptr, 0) {}
+ MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
+
+ // Implicit conversion constructors
+ MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
+ : impl_(v->data(), v->size()) {}
+
+ template <size_t N>
+ MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
+ : impl_(a, N) {}
+
+ template <int N>
+ MutableArraySlice(
+ InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
+ : impl_(v->mutable_array(), v->size()) {}
+
+ // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
+ // (the former is called if both exist), and 'some_integral_type size()
+ // const'. proto2::RepeatedField is an example of this. Also supports string
+ // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
+ // array_slice_internal.h for details.
+ template <typename V,
+ typename = typename Impl::template EnableIfConvertibleFrom<V>>
+ MutableArraySlice(V* v) // NOLINT(runtime/explicit)
+ : impl_(v) {}
+
+ // Substring of another MutableArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ // If len==npos, the substring continues till the end of x.
+ MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
+ : impl_(x.impl_, pos, len) {}
+
+ // Accessors.
+ pointer data() const { return impl_.data(); }
+ size_type size() const { return impl_.size(); }
+ size_type length() const { return size(); }
+ bool empty() const { return size() == 0; }
+
+ void clear() { impl_.clear(); }
+
+ reference operator[](size_type i) const { return impl_[i]; }
+ reference at(size_type i) const { return impl_.at(i); }
+ reference front() const { return impl_.front(); }
+ reference back() const { return impl_.back(); }
+
+ iterator begin() const { return impl_.begin(); }
+ iterator end() const { return impl_.end(); }
+ reverse_iterator rbegin() const { return impl_.rbegin(); }
+ reverse_iterator rend() const { return impl_.rend(); }
+
+ void remove_prefix(size_type n) { impl_.remove_prefix(n); }
+ void remove_suffix(size_type n) { impl_.remove_suffix(n); }
+ void pop_back() { remove_suffix(1); }
+ void pop_front() { remove_prefix(1); }
+
+ bool operator==(ArraySlice<T> other) const {
+ return ArraySlice<T>(*this) == other;
+ }
+ bool operator!=(ArraySlice<T> other) const {
+ return ArraySlice<T>(*this) != other;
+ }
+
+ // DEPRECATED(jacobsa): Please use data() instead.
+ pointer mutable_data() const { return impl_.data(); }
+
+ private:
+ Impl impl_;
+};
+
+template <typename T>
+const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
+template <typename T>
+const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
new file mode 100644
index 0000000000..080f0a38d8
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice_internal.h
@@ -0,0 +1,253 @@
+// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
+// array_slice.h.
+
+// Helper functions and templates for ArraySlice.
+
+#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
+#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace gtl {
+namespace array_slice_internal {
+
+// Template logic for generic constructors.
+
+// Wrappers whose Get() delegates to the appropriate method of a container, and
+// is defined when this method exists. Delegates to the const method if C is a
+// const type.
+struct Data {
+ template <typename C>
+ static decltype(std::declval<C>().data()) Get(C* v) {
+ return v->data();
+ }
+};
+
+struct MutableData {
+ template <typename C>
+ static decltype(std::declval<C>().mutable_data()) Get(C* v) {
+ return v->mutable_data();
+ }
+};
+
+struct Size {
+ template <typename C>
+ static decltype(std::declval<C>().size()) Get(C* v) {
+ return v->size();
+ }
+};
+
+struct MutableStringData {
+ // Defined only for string.
+ static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
+};
+
+// Checks whether M::Get(C*) is defined and has a return type R such that
+// Checker::valid<R>()==true.
+template <typename M, typename Checker, typename C>
+struct HasGetHelper : public M {
+ private:
+ struct None {};
+ // M::Get is selected when it is viable. Get(...) is selected otherwise.
+ using M::Get;
+ static None Get(...);
+
+ public:
+ static constexpr bool HasGet() {
+ using Result = decltype(Get(std::declval<C*>()));
+ return !std::is_same<Result, None>() && Checker::template valid<Result>();
+ }
+};
+
+// Defines HasGet() for a particular method, container, and checker. If
+// HasGet()==true, provides Get() that delegates to the method.
+template <typename M, typename Checker, typename C,
+ bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
+struct Wrapper {
+ static constexpr bool HasGet() { return false; }
+};
+
+template <typename M, typename Checker, typename C>
+struct Wrapper<M, Checker, C, true> {
+ static constexpr bool HasGet() { return true; }
+ static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
+};
+
+// Type checker for a method returning an integral value.
+struct SizeChecker {
+ template <typename R>
+ static constexpr bool valid() {
+ return std::is_integral<R>::value;
+ }
+};
+
+// Type checker for a method returning either a pointer to T or a less const
+// version of that.
+template <typename T>
+struct DataChecker {
+  // We want to enable conversion from std::vector<T*> to
+  // ArraySlice<const T*> but disable conversion from std::vector<Derived>
+  // to ArraySlice<Base>. Here we use the fact that U** is convertible to
+  // Q* const* if and only if Q is the same type or a more cv-qualified
+  // version of U.
+ template <typename R>
+ static constexpr bool valid() {
+ return std::is_convertible<R*, T* const*>::value;
+ }
+};
+
+// Aliases to A if A::HasGet()==true, or to B otherwise.
+template <typename A, typename B>
+using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
+
+// Wraps C::data() const, returning a pointer to const data.
+template <typename T, typename C>
+using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
+
+// Wraps a method returning a pointer to mutable data. Prefers data() over
+// mutable_data(), and handles strings when T==char. If data() returns a pointer
+// to mutable data, it is most likely overloaded, but may also be a single
+// method 'T* C::data() const' in a non-STL-compliant container.
+template <typename T, typename C>
+using ContainerMutableData =
+ FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
+ FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
+ Wrapper<MutableStringData, DataChecker<T>, C>>>;
+
+// Wraps C::size() const.
+template <typename C>
+using ContainerSize = Wrapper<Size, SizeChecker, const C>;
+
+// Implementation class for ArraySlice and MutableArraySlice. In the case of
+// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
+// mutable type.
+template <typename T>
+class ArraySliceImplBase {
+ public:
+ typedef T* pointer;
+ typedef const T* const_pointer;
+ typedef T& reference;
+ typedef const T& const_reference;
+ typedef pointer iterator;
+ typedef const_pointer const_iterator;
+ typedef std::reverse_iterator<iterator> reverse_iterator;
+ typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+ typedef size_t size_type;
+ typedef ptrdiff_t difference_type;
+
+ static const size_type npos = -1;
+
+ ArraySliceImplBase(pointer array, size_type length)
+ : ptr_(array), length_(length) {}
+
+ // Substring of another ArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
+ : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
+
+ // Some of the const methods below return pointers and references to mutable
+ // data. This is only the case in this internal class; ArraySlice and
+ // MutableArraySlice provide deep-constness.
+
+ pointer data() const { return ptr_; }
+ size_type size() const { return length_; }
+
+ void clear() {
+ ptr_ = nullptr;
+ length_ = 0;
+ }
+
+ reference operator[](size_type i) const { return ptr_[i]; }
+ reference at(size_type i) const {
+ DCHECK_LT(i, length_);
+ return ptr_[i];
+ }
+ reference front() const {
+ DCHECK_GT(length_, 0);
+ return ptr_[0];
+ }
+ reference back() const {
+ DCHECK_GT(length_, 0);
+ return ptr_[length_ - 1];
+ }
+
+ void remove_prefix(size_type n) {
+ DCHECK_GE(length_, n);
+ ptr_ += n;
+ length_ -= n;
+ }
+ void remove_suffix(size_type n) {
+ DCHECK_GE(length_, n);
+ length_ -= n;
+ }
+
+ iterator begin() const { return ptr_; }
+ iterator end() const { return ptr_ + length_; }
+ reverse_iterator rbegin() const { return reverse_iterator(end()); }
+ reverse_iterator rend() const { return reverse_iterator(begin()); }
+
+ bool operator==(const ArraySliceImplBase& other) const {
+ if (size() != other.size()) return false;
+ if (data() == other.data()) return true;
+ return std::equal(data(), data() + size(), other.data());
+ }
+ bool operator!=(const ArraySliceImplBase& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ pointer ptr_;
+ size_type length_;
+};
+
+template <typename T>
+class ArraySliceImpl : public ArraySliceImplBase<const T> {
+ public:
+ using ArraySliceImplBase<const T>::ArraySliceImplBase;
+
+ // Defined iff the data and size accessors for the container C have been
+ // defined.
+ template <typename C>
+ using EnableIfConvertibleFrom =
+ typename std::enable_if<ContainerData<T, C>::HasGet() &&
+ ContainerSize<C>::HasGet()>::type;
+
+ // Constructs from a container when EnableIfConvertibleFrom is
+ // defined. std::addressof handles types with overloaded operator&.
+ template <typename C>
+ explicit ArraySliceImpl(const C& v)
+ : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
+ ContainerSize<C>::Get(std::addressof(v))) {}
+};
+
+template <typename T>
+class MutableArraySliceImpl : public ArraySliceImplBase<T> {
+ public:
+ using ArraySliceImplBase<T>::ArraySliceImplBase;
+
+ template <typename C>
+ using EnableIfConvertibleFrom =
+ typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
+ ContainerSize<C>::HasGet()>::type;
+
+ template <typename C>
+ explicit MutableArraySliceImpl(C* v)
+ : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
+ ContainerSize<C>::Get(v)) {}
+};
+
+} // namespace array_slice_internal
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
new file mode 100644
index 0000000000..33ee8fc8dd
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice_test.cc
@@ -0,0 +1,646 @@
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+#include <algorithm>
+#include <array>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+typedef ArraySlice<int> IntSlice;
+typedef ArraySlice<char> CharSlice;
+typedef MutableArraySlice<int> MutableIntSlice;
+typedef MutableArraySlice<char> MutableCharSlice;
+typedef std::vector<int> IntVec;
+
+// Append 0..len-1 to *v
+template <typename Vector>
+static void Fill(Vector* v, int len, int offset = 0) {
+ for (int i = 0; i < len; i++) {
+ v->push_back(i + offset);
+ }
+}
+
+static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
+ IntSlice other; // To test the assignment return value.
+ IntSlice v = other = vorig;
+ const int len = vec.size();
+ EXPECT_EQ(v.size(), vec.size());
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(v[i], vec[i]);
+ EXPECT_EQ(v.at(i), vec[i]);
+ }
+ EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
+
+ int counter = 0;
+ for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(counter, *it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ counter = 0;
+ for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(counter, *it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ EXPECT_EQ(0, v.front());
+ EXPECT_EQ(len - 1, v.back());
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i, v[i]);
+ }
+ if (len > 1) {
+ v.pop_front();
+ EXPECT_EQ(len - 2, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i + 1, v[i]);
+ }
+ }
+ }
+}
+
+// The element access test that is applicable both when MutableArraySlice is
+// const and when it's not.
+template <class V>
+void MutableTestHelperTemplated(V v, int* ptr, const int len) {
+ CHECK_EQ(v.size(), len);
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(ptr + i, &v[i]);
+ EXPECT_EQ(ptr + i, &v.at(i));
+ }
+ EXPECT_EQ(ptr, v.begin());
+ EXPECT_EQ(ptr + len, v.end());
+ EXPECT_EQ(ptr, v.data());
+
+ int counter = 0;
+ for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(ptr + counter, &*it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
+
+ if (len > 0) {
+ EXPECT_EQ(ptr, &v.front());
+ EXPECT_EQ(ptr + len - 1, &v.back());
+ EXPECT_EQ(ptr + len - 1, &*v.rbegin());
+ EXPECT_EQ(ptr, &*(v.rend() - 1));
+ }
+}
+
+static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
+ const int len) {
+ // Test the data accessors both when the MutableArraySlice is declared const,
+ // and when it is not.
+ MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
+ MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
+
+ MutableIntSlice other; // To test the assignment return value.
+ MutableIntSlice v = other = vorig;
+ EXPECT_EQ(ptr, v.mutable_data());
+
+ int counter = 0;
+ for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(ptr + counter, &*it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ // Test that elements are assignable.
+ v[0] = 1;
+ v.front() = 2;
+ v.back() = 5;
+ *v.mutable_data() = 4;
+ std::fill(v.begin(), v.end(), 5);
+ std::fill(v.rbegin(), v.rend(), 6);
+ // Test size-changing methods.
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(ptr + i, &v[i]);
+ }
+ if (len > 1) {
+ v.pop_front();
+ EXPECT_EQ(len - 2, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(ptr + i + 1, &v[i]);
+ }
+ }
+ }
+}
+
+template <typename Vector>
+static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
+ EXPECT_EQ(v.size(), vec.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(v[i], vec[i]);
+ }
+}
+
+template <typename Vector>
+static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
+ TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
+}
+
+static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
+ int size) {
+ EXPECT_EQ(size, v.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(data + i, &v[i]);
+ }
+}
+
+static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
+ int size) {
+ EXPECT_EQ(size, v.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(data + i, &v[i]);
+ }
+}
+// A struct supplying the data(), mutable_data() and size() methods, just like
+// e.g. proto2::RepeatedField.
+struct RepeatedField {
+ std::vector<int> storage;
+ const int* data() const { return storage.data(); }
+ int* mutable_data() { return storage.data(); }
+ int size() const { return storage.size(); }
+};
+
+// A struct supplying the data() (both mutable and const versions) and
+// size(). It also supplies mutable_data() but we test that data() is selected
+// instead.
+struct ContainerWithOverloads {
+ std::vector<int> storage;
+ std::vector<int> wrong_storage;
+ const int* data() const { return storage.data(); }
+ int* data() { return storage.data(); }
+ // MutableArraySlice should not call mutable_data(), preferring data()
+ // instead.
+ int* mutable_data() { return wrong_storage.data(); }
+ int size() const { return storage.size(); }
+};
+
+// A struct supplying data() and size() methods.
+struct ContainerWithShallowConstData {
+ std::vector<int> storage;
+ int* data() const { return const_cast<int*>(storage.data()); }
+ int size() const { return storage.size(); }
+};
+
+TEST(IntSlice, Simple) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ TestHelper(IntSlice(vec), vec);
+ TestHelper(IntSlice(vec.data(), vec.size()), vec);
+ }
+}
+
+TEST(IntSlice, WithPosAndLen) {
+ IntVec vec;
+ Fill(&vec, 20);
+ for (size_t len = 0; len < vec.size(); len++) {
+ IntVec subvec(vec.begin(), vec.begin() + len);
+ TestImplicitConversion(IntSlice(vec, 0, len), subvec);
+ TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
+ }
+ EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
+ EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
+ TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
+}
+
+TEST(IntSlice, Clear) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice v(vec);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(IntSlice, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ IntVec avec, bvec;
+ Fill(&avec, l1);
+ Fill(&bvec, l2, 100);
+ IntSlice a(avec), b(bvec);
+ using std::swap;
+ swap(a, b);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(i, b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(100 + i, a[i]);
+ }
+ }
+ }
+}
+
+TEST(IntSlice, ImplicitConversion) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice slice;
+ slice = vec;
+ TestImplicitConversion(vec, vec);
+ TestImplicitConversion(slice, vec);
+ TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
+ }
+}
+
+TEST(IntSlice, InlinedVectorConversion) {
+ for (int len = 0; len < 20; len++) {
+ InlinedVector<int, 4> inline_vec;
+ for (int i = 0; i < len; i++) {
+ inline_vec.push_back(i);
+ }
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice v = inline_vec; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(inline_vec, vec);
+ }
+}
+
+TEST(IntSlice, StaticArrayConversion) {
+ int array[20];
+ IntVec vec;
+ Fill(&vec, TF_ARRAYSIZE(array));
+ std::copy(vec.begin(), vec.end(), array);
+ IntSlice v = array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(array, vec);
+}
+
+TEST(IntSlice, StdArrayConversion) {
+ std::array<int, 20> array;
+ IntVec vec;
+ Fill(&vec, array.size());
+ std::copy(vec.begin(), vec.end(), array.begin());
+
+ // Check assignment.
+ {
+ IntSlice v = array;
+ static_cast<void>(v);
+ }
+
+ // Check sub-slice initialization.
+ {
+ IntSlice v = {array, 10, 15};
+ static_cast<void>(v);
+ }
+
+ TestImplicitConversion(array, vec);
+}
+
+// Values according to the Fill function.
+static const int test_const_array[] = {0, 1, 2};
+
+TEST(IntSlice, ConstStaticArrayConversion) {
+ IntVec vec;
+ Fill(&vec, TF_ARRAYSIZE(test_const_array));
+ IntSlice v = test_const_array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(test_const_array, vec);
+}
+
+TEST(IntSlice, RepeatedFieldConversion) {
+ RepeatedField repeated_field;
+ IntVec vec;
+ Fill(&vec, 20);
+ repeated_field.storage = vec;
+ IntSlice v = repeated_field; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(repeated_field, vec);
+}
+
+TEST(IntSlice, ContainerWithOverloadsConversion) {
+ ContainerWithOverloads container;
+ Fill(&container.storage, 20);
+ container.wrong_storage.resize(container.size());
+ IntSlice v = container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(container, container.storage);
+}
+
+TEST(IntSlice, ContainerWithShallowConstDataConversion) {
+ ContainerWithShallowConstData container;
+ Fill(&container.storage, 20);
+ IntSlice v = container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(container, container.storage);
+}
+
+TEST(IntSlice, MutableIntSliceConversion) {
+ IntVec vec(20);
+ IntSlice slice = MutableIntSlice(&vec);
+ EXPECT_EQ(vec.size(), slice.size());
+ EXPECT_EQ(vec.data(), slice.data());
+}
+
+TEST(IntSlice, Equality) {
+ IntVec vec1(20);
+ IntVec vec2(20);
+ // These two slices are from different vectors, but have the same
+ // size and have the same elements (right now). They should
+ // compare equal.
+ const IntSlice from1(vec1);
+ const IntSlice from2(vec2);
+ EXPECT_EQ(from1, from1);
+ EXPECT_EQ(from1, from2);
+
+ // This verifies that MutableArraySlices can be compared freely with
+ // ArraySlices.
+ const MutableIntSlice mutable_from1(&vec1);
+ const MutableIntSlice mutable_from2(&vec2);
+ EXPECT_EQ(from1, mutable_from1);
+ EXPECT_EQ(mutable_from1, from1);
+ EXPECT_EQ(mutable_from1, mutable_from2);
+ EXPECT_EQ(mutable_from2, mutable_from1);
+
+ // With a different size, the array slices should not be equal.
+ EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
+
+ // With different contents, the array slices should not be equal.
+ ++vec2.back();
+ EXPECT_NE(from1, from2);
+}
+
+// Compile-asserts that the argument has the expected type.
+template <typename Expected, typename T>
+void CheckType(const T& value) {
+ testing::StaticAssertTypeEq<Expected, T>();
+}
+
+TEST(IntSlice, ExposesContainerTypesAndConsts) {
+ IntSlice slice;
+ const IntSlice const_slice;
+ CheckType<IntSlice::iterator>(slice.begin());
+ CheckType<IntSlice::const_iterator>(const_slice.end());
+ CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
+ CheckType<IntSlice::reverse_iterator>(slice.rend());
+ testing::StaticAssertTypeEq<int, IntSlice::value_type>();
+ testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
+ testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
+ EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
+}
+
+void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
+
+void TestRange(IntSlice slice, int from, int to) {
+ ASSERT_EQ(to - from + 1, slice.size());
+ for (size_t i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(from + i, slice[i]);
+ }
+}
+
+TEST(IntSlice, InitializerListConversion) {
+ TestEmpty({});
+ TestRange({1}, 1, 1);
+ TestRange({10, 11, 12, 13}, 10, 13);
+}
+
+TEST(CharSlice, StringConversion) {
+ IntVec vec;
+ Fill(&vec, 20);
+ string str(vec.begin(), vec.end());
+ CharSlice v = str; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(str, vec);
+}
+
+TEST(IntPtrSlice, ConstConversion) {
+ int one = 1;
+ int two = 2;
+ std::vector<int*> vec;
+ vec.push_back(&one);
+ vec.push_back(&two);
+ ArraySlice<const int*> v = vec;
+ ASSERT_EQ(2, v.size());
+ EXPECT_EQ(&one, v[0]);
+ EXPECT_EQ(&two, v[1]);
+}
+
+TEST(MutableIntSlice, Simple) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
+ MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
+ }
+}
+
+TEST(MutableIntSlice, WithPosAndLen) {
+ IntVec vec(20);
+ for (size_t len = 0; len < vec.size(); len++) {
+ TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
+ TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
+ vec.data(), len);
+ }
+ EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
+ EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
+ TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
+ vec.data(), vec.size());
+}
+
+TEST(MutableIntSlice, Clear) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableIntSlice v(&vec);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(MutableIntSlice, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ IntVec avec(l1), bvec(l2);
+ MutableIntSlice a(&avec), b(&bvec);
+ using std::swap;
+ swap(a, b);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(&avec[i], &b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(&bvec[i], &a[i]);
+ }
+ }
+ }
+}
+
+TEST(MutableIntSlice, ImplicitConversion) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableIntSlice slice;
+ slice = &vec;
+ TestImplicitConversion(&vec, vec.data(), len);
+ TestImplicitConversion(slice, vec.data(), len);
+ TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
+ len);
+ }
+}
+
+TEST(MutableIntSlice, InlinedVectorConversion) {
+ for (int len = 0; len < 20; len++) {
+ InlinedVector<int, 4> inline_vec;
+ for (int i = 0; i < len; i++) {
+ inline_vec.push_back(i);
+ }
+ MutableIntSlice v = &inline_vec; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&inline_vec, inline_vec.array(), inline_vec.size());
+ }
+}
+
+TEST(MutableIntSlice, StaticArrayConversion) {
+ int array[20];
+ MutableIntSlice v = array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
+}
+
+TEST(MutableIntSlice, StdArrayConversion) {
+ std::array<int, 20> array;
+
+ // Check assignment.
+ {
+ MutableIntSlice v = &array;
+ static_cast<void>(v);
+ }
+
+ // Check sub-slice initialization.
+ {
+ MutableIntSlice v = {&array, 10, 15};
+ static_cast<void>(v);
+ }
+
+ TestImplicitConversion(&array, &array[0], array.size());
+}
+
+TEST(MutableIntSlice, RepeatedFieldConversion) {
+ RepeatedField repeated_field;
+ Fill(&repeated_field.storage, 20);
+ MutableIntSlice v = &repeated_field; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
+ repeated_field.storage.size());
+}
+
+TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
+ ContainerWithOverloads container;
+ Fill(&container.storage, 20);
+ container.wrong_storage.resize(container.size());
+ MutableIntSlice v = &container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&container, container.storage.data(),
+ container.storage.size());
+}
+
+TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
+ ContainerWithShallowConstData container;
+ Fill(&container.storage, 20);
+ MutableIntSlice v = &container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&container, container.storage.data(),
+ container.storage.size());
+}
+
+TEST(MutableIntSlice, TypedefsAndConstants) {
+ testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
+ testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
+ testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
+ testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
+ testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
+
+ EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
+}
+
+TEST(MutableIntSlice, IteratorsAndReferences) {
+ auto accept_pointer = [](int* x) {};
+ auto accept_reference = [](int& x) {};
+ auto accept_iterator = [](MutableIntSlice::iterator x) {};
+ auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
+
+ int a[1];
+ MutableIntSlice s = a;
+
+ accept_pointer(s.data());
+ accept_pointer(s.mutable_data());
+ accept_iterator(s.begin());
+ accept_iterator(s.end());
+ accept_reverse_iterator(s.rbegin());
+ accept_reverse_iterator(s.rend());
+
+ accept_reference(s[0]);
+ accept_reference(s.at(0));
+ accept_reference(s.front());
+ accept_reference(s.back());
+}
+
+TEST(MutableIntSlice, IteratorsAndReferences_Const) {
+ auto accept_pointer = [](int* x) {};
+ auto accept_reference = [](int& x) {};
+ auto accept_iterator = [](MutableIntSlice::iterator x) {};
+ auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
+
+ int a[1];
+ const MutableIntSlice s = a;
+
+ accept_pointer(s.data());
+ accept_pointer(s.mutable_data());
+ accept_iterator(s.begin());
+ accept_iterator(s.end());
+ accept_reverse_iterator(s.rbegin());
+ accept_reverse_iterator(s.rend());
+
+ accept_reference(s[0]);
+ accept_reference(s.at(0));
+ accept_reference(s.front());
+ accept_reference(s.back());
+}
+
+bool TestMutableOverload(MutableIntSlice slice) { return false; }
+
+bool TestMutableOverload(MutableCharSlice slice) { return true; }
+
+TEST(MutableCharSlice, StringConversion) {
+ for (int len = 0; len < 20; len++) {
+ string str(len, '\0');
+ MutableCharSlice v = &str; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(v, str.data(), str.size());
+ }
+  // Verify that only the correct overload is feasible. Note that this would
+  // fail if the string ctor were declared simply as
+  // MutableArraySlice(string*), since in that case both overloads would be
+  // feasible.
+ string str;
+ EXPECT_TRUE(TestMutableOverload(&str));
+}
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h
new file mode 100644
index 0000000000..82b6c2299f
--- /dev/null
+++ b/tensorflow/core/lib/gtl/edit_distance.h
@@ -0,0 +1,82 @@
+#ifndef TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
+#define TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// Calculate the Levenshtein Edit Distance between two contiguous
+// sequences, s and t, of type T.
+//
+// The Levenshtein distance is a symmetric distance defined as the
+// smallest number of insertions, deletions, and substitutions
+// required to convert sequence s to t (and vice versa).
+// Note that this distance does not consider transpositions, so converting
+// "ab" to "ba" costs two substitutions rather than a single transposition.
+//
+// For more details and a reference implementation, see:
+// https://en.wikipedia.org/wiki/Levenshtein_distance
+//
+// This implementation has time complexity O(|s|*|t|)
+// and space complexity O(min(|s|, |t|)), where
+// |x| := x.size()
+//
+// A simple call to LevenshteinDistance looks like:
+//
+// int64 dist = LevenshteinDistance("hi", "bye", std::equal_to<char>());
+//
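+// A quick worked example: for s = "kitten" and t = "sitting" the distance
+// is 3 (substitute 'k'->'s', substitute 'e'->'i', insert 'g'). Only two DP
+// rows are kept alive at a time, which is what yields the O(min(|s|, |t|))
+// space bound above.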
+template <typename T, typename Cmp>
+inline int64 LevenshteinDistance(const gtl::ArraySlice<T>& s,
+ const gtl::ArraySlice<T>& t, const Cmp& cmp) {
+ const int64 s_size = s.size();
+ const int64 t_size = t.size();
+
+ if (s_size == 0) return t_size;
+ if (t_size == 0) return s_size;
+ if (s == t) return 0;
+ if (t_size > s_size) return LevenshteinDistance(t, s, cmp);
+
+ // Create work vectors
+ gtl::InlinedVector<int64, 32> scratch0(t_size + 1);
+ gtl::InlinedVector<int64, 32> scratch1(t_size + 1);
+
+ int64* previous = scratch0.data();
+ int64* current = scratch1.data();
+
+ // Initialize previous row of distances
+ std::iota(scratch0.begin(), scratch0.end(), 0);
+
+ for (int64 i = 0; i < s_size; ++i) {
+ // Swap current and previous rows for next iteration
+ std::swap(previous, current);
+
+ // Calculate current row distances from previous row
+ current[0] = i + 1;
+
+ // Fill in the rest of the row
+ for (int64 j = 0; j < t_size; ++j) {
+ const int64 cost = cmp(s[i], t[j]) ? 0 : 1;
+ current[j + 1] =
+ std::min(current[j] + 1, // deletion cost
+ std::min(previous[j + 1] + 1, // insertion cost
+ previous[j] + cost)); // substitution cost
+ }
+ }
+
+ return current[t_size];
+}
+
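+// Convenience overload for contiguous containers that expose data() and
+// size() (e.g. std::string or std::vector). Note that both slices below are
+// built with Container1's value_type, so the two containers must hold
+// compatible element types.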
+template <typename Container1, typename Container2, typename Cmp>
+inline int64 LevenshteinDistance(const Container1& s, const Container2& t,
+ const Cmp& cmp) {
+ return LevenshteinDistance(
+ gtl::ArraySlice<typename Container1::value_type>(s.data(), s.size()),
+ gtl::ArraySlice<typename Container1::value_type>(t.data(), t.size()),
+ cmp);
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
diff --git a/tensorflow/core/lib/gtl/edit_distance_test.cc b/tensorflow/core/lib/gtl/edit_distance_test.cc
new file mode 100644
index 0000000000..0526ee0a05
--- /dev/null
+++ b/tensorflow/core/lib/gtl/edit_distance_test.cc
@@ -0,0 +1,125 @@
+#include "tensorflow/core/lib/gtl/edit_distance.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+class LevenshteinDistanceTest : public ::testing::Test {
+ protected:
+ std::vector<char> empty_;
+ std::string s1_;
+ std::string s1234_;
+ std::string s567_;
+ std::string kilo_;
+ std::string kilogram_;
+ std::string mother_;
+ std::string grandmother_;
+ std::string lower_;
+ std::string upper_;
+
+ void SetUp() override {
+ s1_ = "1";
+ s1234_ = "1234";
+ s567_ = "567";
+ kilo_ = "kilo";
+ kilogram_ = "kilogram";
+ mother_ = "mother";
+ grandmother_ = "grandmother";
+ lower_ = "lower case";
+ upper_ = "UPPER case";
+ }
+};
+
+TEST_F(LevenshteinDistanceTest, BothEmpty) {
+ ASSERT_EQ(LevenshteinDistance(empty_, empty_, std::equal_to<char>()), 0);
+}
+
+TEST_F(LevenshteinDistanceTest, OneEmpty) {
+ ASSERT_EQ(LevenshteinDistance(s1234_, empty_, std::equal_to<char>()), 4);
+ ASSERT_EQ(LevenshteinDistance(empty_, s567_, std::equal_to<char>()), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, SingleElement) {
+ ASSERT_EQ(LevenshteinDistance(s1234_, s1_, std::equal_to<char>()), 3);
+ ASSERT_EQ(LevenshteinDistance(s1_, s1234_, std::equal_to<char>()), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, Prefix) {
+ ASSERT_EQ(LevenshteinDistance(kilo_, kilogram_, std::equal_to<char>()), 4);
+ ASSERT_EQ(LevenshteinDistance(kilogram_, kilo_, std::equal_to<char>()), 4);
+}
+
+TEST_F(LevenshteinDistanceTest, Suffix) {
+ ASSERT_EQ(LevenshteinDistance(mother_, grandmother_, std::equal_to<char>()),
+ 5);
+ ASSERT_EQ(LevenshteinDistance(grandmother_, mother_, std::equal_to<char>()),
+ 5);
+}
+
+TEST_F(LevenshteinDistanceTest, DifferentComparisons) {
+ ASSERT_EQ(LevenshteinDistance(lower_, upper_, std::equal_to<char>()), 5);
+ ASSERT_EQ(LevenshteinDistance(upper_, lower_, std::equal_to<char>()), 5);
+ ASSERT_EQ(
+ LevenshteinDistance(gtl::ArraySlice<char>(lower_.data(), lower_.size()),
+ gtl::ArraySlice<char>(upper_.data(), upper_.size()),
+ std::equal_to<char>()),
+ 5);
+ auto no_case_cmp = [](char c1, char c2) {
+ return std::tolower(c1) == std::tolower(c2);
+ };
+ ASSERT_EQ(LevenshteinDistance(lower_, upper_, no_case_cmp), 3);
+ ASSERT_EQ(LevenshteinDistance(upper_, lower_, no_case_cmp), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, Vectors) {
+ ASSERT_EQ(
+ LevenshteinDistance(std::string("algorithm"), std::string("altruistic"),
+ std::equal_to<char>()),
+ 6);
+}
+
+static void BM_EditDistanceHelper(int n, int len, bool completely_different) {
+ string a =
+ "The quick brown fox jumped over the lazy dog and on and on and on"
+ " Every good boy deserves fudge. In fact, this is a very long sentence "
+ " w/many bytes..";
+ while (a.size() < static_cast<size_t>(len)) {
+ a = a + a;
+ }
+ string b = a;
+ if (completely_different) {
+ for (size_t i = 0; i < b.size(); i++) {
+ b[i]++;
+ }
+ }
+ while (n-- > 0) {
+ LevenshteinDistance(gtl::ArraySlice<char>(a.data(), len),
+ gtl::ArraySlice<char>(b.data(), len),
+ std::equal_to<char>());
+ }
+}
+
+static void BM_EditDistanceSame(int n, int len) {
+ BM_EditDistanceHelper(n, len, false);
+}
+static void BM_EditDistanceDiff(int n, int len) {
+ BM_EditDistanceHelper(n, len, true);
+}
+
+BENCHMARK(BM_EditDistanceSame)->Arg(5);
+BENCHMARK(BM_EditDistanceSame)->Arg(50);
+BENCHMARK(BM_EditDistanceSame)->Arg(200);
+BENCHMARK(BM_EditDistanceSame)->Arg(1000);
+BENCHMARK(BM_EditDistanceDiff)->Arg(5);
+BENCHMARK(BM_EditDistanceDiff)->Arg(50);
+BENCHMARK(BM_EditDistanceDiff)->Arg(200);
+BENCHMARK(BM_EditDistanceDiff)->Arg(1000);
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
new file mode 100644
index 0000000000..c23075129c
--- /dev/null
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -0,0 +1,839 @@
+// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
+// for sequences of length <= N are provided inline without requiring
+// any heap allocation. Typically N is very small (e.g., 4) so that
+// sequences that are expected to be short do not require allocations.
+//
+// Only some of the std::vector<> operations are currently implemented.
+// Other operations may be added as needed to facilitate migrating
+// code that uses std::vector<> to InlinedVector<>.
+//
+// NOTE: If you want an inlined version to replace use of a
+// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
+// in util/bitmap/inlined_bitvector.h
+//
+// TODO(billydonahue): change size_t to size_type where appropriate.
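+//
+// A minimal usage sketch (illustrative values only; InlinedVector is used
+// much like std::vector):
+//
+//   tensorflow::gtl::InlinedVector<int, 4> v;
+//   v.push_back(1);
+//   v.push_back(2);   // still inline: no heap allocation while size() <= 4
+//   v.resize(100);    // exceeds N == 4, spills to a heap allocation
+//   for (size_t i = 0; i < v.size(); ++i) { /* use v[i] */ }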
+
+#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <type_traits>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+#include <initializer_list> // NOLINT(build/include_order)
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename T, int N, typename A = std::allocator<T> >
+class InlinedVector {
+ public:
+ typedef A allocator_type;
+ typedef typename allocator_type::value_type value_type;
+ typedef typename allocator_type::pointer pointer;
+ typedef typename allocator_type::const_pointer const_pointer;
+ typedef typename allocator_type::reference reference;
+ typedef typename allocator_type::const_reference const_reference;
+ typedef typename allocator_type::size_type size_type;
+ typedef typename allocator_type::difference_type difference_type;
+ typedef pointer iterator;
+ typedef const_pointer const_iterator;
+
+ // Create an empty vector
+ InlinedVector();
+ explicit InlinedVector(const allocator_type& alloc);
+
+ // Create a vector with n copies of value_type().
+ explicit InlinedVector(size_t n);
+
+ // Create a vector with n copies of elem
+ InlinedVector(size_t n, const value_type& elem,
+ const allocator_type& alloc = allocator_type());
+
+ // Create and initialize with the elements [range_start .. range_end).
+ // The unused enable_if argument restricts this constructor so that it is
+ // elided when value_type is an integral type. This prevents ambiguous
+ // interpretation between a call to this constructor with two integral
+ // arguments and a call to the preceding (n, elem) constructor.
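+  // For instance, InlinedVector<int, 4> v(3, 7) must construct three copies
+  // of 7; without the enable_if, the two integral arguments could equally
+  // well match this range constructor.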
+ template <typename InputIterator>
+ InlinedVector(
+ InputIterator range_start, InputIterator range_end,
+ const allocator_type& alloc = allocator_type(),
+ typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
+ NULL)
+ : allocator_and_tag_(alloc) {
+ AppendRange(range_start, range_end);
+ }
+
+ InlinedVector(std::initializer_list<value_type> init,
+ const allocator_type& alloc = allocator_type())
+ : allocator_and_tag_(alloc) {
+ AppendRange(init.begin(), init.end());
+ }
+
+ InlinedVector(const InlinedVector& v);
+
+ ~InlinedVector() { clear(); }
+
+ InlinedVector& operator=(const InlinedVector& v) {
+ // Optimized to avoid reallocation.
+ // Prefer reassignment to copy construction for elements.
+ if (size() < v.size()) { // grow
+ reserve(v.size());
+ std::copy(v.begin(), v.begin() + size(), begin());
+ std::copy(v.begin() + size(), v.end(), std::back_inserter(*this));
+ } else { // maybe shrink
+ erase(begin() + v.size(), end());
+ std::copy(v.begin(), v.end(), begin());
+ }
+ return *this;
+ }
+
+ size_t size() const {
+ return allocated() ? allocation().size() : tag().size();
+ }
+
+ bool empty() const { return (size() == 0); }
+
+  // Return the number of elements that can be stored in the vector
+  // without requiring a reallocation of the underlying memory.
+ size_t capacity() const { return allocated() ? allocation().capacity() : N; }
+
+ // Return a pointer to the underlying array.
+ // Only result[0,size()-1] are defined.
+ const_pointer data() const {
+ return allocated() ? allocated_space() : inlined_space();
+ }
+ pointer data() { return allocated() ? allocated_space() : inlined_space(); }
+
+ // An older name for the more standard-friendly .data().
+ const_pointer array() const { return data(); }
+ pointer mutable_array() { return data(); }
+
+ // Remove all elements
+ void clear() {
+ size_t s = size();
+ if (allocated()) {
+ DestroyAllocated(allocated_space(), allocated_space() + s);
+ allocation().Dealloc(allocator());
+ } else {
+ DestroyInlined(inlined_space(), inlined_space() + s);
+ }
+ tag() = Tag();
+ }
+
+ // Return the ith element
+ // REQUIRES: 0 <= i < size()
+ const value_type& at(size_t i) const {
+ DCHECK_LT(i, size());
+ return array()[i];
+ }
+ const value_type& operator[](size_t i) const {
+ DCHECK_LT(i, size());
+ return array()[i];
+ }
+
+ // Return a non-const reference to the ith element
+ // REQUIRES: 0 <= i < size()
+ value_type& at(size_t i) {
+ DCHECK_LT(i, size());
+ return mutable_array()[i];
+ }
+ value_type& operator[](size_t i) {
+ DCHECK_LT(i, size());
+ return mutable_array()[i];
+ }
+
+ value_type& back() {
+ DCHECK(!empty());
+ return at(size() - 1);
+ }
+
+ const value_type& back() const {
+ DCHECK(!empty());
+ return at(size() - 1);
+ }
+
+ value_type& front() {
+ DCHECK(!empty());
+ return at(0);
+ }
+
+ const value_type& front() const {
+ DCHECK(!empty());
+ return at(0);
+ }
+
+ // Append t to the vector.
+ // Increases size() by one.
+ // Amortized complexity: O(1)
+ // Worst-case complexity: O(size())
+ void push_back(const value_type& t) {
+ size_t s = size();
+ DCHECK_LE(s, capacity());
+ if (s == capacity()) {
+ return GrowAndPushBack(t);
+ }
+ DCHECK_LT(s, capacity());
+
+ if (allocated()) {
+ ConstructAllocated(allocated_space() + s, t);
+ } else {
+ ConstructInlined(inlined_space() + s, t);
+ }
+
+ set_size_internal(s + 1);
+ }
+
+ void pop_back() {
+ DCHECK(!empty());
+ size_t s = size();
+ if (allocated()) {
+ DestroyAllocated(allocated_space() + s - 1, allocated_space() + s);
+ } else {
+ DestroyInlined(inlined_space() + s - 1, inlined_space() + s);
+ }
+ set_size_internal(s - 1);
+ }
+
+ // Resizes the vector to contain "n" elements.
+ // If "n" is smaller than the initial size, extra elements are destroyed.
+ // If "n" is larger than the initial size, enough copies of "elem"
+ // are appended to increase the size to "n". If "elem" is omitted,
+ // new elements are value-initialized.
+ void resize(size_t n);
+ void resize(size_t n, const value_type& elem);
+
+ iterator begin() { return mutable_array(); }
+ const_iterator begin() const { return array(); }
+
+ iterator end() { return mutable_array() + size(); }
+ const_iterator end() const { return array() + size(); }
+
+ iterator insert(iterator pos, const value_type& v);
+
+ iterator erase(iterator pos) {
+ DCHECK_LT(pos, end());
+ DCHECK_GE(pos, begin());
+ std::copy(pos + 1, end(), pos);
+ pop_back();
+ return pos;
+ }
+
+ iterator erase(iterator first, iterator last);
+
+ // Enlarges the underlying representation so it can hold at least
+ // "n" elements without reallocation.
+ // Does not change size() or the actual contents of the vector.
+ void reserve(size_t n) {
+ if (n > capacity()) {
+ // Make room for new elements
+ EnlargeBy(n - size());
+ }
+ }
+
+ // Swap the contents of *this with other.
+ // REQUIRES: value_type is swappable and copyable.
+ void swap(InlinedVector& other);
+
+ allocator_type get_allocator() const { return allocator(); }
+
+ private:
+ struct AllocatorTraits {
+ typedef typename allocator_type::value_type value_type;
+ typedef typename allocator_type::pointer pointer;
+ typedef typename allocator_type::size_type size_type;
+
+ static void construct(allocator_type& a, // NOLINT(runtime/references)
+ pointer p) {
+ // Tricky: do we support non-copyable types, or support allocators
+ // that do special things with construct()? Non-copyable types are
+ // needed today, so they are more important. When we sort out the
+ // Android NDK C++11 problem, we will be able to use the proper
+ // std::allocator_traits<A>::construct(p, ...).
+ //
+ // a.construct(p, value_type());
+ new (p) value_type();
+ }
+ static void construct(allocator_type& a, // NOLINT(runtime/references)
+ pointer p, const value_type& t) {
+ a.construct(p, t);
+ }
+ static void destroy(allocator_type& a, // NOLINT(runtime/references)
+ pointer p) {
+ a.destroy(p);
+ }
+ static pointer allocate(allocator_type& a, // NOLINT(runtime/references)
+ size_type n) {
+ return a.allocate(n);
+ }
+ static void deallocate(allocator_type& a, // NOLINT(runtime/references)
+ pointer p, size_type n) {
+ a.deallocate(p, n);
+ }
+ };
+
+ // If the vector is inlined, holds the size of the vector.
+ // If the vector is allocated, holds the special value kAllocated,
+ // and the size is stored in the vector's Allocation.
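+  // For example: an inlined vector holding three elements stores size_ == 3
+  // here, while a heap-allocated one stores size_ == kAllocated and keeps
+  // its true length in its Allocation.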
+ class Tag {
+ public:
+ Tag() : size_(0) {}
+ size_t size() const { return size_; }
+ void set_size(size_t n) { size_ = n; }
+ bool allocated() const { return size_ == kAllocated; }
+ void set_allocated() { size_ = kAllocated; }
+
+ private:
+ static const size_t kAllocated = -1;
+ size_t size_;
+ };
+
+ // Derives from allocator_type to use the empty base class optimization.
+ // If the allocator_type is stateless, we can 'store'
+ // our instance of it for free.
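+  // (With a stateless allocator such as std::allocator<T>, this typically
+  // makes sizeof(AllocatorAndTag) == sizeof(Tag).)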
+ class AllocatorAndTag : private allocator_type {
+ public:
+ explicit AllocatorAndTag(const allocator_type& a, Tag t = Tag())
+ : allocator_type(a), tag_(t) {}
+ Tag& tag() { return tag_; }
+ const Tag& tag() const { return tag_; }
+ allocator_type& allocator() { return *this; }
+ const allocator_type& allocator() const { return *this; }
+
+ private:
+ Tag tag_;
+ };
+
+ class Allocation {
+ public:
+ Allocation(allocator_type& a, // NOLINT(runtime/references)
+ size_t capacity)
+ : size_(0),
+ capacity_(capacity),
+ buffer_(AllocatorTraits::allocate(a, capacity_)) {}
+
+ void Dealloc(allocator_type& a) { // NOLINT(runtime/references)
+ AllocatorTraits::deallocate(a, buffer(), capacity());
+ }
+
+ size_t size() const { return size_; }
+ void set_size(size_t s) { size_ = s; }
+ size_t capacity() const { return capacity_; }
+ const value_type* buffer() const { return buffer_; }
+ value_type* buffer() { return buffer_; }
+
+ private:
+ size_t size_;
+ size_t capacity_;
+ value_type* buffer_;
+ };
+
+ const Tag& tag() const { return allocator_and_tag_.tag(); }
+ Tag& tag() { return allocator_and_tag_.tag(); }
+
+ Allocation& allocation() { return *rep_.allocation_storage.allocation.get(); }
+ const Allocation& allocation() const {
+ return *rep_.allocation_storage.allocation.get();
+ }
+ void init_allocation(const Allocation& allocation) {
+ rep_.allocation_storage.allocation.Init(allocation);
+ }
+
+ value_type* inlined_space() { return rep_.inlined_storage.inlined[0].get(); }
+ const value_type* inlined_space() const {
+ return rep_.inlined_storage.inlined[0].get();
+ }
+
+ value_type* allocated_space() { return allocation().buffer(); }
+ const value_type* allocated_space() const { return allocation().buffer(); }
+
+ const allocator_type& allocator() const {
+ return allocator_and_tag_.allocator();
+ }
+ allocator_type& allocator() { return allocator_and_tag_.allocator(); }
+
+ bool allocated() const { return tag().allocated(); }
+ void set_allocated() { return tag().set_allocated(); }
+
+ void set_size_internal(size_t n) {
+ if (allocated()) {
+ allocation().set_size(n);
+ } else {
+ tag().set_size(n);
+ }
+ }
+
+ // Enlarge the underlying representation so we can store size_ + delta elems.
+ // The size is not changed, and any newly added memory is not initialized.
+ void EnlargeBy(size_t delta);
+
+ void ResetAllocation(Allocation new_allocation) {
+ if (allocated()) {
+ DestroyAllocated(allocated_space(), allocated_space() + size());
+ DCHECK_EQ(begin(), allocated_space());
+ allocation().Dealloc(allocator());
+ allocation() = new_allocation;
+ } else {
+ DestroyInlined(inlined_space(), inlined_space() + size());
+      // The union held inlined storage until now, so the Allocation member
+      // is uninitialized; it must be Init()'d here exactly once.
+      init_allocation(new_allocation);
+ set_allocated();
+ }
+ }
+
+ void GrowAndPushBack(const value_type& t) {
+ DCHECK_EQ(size(), capacity());
+ const size_t s = size();
+
+ Allocation new_allocation(allocator(), 2 * capacity());
+ new_allocation.set_size(s + 1);
+
+ UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer());
+ ConstructAllocated(new_allocation.buffer() + s, t);
+
+ ResetAllocation(new_allocation);
+ }
+
+ void InitAssign(size_t n);
+ void InitAssign(size_t n, const value_type& t);
+
+ void ConstructInlined(pointer p) { new (p) value_type(); }
+
+ void ConstructInlined(pointer p, const value_type& t) {
+ new (p) value_type(t);
+ }
+
+ void ConstructAllocated(pointer p) {
+ AllocatorTraits::construct(allocator(), p);
+ }
+ void ConstructAllocated(pointer p, const value_type& t) {
+ AllocatorTraits::construct(allocator(), p, t);
+ }
+
+ template <typename Iter>
+ void UninitializedCopyInlined(Iter src, Iter src_last, value_type* dst) {
+ std::uninitialized_copy(src, src_last, dst);
+ }
+
+ template <typename Iter>
+ void UninitializedCopyAllocated(Iter src, Iter src_last, value_type* dst) {
+ for (; src != src_last; ++dst, ++src) ConstructAllocated(dst, *src);
+ }
+
+ void UninitializedFillInlined(value_type* dst, value_type* dst_last) {
+ for (; dst != dst_last; ++dst) ConstructInlined(dst);
+ }
+ void UninitializedFillInlined(value_type* dst, value_type* dst_last,
+ const value_type& t) {
+ std::uninitialized_fill(dst, dst_last, t);
+ }
+
+ void UninitializedFillAllocated(value_type* dst, value_type* dst_last) {
+ for (; dst != dst_last; ++dst) ConstructAllocated(dst);
+ }
+ void UninitializedFillAllocated(value_type* dst, value_type* dst_last,
+ const value_type& t) {
+ for (; dst != dst_last; ++dst) ConstructAllocated(dst, t);
+ }
+
+ // Destroy [ptr, ptr_last) in place.
+ void DestroyInlined(value_type* ptr, value_type* ptr_last);
+ void DestroyAllocated(value_type* ptr, value_type* ptr_last);
+
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last, std::input_iterator_tag);
+
+ // Faster path for forward iterators.
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
+
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last);
+
+ AllocatorAndTag allocator_and_tag_;
+
+ // Either the inlined or allocated representation
+ union Rep {
+ // Use struct to perform indirection that solves a bizarre compilation
+ // error on Visual Studio (all known versions).
+ struct {
+ tensorflow::ManualConstructor<value_type> inlined[N];
+ } inlined_storage;
+ struct {
+ tensorflow::ManualConstructor<Allocation> allocation;
+ } allocation_storage;
+ } rep_;
+};
+
+template <typename T, int N, typename A>
+const size_t InlinedVector<T, N, A>::Tag::kAllocated;
+
+template <typename T, int N, typename A>
+inline void swap(InlinedVector<T, N, A>& a, InlinedVector<T, N, A>& b) {
+ a.swap(b);
+}
+
+template <typename T, int N, typename A>
+inline bool operator==(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
+}
+
+template <typename T, int N, typename A>
+inline bool operator!=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(a == b);
+}
+
+template <typename T, int N, typename A>
+inline bool operator<(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
+}
+
+template <typename T, int N, typename A>
+inline bool operator>(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return b < a;
+}
+
+template <typename T, int N, typename A>
+inline bool operator<=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(b < a);
+}
+
+template <typename T, int N, typename A>
+inline bool operator>=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(a < b);
+}
+
+// ========================================
+// Implementation
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector()
+ : allocator_and_tag_(allocator_type()) {}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(const allocator_type& alloc)
+ : allocator_and_tag_(alloc) {}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(size_t n)
+ : allocator_and_tag_(allocator_type()) {
+ InitAssign(n);
+}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(size_t n, const value_type& elem,
+ const allocator_type& alloc)
+ : allocator_and_tag_(alloc) {
+ InitAssign(n, elem);
+}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(const InlinedVector& v)
+ : allocator_and_tag_(v.allocator()) {
+ reserve(v.size());
+ if (allocated()) {
+ UninitializedCopyAllocated(v.begin(), v.end(), allocated_space());
+ } else {
+ UninitializedCopyInlined(v.begin(), v.end(), inlined_space());
+ }
+ set_size_internal(v.size());
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::InitAssign(size_t n, const value_type& t) {
+ if (n > static_cast<size_t>(N)) {
+ Allocation new_allocation(allocator(), n);
+ init_allocation(new_allocation);
+ set_allocated();
+ UninitializedFillAllocated(allocated_space(), allocated_space() + n, t);
+ } else {
+ UninitializedFillInlined(inlined_space(), inlined_space() + n, t);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::InitAssign(size_t n) {
+ if (n > static_cast<size_t>(N)) {
+ Allocation new_allocation(allocator(), n);
+ init_allocation(new_allocation);
+ set_allocated();
+ UninitializedFillAllocated(allocated_space(), allocated_space() + n);
+ } else {
+ UninitializedFillInlined(inlined_space(), inlined_space() + n);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::resize(size_t n) {
+ size_t s = size();
+ if (n < s) {
+ erase(begin() + n, end());
+ return;
+ }
+ reserve(n);
+ DCHECK_GE(capacity(), n);
+
+ // Fill new space with elements constructed in-place.
+ if (allocated()) {
+ UninitializedFillAllocated(allocated_space() + s, allocated_space() + n);
+ } else {
+ UninitializedFillInlined(inlined_space() + s, inlined_space() + n);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::resize(size_t n, const value_type& elem) {
+ size_t s = size();
+ if (n < s) {
+ erase(begin() + n, end());
+ return;
+ }
+ reserve(n);
+ DCHECK_GE(capacity(), n);
+
+ // Fill new space with copies of 'elem'.
+ if (allocated()) {
+ UninitializedFillAllocated(allocated_space() + s, allocated_space() + n,
+ elem);
+ } else {
+ UninitializedFillInlined(inlined_space() + s, inlined_space() + n, elem);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::insert(
+ iterator pos, const value_type& v) {
+ DCHECK_GE(pos, begin());
+ DCHECK_LE(pos, end());
+ if (pos == end()) {
+ push_back(v);
+ return end() - 1;
+ }
+ size_t s = size();
+ size_t idx = std::distance(begin(), pos);
+ if (s == capacity()) {
+ EnlargeBy(1);
+ }
+ CHECK_LT(s, capacity());
+ pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
+
+ if (allocated()) {
+ ConstructAllocated(allocated_space() + s, *(allocated_space() + s - 1));
+ std::copy_backward(pos, allocated_space() + s - 1, allocated_space() + s);
+ } else {
+ ConstructInlined(inlined_space() + s, *(inlined_space() + s - 1));
+ std::copy_backward(pos, inlined_space() + s - 1, inlined_space() + s);
+ }
+
+ *pos = v;
+
+ set_size_internal(s + 1);
+ return pos;
+}
+
+template <typename T, int N, typename A>
+typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::erase(
+ iterator first, iterator last) {
+ DCHECK_LE(begin(), first);
+ DCHECK_LE(first, last);
+ DCHECK_LE(last, end());
+
+ size_t s = size();
+ ptrdiff_t erase_gap = std::distance(first, last);
+
+ if (allocated()) {
+ std::copy(last, allocated_space() + s, first);
+ DestroyAllocated(allocated_space() + s - erase_gap, allocated_space() + s);
+ } else {
+ std::copy(last, inlined_space() + s, first);
+ DestroyInlined(inlined_space() + s - erase_gap, inlined_space() + s);
+ }
+
+ set_size_internal(size() - erase_gap);
+
+ return first;
+}
+
+template <typename T, int N, typename A>
+void InlinedVector<T, N, A>::swap(InlinedVector& other) {
+ using std::swap; // Augment ADL with std::swap.
+ if (&other == this) {
+ return;
+ }
+ if (allocated() && other.allocated()) {
+ // Both out of line, so just swap the tag, allocation, and allocator.
+ swap(tag(), other.tag());
+ swap(allocation(), other.allocation());
+ swap(allocator(), other.allocator());
+ return;
+ }
+ if (!allocated() && !other.allocated()) {
+ // Both inlined: swap up to smaller size, then move remaining elements.
+ InlinedVector* a = this;
+ InlinedVector* b = &other;
+ if (size() < other.size()) {
+ swap(a, b);
+ }
+
+ const size_t a_size = a->size();
+ const size_t b_size = b->size();
+ DCHECK_GE(a_size, b_size);
+ // 'a' is larger. Swap the elements up to the smaller array size.
+ std::swap_ranges(a->inlined_space(), a->inlined_space() + b_size,
+ b->inlined_space());
+
+ // Move the remaining elements: A[b_size,a_size) -> B[b_size,a_size)
+ b->UninitializedCopyInlined(a->inlined_space() + b_size,
+ a->inlined_space() + a_size,
+ b->inlined_space() + b_size);
+ a->DestroyInlined(a->inlined_space() + b_size, a->inlined_space() + a_size);
+
+ swap(a->tag(), b->tag());
+ swap(a->allocator(), b->allocator());
+ DCHECK_EQ(b->size(), a_size);
+ DCHECK_EQ(a->size(), b_size);
+ return;
+ }
+ // One is out of line, one is inline.
+ // We first move the elements from the inlined vector into the
+ // inlined space in the other vector. We then put the other vector's
+ // pointer/capacity into the originally inlined vector and swap
+ // the tags.
+ InlinedVector* a = this;
+ InlinedVector* b = &other;
+ if (a->allocated()) {
+ swap(a, b);
+ }
+ DCHECK(!a->allocated());
+ DCHECK(b->allocated());
+ const size_t a_size = a->size();
+ const size_t b_size = b->size();
+
+  // We made local copies of size() above, so tag() no longer needs to stay
+  // accurate.
+ swap(a->tag(), b->tag());
+
+  // Copy b's allocation out before b's union is clobbered via inlined_space.
+ Allocation b_allocation = b->allocation();
+
+ b->UninitializedCopyInlined(a->inlined_space(), a->inlined_space() + a_size,
+ b->inlined_space());
+ a->DestroyInlined(a->inlined_space(), a->inlined_space() + a_size);
+
+ a->allocation() = b_allocation;
+
+ if (a->allocator() != b->allocator()) {
+ swap(a->allocator(), b->allocator());
+ }
+
+ DCHECK_EQ(b->size(), a_size);
+ DCHECK_EQ(a->size(), b_size);
+}
+
+template <typename T, int N, typename A>
+void InlinedVector<T, N, A>::EnlargeBy(size_t delta) {
+ const size_t s = size();
+ DCHECK_LE(s, capacity());
+
+ size_t target = std::max(static_cast<size_t>(N), s + delta);
+
+ // Compute new capacity by repeatedly doubling current capacity
+ // TODO(psrc): Check and avoid overflow?
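+  // For example, an inline capacity of 8 asked to hold 20 elements grows
+  // 8 -> 16 -> 32, leaving new_capacity at 32.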
+ size_t new_capacity = capacity();
+ while (new_capacity < target) {
+ new_capacity <<= 1;
+ }
+
+ Allocation new_allocation(allocator(), new_capacity);
+ new_allocation.set_size(s);
+
+ UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer());
+
+ ResetAllocation(new_allocation);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::DestroyInlined(value_type* ptr,
+ value_type* ptr_last) {
+ for (value_type* p = ptr; p != ptr_last; ++p) {
+ p->~value_type();
+ }
+
+// Overwrite unused memory with 0xab so we can catch uninitialized usage.
+// Cast to void* to tell the compiler that we don't care that we might be
+// scribbling on a vtable pointer.
+#ifndef NDEBUG
+ if (ptr != ptr_last) {
+ memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr));
+ }
+#endif
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::DestroyAllocated(value_type* ptr,
+ value_type* ptr_last) {
+ for (value_type* p = ptr; p != ptr_last; ++p) {
+ AllocatorTraits::destroy(allocator(), p);
+ }
+
+// Overwrite unused memory with 0xab so we can catch uninitialized usage.
+// Cast to void* to tell the compiler that we don't care that we might be
+// scribbling on a vtable pointer.
+#ifndef NDEBUG
+ if (ptr != ptr_last) {
+ memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr));
+ }
+#endif
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last,
+ std::input_iterator_tag) {
+ std::copy(first, last, std::back_inserter(*this));
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last,
+ std::forward_iterator_tag) {
+ typedef typename std::iterator_traits<Iter>::difference_type Length;
+ Length length = std::distance(first, last);
+ reserve(size() + length);
+ if (allocated()) {
+ UninitializedCopyAllocated(first, last, allocated_space() + size());
+ } else {
+ UninitializedCopyInlined(first, last, inlined_space() + size());
+ }
+ set_size_internal(size() + length);
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last) {
+ typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
+ AppendRange(first, last, IterTag());
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
new file mode 100644
index 0000000000..ec5fe1eaa8
--- /dev/null
+++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc
@@ -0,0 +1,905 @@
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
+
+// A type that counts the number of live instances of the type.
+static int64 instances = 0;
+class Instance {
+ public:
+ int value_;
+ explicit Instance(int x) : value_(x) { instances++; }
+ Instance(const Instance& x) : value_(x.value_) { instances++; }
+ ~Instance() { instances--; }
+
+ friend inline void swap(Instance& a, Instance& b) {
+ using std::swap;
+ swap(a.value_, b.value_);
+ }
+
+ friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
+ return o << "[value:" << v.value_ << "]";
+ }
+};
+
+typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
+
+// A simple reference counted class to make sure that the proper elements are
+// destroyed in the erase(begin, end) test.
+class RefCounted {
+ public:
+ RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
+
+ RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
+ VLOG(5) << "[RefCounted: copy"
+ << " from count @" << v.count_ << "]";
+ Ref();
+ }
+
+ ~RefCounted() {
+ Unref();
+ count_ = NULL;
+ }
+
+ friend void swap(RefCounted& a, RefCounted& b) {
+ using std::swap;
+ swap(a.value_, b.value_);
+ swap(a.count_, b.count_);
+ }
+
+ RefCounted& operator=(RefCounted v) {
+ using std::swap;
+ swap(*this, v);
+ return *this;
+ }
+
+ void Ref() const {
+ CHECK(count_ != NULL);
+ ++(*count_);
+ VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
+ }
+
+ void Unref() const {
+ --(*count_);
+ CHECK_GE(*count_, 0);
+ VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
+ }
+
+ int count() const { return *count_; }
+
+ friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
+ return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
+ }
+
+ int value_;
+ int* count_;
+};
+
+typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
+
+// A class with a vtable pointer
+class Dynamic {
+ public:
+ virtual ~Dynamic() {}
+
+ friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
+ return o << "[Dynamic]";
+ }
+};
+
+typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
+
+// Append 0..len-1 to *v
+static void Fill(IntVec* v, int len, int offset = 0) {
+ for (int i = 0; i < len; i++) {
+ v->push_back(i + offset);
+ }
+}
+
+static IntVec Fill(int len, int offset = 0) {
+ IntVec v;
+ Fill(&v, len, offset);
+ return v;
+}
+
+TEST(IntVec, SimpleOps) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ const IntVec& cv = v; // const alias
+
+ Fill(&v, len);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(i, v[i]);
+ }
+ EXPECT_EQ(v.begin(), v.array());
+ EXPECT_EQ(v.begin(), v.mutable_array());
+
+ EXPECT_EQ(v.begin(), v.data());
+ EXPECT_EQ(cv.begin(), cv.data());
+
+ int counter = 0;
+ for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
+ EXPECT_EQ(counter, *iter);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ counter = 0;
+ for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
+ EXPECT_EQ(counter, *iter);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ EXPECT_EQ(0, v.front());
+ EXPECT_EQ(len - 1, v.back());
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i, v[i]);
+ }
+ }
+ }
+}
+
+TEST(IntVec, Erase) {
+ for (int len = 1; len < 20; len++) {
+ for (int i = 0; i < len; ++i) {
+ IntVec v;
+ Fill(&v, len);
+ v.erase(v.begin() + i);
+ EXPECT_EQ(len - 1, v.size());
+ for (int j = 0; j < i; ++j) {
+ EXPECT_EQ(j, v[j]);
+ }
+ for (int j = i; j < len - 1; ++j) {
+ EXPECT_EQ(j + 1, v[j]);
+ }
+ }
+ }
+}
+
+// At the end of this test loop, the elements between [erase_begin, erase_end)
+// should have reference counts == 0, and all other elements should have
+// reference counts == 1.
+TEST(RefCountedVec, EraseBeginEnd) {
+ for (int len = 1; len < 20; ++len) {
+ for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
+ for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
+ std::vector<int> counts(len, 0);
+ RefCountedVec v;
+ for (int i = 0; i < len; ++i) {
+ v.push_back(RefCounted(i, &counts[i]));
+ }
+
+ int erase_len = erase_end - erase_begin;
+
+ v.erase(v.begin() + erase_begin, v.begin() + erase_end);
+
+ EXPECT_EQ(len - erase_len, v.size());
+
+ // Check the elements before the first element erased.
+ for (int i = 0; i < erase_begin; ++i) {
+ EXPECT_EQ(i, v[i].value_);
+ }
+
+ // Check the elements after the first element erased.
+ for (size_t i = erase_begin; i < v.size(); ++i) {
+ EXPECT_EQ(i + erase_len, v[i].value_);
+ }
+
+ // Check that the elements at the beginning are preserved.
+ for (int i = 0; i < erase_begin; ++i) {
+ EXPECT_EQ(1, counts[i]);
+ }
+
+ // Check that the erased elements are destroyed
+ for (int i = erase_begin; i < erase_end; ++i) {
+ EXPECT_EQ(0, counts[i]);
+ }
+
+ // Check that the elements at the end are preserved.
+ for (int i = erase_end; i < len; ++i) {
+ EXPECT_EQ(1, counts[i]);
+ }
+ }
+ }
+ }
+}
+
+struct NoDefaultCtor {
+ explicit NoDefaultCtor(int /* x */) {}
+};
+struct NoCopy {
+ NoCopy() {}
+ NoCopy(const NoCopy& /* x */) = delete;
+};
+struct NoAssign {
+ NoAssign() {}
+ NoAssign& operator=(const NoAssign& /* x */) = delete;
+};
+TEST(InlinedVectorTest, NoDefaultCtor) {
+ tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
+ (void)v;
+}
+TEST(InlinedVectorTest, NoCopy) {
+ tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
+ (void)v;
+}
+TEST(InlinedVectorTest, NoAssign) {
+ tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
+ (void)v;
+}
+
+TEST(IntVec, Insert) {
+ for (int len = 0; len < 20; len++) {
+ for (int pos = 0; pos <= len; pos++) {
+ IntVec v;
+ Fill(&v, len);
+ v.insert(v.begin() + pos, 9999);
+ EXPECT_EQ(v.size(), len + 1);
+ for (int i = 0; i < pos; i++) {
+ EXPECT_EQ(v[i], i);
+ }
+ EXPECT_EQ(v[pos], 9999);
+ for (size_t i = pos + 1; i < v.size(); i++) {
+ EXPECT_EQ(v[i], i - 1);
+ }
+ }
+ }
+}
+
+TEST(RefCountedVec, InsertConstructorDestructor) {
+ // Make sure the proper construction/destruction happen during insert
+ // operations.
+ for (int len = 0; len < 20; len++) {
+ SCOPED_TRACE(len);
+ for (int pos = 0; pos <= len; pos++) {
+ SCOPED_TRACE(pos);
+ std::vector<int> counts(len, 0);
+ RefCountedVec v;
+ for (int i = 0; i < len; ++i) {
+ SCOPED_TRACE(i);
+ v.push_back(RefCounted(i, &counts[i]));
+ }
+
+ for (auto elem : counts) {
+ EXPECT_EQ(1, elem);
+ }
+
+ int inserted_count = 0;
+ RefCounted insert_element(9999, &inserted_count);
+ EXPECT_EQ(1, inserted_count);
+ v.insert(v.begin() + pos, insert_element);
+ EXPECT_EQ(2, inserted_count);
+ // Check that the elements at the end are preserved.
+ for (auto elem : counts) {
+ EXPECT_EQ(1, elem);
+ }
+ EXPECT_EQ(2, inserted_count);
+ }
+ }
+}
+
+TEST(IntVec, Resize) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+
+ // Try resizing up and down by k elements
+ static const int kResizeElem = 1000000;
+ for (int k = 0; k < 10; k++) {
+ // Enlarging resize
+ v.resize(len + k, kResizeElem);
+ EXPECT_EQ(len + k, v.size());
+ EXPECT_LE(len + k, v.capacity());
+ for (int i = 0; i < len + k; i++) {
+ if (i < len) {
+ EXPECT_EQ(i, v[i]);
+ } else {
+ EXPECT_EQ(kResizeElem, v[i]);
+ }
+ }
+
+ // Shrinking resize
+ v.resize(len, kResizeElem);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(i, v[i]);
+ }
+ }
+ }
+}
+
+TEST(IntVec, InitWithLength) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v(len, 7);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(7, v[i]);
+ }
+ }
+}
+
+TEST(IntVec, CopyConstructorAndAssignment) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+
+ IntVec v2(v);
+ EXPECT_EQ(v, v2);
+
+ for (int start_len = 0; start_len < 20; start_len++) {
+ IntVec v3;
+ Fill(&v3, start_len, 99); // Add dummy elements that should go away
+ v3 = v;
+ EXPECT_EQ(v, v3);
+ }
+ }
+}
+
+TEST(OverheadTest, Storage) {
+ // Check for size overhead.
+ // In particular, ensure that std::allocator doesn't cost anything to store.
+ // The union should be absorbing some of the allocation bookkeeping overhead
+ // in the larger vectors, leaving only the size_ field as overhead.
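+  // (Illustrative arithmetic, assuming a typical 64-bit layout: the rep_
+  // union occupies max(N * sizeof(int*), 3 * sizeof(int*)) bytes for the
+  // three-word Allocation, plus one word of tag. For N == 1 that is four
+  // words total, i.e. three words of overhead beyond the single inline
+  // slot, matching the first expectation below.)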
+ using tensorflow::gtl::InlinedVector;
+ EXPECT_EQ(3 * sizeof(int*),
+ sizeof(InlinedVector<int*, 1>) - 1 * sizeof(int*));
+ EXPECT_EQ(2 * sizeof(int*),
+ sizeof(InlinedVector<int*, 2>) - 2 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 3>) - 3 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 4>) - 4 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 5>) - 5 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 6>) - 6 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 7>) - 7 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 8>) - 8 * sizeof(int*));
+}
+
+TEST(IntVec, Clear) {
+ for (int len = 0; len < 20; len++) {
+ SCOPED_TRACE(len);
+ IntVec v;
+ Fill(&v, len);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(IntVec, Reserve) {
+ for (size_t len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+
+ for (size_t newlen = 0; newlen < 100; newlen++) {
+ const int* start_rep = v.array();
+ v.reserve(newlen);
+ const int* final_rep = v.array();
+ if (newlen <= len) {
+ EXPECT_EQ(start_rep, final_rep);
+ }
+ EXPECT_LE(newlen, v.capacity());
+
+ // Filling up to newlen should not change rep
+ while (v.size() < newlen) {
+ v.push_back(0);
+ }
+ EXPECT_EQ(final_rep, v.array());
+ }
+ }
+}
+
+template <typename T>
+static std::vector<typename T::value_type> Vec(const T& src) {
+ std::vector<typename T::value_type> result;
+ for (const auto& elem : src) {
+ result.push_back(elem);
+ }
+ return result;
+}
+
+TEST(IntVec, SelfRefPushBack) {
+ std::vector<string> std_v;
+ tensorflow::gtl::InlinedVector<string, 4> v;
+ const string s = "A very long string to ensure heap.";
+ std_v.push_back(s);
+ v.push_back(s);
+ for (int i = 0; i < 20; ++i) {
+ EXPECT_EQ(std_v, Vec(v));
+
+ v.push_back(v.back());
+ std_v.push_back(std_v.back());
+ }
+ EXPECT_EQ(std_v, Vec(v));
+}
+
+TEST(IntVec, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ SCOPED_TRACE(l1);
+ for (int l2 = 0; l2 < 20; l2++) {
+ SCOPED_TRACE(l2);
+ IntVec a = Fill(l1, 0);
+ IntVec b = Fill(l2, 100);
+ {
+ using std::swap;
+ swap(a, b);
+ }
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ SCOPED_TRACE(i);
+ EXPECT_EQ(i, b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ SCOPED_TRACE(i);
+ EXPECT_EQ(100 + i, a[i]);
+ }
+ }
+ }
+}
+
+TEST(InstanceVec, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ InstanceVec a, b;
+ for (int i = 0; i < l1; i++) a.push_back(Instance(i));
+ for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
+ EXPECT_EQ(l1 + l2, instances);
+ {
+ using std::swap;
+ swap(a, b);
+ }
+ EXPECT_EQ(l1 + l2, instances);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(i, b[i].value_);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(100 + i, a[i].value_);
+ }
+ }
+ }
+}
+
+TEST(IntVec, EqualAndNotEqual) {
+ IntVec a, b;
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ a.push_back(3);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ b.push_back(3);
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ b.push_back(7);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ a.push_back(6);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ a.clear();
+ b.clear();
+ for (int i = 0; i < 100; i++) {
+ a.push_back(i);
+ b.push_back(i);
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ b[i] = b[i] + 1;
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ b[i] = b[i] - 1; // Back to before
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+ }
+}
+
+TEST(IntVec, RelationalOps) {
+ IntVec a, b;
+ EXPECT_FALSE(a < b);
+ EXPECT_FALSE(b < a);
+ EXPECT_FALSE(a > b);
+ EXPECT_FALSE(b > a);
+ EXPECT_TRUE(a <= b);
+ EXPECT_TRUE(b <= a);
+ EXPECT_TRUE(a >= b);
+ EXPECT_TRUE(b >= a);
+ b.push_back(3);
+ EXPECT_TRUE(a < b);
+ EXPECT_FALSE(b < a);
+ EXPECT_FALSE(a > b);
+ EXPECT_TRUE(b > a);
+ EXPECT_TRUE(a <= b);
+ EXPECT_FALSE(b <= a);
+ EXPECT_FALSE(a >= b);
+ EXPECT_TRUE(b >= a);
+}
+
+TEST(InstanceVec, CountConstructorsDestructors) {
+ const int start = instances;
+ for (int len = 0; len < 20; len++) {
+ InstanceVec v;
+ for (int i = 0; i < len; i++) {
+ v.push_back(Instance(i));
+ }
+ EXPECT_EQ(start + len, instances);
+
+ { // Copy constructor should create 'len' more instances.
+ InstanceVec v_copy(v);
+ EXPECT_EQ(start + len + len, instances);
+ }
+ EXPECT_EQ(start + len, instances);
+
+ // Enlarging resize() must construct some objects
+ v.resize(len + 10, Instance(100));
+ EXPECT_EQ(start + len + 10, instances);
+
+ // Shrinking resize() must destroy some objects
+ v.resize(len, Instance(100));
+ EXPECT_EQ(start + len, instances);
+
+ // reserve() must not increase the number of initialized objects
+ v.reserve(len + 1000);
+ EXPECT_EQ(start + len, instances);
+
+ // pop_back() and erase() must destroy one object
+ if (len > 0) {
+ v.pop_back();
+ EXPECT_EQ(start + len - 1, instances);
+ if (!v.empty()) {
+ v.erase(v.begin());
+ EXPECT_EQ(start + len - 2, instances);
+ }
+ }
+ }
+ EXPECT_EQ(start, instances);
+}
+
+TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
+ const int start = instances;
+ for (int len = 0; len < 20; len++) {
+ for (int longorshort = 0; longorshort <= 1; ++longorshort) {
+ InstanceVec longer, shorter;
+ for (int i = 0; i < len; i++) {
+ longer.push_back(Instance(i));
+ shorter.push_back(Instance(i));
+ }
+ longer.push_back(Instance(len));
+ EXPECT_EQ(start + len + len + 1, instances);
+
+ if (longorshort) {
+ shorter = longer;
+ EXPECT_EQ(start + (len + 1) + (len + 1), instances);
+ } else {
+ longer = shorter;
+ EXPECT_EQ(start + len + len, instances);
+ }
+ }
+ }
+ EXPECT_EQ(start, instances);
+}
+
+TEST(RangedConstructor, SimpleType) {
+ std::vector<int> source_v = {4, 5, 6};
+ // First try to fit in inline backing
+ tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
+ EXPECT_EQ(3, v.size());
+ EXPECT_EQ(4, v.capacity()); // Indication that we're still on inlined storage
+ EXPECT_EQ(4, v[0]);
+ EXPECT_EQ(5, v[1]);
+ EXPECT_EQ(6, v[2]);
+
+ // Now, force a re-allocate
+ tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
+ source_v.end());
+ EXPECT_EQ(3, realloc_v.size());
+ EXPECT_LT(2, realloc_v.capacity());
+ EXPECT_EQ(4, realloc_v[0]);
+ EXPECT_EQ(5, realloc_v[1]);
+ EXPECT_EQ(6, realloc_v[2]);
+}
+
+TEST(RangedConstructor, ComplexType) {
+ // We also use a list here to pass a different flavor of iterator (e.g. not
+ // random-access).
+ std::list<Instance> source_v = {Instance(0)};
+
+ // First try to fit in inline backing
+ tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
+ source_v.end());
+ EXPECT_EQ(1, v.size());
+ EXPECT_EQ(1, v.capacity()); // Indication that we're still on inlined storage
+ EXPECT_EQ(0, v[0].value_);
+
+ std::list<Instance> source_v2 = {Instance(0), Instance(1)};
+ // Now, force a re-allocate
+ tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
+ source_v2.end());
+ EXPECT_EQ(2, realloc_v.size());
+ EXPECT_LT(1, realloc_v.capacity());
+ EXPECT_EQ(0, realloc_v[0].value_);
+ EXPECT_EQ(1, realloc_v[1].value_);
+}
+
+TEST(RangedConstructor, ElementsAreConstructed) {
+ std::vector<string> source_v = {"cat", "dog"};
+
+ // Force expansion and re-allocation of v. Ensures that when the vector is
+ // expanded that new elements are constructed.
+ tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
+ EXPECT_EQ("cat", v[0]);
+ EXPECT_EQ("dog", v[1]);
+}
+
+TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
+ auto vec = tensorflow::gtl::InlinedVector<int, 4>{4, 5, 6};
+ EXPECT_EQ(3, vec.size());
+ EXPECT_EQ(4, vec.capacity());
+ EXPECT_EQ(4, vec[0]);
+ EXPECT_EQ(5, vec[1]);
+ EXPECT_EQ(6, vec[2]);
+}
+
+TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
+ auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
+ EXPECT_EQ(3, vec.size());
+ EXPECT_LE(3, vec.capacity());
+ EXPECT_EQ(4, vec[0]);
+ EXPECT_EQ(5, vec[1]);
+ EXPECT_EQ(6, vec[2]);
+}
+
+TEST(InitializerListConstructor, DisparateTypesInList) {
+ EXPECT_EQ((std::vector<int>{-7, 8}),
+ Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
+
+ EXPECT_EQ(
+ (std::vector<string>{"foo", "bar"}),
+ Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
+}
+
+TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
+ auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
+ EXPECT_EQ(1, vec.size());
+ EXPECT_EQ(1, vec.capacity());
+ EXPECT_EQ(0, vec[0].value_);
+}
+
+TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
+ auto vec =
+ tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
+ EXPECT_EQ(2, vec.size());
+ EXPECT_LE(2, vec.capacity());
+ EXPECT_EQ(0, vec[0].value_);
+ EXPECT_EQ(1, vec[1].value_);
+}
+
+TEST(DynamicVec, DynamicVecCompiles) {
+ DynamicVec v;
+ (void)v;
+}
+
+#ifdef INLINED_VECTOR_HAS_ALLOC
+TEST(AllocatorSupportTest, Constructors) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ int64 allocated = 0;
+ MyAlloc alloc(&allocated);
+ { AllocVec TF_ATTRIBUTE_UNUSED v; }
+ { AllocVec TF_ATTRIBUTE_UNUSED v(alloc); }
+ { AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); }
+#ifdef LANG_CXX11
+ { AllocVec TF_ATTRIBUTE_UNUSED v({1, 2, 3}, alloc); }
+#endif // LANG_CXX11
+}
+
+TEST(AllocatorSupportTest, CountAllocations) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ int64 allocated = 0;
+ MyAlloc alloc(&allocated);
+ {
+ AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + 4, alloc);
+ EXPECT_THAT(allocated, 0);
+ }
+ EXPECT_THAT(allocated, 0);
+ {
+ AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc);
+ EXPECT_THAT(allocated, v.size() * sizeof(int));
+ }
+ EXPECT_THAT(allocated, 0);
+}
+
+TEST(AllocatorSupportTest, SwapBothAllocated) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ int64 allocated1 = 0;
+ int64 allocated2 = 0;
+ {
+ const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7};
+ const std::vector<int> ia2 = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ MyAlloc a1(&allocated1);
+ MyAlloc a2(&allocated2);
+ AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1);
+ AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2);
+ EXPECT_LT(v1.capacity(), v2.capacity());
+ EXPECT_THAT(allocated1, v1.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, v2.capacity() * sizeof(int));
+ v1.swap(v2);
+ EXPECT_EQ(ia2, Vec(v1));
+ EXPECT_EQ(ia1, Vec(v2));
+ EXPECT_THAT(allocated1, v2.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, v1.capacity() * sizeof(int));
+ }
+ EXPECT_THAT(allocated1, 0);
+ EXPECT_THAT(allocated2, 0);
+}
+
+TEST(AllocatorSupportTest, SwapOneAllocated) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ int64 allocated1 = 0;
+ int64 allocated2 = 0;
+ {
+ const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7};
+ const std::vector<int> ia2 = {0, 1, 2, 3};
+ MyAlloc a1(&allocated1);
+ MyAlloc a2(&allocated2);
+ AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1);
+ AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2);
+ EXPECT_THAT(allocated1, v1.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, 0);
+ v1.swap(v2);
+ EXPECT_EQ(ia2, Vec(v1));
+ EXPECT_EQ(ia1, Vec(v2));
+ EXPECT_THAT(allocated1, v2.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, 0);
+ EXPECT_TRUE(v2.get_allocator() == a1);
+ EXPECT_TRUE(v1.get_allocator() == a2);
+ }
+ EXPECT_THAT(allocated1, 0);
+ EXPECT_THAT(allocated2, 0);
+}
+#endif // INLINED_VECTOR_HAS_ALLOC
+
+static void BM_InlinedVectorFill(int iters, int len) {
+ for (int i = 0; i < iters; i++) {
+ IntVec v;
+ for (int j = 0; j < len; j++) {
+ v.push_back(j);
+ }
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
+
+static void BM_InlinedVectorFillRange(int iters, int len) {
+ std::unique_ptr<int[]> ia(new int[len]);
+ for (int j = 0; j < len; j++) {
+ ia[j] = j;
+ }
+ for (int i = 0; i < iters; i++) {
+ IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
+
+static void BM_StdVectorFill(int iters, int len) {
+ for (int i = 0; i < iters; i++) {
+ std::vector<int> v;
+ for (int j = 0; j < len; j++) {
+ v.push_back(j);
+ }
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
+
+namespace {
+struct Buffer { // some arbitrary structure for benchmarking.
+ char* base;
+ int length;
+ int capacity;
+ void* user_data;
+};
+} // anonymous namespace
+
+static void BM_InlinedVectorTenAssignments(int iters, int len) {
+ typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
+
+ BufferVec src;
+ src.resize(len);
+
+ iters *= 10;
+ BufferVec dst;
+ for (int i = 0; i < iters; i++) {
+ dst = src;
+ }
+}
+BENCHMARK(BM_InlinedVectorTenAssignments)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(2)
+ ->Arg(3)
+ ->Arg(4)
+ ->Arg(20);
+
+static void BM_CreateFromInitializerList(int iters) {
+ for (; iters > 0; iters--) {
+ tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
+ (void)x[0];
+ }
+}
+BENCHMARK(BM_CreateFromInitializerList);
+
+namespace {
+
+struct LargeSwappable {
+ LargeSwappable() : d_(1024, 17) {}
+ ~LargeSwappable() {}
+ LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
+
+ friend void swap(LargeSwappable& a, LargeSwappable& b) {
+ using std::swap;
+ swap(a.d_, b.d_);
+ }
+
+ LargeSwappable& operator=(LargeSwappable o) {
+ using std::swap;
+ swap(*this, o);
+ return *this;
+ }
+
+ std::vector<int> d_;
+};
+
+} // namespace
+
+static void BM_LargeSwappableElements(int iters, int len) {
+ typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
+ Vec a(len);
+ Vec b;
+ while (--iters >= 0) {
+ using std::swap;
+ swap(a, b);
+ }
+}
+BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h
new file mode 100644
index 0000000000..d3fcb08d38
--- /dev/null
+++ b/tensorflow/core/lib/gtl/int_type.h
@@ -0,0 +1,343 @@
+// #status: LEGACY
+// #category: Miscellaneous
+// #summary: Integral types; prefer util/intops/strong_int.h
+// #bugs: Infrastructure > C++ Library Team > util
+//
+// IntType is a simple template class mechanism for defining "logical"
+// integer-like class types that support many of the same functionalities
+// as native integer types, but which prevent assignment, construction, and
+// other operations from other similar integer-like types. Essentially, the
+// template class IntType<IntTypeName, ValueType> (where ValueType is an
+// integral type such as int, uint, int32, etc.) has the additional
+// property that it cannot be assigned to or constructed from other IntTypes
+// or native integer types of equal or implicitly convertible type.
+//
+// The class is useful for preventing mingling of integer variables with
+// different logical roles or units. Unfortunately, C++ provides relatively
+// good type-safety for user-defined classes but not for integer types. It is
+// essentially up to the user to use nice variable names and comments to prevent
+// accidental mismatches, such as confusing a user-index with a group-index or a
+// time-in-milliseconds with a time-in-seconds. The use of typedefs is limited
+// in that regard, as they do not enforce type-safety.
+//
+// USAGE -----------------------------------------------------------------------
+//
+// TF_LIB_GTL_DEFINE_INT_TYPE(IntTypeName, ValueType);
+// (the examples below use the shorter legacy spelling DEFINE_INT_TYPE)
+//
+// where:
+// IntTypeName: is the desired (unique) name for the "logical" integer type.
+// ValueType: is one of the integral types as defined by std::is_integral;
+// this is enforced by a static_assert inside the class.
+//
+// DISALLOWED OPERATIONS / TYPE-SAFETY ENFORCEMENT -----------------------------
+//
+// Consider these definitions and variable declarations:
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// GlobalDocID global;
+// LocalDocID local;
+//
+// The class IntType prevents:
+//
+// 1) Assignments of other IntTypes with different IntTypeNames.
+//
+// global = local; <-- Fails to compile!
+// local = global; <-- Fails to compile!
+//
+// 2) Explicit/implicit conversion from an IntType to another IntType.
+//
+// LocalDocID l(global); <-- Fails to compile!
+// LocalDocID l = global; <-- Fails to compile!
+//
+// void GetGlobalDoc(GlobalDocID global) { }
+// GetGlobalDoc(global); <-- Compiles fine, types match!
+// GetGlobalDoc(local); <-- Fails to compile!
+//
+// 3) Implicit conversion from an IntType to a native integer type.
+//
+// void GetGlobalDoc(int64 global) { ...
+// GetGlobalDoc(global); <-- Fails to compile!
+// GetGlobalDoc(local); <-- Fails to compile!
+//
+// void GetLocalDoc(int32 local) { ...
+// GetLocalDoc(global); <-- Fails to compile!
+// GetLocalDoc(local); <-- Fails to compile!
+//
+//
+// SUPPORTED OPERATIONS --------------------------------------------------------
+//
+// The following operators are supported: unary: ++ and -- (both prefix and
+// postfix), +, -, ! (logical not), ~ (one's complement); comparison: ==, !=,
+// <, <=, >, >=; numerical: +, -, *, /, <<, >>, %; assignment: =, +=, -=, *=,
+// /=, <<=, >>=, %=; stream: <<. Each operator allows the same IntTypeName and
+// the ValueType to be used on both left- and right-hand sides.
+//
+// It also supports an accessor value() returning the stored value as ValueType,
+// and a templatized accessor value<T>() method that serves as syntactic sugar
+// for static_cast<T>(var.value()). These accessors are useful when assigning
+// the stored value into protocol buffer fields and using it as printf args.
+//
+// The class also defines a hash functor that allows the IntType to be used
+// as key to hashable containers such as std::unordered_map and
+// std::unordered_set.
+//
+// We suggest using the IntTypeIndexedContainer wrapper around FixedArray and
+// STL vector (see int-type-indexed-container.h) if an IntType is intended to
+// be used as an index into these containers. These wrappers index with the
+// IntType itself, keeping access type-safe.
+//
+// NB: this implementation does not attempt to abide by or enforce dimensional
+// analysis on these scalar types.
+//
+// EXAMPLES --------------------------------------------------------------------
+//
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// GlobalDocID global = 3;
+// cout << global; <-- Prints 3 to stdout.
+//
+// for (GlobalDocID i(0); i < global; ++i) {
+// cout << i;
+// }                                 <-- Prints 0 1 2 to stdout.
+//
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// LocalDocID local;
+// cout << local;                    <-- Prints 0 to stdout, since the
+//                                       value is default-initialized to 0.
+//
+// local = 5;
+// local *= 2;
+// LocalDocID l(local);
+// cout << l + local; <-- Prints 20 to stdout.
+//
+// GenericSearchRequest request;
+// request.set_doc_id(global.value()); <-- Uses value() to extract the value
+// from the IntType class.
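+//
+// The nested Hasher functor (defined in the class below) lets an IntType be
+// used as a key in hash containers; a minimal sketch:
+//
+// std::unordered_map<GlobalDocID, string, GlobalDocID::Hasher> names;
+// names[global] = "doc";              <-- Hashes the key via Hasher.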
+//
+// REMARKS ---------------------------------------------------------------------
+//
+// The following bad usage is permissible although discouraged. Essentially, it
+// involves using the value*() accessors to extract the native integer type out
+// of the IntType class. Keep in mind that the primary reason for the IntType
+// class is to prevent *accidental* mingling of similar logical integer types --
+// and not type casting from one type to another.
+//
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// GlobalDocID global;
+// LocalDocID local;
+//
+// global = local.value(); <-- Compiles fine.
+//
+// void GetGlobalDoc(GlobalDocID global) { ...
+// GetGlobalDoc(local.value()); <-- Compiles fine.
+//
+// void GetGlobalDoc(int64 global) { ...
+// GetGlobalDoc(local.value()); <-- Compiles fine.
+
+#ifndef TENSORFLOW_LIB_GTL_INT_TYPE_H_
+#define TENSORFLOW_LIB_GTL_INT_TYPE_H_
+
+#include <stddef.h>
+#include <functional>
+#include <iosfwd>
+#include <ostream> // NOLINT
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename IntTypeName, typename _ValueType>
+class IntType;
+
+// Defines the IntType using value_type and typedefs it to int_type_name.
+// The struct int_type_name ## _tag_ trickery is needed to ensure that a new
+// type is created per int_type_name.
+#define TF_LIB_GTL_DEFINE_INT_TYPE(int_type_name, value_type) \
+ struct int_type_name##_tag_ {}; \
+ typedef ::tensorflow::gtl::IntType<int_type_name##_tag_, value_type> \
+ int_type_name;
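+
+// For example, a minimal sketch (GlobalDocID mirrors the usage examples above
+// and the tests in int_type_test.cc):
+//
+// TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);
+// GlobalDocID id(3);  // A distinct type; will not mix with other IntTypes
+//                     // or with raw integers.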
+
+// Holds an integer value (of type ValueType) and behaves as a ValueType by
+// exposing assignment, unary, comparison, and arithmetic operators.
+//
+// The template parameter IntTypeName defines the name for the int type and must
+// be unique within a binary (the convenient TF_LIB_GTL_DEFINE_INT_TYPE macro
+// defined above generates a unique IntTypeName). The parameter ValueType defines
+// the integer type value (see supported list above).
+//
+// This class is NOT thread-safe.
+template <typename IntTypeName, typename _ValueType>
+class IntType {
+ public:
+ typedef _ValueType ValueType; // for non-member operators
+ typedef IntType<IntTypeName, ValueType> ThisType; // Syntactic sugar.
+
+ // Note that this may change from time to time without notice.
+ struct Hasher {
+ size_t operator()(const IntType& arg) const {
+ return static_cast<size_t>(arg.value());
+ }
+ };
+
+ public:
+ // Default c'tor initializing value_ to 0.
+ constexpr IntType() : value_(0) {}
+ // C'tor explicitly initializing from a ValueType.
+ constexpr explicit IntType(ValueType value) : value_(value) {}
+
+ // IntType uses the default copy constructor, destructor and assign operator.
+ // The defaults are sufficient and omitting them allows the compiler to add
+ // the move constructor/assignment.
+
+ // -- ACCESSORS --------------------------------------------------------------
+ // The class provides a value() accessor returning the stored ValueType value_
+// as well as a templatized accessor that is just syntactic sugar for
+ // static_cast<T>(var.value());
+ constexpr ValueType value() const { return value_; }
+
+ template <typename ValType>
+ constexpr ValType value() const {
+ return static_cast<ValType>(value_);
+ }
+
+ // -- UNARY OPERATORS --------------------------------------------------------
+ ThisType& operator++() { // prefix ++
+ ++value_;
+ return *this;
+ }
+ const ThisType operator++(int v) { // postfix ++
+ ThisType temp(*this);
+ ++value_;
+ return temp;
+ }
+ ThisType& operator--() { // prefix --
+ --value_;
+ return *this;
+ }
+ const ThisType operator--(int v) { // postfix --
+ ThisType temp(*this);
+ --value_;
+ return temp;
+ }
+
+ constexpr bool operator!() const { return value_ == 0; }
+ constexpr const ThisType operator+() const { return ThisType(value_); }
+ constexpr const ThisType operator-() const { return ThisType(-value_); }
+ constexpr const ThisType operator~() const { return ThisType(~value_); }
+
+// -- ASSIGNMENT OPERATORS ---------------------------------------------------
+// We support the following assignment operators: =, +=, -=, *=, /=, <<=, >>=
+// and %= for both ThisType and ValueType.
+#define INT_TYPE_ASSIGNMENT_OP(op) \
+ ThisType& operator op(const ThisType& arg_value) { \
+ value_ op arg_value.value(); \
+ return *this; \
+ } \
+ ThisType& operator op(ValueType arg_value) { \
+ value_ op arg_value; \
+ return *this; \
+ }
+ INT_TYPE_ASSIGNMENT_OP(+= );
+ INT_TYPE_ASSIGNMENT_OP(-= );
+ INT_TYPE_ASSIGNMENT_OP(*= );
+ INT_TYPE_ASSIGNMENT_OP(/= );
+ INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT
+ INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT
+ INT_TYPE_ASSIGNMENT_OP(%= );
+#undef INT_TYPE_ASSIGNMENT_OP
+
+ ThisType& operator=(ValueType arg_value) {
+ value_ = arg_value;
+ return *this;
+ }
+
+ private:
+ // The integer value of type ValueType.
+ ValueType value_;
+
+ static_assert(std::is_integral<ValueType>::value, "invalid integer type");
+} TF_PACKED;
+
+// -- NON-MEMBER STREAM OPERATORS ----------------------------------------------
+// We provide the << operator, primarily for logging purposes. Currently, there
+// seems to be no need for a >> operator.
+template <typename IntTypeName, typename ValueType>
+std::ostream& operator<<(std::ostream& os, // NOLINT
+ IntType<IntTypeName, ValueType> arg) {
+ return os << arg.value();
+}
+
+// -- NON-MEMBER ARITHMETIC OPERATORS ------------------------------------------
+// We support the +, -, *, /, <<, >>, and % operators with the same IntType and
+// ValueType types. The reason is to allow simple manipulation on these IDs
+// when used as indices in vectors and arrays.
+//
+// NB: Although it is possible to do IntType * IntType and IntType / IntType,
+// it is probably nonsensical from a dimensional analysis perspective.
+#define INT_TYPE_ARITHMETIC_OP(op) \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ IntType<IntTypeName, ValueType> id_1, \
+ IntType<IntTypeName, ValueType> id_2) { \
+ return IntType<IntTypeName, ValueType>(id_1.value() op id_2.value()); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ IntType<IntTypeName, ValueType> id, \
+ typename IntType<IntTypeName, ValueType>::ValueType arg_val) { \
+ return IntType<IntTypeName, ValueType>(id.value() op arg_val); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ typename IntType<IntTypeName, ValueType>::ValueType arg_val, \
+ IntType<IntTypeName, ValueType> id) { \
+ return IntType<IntTypeName, ValueType>(arg_val op id.value()); \
+ }
+INT_TYPE_ARITHMETIC_OP(+);
+INT_TYPE_ARITHMETIC_OP(-);
+INT_TYPE_ARITHMETIC_OP(*);
+INT_TYPE_ARITHMETIC_OP(/ );
+INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT
+INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT
+INT_TYPE_ARITHMETIC_OP(% );
+#undef INT_TYPE_ARITHMETIC_OP
+
+// -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
+// Static inline comparison operators. We allow all comparison operators among
+// the following types (OP \in [==, !=, <, <=, >, >=]):
+// IntType<IntTypeName, ValueType> OP IntType<IntTypeName, ValueType>
+// IntType<IntTypeName, ValueType> OP ValueType
+// ValueType OP IntType<IntTypeName, ValueType>
+#define INT_TYPE_COMPARISON_OP(op) \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ IntType<IntTypeName, ValueType> id_1, \
+ IntType<IntTypeName, ValueType> id_2) { \
+ return id_1.value() op id_2.value(); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ IntType<IntTypeName, ValueType> id, \
+ typename IntType<IntTypeName, ValueType>::ValueType val) { \
+ return id.value() op val; \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ typename IntType<IntTypeName, ValueType>::ValueType val, \
+ IntType<IntTypeName, ValueType> id) { \
+ return val op id.value(); \
+ }
+INT_TYPE_COMPARISON_OP(== ); // NOLINT
+INT_TYPE_COMPARISON_OP(!= ); // NOLINT
+INT_TYPE_COMPARISON_OP(< ); // NOLINT
+INT_TYPE_COMPARISON_OP(<= ); // NOLINT
+INT_TYPE_COMPARISON_OP(> ); // NOLINT
+INT_TYPE_COMPARISON_OP(>= ); // NOLINT
+#undef INT_TYPE_COMPARISON_OP
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_INT_TYPE_H_
diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc
new file mode 100644
index 0000000000..694886d345
--- /dev/null
+++ b/tensorflow/core/lib/gtl/int_type_test.cc
@@ -0,0 +1,282 @@
+// Unit test cases for IntType.
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TF_LIB_GTL_DEFINE_INT_TYPE(Int8_IT, int8);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt8_IT, uint8);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int16_IT, int16);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt16_IT, uint16);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int32_IT, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int64_IT, int64);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt32_IT, uint32);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt64_IT, uint64);
+TF_LIB_GTL_DEFINE_INT_TYPE(Long_IT, long); // NOLINT
+
+template <typename IntType_Type>
+class IntTypeTest : public ::testing::Test {
+ public:
+ typedef IntType_Type T;
+};
+
+// All tests below will be executed on all supported IntTypes.
+typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT,
+                         UInt32_IT, Int64_IT, UInt64_IT, Long_IT>
+    SupportedIntTypes;
+
+TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes);
+
+TYPED_TEST(IntTypeTest, TestInitialization) {
+ constexpr typename TestFixture::T a;
+ constexpr typename TestFixture::T b(1);
+ constexpr typename TestFixture::T c(b);
+ EXPECT_EQ(0, a); // default initialization to 0
+ EXPECT_EQ(1, b);
+ EXPECT_EQ(1, c);
+}
+
+TYPED_TEST(IntTypeTest, TestOperators) {
+ typename TestFixture::T a(0);
+ typename TestFixture::T b(1);
+ typename TestFixture::T c(2);
+ constexpr typename TestFixture::T d(3);
+ constexpr typename TestFixture::T e(4);
+
+// On all EXPECT_EQ below, we use the accessor value() so as not to invoke the
+// comparison operators, which must themselves be tested.
+
+ // -- UNARY OPERATORS --------------------------------------------------------
+ EXPECT_EQ(0, (a++).value());
+ EXPECT_EQ(2, (++a).value());
+ EXPECT_EQ(2, (a--).value());
+ EXPECT_EQ(0, (--a).value());
+
+ EXPECT_EQ(true, !a);
+ EXPECT_EQ(false, !b);
+ static_assert(!d == false, "Unary operator! failed");
+
+ EXPECT_EQ(a.value(), +a);
+ static_assert(+d == d.value(), "Unary operator+ failed");
+ EXPECT_EQ(-a.value(), -a);
+ static_assert(-d == -d.value(), "Unary operator- failed");
+ EXPECT_EQ(~a.value(), ~a); // ~zero
+ EXPECT_EQ(~b.value(), ~b); // ~non-zero
+ static_assert(~d == ~d.value(), "Unary operator~ failed");
+
+ // -- ASSIGNMENT OPERATORS ---------------------------------------------------
+ // We test all assignment operators using IntType and constant as arguments.
+ // We also test the return from the operators.
+ // From same IntType
+ c = a = b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ // From constant
+ c = b = 2;
+ EXPECT_EQ(2, b.value());
+ EXPECT_EQ(2, c.value());
+ // From same IntType
+ c = a += b;
+ EXPECT_EQ(3, a.value());
+ EXPECT_EQ(3, c.value());
+ c = a -= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a *= b;
+ EXPECT_EQ(2, a.value());
+ EXPECT_EQ(2, c.value());
+ c = a /= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a <<= b;
+ EXPECT_EQ(4, a.value());
+ EXPECT_EQ(4, c.value());
+ c = a >>= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a %= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ // From constant
+ c = a += 2;
+ EXPECT_EQ(3, a.value());
+ EXPECT_EQ(3, c.value());
+ c = a -= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a *= 2;
+ EXPECT_EQ(2, a.value());
+ EXPECT_EQ(2, c.value());
+ c = a /= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a <<= 2;
+ EXPECT_EQ(4, a.value());
+ EXPECT_EQ(4, c.value());
+ c = a >>= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a %= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+
+ // -- COMPARISON OPERATORS ---------------------------------------------------
+ a = 0;
+ b = 1;
+
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a == 0); // NOLINT
+ EXPECT_FALSE(1 == a); // NOLINT
+ static_assert(d == d, "operator== failed");
+ static_assert(d == 3, "operator== failed");
+ static_assert(3 == d, "operator== failed");
+ EXPECT_TRUE(a != b);
+ EXPECT_TRUE(a != 1); // NOLINT
+ EXPECT_FALSE(0 != a); // NOLINT
+ static_assert(d != e, "operator!= failed");
+ static_assert(d != 4, "operator!= failed");
+ static_assert(4 != d, "operator!= failed");
+ EXPECT_TRUE(a < b);
+ EXPECT_TRUE(a < 1); // NOLINT
+ EXPECT_FALSE(0 < a); // NOLINT
+ static_assert(d < e, "operator< failed");
+ static_assert(d < 4, "operator< failed");
+ static_assert(3 < e, "operator< failed");
+ EXPECT_TRUE(a <= b);
+ EXPECT_TRUE(a <= 1); // NOLINT
+ EXPECT_TRUE(0 <= a); // NOLINT
+ static_assert(d <= e, "operator<= failed");
+ static_assert(d <= 4, "operator<= failed");
+ static_assert(3 <= e, "operator<= failed");
+ EXPECT_FALSE(a > b);
+ EXPECT_FALSE(a > 1); // NOLINT
+ EXPECT_FALSE(0 > a); // NOLINT
+ static_assert(e > d, "operator> failed");
+ static_assert(e > 3, "operator> failed");
+ static_assert(4 > d, "operator> failed");
+ EXPECT_FALSE(a >= b);
+ EXPECT_FALSE(a >= 1); // NOLINT
+ EXPECT_TRUE(0 >= a); // NOLINT
+ static_assert(e >= d, "operator>= failed");
+ static_assert(e >= 3, "operator>= failed");
+ static_assert(4 >= d, "operator>= failed");
+
+ // -- BINARY OPERATORS -------------------------------------------------------
+ a = 1;
+ b = 3;
+ EXPECT_EQ(4, (a + b).value());
+ EXPECT_EQ(4, (a + 3).value());
+ EXPECT_EQ(4, (1 + b).value());
+ static_assert((d + e).value() == 7, "Binary operator+ failed");
+ static_assert((d + 4).value() == 7, "Binary operator+ failed");
+ static_assert((3 + e).value() == 7, "Binary operator+ failed");
+ EXPECT_EQ(2, (b - a).value());
+ EXPECT_EQ(2, (b - 1).value());
+ EXPECT_EQ(2, (3 - a).value());
+ static_assert((e - d).value() == 1, "Binary operator- failed");
+ static_assert((e - 3).value() == 1, "Binary operator- failed");
+ static_assert((4 - d).value() == 1, "Binary operator- failed");
+ EXPECT_EQ(3, (a * b).value());
+ EXPECT_EQ(3, (a * 3).value());
+ EXPECT_EQ(3, (1 * b).value());
+ static_assert((d * e).value() == 12, "Binary operator* failed");
+ static_assert((d * 4).value() == 12, "Binary operator* failed");
+ static_assert((3 * e).value() == 12, "Binary operator* failed");
+ EXPECT_EQ(0, (a / b).value());
+ EXPECT_EQ(0, (a / 3).value());
+ EXPECT_EQ(0, (1 / b).value());
+ static_assert((d / e).value() == 0, "Binary operator/ failed");
+ static_assert((d / 4).value() == 0, "Binary operator/ failed");
+ static_assert((3 / e).value() == 0, "Binary operator/ failed");
+ EXPECT_EQ(8, (a << b).value());
+ EXPECT_EQ(8, (a << 3).value());
+ EXPECT_EQ(8, (1 << b).value());
+ static_assert((d << e).value() == 48, "Binary operator<< failed");
+ static_assert((d << 4).value() == 48, "Binary operator<< failed");
+ static_assert((3 << e).value() == 48, "Binary operator<< failed");
+ b = 8;
+ EXPECT_EQ(4, (b >> a).value());
+ EXPECT_EQ(4, (b >> 1).value());
+ EXPECT_EQ(4, (8 >> a).value());
+ static_assert((d >> e).value() == 0, "Binary operator>> failed");
+ static_assert((d >> 4).value() == 0, "Binary operator>> failed");
+ static_assert((3 >> e).value() == 0, "Binary operator>> failed");
+ b = 3;
+ a = 2;
+ EXPECT_EQ(1, (b % a).value());
+ EXPECT_EQ(1, (b % 2).value());
+ EXPECT_EQ(1, (3 % a).value());
+ static_assert((e % d).value() == 1, "Binary operator% failed");
+ static_assert((e % 3).value() == 1, "Binary operator% failed");
+ static_assert((4 % d).value() == 1, "Binary operator% failed");
+}
+
+TYPED_TEST(IntTypeTest, TestHashFunctor) {
+ std::unordered_map<typename TestFixture::T, char,
+ typename TestFixture::T::Hasher> map;
+ typename TestFixture::T a(0);
+ map[a] = 'c';
+ EXPECT_EQ('c', map[a]);
+ map[++a] = 'o';
+ EXPECT_EQ('o', map[a]);
+
+ typename TestFixture::T b(a);
+ EXPECT_EQ(typename TestFixture::T::Hasher()(a),
+ typename TestFixture::T::Hasher()(b));
+}
+
+// Tests the use of the templatized value accessor that performs static_casts.
+// We use -1 to force casting in unsigned integers.
+TYPED_TEST(IntTypeTest, TestValueAccessor) {
+ constexpr typename TestFixture::T::ValueType i = -1;
+ constexpr typename TestFixture::T int_type(i);
+ EXPECT_EQ(i, int_type.value());
+ static_assert(int_type.value() == i, "value() failed");
+ // The use of the keyword 'template' (suggested by Clang) is only necessary
+ // as this code is part of a template class. Weird syntax though. Good news
+ // is that only int_type.value<int>() is needed in most code.
+ EXPECT_EQ(static_cast<int>(i), int_type.template value<int>());
+ EXPECT_EQ(static_cast<int8>(i), int_type.template value<int8>());
+ EXPECT_EQ(static_cast<int16>(i), int_type.template value<int16>());
+ EXPECT_EQ(static_cast<int32>(i), int_type.template value<int32>());
+ EXPECT_EQ(static_cast<uint32>(i), int_type.template value<uint32>());
+ EXPECT_EQ(static_cast<int64>(i), int_type.template value<int64>());
+ EXPECT_EQ(static_cast<uint64>(i), int_type.template value<uint64>());
+ EXPECT_EQ(static_cast<long>(i), int_type.template value<long>()); // NOLINT
+ static_assert(int_type.template value<int>() == static_cast<int>(i),
+ "value<Value>() failed");
+}
+
+TYPED_TEST(IntTypeTest, TestMove) {
+ // Check that the int types have move constructor/assignment.
+ // We do this by composing a struct with an int type and a unique_ptr. This
+ // struct can't be copied due to the unique_ptr, so it must be moved.
+ // If this compiles, it means that the int types have move operators.
+ struct NotCopyable {
+ typename TestFixture::T inttype;
+ std::unique_ptr<int> ptr;
+
+ static NotCopyable Make(int i) {
+ NotCopyable f;
+ f.inttype = typename TestFixture::T(i);
+ f.ptr.reset(new int(i));
+ return f;
+ }
+ };
+
+ // Test move constructor.
+ NotCopyable foo = NotCopyable::Make(123);
+ EXPECT_EQ(123, foo.inttype);
+ EXPECT_EQ(123, *foo.ptr);
+
+ // Test move assignment.
+ foo = NotCopyable::Make(321);
+ EXPECT_EQ(321, foo.inttype);
+ EXPECT_EQ(321, *foo.ptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h
new file mode 100644
index 0000000000..baec85c40a
--- /dev/null
+++ b/tensorflow/core/lib/gtl/iterator_range.h
@@ -0,0 +1,49 @@
+// This provides a very simple, boring adaptor for a begin and end iterator
+// into a range type. This should be used to build range views that work well
+// with range based for loops and range based constructors.
+//
+// Note that code here follows more standards-based coding conventions as it
+// is mirroring proposed interfaces for standardization.
+//
+// Converted from chandlerc@'s code to Google style by joshl@.
+
+#ifndef TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
+#define TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
+
+#include <utility>
+
+namespace tensorflow {
+namespace gtl {
+
+// A range adaptor for a pair of iterators.
+//
+// This just wraps two iterators into a range-compatible interface. Nothing
+// fancy at all.
+template <typename IteratorT>
+class iterator_range {
+ public:
+ iterator_range() : begin_iterator_(), end_iterator_() {}
+ iterator_range(IteratorT begin_iterator, IteratorT end_iterator)
+ : begin_iterator_(std::move(begin_iterator)),
+ end_iterator_(std::move(end_iterator)) {}
+
+ IteratorT begin() const { return begin_iterator_; }
+ IteratorT end() const { return end_iterator_; }
+
+ private:
+ IteratorT begin_iterator_, end_iterator_;
+};
+
+// Convenience function for iterating over sub-ranges.
+//
+// This provides a bit of syntactic sugar to make using sub-ranges
+// in for loops a bit easier. Analogous to std::make_pair().
+template <class T>
+iterator_range<T> make_range(T x, T y) {
+ return iterator_range<T>(std::move(x), std::move(y));
+}
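+
+// For example, a minimal sketch (mirroring iterator_range_test.cc):
+//
+// std::vector<int> v = {2, 3, 5};
+// for (int x : make_range(v.begin(), v.end())) { ... }  // visits 2, 3, 5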
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
diff --git a/tensorflow/core/lib/gtl/iterator_range_test.cc b/tensorflow/core/lib/gtl/iterator_range_test.cc
new file mode 100644
index 0000000000..328be4ecbc
--- /dev/null
+++ b/tensorflow/core/lib/gtl/iterator_range_test.cc
@@ -0,0 +1,60 @@
+#include "tensorflow/core/lib/gtl/iterator_range.h"
+
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+TEST(IteratorRange, WholeVector) {
+ std::vector<int> v = {2, 3, 5, 7, 11, 13};
+ iterator_range<std::vector<int>::iterator> range(v.begin(), v.end());
+ int index = 0;
+ for (int prime : range) {
+ ASSERT_LT(index, v.size());
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(v.size(), index);
+}
+
+TEST(IteratorRange, VectorMakeRange) {
+ std::vector<int> v = {2, 3, 5, 7, 11, 13};
+ auto range = make_range(v.begin(), v.end());
+ int index = 0;
+ for (int prime : range) {
+ ASSERT_LT(index, v.size());
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(v.size(), index);
+}
+
+TEST(IteratorRange, PartArray) {
+ int v[] = {2, 3, 5, 7, 11, 13};
+ iterator_range<int*> range(&v[1], &v[4]); // 3, 5, 7
+ int index = 1;
+ for (int prime : range) {
+ ASSERT_LT(index, TF_ARRAYSIZE(v));
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(4, index);
+}
+
+TEST(IteratorRange, ArrayMakeRange) {
+ int v[] = {2, 3, 5, 7, 11, 13};
+ auto range = make_range(&v[1], &v[4]); // 3, 5, 7
+ int index = 1;
+ for (int prime : range) {
+ ASSERT_LT(index, TF_ARRAYSIZE(v));
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(4, index);
+}
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/manual_constructor.h b/tensorflow/core/lib/gtl/manual_constructor.h
new file mode 100644
index 0000000000..39f029ed4a
--- /dev/null
+++ b/tensorflow/core/lib/gtl/manual_constructor.h
@@ -0,0 +1,230 @@
+// ManualConstructor statically-allocates space in which to store some
+// object, but does not initialize it. You can then call the constructor
+// and destructor for the object yourself as you see fit. This is useful
+// for memory management optimizations, where you want to initialize and
+// destroy an object multiple times but only allocate it once.
+//
+// (When I say ManualConstructor statically allocates space, I mean that
+// the ManualConstructor object itself is forced to be the right size.)
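+//
+// A minimal usage sketch (Widget and its members are illustrative, not part
+// of this library):
+//
+// static ManualConstructor<Widget> storage; // space only; nothing constructed
+// storage.Init(42);                         // placement-new: Widget(42)
+// storage->DoStuff();                       // operator-> forwards to the Widget
+// storage.Destroy();                        // explicit destructor call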
+
+#ifndef TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
+#define TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
+
+#include <stddef.h>
+#include <new>
+#include <utility>
+
+#include "tensorflow/core/platform/port.h" // For aligned_malloc/aligned_free
+
+namespace tensorflow {
+namespace gtl {
+namespace internal {
+
+//
+// Provides a char array with the exact same alignment as another type. The
+// first parameter must be a complete type, the second parameter is how many
+// of that type to provide space for.
+//
+// TF_LIB_GTL_ALIGNED_CHAR_ARRAY(struct stat, 16) storage_;
+//
+// Because MSVC and older GCCs require the argument to their alignment
+// construct to be a literal constant integer, we use a template instantiated
+// at all the possible powers of two.
+#ifndef SWIG
+template <int alignment, int size>
+struct AlignType {};
+template <int size>
+struct AlignType<0, size> {
+ typedef char result[size];
+};
+#if defined(COMPILER_MSVC)
+#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X))
+#define TF_LIB_GTL_ALIGN_OF(T) __alignof(T)
+#elif defined(COMPILER_GCC3) || __GNUC__ >= 3 || defined(__APPLE__) || \
+ defined(COMPILER_ICC) || defined(OS_NACL) || defined(__clang__)
+#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X)))
+#define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T)
+#endif
+
+#if defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+
+#define TF_LIB_GTL_ALIGNTYPE_TEMPLATE(X) \
+ template <int size> \
+ struct AlignType<X, size> { \
+ typedef TF_LIB_GTL_ALIGN_ATTRIBUTE(X) char result[size]; \
+ }
+
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(16);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(32);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(64);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(128);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(256);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(512);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1024);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2048);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4096);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8192);
+// Any larger and MSVC++ will complain.
+
+#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \
+ typename tensorflow::gtl::internal::AlignType<TF_LIB_GTL_ALIGN_OF(T), \
+ sizeof(T) * Size>::result
+
+#undef TF_LIB_GTL_ALIGNTYPE_TEMPLATE
+#undef TF_LIB_GTL_ALIGN_ATTRIBUTE
+
+#else // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+#error "You must define TF_LIB_GTL_ALIGNED_CHAR_ARRAY for your compiler."
+#endif // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+
+#else // !SWIG
+
+// SWIG can't represent alignment and doesn't care about alignment on data
+// members (it works fine without it).
+template <typename Size>
+struct AlignType {
+ typedef char result[Size];
+};
+#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \
+ tensorflow::gtl::internal::AlignType<Size * sizeof(T)>::result
+
+// Enough to parse with SWIG, will never be used by running code.
+#define TF_LIB_GTL_ALIGN_OF(Type) 16
+
+#endif // !SWIG
+
+} // namespace internal
+} // namespace gtl
+
+template <typename Type>
+class ManualConstructor {
+ public:
+ // No constructor or destructor because one of the most useful uses of
+ // this class is as part of a union, and members of a union cannot have
+ // constructors or destructors. And, anyway, the whole point of this
+ // class is to bypass these.
+
+ // Support users creating arrays of ManualConstructor<>s. This ensures that
+ // the array itself has the correct alignment.
+ static void* operator new[](size_t size) {
+ return port::aligned_malloc(size, TF_LIB_GTL_ALIGN_OF(Type));
+ }
+ static void operator delete[](void* mem) { port::aligned_free(mem); }
+
+ inline Type* get() { return reinterpret_cast<Type*>(space_); }
+ inline const Type* get() const {
+ return reinterpret_cast<const Type*>(space_);
+ }
+
+ inline Type* operator->() { return get(); }
+ inline const Type* operator->() const { return get(); }
+
+ inline Type& operator*() { return *get(); }
+ inline const Type& operator*() const { return *get(); }
+
+ inline void Init() { new (space_) Type; }
+
+// Init() constructs the Type instance using the given arguments
+// (which are forwarded to Type's constructor). In C++11, Init() can
+// take any number of arguments of any type, and forwards them perfectly.
+// On pre-C++11 platforms, it can take up to 11 arguments, and may not be
+// able to forward certain kinds of arguments.
+//
+// Note that Init() with no arguments performs default-initialization,
+// not zero-initialization (i.e. it behaves the same as "new Type;", not
+// "new Type();"), so it will leave non-class types uninitialized.
+#ifdef LANG_CXX11
+ template <typename... Ts>
+ inline void Init(Ts&&... args) { // NOLINT
+ new (space_) Type(std::forward<Ts>(args)...); // NOLINT
+ }
+#else // !defined(LANG_CXX11)
+ template <typename T1>
+ inline void Init(const T1& p1) {
+ new (space_) Type(p1);
+ }
+
+ template <typename T1, typename T2>
+ inline void Init(const T1& p1, const T2& p2) {
+ new (space_) Type(p1, p2);
+ }
+
+ template <typename T1, typename T2, typename T3>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3) {
+ new (space_) Type(p1, p2, p3);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4) {
+ new (space_) Type(p1, p2, p3, p4);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5) {
+ new (space_) Type(p1, p2, p3, p4, p5);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9, const T10& p10) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9, const T10& p10, const T11& p11) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11);
+ }
+#endif // LANG_CXX11
+
+ inline void Destroy() { get()->~Type(); }
+
+ private:
+ TF_LIB_GTL_ALIGNED_CHAR_ARRAY(Type, 1) space_;
+};
+
+#undef TF_LIB_GTL_ALIGNED_CHAR_ARRAY
+#undef TF_LIB_GTL_ALIGN_OF
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
diff --git a/tensorflow/core/lib/gtl/manual_constructor_test.cc b/tensorflow/core/lib/gtl/manual_constructor_test.cc
new file mode 100644
index 0000000000..a929591be2
--- /dev/null
+++ b/tensorflow/core/lib/gtl/manual_constructor_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+#include <stdint.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+static int constructor_count_ = 0;
+
+template <int kSize>
+struct TestN {
+ TestN() { ++constructor_count_; }
+ ~TestN() { --constructor_count_; }
+ char a[kSize];
+};
+
+typedef TestN<1> Test1;
+typedef TestN<2> Test2;
+typedef TestN<3> Test3;
+typedef TestN<4> Test4;
+typedef TestN<5> Test5;
+typedef TestN<9> Test9;
+typedef TestN<15> Test15;
+
+} // namespace
+
+namespace {
+
+TEST(ManualConstructorTest, Sizeof) {
+ CHECK_EQ(sizeof(ManualConstructor<Test1>), sizeof(Test1));
+ CHECK_EQ(sizeof(ManualConstructor<Test2>), sizeof(Test2));
+ CHECK_EQ(sizeof(ManualConstructor<Test3>), sizeof(Test3));
+ CHECK_EQ(sizeof(ManualConstructor<Test4>), sizeof(Test4));
+ CHECK_EQ(sizeof(ManualConstructor<Test5>), sizeof(Test5));
+ CHECK_EQ(sizeof(ManualConstructor<Test9>), sizeof(Test9));
+ CHECK_EQ(sizeof(ManualConstructor<Test15>), sizeof(Test15));
+
+ CHECK_EQ(constructor_count_, 0);
+ ManualConstructor<Test1> mt[4];
+ CHECK_EQ(sizeof(mt), 4);
+ CHECK_EQ(constructor_count_, 0);
+ mt[0].Init();
+ CHECK_EQ(constructor_count_, 1);
+ mt[0].Destroy();
+}
+
+TEST(ManualConstructorTest, Alignment) {
+ // We want to make sure that ManualConstructor aligns its memory properly
+// on a word boundary. Otherwise, it might be unexpectedly slow, since
+ // memory access will be unaligned.
+
+ struct {
+ char a;
+ ManualConstructor<void*> b;
+ } test1;
+ struct {
+ char a;
+ void* b;
+ } control1;
+
+ // TODO(bww): Make these tests more direct with C++11 alignment_of<T>::value.
+ EXPECT_EQ(reinterpret_cast<char*>(test1.b.get()) - &test1.a,
+ reinterpret_cast<char*>(&control1.b) - &control1.a);
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test1.b.get()) % sizeof(control1.b), 0);
+
+ struct {
+ char a;
+ ManualConstructor<long double> b;
+ } test2;
+ struct {
+ char a;
+ long double b;
+ } control2;
+
+ EXPECT_EQ(reinterpret_cast<char*>(test2.b.get()) - &test2.a,
+ reinterpret_cast<char*>(&control2.b) - &control2.a);
+#ifdef ARCH_K8
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 16, 0);
+#endif
+#ifdef ARCH_PIII
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 4, 0);
+#endif
+}
+
+TEST(ManualConstructorTest, DefaultInitialize) {
+ struct X {
+ X() : x(123) {}
+ int x;
+ };
+ union {
+ ManualConstructor<X> x;
+ ManualConstructor<int> y;
+ } u;
+ *u.y = -1;
+ u.x.Init(); // should default-initialize u.x
+ EXPECT_EQ(123, u.x->x);
+}
+
+TEST(ManualConstructorTest, ZeroInitializePOD) {
+ union {
+ ManualConstructor<int> x;
+ ManualConstructor<int> y;
+ } u;
+ *u.y = -1;
+ u.x.Init(); // should not zero-initialize u.x
+ EXPECT_EQ(-1, *u.y);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h
new file mode 100644
index 0000000000..c953de57c7
--- /dev/null
+++ b/tensorflow/core/lib/gtl/map_util.h
@@ -0,0 +1,123 @@
+// This file provides utility functions for use with STL map-like data
+// structures, such as std::map and hash_map. Some functions will also work with
+// sets, such as InsertIfNotPresent().
+
+#ifndef TENSORFLOW_LIB_GTL_MAP_UTIL_H_
+#define TENSORFLOW_LIB_GTL_MAP_UTIL_H_
+
+#include <stddef.h>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace gtl {
+
+// Returns a pointer to the const value associated with the given key if it
+// exists, or NULL otherwise.
+template <class Collection>
+const typename Collection::value_type::second_type* FindOrNull(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return 0;
+ }
+ return &it->second;
+}
+
+// Same as above but returns a pointer to the non-const value.
+template <class Collection>
+typename Collection::value_type::second_type* FindOrNull(
+ Collection& collection, // NOLINT
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return 0;
+ }
+ return &it->second;
+}
+
+// Returns the pointer value associated with the given key. If none is found,
+// NULL is returned. The function is designed to be used with a map of keys to
+// pointers.
+//
+// This function does not distinguish between a missing key and a key mapped
+// to a NULL value.
+template <class Collection>
+typename Collection::value_type::second_type FindPtrOrNull(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return typename Collection::value_type::second_type();
+ }
+ return it->second;
+}
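+
+// For example, a sketch of the ambiguity above (Foo is an illustrative type):
+//
+// std::map<string, Foo*> m;
+// Foo* p = gtl::FindPtrOrNull(m, "k");  // NULL: key is absent
+// m["k"] = NULL;
+// p = gtl::FindPtrOrNull(m, "k");       // NULL again: key present, value NULL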
+
+// Returns a const reference to the value associated with the given key if it
+// exists, otherwise returns a const reference to the provided default value.
+//
+// WARNING: If a temporary object is passed as the default "value,"
+// this function will return a reference to that temporary object,
+// which will be destroyed at the end of the statement. A common
+// example: if you have a map with string values, and you pass a char*
+// as the default "value," either use the returned value immediately
+// or store it in a string (not string&).
+template <class Collection>
+const typename Collection::value_type::second_type& FindWithDefault(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return value;
+ }
+ return it->second;
+}
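+
+// For example, a sketch of the warning above: with std::map<string, string> m,
+// the char* default below first builds a temporary string, so
+//
+// const string& bad = gtl::FindWithDefault(m, "k", "fb");  // may dangle
+// const string ok = gtl::FindWithDefault(m, "k", "fb");    // safe copy
+//
+// "bad" is left dangling if "k" is absent, while "ok" copies the value out.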
+
+// Inserts the given key and value into the given collection if and only if the
+// given key did NOT already exist in the collection. If the key previously
+// existed in the collection, the value is not changed. Returns true if the
+// key-value pair was inserted; returns false if the key was already present.
+template <class Collection>
+bool InsertIfNotPresent(Collection* const collection,
+ const typename Collection::value_type& vt) {
+ return collection->insert(vt).second;
+}
+
+// Same as above except the key and value are passed separately.
+template <class Collection>
+bool InsertIfNotPresent(
+ Collection* const collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ return InsertIfNotPresent(collection,
+ typename Collection::value_type(key, value));
+}
+
+// Looks up a given key and value pair in a collection and inserts the key-value
+// pair if it's not already present. Returns a reference to the value associated
+// with the key.
+template <class Collection>
+typename Collection::value_type::second_type& LookupOrInsert(
+ Collection* const collection, const typename Collection::value_type& vt) {
+ return collection->insert(vt).first->second;
+}
+
+// Same as above except the key-value are passed separately.
+template <class Collection>
+typename Collection::value_type::second_type& LookupOrInsert(
+ Collection* const collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ return LookupOrInsert(collection,
+ typename Collection::value_type(key, value));
+}
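+
+// For example, a minimal sketch (counts and word are illustrative): counting
+// word occurrences with a std::map<string, int>:
+//
+// ++gtl::LookupOrInsert(&counts, word, 0);  // inserts 0 once, then increments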
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_MAP_UTIL_H_
diff --git a/tensorflow/core/lib/gtl/map_util_test.cc b/tensorflow/core/lib/gtl/map_util_test.cc
new file mode 100644
index 0000000000..356f987337
--- /dev/null
+++ b/tensorflow/core/lib/gtl/map_util_test.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+#include <map>
+#include <set>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(MapUtil, Find) {
+ typedef std::map<string, string> Map;
+ Map m;
+
+ // Check that I can use a type that's implicitly convertible to the
+ // key or value type, such as const char* -> string.
+ EXPECT_EQ("", gtl::FindWithDefault(m, "foo", ""));
+ m["foo"] = "bar";
+ EXPECT_EQ("bar", gtl::FindWithDefault(m, "foo", ""));
+ EXPECT_EQ("bar", *gtl::FindOrNull(m, "foo"));
+ EXPECT_TRUE(m.count("foo") > 0);
+ EXPECT_EQ(m["foo"], "bar");
+}
+
+TEST(MapUtil, LookupOrInsert) {
+ typedef std::map<string, string> Map;
+ Map m;
+
+ // Check that I can use a type that's implicitly convertible to the
+ // key or value type, such as const char* -> string.
+ EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "xyz"));
+ EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "abc"));
+}
+
+TEST(MapUtil, InsertIfNotPresent) {
+ // Set operations
+ typedef std::set<int> Set;
+ Set s;
+ EXPECT_TRUE(gtl::InsertIfNotPresent(&s, 0));
+ EXPECT_EQ(s.count(0), 1);
+ EXPECT_FALSE(gtl::InsertIfNotPresent(&s, 0));
+ EXPECT_EQ(s.count(0), 1);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/stl_util.h b/tensorflow/core/lib/gtl/stl_util.h
new file mode 100644
index 0000000000..83abcd6b55
--- /dev/null
+++ b/tensorflow/core/lib/gtl/stl_util.h
@@ -0,0 +1,130 @@
+// This file provides utility functions for use with STL
+
+#ifndef TENSORFLOW_LIB_GTL_STL_UTIL_H_
+#define TENSORFLOW_LIB_GTL_STL_UTIL_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace gtl {
+
+// Returns a mutable char* pointing to a string's internal buffer, which may not
+// be null-terminated. Returns NULL for an empty string. If non-null,
+// writing through this pointer will modify the string.
+//
+// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the
+// next call to a string method that invalidates iterators.
+//
+// In C++11 you may simply use &str[0] to get a mutable char*.
+//
+// Prior to C++11, there was no standard-blessed way of getting a mutable
+// reference to a string's internal buffer. The requirement that string be
+// contiguous is officially part of the C++11 standard [string.require]/5.
+// According to Matt Austern, this should already work on all current C++98
+// implementations.
+inline char* string_as_array(string* str) {
+ return str->empty() ? NULL : &*str->begin();
+}
+
+// Returns the T* array for the given vector, or NULL if the vector was empty.
+//
+// Note: If you know the array will never be empty, you can use &*v.begin()
+// directly, but that may dump core if v is empty. This function is the most
+// efficient code that will work, taking into account how our STL is actually
+// implemented. THIS IS NON-PORTABLE CODE, so use this function instead of
+// repeating the nonportable code everywhere. If our STL implementation changes,
+// we will need to change this as well.
+template <typename T, typename Allocator>
+inline T* vector_as_array(std::vector<T, Allocator>* v) {
+#if defined NDEBUG && !defined _GLIBCXX_DEBUG
+ return &*v->begin();
+#else
+ return v->empty() ? NULL : &*v->begin();
+#endif
+}
+// vector_as_array overload for const std::vector<>.
+template <typename T, typename Allocator>
+inline const T* vector_as_array(const std::vector<T, Allocator>* v) {
+#if defined NDEBUG && !defined _GLIBCXX_DEBUG
+ return &*v->begin();
+#else
+ return v->empty() ? NULL : &*v->begin();
+#endif
+}
+
+// Like str->resize(new_size), except any new characters added to "*str" as a
+// result of resizing may be left uninitialized, rather than being filled with
+// '\0' bytes. Typically used when code is then going to overwrite the backing
+// store of the string with known data. Uses a Google extension to ::string.
+inline void STLStringResizeUninitialized(string* s, size_t new_size) {
+#if __google_stl_resize_uninitialized_string
+ s->resize_uninitialized(new_size);
+#else
+ s->resize(new_size);
+#endif
+}
+
+// Calls delete (non-array version) on the SECOND item (pointer) in each pair in
+// the range [begin, end).
+//
+// Note: If you're calling this on an entire container, you probably want to
+// call STLDeleteValues(&container) instead, or use ValueDeleter.
+template <typename ForwardIterator>
+void STLDeleteContainerPairSecondPointers(ForwardIterator begin,
+ ForwardIterator end) {
+ while (begin != end) {
+ ForwardIterator temp = begin;
+ ++begin;
+ delete temp->second;
+ }
+}
+
+// Deletes all the elements in an STL container and clears the container. This
+// function is suitable for use with a vector, set, hash_set, or any other STL
+// container which defines sensible begin(), end(), and clear() methods.
+//
+// If container is NULL, this function is a no-op.
+template <typename T>
+void STLDeleteElements(T* container) {
+ if (!container) return;
+ auto it = container->begin();
+ while (it != container->end()) {
+ auto temp = it;
+ ++it;
+ delete *temp;
+ }
+ container->clear();
+}
+
+// Given an STL container consisting of (key, value) pairs, STLDeleteValues
+// deletes all the "value" components and clears the container. Does nothing in
+// the case it's given a NULL pointer.
+template <typename T>
+void STLDeleteValues(T* container) {
+ if (!container) return;
+ auto it = container->begin();
+ while (it != container->end()) {
+ auto temp = it;
+ ++it;
+ delete temp->second;
+ }
+ container->clear();
+}
+
+// Sorts and removes duplicates from a sequence container.
+template <typename T>
+inline void STLSortAndRemoveDuplicates(T* v) {
+ std::sort(v->begin(), v->end());
+ v->erase(std::unique(v->begin(), v->end()), v->end());
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_STL_UTIL_H_
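A minimal sketch of how the ownership helpers above compose; the function name is hypothetical, and only stl_util.h plus standard headers are assumed:

#include <map>
#include <vector>
#include "tensorflow/core/lib/gtl/stl_util.h"

void StlUtilSketch() {
  std::map<int, std::vector<int>*> owned;
  owned[0] = new std::vector<int>{3, 1, 3, 2};
  owned[1] = new std::vector<int>{2, 2};
  // Sort each value and drop duplicates in place: {1, 2, 3} and {2}.
  for (auto& kv : owned) tensorflow::gtl::STLSortAndRemoveDuplicates(kv.second);
  // Delete every value pointer and clear the map in one call.
  tensorflow::gtl::STLDeleteValues(&owned);
}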
diff --git a/tensorflow/core/lib/gtl/top_n.h b/tensorflow/core/lib/gtl/top_n.h
new file mode 100644
index 0000000000..b95b998c21
--- /dev/null
+++ b/tensorflow/core/lib/gtl/top_n.h
@@ -0,0 +1,324 @@
+// This simple class finds the top n elements of an incrementally provided set
+// of elements which you push one at a time. If the number of elements exceeds
+// n, the lowest elements are incrementally dropped. At the end you get
+// a vector of the top elements sorted in descending order (through Extract() or
+// ExtractNondestructive()), or a vector of the top elements but not sorted
+// (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
+//
+// The value n is specified in the constructor. If there are p elements pushed
+// altogether:
+// The total storage requirements are O(min(n, p)) elements
+// The running time is O(p * log(min(n, p))) comparisons
+// If n is a constant, the total storage required is a constant and the running
+// time is linear in p.
+//
+// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
+// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
+// discarding the lowest n elements whenever the buffer is full using a linear-
+// time median algorithm. This may have better performance when the input
+// sequence is partially sorted.
+//
+// NOTE(zhifengc): This class should be redesigned to avoid reallocating a
+// vector for each Extract.
+
+#ifndef TENSORFLOW_LIB_GTL_TOP_N_H_
+#define TENSORFLOW_LIB_GTL_TOP_N_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// Cmp is an STL binary predicate. Note that Cmp is the "greater" predicate,
+// not the more commonly used "less" predicate.
+//
+// If you use a "less" predicate here, the TopN will pick out the bottom N
+// elements out of the ones passed to it, and it will return them sorted in
+// ascending order.
+//
+// TopN is rule-of-zero copyable and movable if its members are.
+template <class T, class Cmp = std::greater<T> >
+class TopN {
+ public:
+ // The TopN is in one of three states:
+ //
+ // o UNORDERED: this is the state an instance is originally in,
+ // where the elements are completely orderless.
+ //
+ // o BOTTOM_KNOWN: in this state, we keep the invariant that there
+ // is at least one element in it, and the lowest element is at
+ // position 0. The elements in other positions remain
+ // unsorted. This state is reached if the state was originally
+ // UNORDERED and a peek_bottom() function call is invoked.
+ //
+ // o HEAP_SORTED: in this state, the array is kept as a heap and
+ // there are exactly (limit_+1) elements in the array. This
+ // state is reached when at least (limit_+1) elements are
+ // pushed in.
+ //
+ // The state transition graph is as follows:
+ //
+ // peek_bottom() (limit_+1) elements
+ // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
+ // | ^
+ // | (limit_+1) elements |
+ // +-----------------------------------------------------------+
+
+ enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
+ using UnsortedIterator = typename std::vector<T>::const_iterator;
+
+ // 'limit' is the maximum number of top results to return.
+ explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
+ TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}
+
+ size_t limit() const { return limit_; }
+
+ // Number of elements currently held by this TopN object. This
+ // will be no greater than 'limit' passed to the constructor.
+ size_t size() const { return std::min(elements_.size(), limit_); }
+
+ bool empty() const { return size() == 0; }
+
+ // If you know how many elements you will push at the time you create the
+ // TopN object, you can call reserve to preallocate the memory that TopN
+ // will need to process all 'n' pushes. Calling this method is optional.
+ void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
+
+ // Push 'v'. If the maximum number of elements was exceeded, drop the
+ // lowest element and return it in 'dropped' (if given). If the maximum is not
+ // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
+ // nullptr, in which case it is not filled in.
+ // Requires: T is CopyAssignable, Swappable
+ void push(const T &v) { push(v, nullptr); }
+ void push(const T &v, T *dropped) { PushInternal(v, dropped); }
+
+ // Move overloads of push.
+ // Requires: T is MoveAssignable, Swappable
+ void push(T &&v) { // NOLINT(build/c++11)
+ push(std::move(v), nullptr);
+ }
+ void push(T &&v, T *dropped) { // NOLINT(build/c++11)
+ PushInternal(std::move(v), dropped);
+ }
+
+ // Peeks the bottom result without calling Extract()
+ const T &peek_bottom();
+
+ // Extract the elements as a vector sorted in descending order. The caller
+ // assumes ownership of the vector and must delete it when done. This is a
+ // destructive operation. The only method that can be called immediately
+ // after Extract() is Reset().
+ std::vector<T> *Extract();
+
+ // Similar to Extract(), but makes no guarantees the elements are in sorted
+ // order. As with Extract(), the caller assumes ownership of the vector and
+ // must delete it when done. This is a destructive operation. The only
+ // method that can be called immediately after ExtractUnsorted() is Reset().
+ std::vector<T> *ExtractUnsorted();
+
+ // A non-destructive version of Extract(). Copy the elements into a new vector
+ // sorted in descending order and return it. The caller assumes ownership of
+ // the new vector and must delete it when done. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ std::vector<T> *ExtractNondestructive() const;
+
+ // A non-destructive version of Extract(). Copy the elements to a given
+ // vector sorted in descending order. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ // Note:
+ // 1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractNondestructive(std::vector<T> *output) const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements into a new
+ // vector and return it, with no guarantees the elements are in sorted order.
+ // The caller assumes ownership of the new vector and must delete it when
+ // done. After calling ExtractUnsortedNondestructive(), the caller can
+ // continue to push() new elements.
+ std::vector<T> *ExtractUnsortedNondestructive() const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements into
+ // a given vector, with no guarantees the elements are in sorted order.
+ // After calling ExtractUnsortedNondestructive(), the caller can continue
+ // to push() new elements.
+ // Note:
+ // 1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractUnsortedNondestructive(std::vector<T> *output) const;
+
+ // Return an iterator to the beginning (end) of the container,
+ // with no guarantees about the order of iteration. These iterators are
+ // invalidated by mutation of the data structure.
+ UnsortedIterator unsorted_begin() const { return elements_.begin(); }
+ UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
+
+ // Accessor for comparator template argument.
+ Cmp *comparator() { return &cmp_; }
+
+ // This removes all elements. If Extract() or ExtractUnsorted() have been
+ // called, this will put it back in an empty but usable state.
+ void Reset();
+
+ private:
+ template <typename U>
+ void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11)
+
+ // elements_ can be in one of two states:
+ // elements_.size() <= limit_: elements_ is an unsorted vector of elements
+ // pushed so far.
+ // elements_.size() > limit_: The last element of elements_ is unused;
+ // the other elements of elements_ form an STL heap whose size is exactly
+ // limit_. In this case elements_.size() is exactly one greater than
+ // limit_, but don't use "elements_.size() == limit_ + 1" to check for
+ // that because you'll get a false positive if limit_ == size_t(-1).
+ std::vector<T> elements_;
+ size_t limit_; // Maximum number of elements to find
+ Cmp cmp_; // Greater-than comparison function
+ State state_ = UNORDERED;
+};
+
+// ----------------------------------------------------------------------
+// Implementations of non-inline functions
+
+template <class T, class Cmp>
+template <typename U>
+void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11)
+ if (limit_ == 0) {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ return;
+ }
+ if (state_ != HEAP_SORTED) {
+ elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11)
+ if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
+ // Easy case: we just pushed the new element back
+ } else {
+ // To maintain the BOTTOM_KNOWN state, we need to make sure that
+ // the element at position 0 is always the smallest. So we put
+ // the new element at position 0 and push the original bottom
+ // element in the back.
+ // Warning: this code is subtle.
+ using std::swap;
+ swap(elements_.front(), elements_.back());
+ }
+ if (elements_.size() == limit_ + 1) {
+ // Transition from unsorted vector to a heap.
+ std::make_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ state_ = HEAP_SORTED;
+ }
+ } else {
+ // Only insert the new element if it is greater than the least element.
+ if (cmp_(v, elements_.front())) {
+ elements_.back() = std::forward<U>(v); // NOLINT(build/c++11)
+ std::push_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ } else {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ }
+ }
+}
+
+template <class T, class Cmp>
+const T &TopN<T, Cmp>::peek_bottom() {
+ CHECK(!empty());
+ if (state_ == UNORDERED) {
+ // We need to do a linear scan to find out the bottom element
+ size_t min_candidate = 0;
+ for (size_t i = 1; i < elements_.size(); ++i) {
+ if (cmp_(elements_[min_candidate], elements_[i])) {
+ min_candidate = i;
+ }
+ }
+ // By swapping the element at position 0 and the minimal
+ // element, we transition to the BOTTOM_KNOWN state
+ if (min_candidate != 0) {
+ using std::swap;
+ swap(elements_[0], elements_[min_candidate]);
+ }
+ state_ = BOTTOM_KNOWN;
+ }
+ return elements_.front();
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::Extract() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ != HEAP_SORTED) {
+ std::sort(out->begin(), out->end(), cmp_);
+ } else {
+ out->pop_back();
+ std::sort_heap(out->begin(), out->end(), cmp_);
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ out->pop_back();
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const {
+ auto out = new std::vector<T>;
+ ExtractNondestructive(out);
+ return out;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
+ CHECK(output);
+ *output = elements_;
+ if (state_ != HEAP_SORTED) {
+ std::sort(output->begin(), output->end(), cmp_);
+ } else {
+ output->pop_back();
+ std::sort_heap(output->begin(), output->end(), cmp_);
+ }
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const {
+ auto elements = new std::vector<T>;
+ ExtractUnsortedNondestructive(elements);
+ return elements;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
+ CHECK(output);
+ *output = elements_;
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ output->pop_back();
+ }
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::Reset() {
+ elements_.clear();
+ state_ = UNORDERED;
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_TOP_N_H_
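To make the intended call pattern concrete, a short sketch (hypothetical function, using only the public interface above):

#include <memory>
#include <vector>
#include "tensorflow/core/lib/gtl/top_n.h"

void TopNSketch() {
  tensorflow::gtl::TopN<int> top3(3);  // track the 3 largest values seen
  for (int v : {5, 1, 9, 7, 3}) top3.push(v);
  // Extract() hands back an owned vector sorted descending: {9, 7, 5}.
  std::unique_ptr<std::vector<int>> out(top3.Extract());
  // After a destructive Extract(), only Reset() may be called before reuse.
  top3.Reset();
}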
diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc
new file mode 100644
index 0000000000..1812a1bd3f
--- /dev/null
+++ b/tensorflow/core/lib/gtl/top_n_test.cc
@@ -0,0 +1,249 @@
+// Unit test for TopN.
+
+#include "tensorflow/core/lib/gtl/top_n.h"
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace {
+
+using tensorflow::gtl::TopN;
+using tensorflow::random::PhiloxRandom;
+using tensorflow::random::SimplePhilox;
+using tensorflow::string;
+
+// Move the contents from an owned raw pointer, returning by value.
+// Objects are easier to manage by value.
+template <class T>
+T ConsumeRawPtr(T *p) {
+ T tmp = std::move(*p);
+ delete p;
+ return tmp;
+}
+
+template <class Cmp>
+void TestIntTopNHelper(size_t limit, size_t n_elements, const Cmp &cmp,
+ SimplePhilox *random, bool test_peek,
+ bool test_extract_unsorted) {
+ LOG(INFO) << "Testing limit=" << limit << ", n_elements=" << n_elements
+ << ", test_peek=" << test_peek
+ << ", test_extract_unsorted=" << test_extract_unsorted;
+ TopN<int, Cmp> top(limit, cmp);
+ std::vector<int> shadow(n_elements);
+ for (int i = 0; i != n_elements; ++i) shadow[i] = random->Uniform(limit);
+ for (int e : shadow) top.push(e);
+ std::sort(shadow.begin(), shadow.end(), cmp);
+ size_t top_size = std::min(limit, n_elements);
+ EXPECT_EQ(top_size, top.size());
+ if (test_peek && top_size != 0) {
+ EXPECT_EQ(shadow[top_size - 1], top.peek_bottom());
+ }
+ std::vector<int> v;
+ if (test_extract_unsorted) {
+ v = ConsumeRawPtr(top.ExtractUnsorted());
+ std::sort(v.begin(), v.end(), cmp);
+ } else {
+ v = ConsumeRawPtr(top.Extract());
+ }
+ EXPECT_EQ(top_size, v.size());
+ for (int i = 0; i != top_size; ++i) {
+ VLOG(1) << "Top element " << v[i];
+ EXPECT_EQ(shadow[i], v[i]);
+ }
+}
+
+template <class Cmp>
+void TestIntTopN(size_t limit, size_t n_elements, const Cmp &cmp,
+ SimplePhilox *random) {
+ // Test peek_bottom() and Extract()
+ TestIntTopNHelper(limit, n_elements, cmp, random, true, false);
+ // Test Extract()
+ TestIntTopNHelper(limit, n_elements, cmp, random, false, false);
+ // Test peek_bottom() and ExtractUnsorted()
+ TestIntTopNHelper(limit, n_elements, cmp, random, true, true);
+ // Test ExtractUnsorted()
+ TestIntTopNHelper(limit, n_elements, cmp, random, false, true);
+}
+
+TEST(TopNTest, Misc) {
+ PhiloxRandom philox(1, 1);
+ SimplePhilox random(&philox);
+
+ TestIntTopN(0, 5, std::greater<int>(), &random);
+ TestIntTopN(32, 0, std::greater<int>(), &random);
+ TestIntTopN(6, 6, std::greater<int>(), &random);
+ TestIntTopN(6, 6, std::less<int>(), &random);
+ TestIntTopN(1000, 999, std::greater<int>(), &random);
+ TestIntTopN(1000, 1000, std::greater<int>(), &random);
+ TestIntTopN(1000, 1001, std::greater<int>(), &random);
+ TestIntTopN(2300, 28393, std::less<int>(), &random);
+ TestIntTopN(30, 100, std::greater<int>(), &random);
+ TestIntTopN(100, 30, std::less<int>(), &random);
+ TestIntTopN(size_t(-1), 3, std::greater<int>(), &random);
+ TestIntTopN(size_t(-1), 0, std::greater<int>(), &random);
+ TestIntTopN(0, 5, std::greater<int>(), &random);
+}
+
+TEST(TopNTest, String) {
+ LOG(INFO) << "Testing strings";
+
+ TopN<string> top(3);
+ EXPECT_TRUE(top.empty());
+ top.push("abracadabra");
+ top.push("waldemar");
+ EXPECT_EQ(2, top.size());
+ EXPECT_EQ("abracadabra", top.peek_bottom());
+ top.push("");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("", top.peek_bottom());
+ top.push("top");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("abracadabra", top.peek_bottom());
+ top.push("Google");
+ top.push("test");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("test", top.peek_bottom());
+ TopN<string> top2(top);
+ TopN<string> top3(5);
+ top3 = top;
+ EXPECT_EQ("test", top3.peek_bottom());
+ {
+ std::vector<string> s = ConsumeRawPtr(top.Extract());
+ EXPECT_EQ(s[0], "waldemar");
+ EXPECT_EQ(s[1], "top");
+ EXPECT_EQ(s[2], "test");
+ }
+
+ top2.push("zero");
+ EXPECT_EQ(top2.peek_bottom(), "top");
+
+ {
+ std::vector<string> s = ConsumeRawPtr(top2.Extract());
+ EXPECT_EQ(s[0], "zero");
+ EXPECT_EQ(s[1], "waldemar");
+ EXPECT_EQ(s[2], "top");
+ }
+ {
+ std::vector<string> s = ConsumeRawPtr(top3.Extract());
+ EXPECT_EQ(s[0], "waldemar");
+ EXPECT_EQ(s[1], "top");
+ EXPECT_EQ(s[2], "test");
+ }
+
+ TopN<string> top4(3);
+ // Run this test twice to check Reset():
+ for (int i = 0; i < 2; ++i) {
+ top4.push("abcd");
+ top4.push("ijkl");
+ top4.push("efgh");
+ top4.push("mnop");
+ std::vector<string> s = ConsumeRawPtr(top4.Extract());
+ EXPECT_EQ(s[0], "mnop");
+ EXPECT_EQ(s[1], "ijkl");
+ EXPECT_EQ(s[2], "efgh");
+ top4.Reset();
+ }
+}
+
+// Test that pointers aren't leaked from a TopN if we use the 2-argument version
+// of push().
+TEST(TopNTest, Ptr) {
+ LOG(INFO) << "Testing 2-argument push()";
+ TopN<string *> topn(3);
+ for (int i = 0; i < 8; ++i) {
+ string *dropped = NULL;
+ topn.push(new string(std::to_string(i)), &dropped);
+ delete dropped;
+ }
+
+ for (int i = 8; i > 0; --i) {
+ string *dropped = NULL;
+ topn.push(new string(std::to_string(i)), &dropped);
+ delete dropped;
+ }
+
+ std::vector<string *> extract = ConsumeRawPtr(topn.Extract());
+ tensorflow::gtl::STLDeleteElements(&extract);
+}
+
+struct PointeeGreater {
+ template <typename T>
+ bool operator()(const T &a, const T &b) const {
+ return *a > *b;
+ }
+};
+
+TEST(TopNTest, MoveOnly) {
+ using StrPtr = std::unique_ptr<string>;
+ TopN<StrPtr, PointeeGreater> topn(3);
+ for (int i = 0; i < 8; ++i) topn.push(StrPtr(new string(std::to_string(i))));
+ for (int i = 8; i > 0; --i) topn.push(StrPtr(new string(std::to_string(i))));
+
+ std::vector<StrPtr> extract = ConsumeRawPtr(topn.Extract());
+ EXPECT_EQ(extract.size(), 3);
+ EXPECT_EQ(*(extract[0]), "8");
+ EXPECT_EQ(*(extract[1]), "7");
+ EXPECT_EQ(*(extract[2]), "7");
+}
+
+// Test that Nondestructive extracts do not need a Reset() afterwards,
+// and that pointers aren't leaked from a TopN after calling them.
+TEST(TopNTest, Nondestructive) {
+ LOG(INFO) << "Testing Nondestructive extracts";
+ TopN<int> top4(4);
+ for (int i = 0; i < 8; ++i) {
+ top4.push(i);
+ std::vector<int> v = ConsumeRawPtr(top4.ExtractNondestructive());
+ EXPECT_EQ(std::min(i + 1, 4), v.size());
+ for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]);
+ }
+
+ TopN<int> top3(3);
+ for (int i = 0; i < 8; ++i) {
+ top3.push(i);
+ std::vector<int> v = ConsumeRawPtr(top3.ExtractUnsortedNondestructive());
+ std::sort(v.begin(), v.end(), std::greater<int>());
+ EXPECT_EQ(std::min(i + 1, 3), v.size());
+ for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]);
+ }
+}
+
+struct ForbiddenCmp {
+ bool operator()(int lhs, int rhs) const {
+ LOG(FATAL) << "ForbiddenCmp called " << lhs << " " << rhs;
+ }
+};
+
+TEST(TopNTest, ZeroLimit) {
+ TopN<int, ForbiddenCmp> top(0);
+ top.push(1);
+ top.push(2);
+
+ int dropped = -1;
+ top.push(1, &dropped);
+ top.push(2, &dropped);
+
+ std::vector<int> v;
+ top.ExtractNondestructive(&v);
+ EXPECT_EQ(0, v.size());
+}
+
+TEST(TopNTest, Iteration) {
+ TopN<int> top(4);
+ for (int i = 0; i < 8; ++i) top.push(i);
+ std::vector<int> actual(top.unsorted_begin(), top.unsorted_end());
+ // Check that we have 4,5,6,7 as the top 4 (in some order, so we sort)
+ std::sort(actual.begin(), actual.end());
+ EXPECT_EQ(actual.size(), 4);
+ EXPECT_EQ(actual[0], 4);
+ EXPECT_EQ(actual[1], 5);
+ EXPECT_EQ(actual[2], 6);
+ EXPECT_EQ(actual[3], 7);
+}
+} // namespace
diff --git a/tensorflow/core/lib/hash/crc32c.cc b/tensorflow/core/lib/hash/crc32c.cc
new file mode 100644
index 0000000000..3bef1cf78d
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c.cc
@@ -0,0 +1,244 @@
+// A portable implementation of crc32c, optimized to handle
+// four bytes at a time.
+
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+#include <stdint.h>
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace crc32c {
+
+static const uint32 table0_[256] = {
+ 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c,
+ 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+ 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c,
+ 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+ 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc,
+ 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+ 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512,
+ 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+ 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad,
+ 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+ 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf,
+ 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+ 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f,
+ 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+ 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f,
+ 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+ 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e,
+ 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+ 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e,
+ 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+ 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de,
+ 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+ 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4,
+ 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+ 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b,
+ 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+ 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5,
+ 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+ 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975,
+ 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+ 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905,
+ 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+ 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8,
+ 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+ 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8,
+ 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+ 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78,
+ 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+ 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6,
+ 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+ 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69,
+ 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+ 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351};
+static const uint32 table1_[256] = {
+ 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab,
+ 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21,
+ 0xd39ea264, 0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6,
+ 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4,
+ 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92,
+ 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b,
+ 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd,
+ 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff,
+ 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28,
+ 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2,
+ 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73,
+ 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41,
+ 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17,
+ 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c,
+ 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a,
+ 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78,
+ 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad,
+ 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27,
+ 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0,
+ 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2,
+ 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94,
+ 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260,
+ 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136,
+ 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004,
+ 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3,
+ 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059,
+ 0x1798a91c, 0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c,
+ 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be,
+ 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8,
+ 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3,
+ 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5,
+ 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287,
+ 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556,
+ 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc,
+ 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b,
+ 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439,
+ 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f,
+ 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766,
+ 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430,
+ 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502,
+ 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5,
+ 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f,
+ 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483};
+static const uint32 table2_[256] = {
+ 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664,
+ 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6,
+ 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4,
+ 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3,
+ 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b,
+ 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67,
+ 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af,
+ 0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8,
+ 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa,
+ 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828,
+ 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7,
+ 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0,
+ 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878,
+ 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20,
+ 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8,
+ 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff,
+ 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9,
+ 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b,
+ 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439,
+ 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e,
+ 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6,
+ 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730,
+ 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8,
+ 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def,
+ 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad,
+ 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f,
+ 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29,
+ 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e,
+ 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6,
+ 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae,
+ 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66,
+ 0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71,
+ 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe,
+ 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c,
+ 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e,
+ 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79,
+ 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1,
+ 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd,
+ 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035,
+ 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622,
+ 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760,
+ 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2,
+ 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8};
+static const uint32 table3_[256] = {
+ 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b,
+ 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf,
+ 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85,
+ 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7,
+ 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990,
+ 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2,
+ 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5,
+ 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7,
+ 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd,
+ 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69,
+ 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f,
+ 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 0xb2a6e72d,
+ 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a,
+ 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3,
+ 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784,
+ 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6,
+ 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027,
+ 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3,
+ 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9,
+ 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b,
+ 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc,
+ 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006,
+ 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061,
+ 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213,
+ 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349,
+ 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd,
+ 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c,
+ 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e,
+ 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519,
+ 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0,
+ 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7,
+ 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5,
+ 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93,
+ 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07,
+ 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d,
+ 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 0x09629f2f,
+ 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48,
+ 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a,
+ 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d,
+ 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f,
+ 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825,
+ 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1,
+ 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842};
+
+// Used to fetch a naturally-aligned 32-bit word in little endian byte-order
+static inline uint32_t LE_LOAD32(const uint8_t *p) {
+ return core::DecodeFixed32(reinterpret_cast<const char *>(p));
+}
+
+uint32 Extend(uint32 crc, const char *buf, size_t size) {
+ const uint8 *p = reinterpret_cast<const uint8 *>(buf);
+ const uint8 *e = p + size;
+ uint32 l = crc ^ 0xffffffffu;
+
+#define STEP1 \
+ do { \
+ int c = (l & 0xff) ^ *p++; \
+ l = table0_[c] ^ (l >> 8); \
+ } while (0)
+
+#define STEP4 \
+ do { \
+ uint32 c = l ^ LE_LOAD32(p); \
+ p += 4; \
+ l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^ \
+ table1_[(c >> 16) & 0xff] ^ table0_[c >> 24]; \
+ } while (0)
+
+ // Point x at first 4-byte aligned byte in string. This might be
+ // just past the end of the string.
+ const uintptr_t pval = reinterpret_cast<uintptr_t>(p);
+ const uint8 *x = reinterpret_cast<const uint8 *>(((pval + 3) >> 2) << 2);
+ if (x <= e) {
+ // Process bytes until finished or p is 4-byte aligned
+ while (p != x) {
+ STEP1;
+ }
+ }
+ // Process bytes 16 at a time
+ while ((e - p) >= 16) {
+ STEP4;
+ STEP4;
+ STEP4;
+ STEP4;
+ }
+ // Process bytes 4 at a time
+ while ((e - p) >= 4) {
+ STEP4;
+ }
+ // Process the last few bytes
+ while (p != e) {
+ STEP1;
+ }
+#undef STEP4
+#undef STEP1
+ return l ^ 0xffffffffu;
+}
+
+} // namespace crc32c
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h
new file mode 100644
index 0000000000..f728b6f5e7
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c.h
@@ -0,0 +1,39 @@
+#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_
+#define TENSORFLOW_LIB_HASH_CRC32C_H_
+
+#include <stddef.h>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace crc32c {
+
+// Return the crc32c of concat(A, data[0,n-1]) where init_crc is the
+// crc32c of some string A. Extend() is often used to maintain the
+// crc32c of a stream of data.
+extern uint32 Extend(uint32 init_crc, const char* data, size_t n);
+
+// Return the crc32c of data[0,n-1]
+inline uint32 Value(const char* data, size_t n) { return Extend(0, data, n); }
+
+static const uint32 kMaskDelta = 0xa282ead8ul;
+
+// Return a masked representation of crc.
+//
+// Motivation: it is problematic to compute the CRC of a string that
+// contains embedded CRCs. Therefore we recommend that CRCs stored
+// somewhere (e.g., in files) should be masked before being stored.
+inline uint32 Mask(uint32 crc) {
+ // Rotate right by 15 bits and add a constant.
+ return ((crc >> 15) | (crc << 17)) + kMaskDelta;
+}
+
+// Return the crc whose masked representation is masked_crc.
+inline uint32 Unmask(uint32 masked_crc) {
+ uint32 rot = masked_crc - kMaskDelta;
+ return ((rot >> 17) | (rot << 15));
+}
+
+} // namespace crc32c
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HASH_CRC32C_H_
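A short sketch tying Extend(), Mask(), and Unmask() together (hypothetical function, assuming only crc32c.h):

#include "tensorflow/core/lib/hash/crc32c.h"

namespace tensorflow {

void Crc32cSketch(const char* a, size_t na, const char* b, size_t nb) {
  // Streaming use: Extend() folds successive chunks into one checksum,
  // so Extend(Value(a, na), b, nb) equals the CRC of the concatenation.
  uint32 crc = crc32c::Extend(crc32c::Value(a, na), b, nb);
  // Mask before embedding the CRC in data that may itself be checksummed;
  // Unmask() inverts the rotate-and-add exactly.
  uint32 stored = crc32c::Mask(crc);
  uint32 recovered = crc32c::Unmask(stored);  // == crc
  (void)recovered;
}

}  // namespace tensorflow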
diff --git a/tensorflow/core/lib/hash/crc32c_test.cc b/tensorflow/core/lib/hash/crc32c_test.cc
new file mode 100644
index 0000000000..54aced3186
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c_test.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include <string.h>
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace crc32c {
+
+TEST(CRC, StandardResults) {
+ // From rfc3720 section B.4.
+ char buf[32];
+
+ memset(buf, 0, sizeof(buf));
+ ASSERT_EQ(0x8a9136aa, Value(buf, sizeof(buf)));
+
+ memset(buf, 0xff, sizeof(buf));
+ ASSERT_EQ(0x62a8ab43, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = i;
+ }
+ ASSERT_EQ(0x46dd794e, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = 31 - i;
+ }
+ ASSERT_EQ(0x113fdb5c, Value(buf, sizeof(buf)));
+
+ unsigned char data[48] = {
+ 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
+ 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+ ASSERT_EQ(0xd9963a56, Value(reinterpret_cast<char*>(data), sizeof(data)));
+}
+
+TEST(CRC, Values) { ASSERT_NE(Value("a", 1), Value("foo", 3)); }
+
+TEST(CRC, Extend) {
+ ASSERT_EQ(Value("hello world", 11), Extend(Value("hello ", 6), "world", 5));
+}
+
+TEST(CRC, Mask) {
+ uint32 crc = Value("foo", 3);
+ ASSERT_NE(crc, Mask(crc));
+ ASSERT_NE(crc, Mask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc)))));
+}
+
+} // namespace crc32c
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/lib/hash/hash.cc
new file mode 100644
index 0000000000..075d252412
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/hash/hash.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/raw_coding.h"
+
+#include <string.h>
+
+namespace tensorflow {
+
+// The "& 0xff" mask is needed in case char is signed.
+static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
+static inline uint64 ByteAs64(char c) { return static_cast<uint64>(c) & 0xff; }
+
+uint32 Hash32(const char* data, size_t n, uint32 seed) {
+ // 'm' and 'r' are mixing constants generated offline.
+ // They're not really 'magic', they just happen to work well.
+
+ const uint32 m = 0x5bd1e995;
+ const int r = 24;
+
+ // Initialize the hash to a 'random' value
+ uint32 h = seed ^ n;
+
+ // Mix 4 bytes at a time into the hash
+ while (n >= 4) {
+ uint32 k = core::DecodeFixed32(data);
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h *= m;
+ h ^= k;
+
+ data += 4;
+ n -= 4;
+ }
+
+ // Handle the last few bytes of the input array
+
+ switch (n) {
+ case 3:
+ h ^= ByteAs32(data[2]) << 16;
+ TF_FALLTHROUGH_INTENDED;
+ case 2:
+ h ^= ByteAs32(data[1]) << 8;
+ TF_FALLTHROUGH_INTENDED;
+ case 1:
+ h ^= ByteAs32(data[0]);
+ h *= m;
+ }
+
+ // Do a few final mixes of the hash to ensure the last few
+ // bytes are well-incorporated.
+
+ h ^= h >> 13;
+ h *= m;
+ h ^= h >> 15;
+
+ return h;
+}
+
+uint64 Hash64(const char* data, size_t n, uint64 seed) {
+ const uint64 m = 0xc6a4a7935bd1e995;
+ const int r = 47;
+
+ uint64 h = seed ^ (n * m);
+
+ while (n >= 8) {
+ uint64 k = core::DecodeFixed64(data);
+ data += 8;
+ n -= 8;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ switch (n) {
+ case 7:
+ h ^= ByteAs64(data[6]) << 48;
+ TF_FALLTHROUGH_INTENDED;
+ case 6:
+ h ^= ByteAs64(data[5]) << 40;
+ TF_FALLTHROUGH_INTENDED;
+ case 5:
+ h ^= ByteAs64(data[4]) << 32;
+ TF_FALLTHROUGH_INTENDED;
+ case 4:
+ h ^= ByteAs64(data[3]) << 24;
+ TF_FALLTHROUGH_INTENDED;
+ case 3:
+ h ^= ByteAs64(data[2]) << 16;
+ TF_FALLTHROUGH_INTENDED;
+ case 2:
+ h ^= ByteAs64(data[1]) << 8;
+ TF_FALLTHROUGH_INTENDED;
+ case 1:
+ h ^= ByteAs64(data[0]);
+ h *= m;
+ }
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
new file mode 100644
index 0000000000..af56218fed
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash.h
@@ -0,0 +1,28 @@
+// Simple hash functions used for internal data structures
+
+#ifndef TENSORFLOW_LIB_HASH_HASH_H_
+#define TENSORFLOW_LIB_HASH_HASH_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+extern uint32 Hash32(const char* data, size_t n, uint32 seed);
+extern uint64 Hash64(const char* data, size_t n, uint64 seed);
+
+inline uint64 Hash64(const char* data, size_t n) {
+ return Hash64(data, n, 0xDECAFCAFFE);
+}
+
+inline uint64 Hash64(const string& str) {
+ return Hash64(str.data(), str.size());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HASH_HASH_H_
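A brief sketch of the seeding behavior (hypothetical function, assuming only hash.h):

#include "tensorflow/core/lib/hash/hash.h"

namespace tensorflow {

void Hash64Sketch(const string& key) {
  uint64 h_default = Hash64(key);  // the string overload uses seed 0xDECAFCAFFE
  // Passing distinct explicit seeds derives independent hash functions
  // over the same bytes, e.g. for double hashing or sharding.
  uint64 h_seeded = Hash64(key.data(), key.size(), 1);
  (void)h_default;
  (void)h_seeded;
}

}  // namespace tensorflow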
diff --git a/tensorflow/core/lib/hash/hash_test.cc b/tensorflow/core/lib/hash/hash_test.cc
new file mode 100644
index 0000000000..9d3b970f3b
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash_test.cc
@@ -0,0 +1,64 @@
+#include <vector>
+
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Hash, SignedUnsignedIssue) {
+ const unsigned char d1[1] = {0x62};
+ const unsigned char d2[2] = {0xc3, 0x97};
+ const unsigned char d3[3] = {0xe2, 0x99, 0xa5};
+ const unsigned char d4[4] = {0xe1, 0x80, 0xb9, 0x32};
+ const unsigned char d5[48] = {
+ 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
+ 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+
+ struct Case {
+ uint32 hash32;
+ uint64 hash64;
+ const unsigned char* data;
+ size_t size;
+ uint32 seed;
+ };
+
+ for (Case c : std::vector<Case>{
+ {0x471a8188u, 0x4c61ea3eeda4cb87ull, nullptr, 0, 0xbc9f1d34},
+ {0xd615eba5u, 0x091309f7ef916c8aull, d1, sizeof(d1), 0xbc9f1d34},
+ {0x0c3cccdau, 0xa815bcdf1d1af01cull, d2, sizeof(d2), 0xbc9f1d34},
+ {0x3ba37e0eu, 0x02167564e4d06430ull, d3, sizeof(d3), 0xbc9f1d34},
+ {0x16174eb3u, 0x8f7ed82ffc21071full, d4, sizeof(d4), 0xbc9f1d34},
+ {0x98b1926cu, 0xce196580c97aff1eull, d5, sizeof(d5), 0x12345678},
+ }) {
+ EXPECT_EQ(c.hash32,
+ Hash32(reinterpret_cast<const char*>(c.data), c.size, c.seed));
+ EXPECT_EQ(c.hash64,
+ Hash64(reinterpret_cast<const char*>(c.data), c.size, c.seed));
+
+ // Check hashes with inputs aligned differently.
+ for (int align = 1; align <= 7; align++) {
+ std::string input(align, 'x');
+ input.append(reinterpret_cast<const char*>(c.data), c.size);
+ EXPECT_EQ(c.hash32, Hash32(&input[align], c.size, c.seed));
+ EXPECT_EQ(c.hash64, Hash64(&input[align], c.size, c.seed));
+ }
+ }
+}
+
+static void BM_Hash32(int iters, int len) {
+ std::string input(len, 'x');
+ uint32 h = 0;
+ for (int i = 0; i < iters; i++) {
+ h = Hash32(input.data(), len, 1);
+ }
+ testing::BytesProcessed(static_cast<int64>(iters) * len);
+ VLOG(1) << h;
+}
+BENCHMARK(BM_Hash32)->Range(1, 1024);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/histogram/histogram.cc b/tensorflow/core/lib/histogram/histogram.cc
new file mode 100644
index 0000000000..4c29d687b7
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram.cc
@@ -0,0 +1,247 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include <float.h>
+#include <math.h>
+#include "tensorflow/core/framework/summary.pb.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+namespace tensorflow {
+namespace histogram {
+
+static std::vector<double>* InitDefaultBucketsInner() {
+ std::vector<double> buckets;
+ std::vector<double> neg_buckets;
+ // Make buckets whose range grows by 10% starting at 1.0e-12 up to 1.0e20
+ double v = 1.0e-12;
+ while (v < 1.0e20) {
+ buckets.push_back(v);
+ neg_buckets.push_back(-v);
+ v *= 1.1;
+ }
+ buckets.push_back(DBL_MAX);
+ neg_buckets.push_back(-DBL_MAX);
+ std::reverse(neg_buckets.begin(), neg_buckets.end());
+ std::vector<double>* result = new std::vector<double>;
+ result->insert(result->end(), neg_buckets.begin(), neg_buckets.end());
+ result->push_back(0.0);
+ result->insert(result->end(), buckets.begin(), buckets.end());
+ return result;
+}
+
+static gtl::ArraySlice<double> InitDefaultBuckets() {
+ static std::vector<double>* default_bucket_limits = InitDefaultBucketsInner();
+ return *default_bucket_limits;
+}
+
+Histogram::Histogram() : bucket_limits_(InitDefaultBuckets()) { Clear(); }
+
+// Create a histogram with a custom set of bucket limits,
+// specified in "custom_buckets[0..custom_buckets.size()-1]"
+Histogram::Histogram(gtl::ArraySlice<double> custom_bucket_limits)
+ : custom_bucket_limits_(custom_bucket_limits.begin(),
+ custom_bucket_limits.end()),
+ bucket_limits_(custom_bucket_limits_) {
+#ifndef NDEBUG
+ DCHECK_GT(bucket_limits_.size(), 0);
+ // Verify that the bucket boundaries are strictly increasing
+ for (size_t i = 1; i < bucket_limits_.size(); i++) {
+ DCHECK_GT(bucket_limits_[i], bucket_limits_[i - 1]);
+ }
+#endif
+ Clear();
+}
+
+bool Histogram::DecodeFromProto(const HistogramProto& proto) {
+ if ((proto.bucket_size() != proto.bucket_limit_size()) ||
+ (proto.bucket_size() == 0)) {
+ return false;
+ }
+ min_ = proto.min();
+ max_ = proto.max();
+ num_ = proto.num();
+ sum_ = proto.sum();
+ sum_squares_ = proto.sum_squares();
+ custom_bucket_limits_.clear();
+ custom_bucket_limits_.insert(custom_bucket_limits_.end(),
+ proto.bucket_limit().begin(),
+ proto.bucket_limit().end());
+ bucket_limits_ = custom_bucket_limits_;
+ buckets_.clear();
+ buckets_.insert(buckets_.end(), proto.bucket().begin(), proto.bucket().end());
+ return true;
+}
+
+void Histogram::Clear() {
+ min_ = bucket_limits_[bucket_limits_.size() - 1];
+ max_ = -DBL_MAX;
+ num_ = 0;
+ sum_ = 0;
+ sum_squares_ = 0;
+ buckets_.resize(bucket_limits_.size());
+ for (size_t i = 0; i < bucket_limits_.size(); i++) {
+ buckets_[i] = 0;
+ }
+}
+
+void Histogram::Add(double value) {
+ int b =
+ std::upper_bound(bucket_limits_.begin(), bucket_limits_.end(), value) -
+ bucket_limits_.begin();
+
+ buckets_[b] += 1.0;
+ if (min_ > value) min_ = value;
+ if (max_ < value) max_ = value;
+ num_++;
+ sum_ += value;
+ sum_squares_ += (value * value);
+}
+
+double Histogram::Median() const { return Percentile(50.0); }
+
+double Histogram::Percentile(double p) const {
+ if (num_ == 0.0) return 0.0;
+ double threshold = num_ * (p / 100.0);
+ double sum = 0;
+ for (size_t b = 0; b < buckets_.size(); b++) {
+ sum += buckets_[b];
+ if (sum >= threshold) {
+ // Scale linearly within this bucket
+ double left_point = (b == 0) ? min_ : bucket_limits_[b - 1];
+ double right_point = bucket_limits_[b];
+ double left_sum = sum - buckets_[b];
+ double right_sum = sum;
+ double pos = (threshold - left_sum) / (right_sum - left_sum);
+ double r = left_point + (right_point - left_point) * pos;
+ if (r < min_) r = min_;
+ if (r > max_) r = max_;
+ return r;
+ }
+ }
+ return max_;
+}
+
+double Histogram::Average() const {
+ if (num_ == 0.0) return 0;
+ return sum_ / num_;
+}
+
+double Histogram::StandardDeviation() const {
+ if (num_ == 0.0) return 0;
+ double variance = (sum_squares_ * num_ - sum_ * sum_) / (num_ * num_);
+ return sqrt(variance);
+}
+
+std::string Histogram::ToString() const {
+ std::string r;
+ char buf[200];
+ snprintf(buf, sizeof(buf), "Count: %.0f Average: %.4f StdDev: %.2f\n", num_,
+ Average(), StandardDeviation());
+ r.append(buf);
+ snprintf(buf, sizeof(buf), "Min: %.4f Median: %.4f Max: %.4f\n",
+ (num_ == 0.0 ? 0.0 : min_), Median(), max_);
+ r.append(buf);
+ r.append("------------------------------------------------------\n");
+ const double mult = num_ > 0 ? 100.0 / num_ : 0.0;
+ double sum = 0;
+ for (size_t b = 0; b < buckets_.size(); b++) {
+ if (buckets_[b] <= 0.0) continue;
+ sum += buckets_[b];
+ snprintf(buf, sizeof(buf), "[ %10.2g, %10.2g ) %7.0f %7.3f%% %7.3f%% ",
+ ((b == 0) ? -DBL_MAX : bucket_limits_[b - 1]), // left
+ bucket_limits_[b], // right
+ buckets_[b], // count
+ mult * buckets_[b], // percentage
+ mult * sum); // cum percentage
+ r.append(buf);
+
+ // Add hash marks based on percentage; 20 marks for 100%.
+ int marks = static_cast<int>(20 * (buckets_[b] / num_) + 0.5);
+ r.append(marks, '#');
+ r.push_back('\n');
+ }
+ return r;
+}
+
+void Histogram::EncodeToProto(HistogramProto* proto,
+ bool preserve_zero_buckets) const {
+ proto->Clear();
+ proto->set_min(min_);
+ proto->set_max(max_);
+ proto->set_num(num_);
+ proto->set_sum(sum_);
+ proto->set_sum_squares(sum_squares_);
+ for (size_t i = 0; i < buckets_.size();) {
+ double end = bucket_limits_[i];
+ double count = buckets_[i];
+ i++;
+ if (!preserve_zero_buckets && count <= 0.0) {
+ // Find run of empty buckets and collapse them into one
+ while (i < buckets_.size() && buckets_[i] <= 0.0) {
+ end = bucket_limits_[i];
+ count = buckets_[i];
+ i++;
+ }
+ }
+ proto->add_bucket_limit(end);
+ proto->add_bucket(count);
+ }
+ if (proto->bucket_size() == 0) {
+ // It's easier when we restore if we always have at least one bucket entry
+ proto->add_bucket_limit(DBL_MAX);
+ proto->add_bucket(0.0);
+ }
+}
+
+// ThreadSafeHistogram implementation.
+bool ThreadSafeHistogram::DecodeFromProto(const HistogramProto& proto) {
+ mutex_lock l(mu_);
+ return histogram_.DecodeFromProto(proto);
+}
+
+void ThreadSafeHistogram::Clear() {
+ mutex_lock l(mu_);
+ histogram_.Clear();
+}
+
+void ThreadSafeHistogram::Add(double value) {
+ mutex_lock l(mu_);
+ histogram_.Add(value);
+}
+
+void ThreadSafeHistogram::EncodeToProto(HistogramProto* proto,
+ bool preserve_zero_buckets) const {
+ mutex_lock l(mu_);
+ histogram_.EncodeToProto(proto, preserve_zero_buckets);
+}
+
+double ThreadSafeHistogram::Median() const {
+ mutex_lock l(mu_);
+ return histogram_.Median();
+}
+
+double ThreadSafeHistogram::Percentile(double p) const {
+ mutex_lock l(mu_);
+ return histogram_.Percentile(p);
+}
+
+double ThreadSafeHistogram::Average() const {
+ mutex_lock l(mu_);
+ return histogram_.Average();
+}
+
+double ThreadSafeHistogram::StandardDeviation() const {
+ mutex_lock l(mu_);
+ return histogram_.StandardDeviation();
+}
+
+std::string ThreadSafeHistogram::ToString() const {
+ mutex_lock l(mu_);
+ return histogram_.ToString();
+}
+
+} // namespace histogram
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h
new file mode 100644
index 0000000000..9b655f3acb
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram.h
@@ -0,0 +1,119 @@
+#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+
+#include <string>
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class HistogramProto;
+
+namespace histogram {
+
+class Histogram {
+ public:
+ // Create a histogram with a default set of bucket boundaries.
+ // Buckets near zero cover very small ranges (e.g. 10^-12), and each
+ // bucket range grows by ~10% as we head away from zero. The
+ // buckets cover the range from -DBL_MAX to DBL_MAX.
+ Histogram();
+
+ // Create a histogram with a custom set of bucket boundaries,
+ // specified in "custom_bucket_limits[0..custom_bucket_limits.size()-1]"
+ // REQUIRES: custom_bucket_limits[i] values are monotonically increasing.
+ // REQUIRES: custom_bucket_limits is not empty()
+ explicit Histogram(gtl::ArraySlice<double> custom_bucket_limits);
+
+ // Restore the state of a histogram that was previously encoded
+ // via Histogram::EncodeToProto. Note that only the bucket boundaries
+ // generated by EncodeToProto will be restored.
+ bool DecodeFromProto(const HistogramProto& proto);
+
+ ~Histogram() {}
+
+ void Clear();
+ void Add(double value);
+
+ // Save the current state of the histogram to "*proto". If
+ // "preserve_zero_buckets" is false, only non-zero bucket values and
+ // ranges are saved, and the bucket boundaries of zero-valued buckets
+ // are lost.
+ void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const;
+
+  // Return the median of the values in the histogram.
+  double Median() const;
+
+  // Return the "p"th percentile [0.0..100.0] of the values in the
+  // distribution.
+  double Percentile(double p) const;
+
+  // Return the average value of the distribution.
+  double Average() const;
+
+  // Return the standard deviation of values in the distribution.
+  double StandardDeviation() const;
+
+ // Returns a multi-line human-readable string representing the histogram
+ // contents. Example output:
+ // Count: 4 Average: 251.7475 StdDev: 432.02
+ // Min: -3.0000 Median: 5.0000 Max: 1000.0000
+ // ------------------------------------------------------
+ // [ -5, 0 ) 1 25.000% 25.000% #####
+ // [ 0, 5 ) 1 25.000% 50.000% #####
+ // [ 5, 10 ) 1 25.000% 75.000% #####
+ // [ 1000, 10000 ) 1 25.000% 100.000% #####
+ std::string ToString() const;
+
+ private:
+ double min_;
+ double max_;
+ double num_;
+ double sum_;
+ double sum_squares_;
+
+ std::vector<double> custom_bucket_limits_;
+ gtl::ArraySlice<double> bucket_limits_;
+ std::vector<double> buckets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Histogram);
+};
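+
+// A minimal usage sketch (illustrative only; "proto" is any HistogramProto
+// the caller owns):
+//
+//   Histogram h;
+//   h.Add(3.0);
+//   h.Add(42.0);
+//   HistogramProto proto;
+//   h.EncodeToProto(&proto, false /* preserve_zero_buckets */);
+//   printf("%s\n", h.ToString().c_str());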
+
+// Wrapper around a Histogram object that is thread safe.
+//
+// All methods hold a lock while delegating to a Histogram object owned by the
+// ThreadSafeHistogram instance.
+//
+// See Histogram for documentation of the methods.
+class ThreadSafeHistogram {
+ public:
+ ThreadSafeHistogram() {}
+ explicit ThreadSafeHistogram(gtl::ArraySlice<double> custom_bucket_limits)
+ : histogram_(custom_bucket_limits) {}
+ bool DecodeFromProto(const HistogramProto& proto);
+
+ ~ThreadSafeHistogram() {}
+
+ void Clear();
+
+  // TODO(mdevin): It might be a good idea to provide an AddN(<many values>)
+  // method to avoid grabbing/releasing the lock when adding many values.
+ void Add(double value);
+
+ void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const;
+ double Median() const;
+ double Percentile(double p) const;
+ double Average() const;
+ double StandardDeviation() const;
+ std::string ToString() const;
+
+ private:
+ mutable mutex mu_;
+ Histogram histogram_ GUARDED_BY(mu_);
+};
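+
+// Sketch of intended use: many threads may share one instance, e.g.
+//
+//   ThreadSafeHistogram latencies;
+//   // ... each worker thread: latencies.Add(elapsed_usec); ...
+//   LOG(INFO) << latencies.ToString();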
+
+} // namespace histogram
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
diff --git a/tensorflow/core/lib/histogram/histogram_test.cc b/tensorflow/core/lib/histogram/histogram_test.cc
new file mode 100644
index 0000000000..ede44fe85b
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram_test.cc
@@ -0,0 +1,112 @@
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include <float.h>
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace histogram {
+
+static void Validate(const Histogram& h) {
+ string s1 = h.ToString();
+ LOG(ERROR) << s1;
+
+ HistogramProto proto_with_zeroes;
+ h.EncodeToProto(&proto_with_zeroes, true);
+ Histogram h2;
+ EXPECT_TRUE(h2.DecodeFromProto(proto_with_zeroes));
+ string s2 = h2.ToString();
+ LOG(ERROR) << s2;
+
+ EXPECT_EQ(s1, s2);
+
+ HistogramProto proto_no_zeroes;
+ h.EncodeToProto(&proto_no_zeroes, false);
+ LOG(ERROR) << proto_no_zeroes.DebugString();
+ Histogram h3;
+ EXPECT_TRUE(h3.DecodeFromProto(proto_no_zeroes));
+ string s3 = h3.ToString();
+ LOG(ERROR) << s3;
+
+ EXPECT_EQ(s1, s3);
+}
+
+TEST(Histogram, Empty) {
+ Histogram h;
+ Validate(h);
+}
+
+TEST(Histogram, SingleValue) {
+ Histogram h;
+ h.Add(-3.0);
+ Validate(h);
+}
+
+TEST(Histogram, CustomBuckets) {
+ Histogram h({-10, -5, 0, 5, 10, 100, 1000, 10000, DBL_MAX});
+ h.Add(-3.0);
+ h.Add(4.99);
+ h.Add(5.0);
+ h.Add(1000.0);
+ Validate(h);
+}
+
+TEST(Histogram, Percentile) {
+ Histogram h({0, 10, 100, DBL_MAX});
+ h.Add(-2);
+ h.Add(-2);
+ h.Add(0);
+ double median = h.Percentile(50.0);
+ EXPECT_EQ(median, -0.5);
+}
+
+TEST(Histogram, Basic) {
+ Histogram h;
+ for (int i = 0; i < 100; i++) {
+ h.Add(i);
+ }
+ for (int i = 1000; i < 100000; i += 1000) {
+ h.Add(i);
+ }
+ Validate(h);
+}
+
+TEST(ThreadSafeHistogram, Basic) {
+ // Fill a normal histogram.
+ Histogram h;
+ for (int i = 0; i < 100; i++) {
+ h.Add(i);
+ }
+
+ // Fill a thread-safe histogram with the same values.
+ ThreadSafeHistogram tsh;
+ for (int i = 0; i < 100; i++) {
+ tsh.Add(i);
+ }
+
+ for (int i = 0; i < 2; ++i) {
+ bool preserve_zero_buckets = (i == 0);
+ HistogramProto h_proto;
+ h.EncodeToProto(&h_proto, preserve_zero_buckets);
+ HistogramProto tsh_proto;
+ tsh.EncodeToProto(&tsh_proto, preserve_zero_buckets);
+
+ // Let's decode from the proto of the other histogram type.
+ Histogram h2;
+ EXPECT_TRUE(h2.DecodeFromProto(tsh_proto));
+ ThreadSafeHistogram tsh2;
+ EXPECT_TRUE(tsh2.DecodeFromProto(h_proto));
+
+ // Now let's reencode and check they match.
+ EXPECT_EQ(h2.ToString(), tsh2.ToString());
+ }
+
+ EXPECT_EQ(h.Median(), tsh.Median());
+ EXPECT_EQ(h.Percentile(40.0), tsh.Percentile(40.0));
+ EXPECT_EQ(h.Average(), tsh.Average());
+ EXPECT_EQ(h.StandardDeviation(), tsh.StandardDeviation());
+ EXPECT_EQ(h.ToString(), tsh.ToString());
+}
+
+} // namespace histogram
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc
new file mode 100644
index 0000000000..1ddaa2eb78
--- /dev/null
+++ b/tensorflow/core/lib/io/block.cc
@@ -0,0 +1,236 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Decodes the blocks generated by block_builder.cc.
+
+#include "tensorflow/core/lib/io/block.h"
+
+#include <vector>
+#include <algorithm>
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+inline uint32 Block::NumRestarts() const {
+ assert(size_ >= sizeof(uint32));
+ return core::DecodeFixed32(data_ + size_ - sizeof(uint32));
+}
+
+Block::Block(const BlockContents& contents)
+ : data_(contents.data.data()),
+ size_(contents.data.size()),
+ owned_(contents.heap_allocated) {
+ if (size_ < sizeof(uint32)) {
+ size_ = 0; // Error marker
+ } else {
+ size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32);
+ if (NumRestarts() > max_restarts_allowed) {
+ // The size is too small for NumRestarts()
+ size_ = 0;
+ } else {
+ restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32);
+ }
+ }
+}
+
+Block::~Block() {
+ if (owned_) {
+ delete[] data_;
+ }
+}
+
+// Helper routine: decode the next block entry starting at "p",
+// storing the number of shared key bytes, non_shared key bytes,
+// and the length of the value in "*shared", "*non_shared", and
+// "*value_length", respectively. Will not dereference past "limit".
+//
+// If any errors are detected, returns NULL. Otherwise, returns a
+// pointer to the key delta (just past the three decoded values).
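+//
+// For example (sketch), a fast-path entry with shared=5, non_shared=5,
+// value_length=2 is laid out as: 0x05 0x05 0x02 "sauce" "v2", since all
+// three values fit in one varint byte each (< 128).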
+static inline const char* DecodeEntry(const char* p, const char* limit,
+ uint32* shared, uint32* non_shared,
+ uint32* value_length) {
+ if (limit - p < 3) return NULL;
+ *shared = reinterpret_cast<const unsigned char*>(p)[0];
+ *non_shared = reinterpret_cast<const unsigned char*>(p)[1];
+ *value_length = reinterpret_cast<const unsigned char*>(p)[2];
+ if ((*shared | *non_shared | *value_length) < 128) {
+ // Fast path: all three values are encoded in one byte each
+ p += 3;
+ } else {
+ if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL;
+ if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL;
+ if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL;
+ }
+
+ if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) {
+ return NULL;
+ }
+ return p;
+}
+
+class Block::Iter : public Iterator {
+ private:
+ const char* const data_; // underlying block contents
+ uint32 const restarts_; // Offset of restart array (list of fixed32)
+ uint32 const num_restarts_; // Number of uint32 entries in restart array
+
+ // current_ is offset in data_ of current entry. >= restarts_ if !Valid
+ uint32 current_;
+ uint32 restart_index_; // Index of restart block in which current_ falls
+ string key_;
+ StringPiece value_;
+ Status status_;
+
+ inline int Compare(const StringPiece& a, const StringPiece& b) const {
+ return a.compare(b);
+ }
+
+ // Return the offset in data_ just past the end of the current entry.
+ inline uint32 NextEntryOffset() const {
+ return (value_.data() + value_.size()) - data_;
+ }
+
+ uint32 GetRestartPoint(uint32 index) {
+ assert(index < num_restarts_);
+ return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32));
+ }
+
+ void SeekToRestartPoint(uint32 index) {
+ key_.clear();
+ restart_index_ = index;
+ // current_ will be fixed by ParseNextKey();
+
+ // ParseNextKey() starts at the end of value_, so set value_ accordingly
+ uint32 offset = GetRestartPoint(index);
+ value_ = StringPiece(data_ + offset, 0);
+ }
+
+ public:
+ Iter(const char* data, uint32 restarts, uint32 num_restarts)
+ : data_(data),
+ restarts_(restarts),
+ num_restarts_(num_restarts),
+ current_(restarts_),
+ restart_index_(num_restarts_) {
+ assert(num_restarts_ > 0);
+ }
+
+ virtual bool Valid() const { return current_ < restarts_; }
+ virtual Status status() const { return status_; }
+ virtual StringPiece key() const {
+ assert(Valid());
+ return key_;
+ }
+ virtual StringPiece value() const {
+ assert(Valid());
+ return value_;
+ }
+
+ virtual void Next() {
+ assert(Valid());
+ ParseNextKey();
+ }
+
+ virtual void Seek(const StringPiece& target) {
+ // Binary search in restart array to find the last restart point
+ // with a key < target
+ uint32 left = 0;
+ uint32 right = num_restarts_ - 1;
+ while (left < right) {
+ uint32 mid = (left + right + 1) / 2;
+ uint32 region_offset = GetRestartPoint(mid);
+ uint32 shared, non_shared, value_length;
+ const char* key_ptr =
+ DecodeEntry(data_ + region_offset, data_ + restarts_, &shared,
+ &non_shared, &value_length);
+ if (key_ptr == NULL || (shared != 0)) {
+ CorruptionError();
+ return;
+ }
+ StringPiece mid_key(key_ptr, non_shared);
+ if (Compare(mid_key, target) < 0) {
+ // Key at "mid" is smaller than "target". Therefore all
+ // blocks before "mid" are uninteresting.
+ left = mid;
+ } else {
+ // Key at "mid" is >= "target". Therefore all blocks at or
+ // after "mid" are uninteresting.
+ right = mid - 1;
+ }
+ }
+
+ // Linear search (within restart block) for first key >= target
+ SeekToRestartPoint(left);
+ while (true) {
+ if (!ParseNextKey()) {
+ return;
+ }
+ if (Compare(key_, target) >= 0) {
+ return;
+ }
+ }
+ }
+
+ virtual void SeekToFirst() {
+ SeekToRestartPoint(0);
+ ParseNextKey();
+ }
+
+ private:
+ void CorruptionError() {
+ current_ = restarts_;
+ restart_index_ = num_restarts_;
+ status_ = errors::DataLoss("bad entry in block");
+ key_.clear();
+ value_.clear();
+ }
+
+ bool ParseNextKey() {
+ current_ = NextEntryOffset();
+ const char* p = data_ + current_;
+ const char* limit = data_ + restarts_; // Restarts come right after data
+ if (p >= limit) {
+ // No more entries to return. Mark as invalid.
+ current_ = restarts_;
+ restart_index_ = num_restarts_;
+ return false;
+ }
+
+ // Decode next entry
+ uint32 shared, non_shared, value_length;
+ p = DecodeEntry(p, limit, &shared, &non_shared, &value_length);
+ if (p == NULL || key_.size() < shared) {
+ CorruptionError();
+ return false;
+ } else {
+ key_.resize(shared);
+ key_.append(p, non_shared);
+ value_ = StringPiece(p + non_shared, value_length);
+ while (restart_index_ + 1 < num_restarts_ &&
+ GetRestartPoint(restart_index_ + 1) < current_) {
+ ++restart_index_;
+ }
+ return true;
+ }
+ }
+};
+
+Iterator* Block::NewIterator() {
+ if (size_ < sizeof(uint32)) {
+ return NewErrorIterator(errors::DataLoss("bad block contents"));
+ }
+ const uint32 num_restarts = NumRestarts();
+ if (num_restarts == 0) {
+ return NewEmptyIterator();
+ } else {
+ return new Iter(data_, restart_offset_, num_restarts);
+ }
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h
new file mode 100644
index 0000000000..bf53245b8d
--- /dev/null
+++ b/tensorflow/core/lib/io/block.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_BLOCK_H_
+#define TENSORFLOW_LIB_IO_BLOCK_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+struct BlockContents;
+
+class Block {
+ public:
+ // Initialize the block with the specified contents.
+ explicit Block(const BlockContents& contents);
+
+ ~Block();
+
+ size_t size() const { return size_; }
+ Iterator* NewIterator();
+
+ private:
+ uint32 NumRestarts() const;
+
+ const char* data_;
+ size_t size_;
+ uint32 restart_offset_; // Offset in data_ of restart array
+ bool owned_; // Block owns data_[]
+
+ // No copying allowed
+ Block(const Block&);
+ void operator=(const Block&);
+
+ class Iter;
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_BLOCK_H_
diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc
new file mode 100644
index 0000000000..d94048d744
--- /dev/null
+++ b/tensorflow/core/lib/io/block_builder.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// BlockBuilder generates blocks where keys are prefix-compressed:
+//
+// When we store a key, we drop the prefix shared with the previous
+// string. This helps reduce the space requirement significantly.
+// Furthermore, once every K keys, we do not apply the prefix
+// compression and store the entire key. We call this a "restart
+// point". The tail end of the block stores the offsets of all of the
+// restart points, and can be used to do a binary search when looking
+// for a particular key. Values are stored as-is (without compression)
+// immediately following the corresponding key.
+//
+// An entry for a particular key-value pair has the form:
+// shared_bytes: varint32
+// unshared_bytes: varint32
+// value_length: varint32
+// key_delta: char[unshared_bytes]
+// value: char[value_length]
+// shared_bytes == 0 for restart points.
+//
+// The trailer of the block has the form:
+// restarts: uint32[num_restarts]
+// num_restarts: uint32
+// restarts[i] contains the offset within the block of the ith restart point.
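+//
+// Worked example (illustrative): adding ("apple", "v1") and then
+// ("applesauce", "v2"), with the first key at a restart point, produces
+//   entry 1: shared=0 unshared=5 value_length=2 key_delta="apple" value="v1"
+//   entry 2: shared=5 unshared=5 value_length=2 key_delta="sauce" value="v2"
+// so the common prefix "apple" is stored only once.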
+
+#include "tensorflow/core/lib/io/block_builder.h"
+
+#include <algorithm>
+#include <assert.h>
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace table {
+
+BlockBuilder::BlockBuilder(const Options* options)
+ : options_(options), restarts_(), counter_(0), finished_(false) {
+ assert(options->block_restart_interval >= 1);
+ restarts_.push_back(0); // First restart point is at offset 0
+}
+
+void BlockBuilder::Reset() {
+ buffer_.clear();
+ restarts_.clear();
+ restarts_.push_back(0); // First restart point is at offset 0
+ counter_ = 0;
+ finished_ = false;
+ last_key_.clear();
+}
+
+size_t BlockBuilder::CurrentSizeEstimate() const {
+ return (buffer_.size() + // Raw data buffer
+ restarts_.size() * sizeof(uint32) + // Restart array
+ sizeof(uint32)); // Restart array length
+}
+
+StringPiece BlockBuilder::Finish() {
+ // Append restart array
+ for (size_t i = 0; i < restarts_.size(); i++) {
+ core::PutFixed32(&buffer_, restarts_[i]);
+ }
+ core::PutFixed32(&buffer_, restarts_.size());
+ finished_ = true;
+ return StringPiece(buffer_);
+}
+
+void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) {
+ StringPiece last_key_piece(last_key_);
+ assert(!finished_);
+ assert(counter_ <= options_->block_restart_interval);
+ assert(buffer_.empty() // No values yet?
+ || key.compare(last_key_piece) > 0);
+ size_t shared = 0;
+ if (counter_ < options_->block_restart_interval) {
+ // See how much sharing to do with previous string
+ const size_t min_length = std::min(last_key_piece.size(), key.size());
+ while ((shared < min_length) && (last_key_piece[shared] == key[shared])) {
+ shared++;
+ }
+ } else {
+ // Restart compression
+ restarts_.push_back(buffer_.size());
+ counter_ = 0;
+ }
+ const size_t non_shared = key.size() - shared;
+
+ // Add "<shared><non_shared><value_size>" to buffer_
+ core::PutVarint32(&buffer_, shared);
+ core::PutVarint32(&buffer_, non_shared);
+ core::PutVarint32(&buffer_, value.size());
+
+ // Add string delta to buffer_ followed by value
+ buffer_.append(key.data() + shared, non_shared);
+ buffer_.append(value.data(), value.size());
+
+ // Update state
+ last_key_.resize(shared);
+ last_key_.append(key.data() + shared, non_shared);
+ assert(StringPiece(last_key_) == key);
+ counter_++;
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
new file mode 100644
index 0000000000..e07a647805
--- /dev/null
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
+#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
+
+#include <vector>
+
+#include <stdint.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace table {
+
+struct Options;
+
+class BlockBuilder {
+ public:
+ explicit BlockBuilder(const Options* options);
+
+ // Reset the contents as if the BlockBuilder was just constructed.
+ void Reset();
+
+ // REQUIRES: Finish() has not been called since the last call to Reset().
+ // REQUIRES: key is larger than any previously added key
+ void Add(const StringPiece& key, const StringPiece& value);
+
+ // Finish building the block and return a slice that refers to the
+ // block contents. The returned slice will remain valid for the
+ // lifetime of this builder or until Reset() is called.
+ StringPiece Finish();
+
+ // Returns an estimate of the current (uncompressed) size of the block
+ // we are building.
+ size_t CurrentSizeEstimate() const;
+
+ // Return true iff no entries have been added since the last Reset()
+ bool empty() const { return buffer_.empty(); }
+
+ private:
+ const Options* options_;
+ string buffer_; // Destination buffer
+ std::vector<uint32> restarts_; // Restart points
+ int counter_; // Number of entries emitted since restart
+ bool finished_; // Has Finish() been called?
+ string last_key_;
+
+ // No copying allowed
+ BlockBuilder(const BlockBuilder&);
+ void operator=(const BlockBuilder&);
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc
new file mode 100644
index 0000000000..259cfc13dc
--- /dev/null
+++ b/tensorflow/core/lib/io/format.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/format.h"
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+void BlockHandle::EncodeTo(string* dst) const {
+ // Sanity check that all fields have been set
+ assert(offset_ != ~static_cast<uint64>(0));
+ assert(size_ != ~static_cast<uint64>(0));
+ core::PutVarint64(dst, offset_);
+ core::PutVarint64(dst, size_);
+}
+
+Status BlockHandle::DecodeFrom(StringPiece* input) {
+ if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) {
+ return Status::OK();
+ } else {
+ return errors::DataLoss("bad block handle");
+ }
+}
+
+void Footer::EncodeTo(string* dst) const {
+#ifndef NDEBUG
+ const size_t original_size = dst->size();
+#endif
+ metaindex_handle_.EncodeTo(dst);
+ index_handle_.EncodeTo(dst);
+ dst->resize(2 * BlockHandle::kMaxEncodedLength); // Padding
+ core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber & 0xffffffffu));
+ core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber >> 32));
+ assert(dst->size() == original_size + kEncodedLength);
+}
+
+Status Footer::DecodeFrom(StringPiece* input) {
+ const char* magic_ptr = input->data() + kEncodedLength - 8;
+ const uint32 magic_lo = core::DecodeFixed32(magic_ptr);
+ const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4);
+ const uint64 magic =
+ ((static_cast<uint64>(magic_hi) << 32) | (static_cast<uint64>(magic_lo)));
+ if (magic != kTableMagicNumber) {
+ return errors::DataLoss("not an sstable (bad magic number)");
+ }
+
+ Status result = metaindex_handle_.DecodeFrom(input);
+ if (result.ok()) {
+ result = index_handle_.DecodeFrom(input);
+ }
+ if (result.ok()) {
+ // We skip over any leftover data (just padding for now) in "input"
+ const char* end = magic_ptr + 8;
+ *input = StringPiece(end, input->data() + input->size() - end);
+ }
+ return result;
+}
+
+Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle,
+ BlockContents* result) {
+ result->data = StringPiece();
+ result->cachable = false;
+ result->heap_allocated = false;
+
+ // Read the block contents as well as the type/crc footer.
+ // See table_builder.cc for the code that built this structure.
+ size_t n = static_cast<size_t>(handle.size());
+ char* buf = new char[n + kBlockTrailerSize];
+ StringPiece contents;
+ Status s =
+ file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf);
+ if (!s.ok()) {
+ delete[] buf;
+ return s;
+ }
+ if (contents.size() != n + kBlockTrailerSize) {
+ delete[] buf;
+ return errors::DataLoss("truncated block read");
+ }
+
+ // Check the crc of the type and the block contents
+ const char* data = contents.data(); // Pointer to where Read put the data
+  // This checksum verification is optional. We leave it on for now.
+ const bool verify_checksum = true;
+ if (verify_checksum) {
+ const uint32 crc = crc32c::Unmask(core::DecodeFixed32(data + n + 1));
+ const uint32 actual = crc32c::Value(data, n + 1);
+ if (actual != crc) {
+ delete[] buf;
+ s = errors::DataLoss("block checksum mismatch");
+ return s;
+ }
+ }
+
+ switch (data[n]) {
+ case kNoCompression:
+ if (data != buf) {
+ // File implementation gave us pointer to some other data.
+ // Use it directly under the assumption that it will be live
+ // while the file is open.
+ delete[] buf;
+ result->data = StringPiece(data, n);
+ result->heap_allocated = false;
+ result->cachable = false; // Do not double-cache
+ } else {
+ result->data = StringPiece(buf, n);
+ result->heap_allocated = true;
+ result->cachable = true;
+ }
+
+ // Ok
+ break;
+ case kSnappyCompression: {
+ size_t ulength = 0;
+ if (!port::Snappy_GetUncompressedLength(data, n, &ulength)) {
+ delete[] buf;
+ return errors::DataLoss("corrupted compressed block contents");
+ }
+ char* ubuf = new char[ulength];
+ if (!port::Snappy_Uncompress(data, n, ubuf)) {
+ delete[] buf;
+ delete[] ubuf;
+ return errors::DataLoss("corrupted compressed block contents");
+ }
+ delete[] buf;
+ result->data = StringPiece(ubuf, ulength);
+ result->heap_allocated = true;
+ result->cachable = true;
+ break;
+ }
+ default:
+ delete[] buf;
+ return errors::DataLoss("bad block type");
+ }
+
+ return Status::OK();
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h
new file mode 100644
index 0000000000..3121c41bb8
--- /dev/null
+++ b/tensorflow/core/lib/io/format.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_FORMAT_H_
+#define TENSORFLOW_LIB_IO_FORMAT_H_
+
+#include <string>
+#include <stdint.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+
+namespace tensorflow {
+class RandomAccessFile;
+namespace table {
+
+class Block;
+
+// BlockHandle is a pointer to the extent of a file that stores a data
+// block or a meta block.
+class BlockHandle {
+ public:
+ BlockHandle();
+
+ // The offset of the block in the file.
+ uint64 offset() const { return offset_; }
+ void set_offset(uint64 offset) { offset_ = offset; }
+
+ // The size of the stored block
+ uint64 size() const { return size_; }
+ void set_size(uint64 size) { size_ = size; }
+
+ void EncodeTo(string* dst) const;
+ Status DecodeFrom(StringPiece* input);
+
+ // Maximum encoding length of a BlockHandle
+ enum { kMaxEncodedLength = 10 + 10 };
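+  // (Each handle field is a varint64, which occupies at most 10 bytes,
+  // hence the 10 + 10 bound.)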
+
+ private:
+ uint64 offset_;
+ uint64 size_;
+};
+
+// Footer encapsulates the fixed information stored at the tail
+// end of every table file.
+class Footer {
+ public:
+ Footer() {}
+
+ // The block handle for the metaindex block of the table
+ const BlockHandle& metaindex_handle() const { return metaindex_handle_; }
+ void set_metaindex_handle(const BlockHandle& h) { metaindex_handle_ = h; }
+
+ // The block handle for the index block of the table
+ const BlockHandle& index_handle() const { return index_handle_; }
+ void set_index_handle(const BlockHandle& h) { index_handle_ = h; }
+
+ void EncodeTo(string* dst) const;
+ Status DecodeFrom(StringPiece* input);
+
+ // Encoded length of a Footer. Note that the serialization of a
+ // Footer will always occupy exactly this many bytes. It consists
+ // of two block handles and a magic number.
+ enum { kEncodedLength = 2 * BlockHandle::kMaxEncodedLength + 8 };
+
+ private:
+ BlockHandle metaindex_handle_;
+ BlockHandle index_handle_;
+};
+
+// kTableMagicNumber was picked by running
+// echo http://code.google.com/p/leveldb/ | sha1sum
+// and taking the leading 64 bits.
+static const uint64 kTableMagicNumber = 0xdb4775248b80fb57ull;
+
+// 1-byte type + 32-bit crc
+static const size_t kBlockTrailerSize = 5;
+
+struct BlockContents {
+ StringPiece data; // Actual contents of data
+ bool cachable; // True iff data can be cached
+ bool heap_allocated; // True iff caller should delete[] data.data()
+};
+
+// Read the block identified by "handle" from "file". On failure
+// return non-OK. On success fill *result and return OK.
+extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle,
+ BlockContents* result);
+
+// Implementation details follow. Clients should ignore.
+
+inline BlockHandle::BlockHandle()
+ : offset_(~static_cast<uint64>(0)), size_(~static_cast<uint64>(0)) {}
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_FORMAT_H_
diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc
new file mode 100644
index 0000000000..8fa245a546
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer.cc
@@ -0,0 +1,112 @@
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace io {
+
+InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes)
+ : file_(file),
+ file_pos_(0),
+ size_(buffer_bytes),
+ buf_(new char[size_]),
+ pos_(buf_),
+ limit_(buf_) {}
+
+InputBuffer::~InputBuffer() {
+ delete file_;
+ delete[] buf_;
+}
+
+Status InputBuffer::FillBuffer() {
+ StringPiece data;
+ Status s = file_->Read(file_pos_, size_, &data, buf_);
+ if (data.data() != buf_) {
+ memmove(buf_, data.data(), data.size());
+ }
+ pos_ = buf_;
+ limit_ = pos_ + data.size();
+ file_pos_ += data.size();
+ return s;
+}
+
+Status InputBuffer::ReadLine(string* result) {
+ result->clear();
+ int i;
+ Status s;
+ for (i = 0;; i++) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ char c = *pos_++;
+ if (c == '\n') {
+ // We don't append the '\n' to *result
+ return Status::OK();
+ }
+ *result += c;
+ }
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ return Status::OK();
+ }
+ return s;
+}
+
+Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
+ result->clear();
+ if (bytes_to_read < 0) {
+ return errors::InvalidArgument("Can't read a negative number of bytes: ",
+ bytes_to_read);
+ }
+ result->reserve(bytes_to_read);
+ Status s;
+ while (result->size() < static_cast<size_t>(bytes_to_read)) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ const int64 bytes_to_copy =
+ std::min<int64>(limit_ - pos_, bytes_to_read - result->size());
+ result->insert(result->size(), pos_, bytes_to_copy);
+ pos_ += bytes_to_copy;
+ }
+ if (errors::IsOutOfRange(s) &&
+ (result->size() == static_cast<size_t>(bytes_to_read))) {
+ return Status::OK();
+ }
+ return s;
+}
+
+Status InputBuffer::SkipNBytes(int64 bytes_to_skip) {
+ if (bytes_to_skip < 0) {
+ return errors::InvalidArgument("Can only skip forward, not ",
+ bytes_to_skip);
+ }
+ int64 bytes_skipped = 0;
+ Status s;
+ while (bytes_skipped < bytes_to_skip) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ const int64 bytes_to_advance =
+ std::min<int64>(limit_ - pos_, bytes_to_skip - bytes_skipped);
+ bytes_skipped += bytes_to_advance;
+ pos_ += bytes_to_advance;
+ }
+ if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) {
+ return Status::OK();
+ }
+ return s;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h
new file mode 100644
index 0000000000..6879f30567
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer.h
@@ -0,0 +1,62 @@
+#ifndef TENSORFLOW_LIB_IO_INPUTBUFFER_H_
+#define TENSORFLOW_LIB_IO_INPUTBUFFER_H_
+
+#include <string>
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace io {
+
+// An InputBuffer provides a buffer on top of a RandomAccessFile.
+// A given instance of an InputBuffer is NOT safe for concurrent use
+// by multiple threads.
+class InputBuffer {
+ public:
+ // Create an InputBuffer for "file" with a buffer size of
+ // "buffer_bytes" bytes. Takes ownership of "file" and will
+ // delete it when the InputBuffer is destroyed.
+ InputBuffer(RandomAccessFile* file, size_t buffer_bytes);
+ ~InputBuffer();
+
+ // Read one text line of data into "*result" until end-of-file or a
+ // \n is read. (The \n is not included in the result.) Overwrites
+ // any existing data in *result.
+ //
+ // If successful, returns OK. If we are already at the end of the
+ // file, we return an OUT_OF_RANGE error. Otherwise, we return
+ // some other non-OK status.
+ Status ReadLine(string* result);
+
+ // Reads bytes_to_read bytes into *result, overwriting *result.
+ //
+  // If successful, returns OK. If there are not enough bytes to
+ // read before the end of the file, we return an OUT_OF_RANGE error.
+ // Otherwise, we return some other non-OK status.
+ Status ReadNBytes(int64 bytes_to_read, string* result);
+
+ // Like ReadNBytes() without returning the bytes read.
+ Status SkipNBytes(int64 bytes_to_skip);
+
+ // Returns the position in the file.
+ int64 Tell() const { return file_pos_ - (limit_ - pos_); }
+
+ private:
+ Status FillBuffer();
+
+ RandomAccessFile* file_; // Owned
+ int64 file_pos_; // Next position to read from in "file_"
+ size_t size_; // Size of "buf_"
+ char* buf_; // The buffer itself
+ // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_"
+  char* pos_;    // Current position in "buf_"
+  char* limit_;  // Just past end of valid data in "buf_"
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InputBuffer);
+};
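+
+// A minimal read loop (sketch; error handling elided):
+//
+//   RandomAccessFile* file;
+//   TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+//   io::InputBuffer in(file, 1 << 16);  // "in" takes ownership of "file"
+//   string line;
+//   while (in.ReadLine(&line).ok()) {
+//     // ... process "line" ...
+//   }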
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_INPUTBUFFER_H_
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
new file mode 100644
index 0000000000..34094f018c
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -0,0 +1,174 @@
+#include "tensorflow/core/lib/io/inputbuffer.h"
+
+#include "tensorflow/core/public/env.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+static std::vector<int> BufferSizes() {
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536};
+}
+
+TEST(InputBuffer, ReadLine_Empty) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine1) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\nline two\nline three\n");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine_NoTrailingNewLine) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\nline two\nline three");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine_EmptyLines) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\n\n\nline two\nline three");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadNBytes) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "0123456789");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string read;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_EQ(0, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(3, &read));
+ EXPECT_EQ(read, "012");
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(4, &read));
+ EXPECT_EQ(read, "3456");
+ EXPECT_EQ(7, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(7, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "789");
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ }
+}
+
+TEST(InputBuffer, SkipNBytes) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "0123456789");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string read;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_EQ(0, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(3));
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(0));
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(2, &read));
+ EXPECT_EQ(read, "34");
+ EXPECT_EQ(5, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(0));
+ EXPECT_EQ(5, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(2));
+ EXPECT_EQ(7, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(1, &read));
+ EXPECT_EQ(read, "7");
+ EXPECT_EQ(8, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5)));
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5)));
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc
new file mode 100644
index 0000000000..878e93a911
--- /dev/null
+++ b/tensorflow/core/lib/io/iterator.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+Iterator::Iterator() {
+ cleanup_.function = NULL;
+ cleanup_.next = NULL;
+}
+
+Iterator::~Iterator() {
+ if (cleanup_.function != NULL) {
+ (*cleanup_.function)(cleanup_.arg1, cleanup_.arg2);
+ for (Cleanup* c = cleanup_.next; c != NULL;) {
+ (*c->function)(c->arg1, c->arg2);
+ Cleanup* next = c->next;
+ delete c;
+ c = next;
+ }
+ }
+}
+
+void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) {
+ assert(func != NULL);
+ Cleanup* c;
+ if (cleanup_.function == NULL) {
+ c = &cleanup_;
+ } else {
+ c = new Cleanup;
+ c->next = cleanup_.next;
+ cleanup_.next = c;
+ }
+ c->function = func;
+ c->arg1 = arg1;
+ c->arg2 = arg2;
+}
+
+namespace {
+class EmptyIterator : public Iterator {
+ public:
+ EmptyIterator(const Status& s) : status_(s) {}
+ virtual bool Valid() const { return false; }
+ virtual void Seek(const StringPiece& target) {}
+ virtual void SeekToFirst() {}
+ virtual void Next() { assert(false); }
+ StringPiece key() const {
+ assert(false);
+ return StringPiece();
+ }
+ StringPiece value() const {
+ assert(false);
+ return StringPiece();
+ }
+ virtual Status status() const { return status_; }
+
+ private:
+ Status status_;
+};
+} // namespace
+
+Iterator* NewEmptyIterator() { return new EmptyIterator(Status::OK()); }
+
+Iterator* NewErrorIterator(const Status& status) {
+ return new EmptyIterator(status);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h
new file mode 100644
index 0000000000..603a2f95fe
--- /dev/null
+++ b/tensorflow/core/lib/io/iterator.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// An iterator yields a sequence of key/value pairs from a source.
+// The following class defines the interface. Multiple implementations
+// are provided by this library. In particular, iterators are provided
+// to access the contents of a Table or a DB.
+//
+// Multiple threads can invoke const methods on an Iterator without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same Iterator must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_IO_ITERATOR_H_
+#define TENSORFLOW_LIB_IO_ITERATOR_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace table {
+
+class Iterator {
+ public:
+ Iterator();
+ virtual ~Iterator();
+
+ // An iterator is either positioned at a key/value pair, or
+ // not valid. This method returns true iff the iterator is valid.
+ virtual bool Valid() const = 0;
+
+ // Position at the first key in the source. The iterator is Valid()
+ // after this call iff the source is not empty.
+ virtual void SeekToFirst() = 0;
+
+ // Position at the first key in the source that is at or past target.
+ // The iterator is Valid() after this call iff the source contains
+ // an entry that comes at or past target.
+ virtual void Seek(const StringPiece& target) = 0;
+
+ // Moves to the next entry in the source. After this call, Valid() is
+ // true iff the iterator was not positioned at the last entry in the source.
+ // REQUIRES: Valid()
+ virtual void Next() = 0;
+
+ // Return the key for the current entry. The underlying storage for
+ // the returned slice is valid only until the next modification of
+ // the iterator.
+ // REQUIRES: Valid()
+ virtual StringPiece key() const = 0;
+
+ // Return the value for the current entry. The underlying storage for
+ // the returned slice is valid only until the next modification of
+ // the iterator.
+ // REQUIRES: Valid()
+ virtual StringPiece value() const = 0;
+
+ // If an error has occurred, return it. Else return an ok status.
+ virtual Status status() const = 0;
+
+ // Clients are allowed to register function/arg1/arg2 triples that
+ // will be invoked when this iterator is destroyed.
+ //
+ // Note that unlike all of the preceding methods, this method is
+ // not abstract and therefore clients should not override it.
+ typedef void (*CleanupFunction)(void* arg1, void* arg2);
+ void RegisterCleanup(CleanupFunction function, void* arg1, void* arg2);
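+  //
+  // For example (sketch), to free a heap-allocated object when the
+  // iterator is destroyed:
+  //   static void DeleteArg(void* arg, void* /* unused */) {
+  //     delete static_cast<string*>(arg);
+  //   }
+  //   ...
+  //   iter->RegisterCleanup(&DeleteArg, owned_string, nullptr);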
+
+ private:
+ struct Cleanup {
+ CleanupFunction function;
+ void* arg1;
+ void* arg2;
+ Cleanup* next;
+ };
+ Cleanup cleanup_;
+
+ // No copying allowed
+ Iterator(const Iterator&);
+ void operator=(const Iterator&);
+};
+
+// Return an empty iterator (yields nothing).
+extern Iterator* NewEmptyIterator();
+
+// Return an empty iterator with the specified status.
+extern Iterator* NewErrorIterator(const Status& status);
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_ITERATOR_H_
diff --git a/tensorflow/core/lib/io/match.cc b/tensorflow/core/lib/io/match.cc
new file mode 100644
index 0000000000..1563642d0b
--- /dev/null
+++ b/tensorflow/core/lib/io/match.cc
@@ -0,0 +1,31 @@
+#include "tensorflow/core/lib/io/match.h"
+#include <fnmatch.h>
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+Status GetMatchingFiles(Env* env, const string& pattern,
+ std::vector<string>* results) {
+ results->clear();
+ std::vector<string> all_files;
+ string dir = Dirname(pattern).ToString();
+ if (dir.empty()) dir = ".";
+ string basename_pattern = Basename(pattern).ToString();
+ Status s = env->GetChildren(dir, &all_files);
+ if (!s.ok()) {
+ return s;
+ }
+ for (const auto& f : all_files) {
+ int flags = 0;
+ if (fnmatch(basename_pattern.c_str(), Basename(f).ToString().c_str(),
+ flags) == 0) {
+ results->push_back(JoinPath(dir, f));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/match.h b/tensorflow/core/lib/io/match.h
new file mode 100644
index 0000000000..fd194178e7
--- /dev/null
+++ b/tensorflow/core/lib/io/match.h
@@ -0,0 +1,24 @@
+#ifndef TENSORFLOW_LIB_IO_MATCH_H_
+#define TENSORFLOW_LIB_IO_MATCH_H_
+
+#include <vector>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+class Env;
+namespace io {
+
+// Given a pattern, return the set of files that match the pattern.
+// Note that this routine only supports wildcard characters in the
+// basename portion of the pattern, not in the directory portion. If
+// successful, return Status::OK and store the matching files in
+// "*results". Otherwise, return a non-OK status.
+Status GetMatchingFiles(Env* env, const string& pattern,
+ std::vector<string>* results);
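+
+// Example (sketch; "/tmp/data-*" stands in for any pattern whose directory
+// part contains no wildcards):
+//
+//   std::vector<string> files;
+//   Status s = GetMatchingFiles(Env::Default(), "/tmp/data-*", &files);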
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_MATCH_H_
diff --git a/tensorflow/core/lib/io/match_test.cc b/tensorflow/core/lib/io/match_test.cc
new file mode 100644
index 0000000000..aaa56e4e7e
--- /dev/null
+++ b/tensorflow/core/lib/io/match_test.cc
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/match.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace io {
+
+static string Match(Env* env, const string& suffix_pattern) {
+ std::vector<string> results;
+ Status s = GetMatchingFiles(env, JoinPath(testing::TmpDir(), suffix_pattern),
+ &results);
+ if (!s.ok()) {
+ return s.ToString();
+ } else {
+ string r;
+ std::sort(results.begin(), results.end());
+ for (size_t i = 0; i < results.size(); i++) {
+ strings::StrAppend(&r, (i > 0) ? "," : "", Basename(results[i]));
+ }
+ return r;
+ }
+}
+TEST(GetMatchingFiles, Simple) {
+ Env* env = Env::Default();
+ EXPECT_EQ(Match(env, "thereisnosuchfile"), "");
+ EXPECT_EQ(Match(env, "thereisnosuchfile*"), "");
+
+ // Populate a few files
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-00"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-0a"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-01"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-aaa"), ""));
+
+ EXPECT_EQ(Match(env, "match-*"), "match-00,match-01,match-0a,match-aaa");
+ EXPECT_EQ(Match(env, "match-0[0-9]"), "match-00,match-01");
+ EXPECT_EQ(Match(env, "match-?[0-9]"), "match-00,match-01");
+ EXPECT_EQ(Match(env, "match-?a*"), "match-0a,match-aaa");
+ EXPECT_EQ(Match(env, "match-??"), "match-00,match-01,match-0a");
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
new file mode 100644
index 0000000000..1359ded0f0
--- /dev/null
+++ b/tensorflow/core/lib/io/path.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace io {
+
+string JoinPath(StringPiece part1, StringPiece part2) {
+ string result;
+
+ StringPiece paths[2] = {part1, part2};
+ for (StringPiece path : paths) {
+ if (path.empty()) continue;
+
+ if (result.empty()) {
+ result = path.ToString();
+ continue;
+ }
+
+ if (result[result.size() - 1] == '/') {
+ if (IsAbsolutePath(path)) {
+ strings::StrAppend(&result, path.substr(1));
+ } else {
+ strings::StrAppend(&result, path);
+ }
+ } else {
+ if (IsAbsolutePath(path)) {
+ strings::StrAppend(&result, path);
+ } else {
+ strings::StrAppend(&result, "/", path);
+ }
+ }
+ }
+
+ return result;
+}
+
+namespace internal {
+
+// Return the parts of the path, split on the final "/". If there is no
+// "/" in the path, the first part of the output is empty and the second
+// is the input. If the only "/" in the path is the first character, it is
+// the first part of the output.
+std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
+ auto pos = path.rfind('/');
+
+ // Handle the case with no '/' in 'path'.
+ if (pos == StringPiece::npos)
+ return std::make_pair(StringPiece(path.data(), 0), path);
+
+ // Handle the case with a single leading '/' in 'path'.
+ if (pos == 0)
+ return std::make_pair(StringPiece(path.data(), 1),
+ StringPiece(path.data() + 1, path.size() - 1));
+
+ return std::make_pair(
+ StringPiece(path.data(), pos),
+ StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
+}
+
+// Return the parts of the basename of path, split on the final ".".
+// If there is no "." in the basename or "." is the final character in the
+// basename, the second value will be empty.
+std::pair<StringPiece, StringPiece> SplitBasename(StringPiece path) {
+ path = Basename(path);
+
+ auto pos = path.rfind('.');
+ if (pos == StringPiece::npos)
+ return std::make_pair(path, StringPiece(path.data() + path.size(), 0));
+ return std::make_pair(
+ StringPiece(path.data(), pos),
+ StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
+}
+} // namespace internal
+
+bool IsAbsolutePath(StringPiece path) {
+ return !path.empty() && path[0] == '/';
+}
+
+StringPiece Dirname(StringPiece path) {
+ return internal::SplitPath(path).first;
+}
+
+StringPiece Basename(StringPiece path) {
+ return internal::SplitPath(path).second;
+}
+
+StringPiece Extension(StringPiece path) {
+ return internal::SplitBasename(path).second;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
new file mode 100644
index 0000000000..01483f1702
--- /dev/null
+++ b/tensorflow/core/lib/io/path.h
@@ -0,0 +1,47 @@
+#ifndef TENSORFLOW_LIB_IO_PATH_H_
+#define TENSORFLOW_LIB_IO_PATH_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+class StringPiece;
+namespace io {
+
+// Utility routines for processing filenames
+
+// Join multiple paths together, without introducing unnecessary path
+// separators.
+// For example:
+//
+// Arguments | JoinPath
+// ---------------------------+----------
+// '/foo', 'bar' | /foo/bar
+// '/foo/', 'bar' | /foo/bar
+// '/foo', '/bar' | /foo/bar
+//
+// Usage:
+// string path = io::JoinPath("/mydir", filename);
+// string path = io::JoinPath(FLAGS_test_srcdir, filename);
+string JoinPath(StringPiece part1, StringPiece part2);
+
+// Return true if path is absolute.
+bool IsAbsolutePath(StringPiece path);
+
+// Returns the part of the path before the final "/". If there is a single
+// leading "/" in the path, the result will be the leading "/". If there is
+// no "/" in the path, the result is the empty prefix of the input.
+StringPiece Dirname(StringPiece path);
+
+// Returns the part of the path after the final "/". If there is no
+// "/" in the path, the result is the same as the input.
+StringPiece Basename(StringPiece path);
+
+// Returns the part of the basename of path after the final ".". If
+// there is no "." in the basename, the result is empty.
+StringPiece Extension(StringPiece path);
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_PATH_H_
diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc
new file mode 100644
index 0000000000..b670e44f1f
--- /dev/null
+++ b/tensorflow/core/lib/io/path_test.cc
@@ -0,0 +1,65 @@
+#include "tensorflow/core/lib/io/path.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace io {
+
+TEST(PathTest, JoinPath) {
+ EXPECT_EQ("/foo/bar", JoinPath("/foo", "bar"));
+ EXPECT_EQ("foo/bar", JoinPath("foo", "bar"));
+ EXPECT_EQ("foo/bar", JoinPath("foo", "/bar"));
+ EXPECT_EQ("/foo/bar", JoinPath("/foo", "/bar"));
+
+ EXPECT_EQ("/bar", JoinPath("", "/bar"));
+ EXPECT_EQ("bar", JoinPath("", "bar"));
+ EXPECT_EQ("/foo", JoinPath("/foo", ""));
+
+ EXPECT_EQ("/foo/bar/baz/blah/blink/biz",
+ JoinPath("/foo/bar/baz/", "/blah/blink/biz"));
+}
+
+TEST(PathTest, IsAbsolutePath) {
+ EXPECT_FALSE(IsAbsolutePath(""));
+ EXPECT_FALSE(IsAbsolutePath("../foo"));
+ EXPECT_FALSE(IsAbsolutePath("foo"));
+ EXPECT_FALSE(IsAbsolutePath("./foo"));
+ EXPECT_FALSE(IsAbsolutePath("foo/bar/baz/"));
+ EXPECT_TRUE(IsAbsolutePath("/foo"));
+ EXPECT_TRUE(IsAbsolutePath("/foo/bar/../baz"));
+}
+
+TEST(PathTest, Dirname) {
+ EXPECT_EQ("/hello", Dirname("/hello/"));
+ EXPECT_EQ("/", Dirname("/hello"));
+ EXPECT_EQ("hello", Dirname("hello/world"));
+ EXPECT_EQ("hello", Dirname("hello/"));
+ EXPECT_EQ("", Dirname("world"));
+ EXPECT_EQ("/", Dirname("/"));
+ EXPECT_EQ("", Dirname(""));
+}
+
+TEST(PathTest, Basename) {
+ EXPECT_EQ("", Basename("/hello/"));
+ EXPECT_EQ("hello", Basename("/hello"));
+ EXPECT_EQ("world", Basename("hello/world"));
+ EXPECT_EQ("", Basename("hello/"));
+ EXPECT_EQ("world", Basename("world"));
+ EXPECT_EQ("", Basename("/"));
+ EXPECT_EQ("", Basename(""));
+}
+
+TEST(PathTest, Extension) {
+ EXPECT_EQ("gif", Extension("foo.gif"));
+ EXPECT_EQ("", Extension("foo."));
+ EXPECT_EQ("", Extension(""));
+ EXPECT_EQ("", Extension("/"));
+ EXPECT_EQ("", Extension("foo"));
+ EXPECT_EQ("", Extension("foo/"));
+ EXPECT_EQ("gif", Extension("/a/path/to/foo.gif"));
+ EXPECT_EQ("html", Extension("/a/path.bar/to/foo.html"));
+ EXPECT_EQ("", Extension("/a/path.bar/to/foo"));
+ EXPECT_EQ("baz", Extension("/a/path.bar/to/foo.bar.baz"));
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
new file mode 100644
index 0000000000..2f0fabff63
--- /dev/null
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -0,0 +1,80 @@
+#include "tensorflow/core/lib/io/record_reader.h"
+
+#include <limits.h>
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace io {
+
+RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {}
+
+RecordReader::~RecordReader() {}
+
+// Read n+4 bytes from the file, verify that the checksum of the first n
+// bytes is stored in the last 4 bytes, and store the first n bytes in
+// *result.  May use *storage as backing store.
+static Status ReadChecksummed(RandomAccessFile* file, uint64 offset,
+ size_t n, StringPiece* result,
+ string* storage) {
+ if (n >= SIZE_MAX - sizeof(uint32)) {
+ return errors::DataLoss("record size too large");
+ }
+
+ const size_t expected = n + sizeof(uint32);
+ storage->resize(expected);
+ StringPiece data;
+ Status s = file->Read(offset, expected, &data, &(*storage)[0]);
+ if (!s.ok()) {
+ return s;
+ }
+ if (data.size() != expected) {
+ if (data.size() == 0) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+ uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(data.data(), n);
+ return Status::OK();
+}
+
+Status RecordReader::ReadRecord(uint64* offset, string* record) {
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
+ // Read length
+ StringPiece lbuf;
+ Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record);
+ if (!s.ok()) {
+ return s;
+ }
+ const uint64 length = core::DecodeFixed64(lbuf.data());
+
+ // Read data
+ StringPiece data;
+ s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ s = errors::DataLoss("truncated record at ", *offset);
+ }
+ return s;
+ }
+ if (record->data() != data.data()) {
+ // RandomAccessFile placed the data in some other location.
+ memmove(&(*record)[0], data.data(), data.size());
+ }
+
+ record->resize(data.size());
+ *offset += kHeaderSize + length + kFooterSize;
+ return Status::OK();
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
new file mode 100644
index 0000000000..a8c1b0dd5d
--- /dev/null
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_
+#define TENSORFLOW_LIB_IO_RECORD_READER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class RandomAccessFile;
+
+namespace io {
+
+class RecordReader {
+ public:
+ // Create a reader that will return log records from "*file".
+ // "*file" must remain live while this Reader is in use.
+ explicit RecordReader(RandomAccessFile* file);
+
+ ~RecordReader();
+
+ // Read the record at "*offset" into *record and update *offset to
+ // point to the offset of the next record. Returns OK on success,
+ // OUT_OF_RANGE for end of file, or something else for an error.
+ Status ReadRecord(uint64* offset, string* record);
+
+ private:
+ RandomAccessFile* src_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
+};
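+
+// A minimal read loop (sketch; assumes "file" is a RandomAccessFile*
+// obtained elsewhere, e.g. from Env::NewRandomAccessFile):
+//
+//   RecordReader reader(file);
+//   uint64 offset = 0;
+//   string record;
+//   Status s;
+//   while ((s = reader.ReadRecord(&offset, &record)).ok()) {
+//     ... process each record ...
+//   }
+//   // "s" is OUT_OF_RANGE once the last record has been read.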
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
new file mode 100644
index 0000000000..3d7f1509ab
--- /dev/null
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -0,0 +1,42 @@
+#include "tensorflow/core/lib/io/record_writer.h"
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+namespace tensorflow {
+namespace io {
+
+RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {}
+
+RecordWriter::~RecordWriter() {}
+
+static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+}
+
+Status RecordWriter::WriteRecord(StringPiece data) {
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ char header[sizeof(uint64) + sizeof(uint32)];
+ core::EncodeFixed64(header + 0, data.size());
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+ Status s = dest_->Append(StringPiece(header, sizeof(header)));
+ if (!s.ok()) {
+ return s;
+ }
+ s = dest_->Append(data);
+ if (!s.ok()) {
+ return s;
+ }
+ char footer[sizeof(uint32)];
+ core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
+ return dest_->Append(StringPiece(footer, sizeof(footer)));
+}
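+
+// Worked example of the layout above: for data == "foo" the record occupies
+// 19 bytes in the file:
+//   bytes  0-7   length (3), as a fixed64
+//   bytes  8-11  masked crc32c of the 8 length bytes
+//   bytes 12-14  "foo"
+//   bytes 15-18  masked crc32c of the data
+// In general, a record costs data.size() + 16 bytes.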
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
new file mode 100644
index 0000000000..c7af00e5ae
--- /dev/null
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -0,0 +1,34 @@
+#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class WritableFile;
+
+namespace io {
+
+class RecordWriter {
+ public:
+ // Create a writer that will append data to "*dest".
+ // "*dest" must be initially empty.
+ // "*dest" must remain live while this Writer is in use.
+ explicit RecordWriter(WritableFile* dest);
+
+ ~RecordWriter();
+
+ Status WriteRecord(StringPiece slice);
+
+ private:
+ WritableFile* const dest_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
+};
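+
+// A minimal write sketch; assumes "file" is an initially empty
+// WritableFile*, e.g. freshly created through Env::NewWritableFile:
+//
+//   RecordWriter writer(file);
+//   Status s = writer.WriteRecord("my first record");
+//   if (s.ok()) s = writer.WriteRecord("my second record");
+//   if (s.ok()) s = file->Close();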
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
new file mode 100644
index 0000000000..3e9c816443
--- /dev/null
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -0,0 +1,245 @@
+#include "tensorflow/core/lib/io/record_reader.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+// Construct a string of the specified length made out of the supplied
+// partial string.
+static string BigString(const string& partial_string, size_t n) {
+ string result;
+ while (result.size() < n) {
+ result.append(partial_string);
+ }
+ result.resize(n);
+ return result;
+}
+
+// Construct a string from a number
+static string NumberString(int n) {
+ char buf[50];
+ snprintf(buf, sizeof(buf), "%d.", n);
+ return string(buf);
+}
+
+// Return a skewed, potentially long string
+static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
+ return BigString(NumberString(i), rnd->Skewed(17));
+}
+
+class RecordioTest : public testing::Test {
+ private:
+ class StringDest : public WritableFile {
+ public:
+ string contents_;
+
+ Status Close() override { return Status::OK(); }
+ Status Flush() override { return Status::OK(); }
+ Status Sync() override { return Status::OK(); }
+ Status Append(const StringPiece& slice) override {
+ contents_.append(slice.data(), slice.size());
+ return Status::OK();
+ }
+ };
+
+ class StringSource : public RandomAccessFile {
+ public:
+ StringPiece contents_;
+ mutable bool force_error_;
+ mutable bool returned_partial_;
+ StringSource() : force_error_(false), returned_partial_(false) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
+
+ if (force_error_) {
+ force_error_ = false;
+ returned_partial_ = true;
+ return errors::DataLoss("read error");
+ }
+
+ if (offset >= contents_.size()) {
+ return errors::OutOfRange("end of file");
+ }
+
+ if (contents_.size() < offset + n) {
+ n = contents_.size() - offset;
+ returned_partial_ = true;
+ }
+ *result = StringPiece(contents_.data() + offset, n);
+ return Status::OK();
+ }
+ };
+
+ StringDest dest_;
+ StringSource source_;
+ bool reading_;
+ uint64 readpos_;
+ RecordWriter* writer_;
+ RecordReader* reader_;
+
+ public:
+ RecordioTest()
+ : reading_(false),
+ readpos_(0),
+ writer_(new RecordWriter(&dest_)),
+ reader_(new RecordReader(&source_)) {}
+
+ ~RecordioTest() override {
+ delete writer_;
+ delete reader_;
+ }
+
+ void Write(const string& msg) {
+ ASSERT_TRUE(!reading_) << "Write() after starting to read";
+ ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
+ }
+
+ size_t WrittenBytes() const { return dest_.contents_.size(); }
+
+ string Read() {
+ if (!reading_) {
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ }
+ string record;
+ Status s = reader_->ReadRecord(&readpos_, &record);
+ if (s.ok()) {
+ return record;
+ } else if (errors::IsOutOfRange(s)) {
+ return "EOF";
+ } else {
+ return s.ToString();
+ }
+ }
+
+ void IncrementByte(int offset, int delta) {
+ dest_.contents_[offset] += delta;
+ }
+
+ void SetByte(int offset, char new_byte) {
+ dest_.contents_[offset] = new_byte;
+ }
+
+ void ShrinkSize(int bytes) {
+ dest_.contents_.resize(dest_.contents_.size() - bytes);
+ }
+
+ void FixChecksum(int header_offset, int len) {
+    // Recompute the masked crc stored at header_offset; it covers the
+    // 1 + len bytes starting at header_offset + 6.
+    uint32 crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
+ crc = crc32c::Mask(crc);
+ core::EncodeFixed32(&dest_.contents_[header_offset], crc);
+ }
+
+ void ForceError() { source_.force_error_ = true; }
+
+  void StartReadingAt(uint64 initial_offset) { readpos_ = initial_offset; }
+
+  void CheckOffsetPastEndReturnsNoRecords(uint64 offset_past_end) {
+ Write("foo");
+ Write("bar");
+ Write(BigString("x", 10000));
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ uint64 offset = WrittenBytes() + offset_past_end;
+ string record;
+ Status s = reader_->ReadRecord(&offset, &record);
+ ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
+ }
+};
+
+TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); }
+
+TEST_F(RecordioTest, ReadWrite) {
+ Write("foo");
+ Write("bar");
+ Write("");
+ Write("xxxx");
+ ASSERT_EQ("foo", Read());
+ ASSERT_EQ("bar", Read());
+ ASSERT_EQ("", Read());
+ ASSERT_EQ("xxxx", Read());
+ ASSERT_EQ("EOF", Read());
+ ASSERT_EQ("EOF", Read()); // Make sure reads at eof work
+}
+
+TEST_F(RecordioTest, ManyRecords) {
+ for (int i = 0; i < 100000; i++) {
+ Write(NumberString(i));
+ }
+ for (int i = 0; i < 100000; i++) {
+ ASSERT_EQ(NumberString(i), Read());
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+TEST_F(RecordioTest, RandomRead) {
+ const int N = 500;
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ Write(RandomSkewedString(i, &rnd));
+ }
+ }
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(RandomSkewedString(i, &rnd), Read());
+ }
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+// Tests of all the error paths in record_reader.cc follow:
+static void AssertHasSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain "
+ << expected;
+}
+
+TEST_F(RecordioTest, ReadError) {
+ Write("foo");
+ ForceError();
+ AssertHasSubstr(Read(), "Data loss");
+}
+
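+// The corruption offsets below follow the on-disk record layout: bytes 0-7
+// hold the length, bytes 8-11 its masked crc, the data starts at byte 12,
+// and the last four bytes hold the data crc.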
+TEST_F(RecordioTest, CorruptLength) {
+ Write("foo");
+ IncrementByte(6, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptLengthCrc) {
+ Write("foo");
+ IncrementByte(10, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptData) {
+ Write("foo");
+ IncrementByte(14, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptDataCrc) {
+ Write("foo");
+ IncrementByte(WrittenBytes() - 1, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); }
+
+TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc
new file mode 100644
index 0000000000..769d7e72a5
--- /dev/null
+++ b/tensorflow/core/lib/io/table.cc
@@ -0,0 +1,169 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table.h"
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/io/two_level_iterator.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace table {
+
+struct Table::Rep {
+ ~Rep() { delete index_block; }
+
+ Options options;
+ Status status;
+ RandomAccessFile* file;
+ // XXX uint64 cache_id;
+
+ BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer
+ Block* index_block;
+};
+
+Status Table::Open(const Options& options, RandomAccessFile* file,
+ uint64 size, Table** table) {
+ *table = NULL;
+ if (size < Footer::kEncodedLength) {
+ return errors::DataLoss("file is too short to be an sstable");
+ }
+
+ char footer_space[Footer::kEncodedLength];
+ StringPiece footer_input;
+ Status s =
+ file->Read(size - Footer::kEncodedLength, Footer::kEncodedLength,
+ &footer_input, footer_space);
+ if (!s.ok()) return s;
+
+ Footer footer;
+ s = footer.DecodeFrom(&footer_input);
+ if (!s.ok()) return s;
+
+ // Read the index block
+ BlockContents contents;
+ Block* index_block = NULL;
+ if (s.ok()) {
+ s = ReadBlock(file, footer.index_handle(), &contents);
+ if (s.ok()) {
+ index_block = new Block(contents);
+ }
+ }
+
+ if (s.ok()) {
+ // We've successfully read the footer and the index block: we're
+ // ready to serve requests.
+ Rep* rep = new Table::Rep;
+ rep->options = options;
+ rep->file = file;
+ rep->metaindex_handle = footer.metaindex_handle();
+ rep->index_block = index_block;
+ // XXX rep->cache_id = (options.block_cache ?
+ // options.block_cache->NewId() : 0);
+ *table = new Table(rep);
+ } else {
+ if (index_block) delete index_block;
+ }
+
+ return s;
+}
+
+Table::~Table() { delete rep_; }
+
+static void DeleteBlock(void* arg, void* ignored) {
+ delete reinterpret_cast<Block*>(arg);
+}
+
+// Convert an index iterator value (i.e., an encoded BlockHandle)
+// into an iterator over the contents of the corresponding block.
+Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) {
+ Table* table = reinterpret_cast<Table*>(arg);
+ // Cache* block_cache = table->rep_->options.block_cache;
+ Block* block = NULL;
+ // Cache::Handle* cache_handle = NULL;
+
+ BlockHandle handle;
+ StringPiece input = index_value;
+ Status s = handle.DecodeFrom(&input);
+ // We intentionally allow extra stuff in index_value so that we
+ // can add more features in the future.
+
+ if (s.ok()) {
+ BlockContents contents;
+ s = ReadBlock(table->rep_->file, handle, &contents);
+ if (s.ok()) {
+ block = new Block(contents);
+ }
+ }
+
+ Iterator* iter;
+ if (block != NULL) {
+ iter = block->NewIterator();
+ iter->RegisterCleanup(&DeleteBlock, block, NULL);
+ } else {
+ iter = NewErrorIterator(s);
+ }
+ return iter;
+}
+
+Iterator* Table::NewIterator() const {
+ return NewTwoLevelIterator(rep_->index_block->NewIterator(),
+ &Table::BlockReader, const_cast<Table*>(this));
+}
+
+Status Table::InternalGet(const StringPiece& k, void* arg,
+ void (*saver)(void*, const StringPiece&,
+ const StringPiece&)) {
+ Status s;
+ Iterator* iiter = rep_->index_block->NewIterator();
+ iiter->Seek(k);
+ if (iiter->Valid()) {
+ BlockHandle handle;
+ Iterator* block_iter = BlockReader(this, iiter->value());
+ block_iter->Seek(k);
+ if (block_iter->Valid()) {
+ (*saver)(arg, block_iter->key(), block_iter->value());
+ }
+ s = block_iter->status();
+ delete block_iter;
+ }
+ if (s.ok()) {
+ s = iiter->status();
+ }
+ delete iiter;
+ return s;
+}
+
+uint64 Table::ApproximateOffsetOf(const StringPiece& key) const {
+ Iterator* index_iter = rep_->index_block->NewIterator();
+ index_iter->Seek(key);
+ uint64 result;
+ if (index_iter->Valid()) {
+ BlockHandle handle;
+ StringPiece input = index_iter->value();
+ Status s = handle.DecodeFrom(&input);
+ if (s.ok()) {
+ result = handle.offset();
+ } else {
+ // Strange: we can't decode the block handle in the index block.
+ // We'll just return the offset of the metaindex block, which is
+ // close to the whole file size for this case.
+ result = rep_->metaindex_handle.offset();
+ }
+ } else {
+ // key is past the last key in the file. Approximate the offset
+ // by returning the offset of the metaindex block (which is
+ // right near the end of the file).
+ result = rep_->metaindex_handle.offset();
+ }
+ delete index_iter;
+ return result;
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h
new file mode 100644
index 0000000000..230dded2d4
--- /dev/null
+++ b/tensorflow/core/lib/io/table.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_TABLE_H_
+#define TENSORFLOW_LIB_IO_TABLE_H_
+
+#include <stdint.h>
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+class RandomAccessFile;
+
+namespace table {
+
+class Block;
+class BlockHandle;
+class Footer;
+struct Options;
+
+// A Table is a sorted map from strings to strings. Tables are
+// immutable and persistent. A Table may be safely accessed from
+// multiple threads without external synchronization.
+class Table {
+ public:
+ // Attempt to open the table that is stored in bytes [0..file_size)
+ // of "file", and read the metadata entries necessary to allow
+ // retrieving data from the table.
+ //
+ // If successful, returns ok and sets "*table" to the newly opened
+ // table. The client should delete "*table" when no longer needed.
+ // If there was an error while initializing the table, sets "*table"
+ // to NULL and returns a non-ok status. Does not take ownership of
+ // "*file", but the client must ensure that "file" remains live
+ // for the duration of the returned table's lifetime.
+ static Status Open(const Options& options, RandomAccessFile* file,
+ uint64 file_size, Table** table);
+
+ ~Table();
+
+ // Returns a new iterator over the table contents.
+ // The result of NewIterator() is initially invalid (caller must
+ // call one of the Seek methods on the iterator before using it).
+ Iterator* NewIterator() const;
+
+ // Given a key, return an approximate byte offset in the file where
+ // the data for that key begins (or would begin if the key were
+ // present in the file). The returned value is in terms of file
+ // bytes, and so includes effects like compression of the underlying data.
+ // E.g., the approximate offset of the last key in the table will
+ // be close to the file length.
+ uint64 ApproximateOffsetOf(const StringPiece& key) const;
+
+ private:
+ struct Rep;
+ Rep* rep_;
+
+ explicit Table(Rep* rep) { rep_ = rep; }
+ static Iterator* BlockReader(void*, const StringPiece&);
+
+ // Calls (*handle_result)(arg, ...) with the entry found after a call
+ // to Seek(key). May not make such a call if filter policy says
+ // that key is not present.
+ Status InternalGet(const StringPiece& key, void* arg,
+ void (*handle_result)(void* arg, const StringPiece& k,
+ const StringPiece& v));
+
+ // No copying allowed
+ Table(const Table&);
+ void operator=(const Table&);
+};
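+
+// A minimal usage sketch; assumes "file" and "file_size" were obtained
+// elsewhere (e.g. through Env) and that "file" outlives the table:
+//
+//   Table* table = NULL;
+//   Status s = Table::Open(Options(), file, file_size, &table);
+//   if (s.ok()) {
+//     Iterator* iter = table->NewIterator();
+//     for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+//       ... use iter->key() and iter->value() ...
+//     }
+//     delete iter;
+//     delete table;
+//   }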
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_H_
diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc
new file mode 100644
index 0000000000..b786888b30
--- /dev/null
+++ b/tensorflow/core/lib/io/table_builder.cc
@@ -0,0 +1,263 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table_builder.h"
+
+#include <assert.h>
+#include "tensorflow/core/lib/io/block_builder.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace {
+
+void FindShortestSeparator(string* start, const StringPiece& limit) {
+ // Find length of common prefix
+ size_t min_length = std::min(start->size(), limit.size());
+ size_t diff_index = 0;
+ while ((diff_index < min_length) &&
+ ((*start)[diff_index] == limit[diff_index])) {
+ diff_index++;
+ }
+
+ if (diff_index >= min_length) {
+ // Do not shorten if one string is a prefix of the other
+ } else {
+ uint8 diff_byte = static_cast<uint8>((*start)[diff_index]);
+ if (diff_byte < static_cast<uint8>(0xff) &&
+ diff_byte + 1 < static_cast<uint8>(limit[diff_index])) {
+ (*start)[diff_index]++;
+ start->resize(diff_index + 1);
+ assert(StringPiece(*start).compare(limit) < 0);
+ }
+ }
+}
+
+void FindShortSuccessor(string* key) {
+ // Find first character that can be incremented
+ size_t n = key->size();
+ for (size_t i = 0; i < n; i++) {
+ const uint8 byte = (*key)[i];
+ if (byte != static_cast<uint8>(0xff)) {
+ (*key)[i] = byte + 1;
+ key->resize(i + 1);
+ return;
+ }
+ }
+ // *key is a run of 0xffs. Leave it alone.
+}
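+
+// Worked examples (illustrative only):
+//   FindShortestSeparator(&s, "abzzzz") with s == "abcdef" leaves
+//   s == "abd", which is >= "abcdef" and < "abzzzz".
+//   FindShortSuccessor(&k) with k == "abc" leaves k == "b".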
+} // namespace
+
+struct TableBuilder::Rep {
+ Options options;
+ Options index_block_options;
+ WritableFile* file;
+ uint64 offset;
+ Status status;
+ BlockBuilder data_block;
+ BlockBuilder index_block;
+ string last_key;
+ int64 num_entries;
+ bool closed; // Either Finish() or Abandon() has been called.
+
+ // We do not emit the index entry for a block until we have seen the
+ // first key for the next data block. This allows us to use shorter
+ // keys in the index block. For example, consider a block boundary
+ // between the keys "the quick brown fox" and "the who". We can use
+ // "the r" as the key for the index block entry since it is >= all
+ // entries in the first block and < all entries in subsequent
+ // blocks.
+ //
+ // Invariant: r->pending_index_entry is true only if data_block is empty.
+ bool pending_index_entry;
+ BlockHandle pending_handle; // Handle to add to index block
+
+ string compressed_output;
+
+ Rep(const Options& opt, WritableFile* f)
+ : options(opt),
+ index_block_options(opt),
+ file(f),
+ offset(0),
+ data_block(&options),
+ index_block(&index_block_options),
+ num_entries(0),
+ closed(false),
+ pending_index_entry(false) {
+ index_block_options.block_restart_interval = 1;
+ }
+};
+
+TableBuilder::TableBuilder(const Options& options, WritableFile* file)
+ : rep_(new Rep(options, file)) {}
+
+TableBuilder::~TableBuilder() {
+ assert(rep_->closed); // Catch errors where caller forgot to call Finish()
+ delete rep_;
+}
+
+void TableBuilder::Add(const StringPiece& key, const StringPiece& value) {
+ Rep* r = rep_;
+ assert(!r->closed);
+ if (!ok()) return;
+ if (r->num_entries > 0) {
+ assert(key.compare(StringPiece(r->last_key)) > 0);
+ // See if this key+value would make our current block overly large. If
+ // so, emit the current block before adding this key/value
+ const int kOverlyLargeBlockRatio = 2;
+ const size_t this_entry_bytes = key.size() + value.size();
+ if (this_entry_bytes >= kOverlyLargeBlockRatio * r->options.block_size) {
+ Flush();
+ }
+ }
+
+ if (r->pending_index_entry) {
+ assert(r->data_block.empty());
+ FindShortestSeparator(&r->last_key, key);
+ string handle_encoding;
+ r->pending_handle.EncodeTo(&handle_encoding);
+ r->index_block.Add(r->last_key, StringPiece(handle_encoding));
+ r->pending_index_entry = false;
+ }
+
+ r->last_key.assign(key.data(), key.size());
+ r->num_entries++;
+ r->data_block.Add(key, value);
+
+ const size_t estimated_block_size = r->data_block.CurrentSizeEstimate();
+ if (estimated_block_size >= r->options.block_size) {
+ Flush();
+ }
+}
+
+void TableBuilder::Flush() {
+ Rep* r = rep_;
+ assert(!r->closed);
+ if (!ok()) return;
+ if (r->data_block.empty()) return;
+ assert(!r->pending_index_entry);
+ WriteBlock(&r->data_block, &r->pending_handle);
+ if (ok()) {
+ r->pending_index_entry = true;
+ r->status = r->file->Flush();
+ }
+}
+
+void TableBuilder::WriteBlock(BlockBuilder* block, BlockHandle* handle) {
+ // File format contains a sequence of blocks where each block has:
+ // block_data: uint8[n]
+ // type: uint8
+ // crc: uint32
+ assert(ok());
+ Rep* r = rep_;
+ StringPiece raw = block->Finish();
+
+ StringPiece block_contents;
+ CompressionType type = r->options.compression;
+ // TODO(postrelease): Support more compression options: zlib?
+ switch (type) {
+ case kNoCompression:
+ block_contents = raw;
+ break;
+
+ case kSnappyCompression: {
+ string* compressed = &r->compressed_output;
+ if (port::Snappy_Compress(raw.data(), raw.size(), compressed) &&
+ compressed->size() < raw.size() - (raw.size() / 8u)) {
+ block_contents = *compressed;
+ } else {
+ // Snappy not supported, or compressed less than 12.5%, so just
+ // store uncompressed form
+ block_contents = raw;
+ type = kNoCompression;
+ }
+ break;
+ }
+ }
+ WriteRawBlock(block_contents, type, handle);
+ r->compressed_output.clear();
+ block->Reset();
+}
+
+void TableBuilder::WriteRawBlock(const StringPiece& block_contents,
+ CompressionType type, BlockHandle* handle) {
+ Rep* r = rep_;
+ handle->set_offset(r->offset);
+ handle->set_size(block_contents.size());
+ r->status = r->file->Append(block_contents);
+ if (r->status.ok()) {
+ char trailer[kBlockTrailerSize];
+ trailer[0] = type;
+ uint32 crc = crc32c::Value(block_contents.data(), block_contents.size());
+ crc = crc32c::Extend(crc, trailer, 1); // Extend crc to cover block type
+ core::EncodeFixed32(trailer + 1, crc32c::Mask(crc));
+ r->status = r->file->Append(StringPiece(trailer, kBlockTrailerSize));
+ if (r->status.ok()) {
+ r->offset += block_contents.size() + kBlockTrailerSize;
+ }
+ }
+}
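+
+// Net effect: each block occupies block_contents.size() + kBlockTrailerSize
+// bytes in the file: the payload, one type byte, and a 4-byte masked crc32c
+// that covers both the payload and the type byte.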
+
+Status TableBuilder::status() const { return rep_->status; }
+
+Status TableBuilder::Finish() {
+ Rep* r = rep_;
+ Flush();
+ assert(!r->closed);
+ r->closed = true;
+
+ BlockHandle metaindex_block_handle, index_block_handle;
+
+ // Write metaindex block
+ if (ok()) {
+ BlockBuilder meta_index_block(&r->options);
+ // TODO(postrelease): Add stats and other meta blocks
+ WriteBlock(&meta_index_block, &metaindex_block_handle);
+ }
+
+ // Write index block
+ if (ok()) {
+ if (r->pending_index_entry) {
+ FindShortSuccessor(&r->last_key);
+ string handle_encoding;
+ r->pending_handle.EncodeTo(&handle_encoding);
+ r->index_block.Add(r->last_key, StringPiece(handle_encoding));
+ r->pending_index_entry = false;
+ }
+ WriteBlock(&r->index_block, &index_block_handle);
+ }
+
+ // Write footer
+ if (ok()) {
+ Footer footer;
+ footer.set_metaindex_handle(metaindex_block_handle);
+ footer.set_index_handle(index_block_handle);
+ string footer_encoding;
+ footer.EncodeTo(&footer_encoding);
+ r->status = r->file->Append(footer_encoding);
+ if (r->status.ok()) {
+ r->offset += footer_encoding.size();
+ }
+ }
+ return r->status;
+}
+
+void TableBuilder::Abandon() {
+ Rep* r = rep_;
+ assert(!r->closed);
+ r->closed = true;
+}
+
+uint64 TableBuilder::NumEntries() const { return rep_->num_entries; }
+
+uint64 TableBuilder::FileSize() const { return rep_->offset; }
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h
new file mode 100644
index 0000000000..cebf4d8e0c
--- /dev/null
+++ b/tensorflow/core/lib/io/table_builder.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// TableBuilder provides the interface used to build a Table
+// (an immutable and sorted map from keys to values).
+//
+// Multiple threads can invoke const methods on a TableBuilder without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same TableBuilder must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+
+#include <stdint.h>
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+class WritableFile;
+namespace table {
+
+class BlockBuilder;
+class BlockHandle;
+
+class TableBuilder {
+ public:
+ // Create a builder that will store the contents of the table it is
+ // building in *file. Does not close the file. It is up to the
+ // caller to close the file after calling Finish().
+ TableBuilder(const Options& options, WritableFile* file);
+
+ // REQUIRES: Either Finish() or Abandon() has been called.
+ ~TableBuilder();
+
+ // Add key,value to the table being constructed.
+ // REQUIRES: key is after any previously added key in lexicographic order.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Add(const StringPiece& key, const StringPiece& value);
+
+ // Advanced operation: flush any buffered key/value pairs to file.
+ // Can be used to ensure that two adjacent entries never live in
+ // the same data block. Most clients should not need to use this method.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Flush();
+
+ // Return non-ok iff some error has been detected.
+ Status status() const;
+
+ // Finish building the table. Stops using the file passed to the
+ // constructor after this function returns.
+ // REQUIRES: Finish(), Abandon() have not been called
+ Status Finish();
+
+ // Indicate that the contents of this builder should be abandoned. Stops
+ // using the file passed to the constructor after this function returns.
+ // If the caller is not going to call Finish(), it must call Abandon()
+ // before destroying this builder.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Abandon();
+
+ // Number of calls to Add() so far.
+ uint64 NumEntries() const;
+
+ // Size of the file generated so far. If invoked after a successful
+ // Finish() call, returns the size of the final generated file.
+ uint64 FileSize() const;
+
+ private:
+ bool ok() const { return status().ok(); }
+ void WriteBlock(BlockBuilder* block, BlockHandle* handle);
+ void WriteRawBlock(const StringPiece& data, CompressionType,
+ BlockHandle* handle);
+
+ struct Rep;
+ Rep* rep_;
+
+ // No copying allowed
+ TableBuilder(const TableBuilder&);
+ void operator=(const TableBuilder&);
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt
new file mode 100644
index 0000000000..7edb9fb121
--- /dev/null
+++ b/tensorflow/core/lib/io/table_format.txt
@@ -0,0 +1,8 @@
+File format
+===========
+
+The table format is heavily based on the table format for the LevelDB
+open source key/value store, with the exception that our tables
+do not support "filter" meta blocks (Bloom Filters). See:
+
+https://code.google.com/p/leveldb/source/browse/doc/table_format.txt
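+
+Roughly, a table file is laid out as:
+
+  [data block 1] ... [data block N]
+  [metaindex block]
+  [index block]
+  [footer]           (fixed length, at the very end of the file)
+
+The footer holds the handles (offset/size pairs) of the metaindex and index
+blocks, and the index block holds one handle per data block.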
diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h
new file mode 100644
index 0000000000..45b061b03b
--- /dev/null
+++ b/tensorflow/core/lib/io/table_options.h
@@ -0,0 +1,53 @@
+#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+
+#include <stddef.h>
+
+namespace tensorflow {
+namespace table {
+
+// DB contents are stored in a set of blocks, each of which holds a
+// sequence of key,value pairs. Each block may be compressed before
+// being stored in a file. The following enum describes which
+// compression method (if any) is used to compress a block.
+enum CompressionType {
+ // NOTE: do not change the values of existing entries, as these are
+ // part of the persistent format on disk.
+ kNoCompression = 0x0,
+ kSnappyCompression = 0x1
+};
+
+// Options to control the behavior of a table (passed to Table::Open)
+struct Options {
+ // Approximate size of user data packed per block. Note that the
+ // block size specified here corresponds to uncompressed data. The
+ // actual size of the unit read from disk may be smaller if
+ // compression is enabled. This parameter can be changed dynamically.
+ size_t block_size = 262144;
+
+ // Number of keys between restart points for delta encoding of keys.
+ // This parameter can be changed dynamically. Most clients should
+ // leave this parameter alone.
+ int block_restart_interval = 16;
+
+ // Compress blocks using the specified compression algorithm. This
+ // parameter can be changed dynamically.
+ //
+ // Default: kSnappyCompression, which gives lightweight but fast
+ // compression.
+ //
+ // Typical speeds of kSnappyCompression on an Intel(R) Core(TM)2 2.4GHz:
+ // ~200-500MB/s compression
+ // ~400-800MB/s decompression
+ // Note that these speeds are significantly faster than most
+ // persistent storage speeds, and therefore it is typically never
+ // worth switching to kNoCompression. Even if the input data is
+ // incompressible, the kSnappyCompression implementation will
+ // efficiently detect that and will switch to uncompressed mode.
+ CompressionType compression = kSnappyCompression;
+};
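+
+// Example (sketch): a caller that prefers smaller, uncompressed blocks can
+// override the defaults before handing the options to Table::Open:
+//
+//   Options options;
+//   options.block_size = 64 * 1024;
+//   options.compression = kNoCompression;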
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
new file mode 100644
index 0000000000..66e90ac64e
--- /dev/null
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -0,0 +1,601 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table.h"
+
+#include <map>
+#include <string>
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/block_builder.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/iterator.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace test {
+static StringPiece RandomString(random::SimplePhilox* rnd, int len,
+ string* dst) {
+ dst->resize(len);
+ for (int i = 0; i < len; i++) {
+ (*dst)[i] = static_cast<char>(' ' + rnd->Uniform(95)); // ' ' .. '~'
+ }
+ return StringPiece(*dst);
+}
+static string RandomKey(random::SimplePhilox* rnd, int len) {
+ // Make sure to generate a wide variety of characters so we
+ // test the boundary conditions for short-key optimizations.
+ static const char kTestChars[] = {'\0', '\1', 'a', 'b', 'c',
+ 'd', 'e', '\xfd', '\xfe', '\xff'};
+ string result;
+ for (int i = 0; i < len; i++) {
+ result += kTestChars[rnd->Uniform(sizeof(kTestChars))];
+ }
+ return result;
+}
+static StringPiece CompressibleString(random::SimplePhilox* rnd,
+ double compressed_fraction, size_t len,
+ string* dst) {
+ int raw = static_cast<int>(len * compressed_fraction);
+ if (raw < 1) raw = 1;
+ string raw_data;
+ RandomString(rnd, raw, &raw_data);
+
+ // Duplicate the random data until we have filled "len" bytes
+ dst->clear();
+ while (dst->size() < len) {
+ dst->append(raw_data);
+ }
+ dst->resize(len);
+ return StringPiece(*dst);
+}
+} // namespace test
+
+static void Increment(string* key) { key->push_back('\0'); }
+
+// An STL comparator that orders two strings by comparing them as StringPieces
+namespace {
+struct STLLessThan {
+ STLLessThan() {}
+ bool operator()(const string& a, const string& b) const {
+ return StringPiece(a).compare(StringPiece(b)) < 0;
+ }
+};
+} // namespace
+
+class StringSink : public WritableFile {
+ public:
+ ~StringSink() {}
+
+ const string& contents() const { return contents_; }
+
+ virtual Status Close() { return Status::OK(); }
+ virtual Status Flush() { return Status::OK(); }
+ virtual Status Sync() { return Status::OK(); }
+
+ virtual Status Append(const StringPiece& data) {
+ contents_.append(data.data(), data.size());
+ return Status::OK();
+ }
+
+ private:
+ string contents_;
+};
+
+class StringSource : public RandomAccessFile {
+ public:
+  explicit StringSource(const StringPiece& contents)
+ : contents_(contents.data(), contents.size()), bytes_read_(0) {}
+
+ virtual ~StringSource() {}
+
+ uint64 Size() const { return contents_.size(); }
+
+ virtual Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const {
+ if (offset > contents_.size()) {
+ return errors::InvalidArgument("invalid Read offset");
+ }
+ if (offset + n > contents_.size()) {
+ n = contents_.size() - offset;
+ }
+ memcpy(scratch, &contents_[offset], n);
+ *result = StringPiece(scratch, n);
+ bytes_read_ += n;
+ return Status::OK();
+ }
+
+ uint64 BytesRead() const { return bytes_read_; }
+
+ private:
+ string contents_;
+ mutable uint64 bytes_read_;
+};
+
+typedef std::map<string, string, STLLessThan> KVMap;
+
+// Helper class for tests to unify the interface between
+// BlockBuilder/TableBuilder and Block/Table.
+class Constructor {
+ public:
+ explicit Constructor() : data_(STLLessThan()) {}
+ virtual ~Constructor() {}
+
+ void Add(const string& key, const StringPiece& value) {
+ data_[key] = value.ToString();
+ }
+
+ // Finish constructing the data structure with all the keys that have
+ // been added so far. Returns the keys in sorted order in "*keys"
+ // and stores the key/value pairs in "*kvmap"
+ void Finish(const Options& options, std::vector<string>* keys, KVMap* kvmap) {
+ *kvmap = data_;
+ keys->clear();
+ for (KVMap::const_iterator it = data_.begin(); it != data_.end(); ++it) {
+ keys->push_back(it->first);
+ }
+ data_.clear();
+ Status s = FinishImpl(options, *kvmap);
+ ASSERT_TRUE(s.ok()) << s.ToString();
+ }
+
+ // Construct the data structure from the data in "data"
+ virtual Status FinishImpl(const Options& options, const KVMap& data) = 0;
+
+ virtual Iterator* NewIterator() const = 0;
+
+ virtual const KVMap& data() { return data_; }
+
+ private:
+ KVMap data_;
+};
+
+class BlockConstructor : public Constructor {
+ public:
+ BlockConstructor() : block_(NULL) {}
+ ~BlockConstructor() { delete block_; }
+ virtual Status FinishImpl(const Options& options, const KVMap& data) {
+ delete block_;
+ block_ = NULL;
+ BlockBuilder builder(&options);
+
+ for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) {
+ builder.Add(it->first, it->second);
+ }
+ // Open the block
+ data_ = builder.Finish().ToString();
+ BlockContents contents;
+ contents.data = data_;
+ contents.cachable = false;
+ contents.heap_allocated = false;
+ block_ = new Block(contents);
+ return Status::OK();
+ }
+ virtual Iterator* NewIterator() const { return block_->NewIterator(); }
+
+ private:
+ string data_;
+ Block* block_;
+};
+
+class TableConstructor : public Constructor {
+ public:
+ TableConstructor() : source_(NULL), table_(NULL) {}
+ ~TableConstructor() { Reset(); }
+ virtual Status FinishImpl(const Options& options, const KVMap& data) {
+ Reset();
+ StringSink sink;
+ TableBuilder builder(options, &sink);
+
+ for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) {
+ builder.Add(it->first, it->second);
+ TF_CHECK_OK(builder.status());
+ }
+ Status s = builder.Finish();
+ TF_CHECK_OK(s) << s.ToString();
+
+ CHECK_EQ(sink.contents().size(), builder.FileSize());
+
+ // Open the table
+ source_ = new StringSource(sink.contents());
+ Options table_options;
+ return Table::Open(table_options, source_, sink.contents().size(), &table_);
+ }
+
+ virtual Iterator* NewIterator() const { return table_->NewIterator(); }
+
+ uint64 ApproximateOffsetOf(const StringPiece& key) const {
+ return table_->ApproximateOffsetOf(key);
+ }
+
+ uint64 BytesRead() const { return source_->BytesRead(); }
+
+ private:
+ void Reset() {
+ delete table_;
+ delete source_;
+ table_ = NULL;
+ source_ = NULL;
+ }
+
+ StringSource* source_;
+ Table* table_;
+};
+
+enum TestType { TABLE_TEST, BLOCK_TEST };
+
+struct TestArgs {
+ TestType type;
+ int restart_interval;
+};
+
+static const TestArgs kTestArgList[] = {
+ {TABLE_TEST, 16}, {TABLE_TEST, 1}, {TABLE_TEST, 1024},
+ {BLOCK_TEST, 16}, {BLOCK_TEST, 1}, {BLOCK_TEST, 1024},
+};
+static const int kNumTestArgs = sizeof(kTestArgList) / sizeof(kTestArgList[0]);
+
+class Harness : public ::testing::Test {
+ public:
+ Harness() : constructor_(NULL) {}
+
+ void Init(const TestArgs& args) {
+ delete constructor_;
+ constructor_ = NULL;
+ options_ = Options();
+
+ options_.block_restart_interval = args.restart_interval;
+ // Use shorter block size for tests to exercise block boundary
+ // conditions more.
+ options_.block_size = 256;
+ switch (args.type) {
+ case TABLE_TEST:
+ constructor_ = new TableConstructor();
+ break;
+ case BLOCK_TEST:
+ constructor_ = new BlockConstructor();
+ break;
+ }
+ }
+
+ ~Harness() { delete constructor_; }
+
+ void Add(const string& key, const string& value) {
+ constructor_->Add(key, value);
+ }
+
+ void Test(random::SimplePhilox* rnd) {
+ std::vector<string> keys;
+ KVMap data;
+ constructor_->Finish(options_, &keys, &data);
+
+ TestForwardScan(keys, data);
+ TestRandomAccess(rnd, keys, data);
+ }
+
+ void TestForwardScan(const std::vector<string>& keys, const KVMap& data) {
+ Iterator* iter = constructor_->NewIterator();
+ ASSERT_TRUE(!iter->Valid());
+ iter->SeekToFirst();
+ for (KVMap::const_iterator model_iter = data.begin();
+ model_iter != data.end(); ++model_iter) {
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ iter->Next();
+ }
+ ASSERT_TRUE(!iter->Valid());
+ delete iter;
+ }
+
+ void TestRandomAccess(random::SimplePhilox* rnd,
+ const std::vector<string>& keys, const KVMap& data) {
+ static const bool kVerbose = false;
+ Iterator* iter = constructor_->NewIterator();
+ ASSERT_TRUE(!iter->Valid());
+ KVMap::const_iterator model_iter = data.begin();
+ if (kVerbose) fprintf(stderr, "---\n");
+ for (int i = 0; i < 200; i++) {
+ const int toss = rnd->Uniform(3);
+ switch (toss) {
+ case 0: {
+ if (iter->Valid()) {
+ if (kVerbose) fprintf(stderr, "Next\n");
+ iter->Next();
+ ++model_iter;
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ }
+ break;
+ }
+
+ case 1: {
+ if (kVerbose) fprintf(stderr, "SeekToFirst\n");
+ iter->SeekToFirst();
+ model_iter = data.begin();
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ break;
+ }
+
+ case 2: {
+ string key = PickRandomKey(rnd, keys);
+ model_iter = data.lower_bound(key);
+ if (kVerbose)
+ fprintf(stderr, "Seek '%s'\n", str_util::CEscape(key).c_str());
+ iter->Seek(StringPiece(key));
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ break;
+ }
+ }
+ }
+ delete iter;
+ }
+
+ string ToString(const KVMap& data, const KVMap::const_iterator& it) {
+ if (it == data.end()) {
+ return "END";
+ } else {
+ return "'" + it->first + "->" + it->second + "'";
+ }
+ }
+
+ string ToString(const KVMap& data, const KVMap::const_reverse_iterator& it) {
+ if (it == data.rend()) {
+ return "END";
+ } else {
+ return "'" + it->first + "->" + it->second + "'";
+ }
+ }
+
+ string ToString(const Iterator* it) {
+ if (!it->Valid()) {
+ return "END";
+ } else {
+ return "'" + it->key().ToString() + "->" + it->value().ToString() + "'";
+ }
+ }
+
+ string PickRandomKey(random::SimplePhilox* rnd,
+ const std::vector<string>& keys) {
+ if (keys.empty()) {
+ return "foo";
+ } else {
+ const int index = rnd->Uniform(keys.size());
+ string result = keys[index];
+ switch (rnd->Uniform(3)) {
+ case 0:
+ // Return an existing key
+ break;
+ case 1: {
+ // Attempt to return something smaller than an existing key
+ if (result.size() > 0 && result[result.size() - 1] > '\0') {
+ result[result.size() - 1]--;
+ }
+ break;
+ }
+ case 2: {
+ // Return something larger than an existing key
+ Increment(&result);
+ break;
+ }
+ }
+ return result;
+ }
+ }
+
+ private:
+ Options options_;
+ Constructor* constructor_;
+};
+
+// Test empty table/block.
+TEST_F(Harness, Empty) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ random::SimplePhilox rnd(&philox);
+ Test(&rnd);
+ }
+}
+
+// Special test for a block with no restart entries. The C++ leveldb
+// code never generates such blocks, but the Java version of leveldb
+// seems to.
+TEST_F(Harness, ZeroRestartPointsInBlock) {
+ char data[sizeof(uint32)];
+ memset(data, 0, sizeof(data));
+ BlockContents contents;
+ contents.data = StringPiece(data, sizeof(data));
+ contents.cachable = false;
+ contents.heap_allocated = false;
+ Block block(contents);
+ Iterator* iter = block.NewIterator();
+ iter->SeekToFirst();
+ ASSERT_TRUE(!iter->Valid());
+ iter->Seek("foo");
+ ASSERT_TRUE(!iter->Valid());
+ delete iter;
+}
+
+// Test the empty key
+TEST_F(Harness, SimpleEmptyKey) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("", "v");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleSingle) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 2, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("abc", "v");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleMulti) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("abc", "v");
+ Add("abcd", "v");
+ Add("ac", "v2");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleMultiBigValues) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("ainitial", "tiny");
+ Add("anext", string(10000000, 'a'));
+ Add("anext2", string(10000000, 'b'));
+ Add("azz", "tiny");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleSpecialKey) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 4, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("\xff\xff", "v3");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, Randomized) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 5, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int num_entries = 0; num_entries < 2000;
+ num_entries += (num_entries < 50 ? 1 : 200)) {
+ if ((num_entries % 10) == 0) {
+ fprintf(stderr, "case %d of %d: num_entries = %d\n", (i + 1),
+ int(kNumTestArgs), num_entries);
+ }
+ for (int e = 0; e < num_entries; e++) {
+ string v;
+ Add(test::RandomKey(&rnd, rnd.Skewed(4)),
+ test::RandomString(&rnd, rnd.Skewed(5), &v).ToString());
+ }
+ Test(&rnd);
+ }
+ }
+}
+
+static bool Between(uint64 val, uint64 low, uint64 high) {
+ bool result = (val >= low) && (val <= high);
+ if (!result) {
+ fprintf(stderr, "Value %llu is not in range [%llu, %llu]\n",
+ (unsigned long long)(val), (unsigned long long)(low),
+ (unsigned long long)(high));
+ }
+ return result;
+}
+
+class TableTest {};
+
+TEST(TableTest, ApproximateOffsetOfPlain) {
+ TableConstructor c;
+ c.Add("k01", "hello");
+ c.Add("k02", "hello2");
+ c.Add("k03", string(10000, 'x'));
+ c.Add("k04", string(200000, 'x'));
+ c.Add("k05", string(300000, 'x'));
+ c.Add("k06", "hello3");
+ c.Add("k07", string(100000, 'x'));
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kNoCompression;
+ c.Finish(options, &keys, &kvmap);
+
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01a"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 10, 500));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 10000, 11000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04a"), 210000, 211000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k05"), 210000, 211000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k06"), 510000, 511000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k07"), 510000, 511000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 610000, 612000));
+}
+
+static bool SnappyCompressionSupported() {
+ string out;
+ StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+ return port::Snappy_Compress(in.data(), in.size(), &out);
+}
+
+TEST(TableTest, ApproximateOffsetOfCompressed) {
+ if (!SnappyCompressionSupported()) {
+ fprintf(stderr, "skipping compression tests\n");
+ return;
+ }
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ TableConstructor c;
+ string tmp;
+ c.Add("k01", "hello");
+ c.Add("k02", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
+ c.Add("k03", "hello3");
+ c.Add("k04", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kSnappyCompression;
+ c.Finish(options, &keys, &kvmap);
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 10, 100));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 2000, 3000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 2000, 3000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 4000, 6000));
+}
+
+TEST(TableTest, SeekToFirstKeyDoesNotReadTooMuch) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string tmp;
+ TableConstructor c;
+ c.Add("k01", "firstvalue");
+ c.Add("k03", test::CompressibleString(&rnd, 0.25, 1000000, &tmp));
+ c.Add("k04", "abc");
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kNoCompression;
+ c.Finish(options, &keys, &kvmap);
+
+ Iterator* iter = c.NewIterator();
+ iter->Seek("k01");
+ delete iter;
+ // Make sure we don't read the big second block when just trying to
+ // retrieve the data in the first key
+ EXPECT_LT(c.BytesRead(), 200);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/two_level_iterator.cc b/tensorflow/core/lib/io/two_level_iterator.cc
new file mode 100644
index 0000000000..409baade6d
--- /dev/null
+++ b/tensorflow/core/lib/io/two_level_iterator.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/two_level_iterator.h"
+
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace {
+
+typedef Iterator* (*BlockFunction)(void*, const StringPiece&);
+
+class TwoLevelIterator : public Iterator {
+ public:
+ TwoLevelIterator(Iterator* index_iter, BlockFunction block_function,
+ void* arg);
+
+ virtual ~TwoLevelIterator();
+
+ virtual void Seek(const StringPiece& target);
+ virtual void SeekToFirst();
+ virtual void Next();
+
+ virtual bool Valid() const {
+    return (data_iter_ == NULL) ? false : data_iter_->Valid();
+ }
+ virtual StringPiece key() const {
+ assert(Valid());
+ return data_iter_->key();
+ }
+ virtual StringPiece value() const {
+ assert(Valid());
+ return data_iter_->value();
+ }
+ virtual Status status() const {
+ // It'd be nice if status() returned a const Status& instead of a
+ // Status
+ if (!index_iter_->status().ok()) {
+ return index_iter_->status();
+ } else if (data_iter_ != NULL && !data_iter_->status().ok()) {
+ return data_iter_->status();
+ } else {
+ return status_;
+ }
+ }
+
+ private:
+ void SaveError(const Status& s) {
+ if (status_.ok() && !s.ok()) status_ = s;
+ }
+ void SkipEmptyDataBlocksForward();
+ void SetDataIterator(Iterator* data_iter);
+ void InitDataBlock();
+
+ BlockFunction block_function_;
+ void* arg_;
+ Status status_;
+ Iterator* index_iter_;
+ Iterator* data_iter_; // May be NULL
+ // If data_iter_ is non-NULL, then "data_block_handle_" holds the
+ // "index_value" passed to block_function_ to create the data_iter_.
+ string data_block_handle_;
+};
+
+TwoLevelIterator::TwoLevelIterator(Iterator* index_iter,
+ BlockFunction block_function, void* arg)
+ : block_function_(block_function),
+ arg_(arg),
+ index_iter_(index_iter),
+ data_iter_(NULL) {}
+
+TwoLevelIterator::~TwoLevelIterator() {
+ delete index_iter_;
+ delete data_iter_;
+}
+
+void TwoLevelIterator::Seek(const StringPiece& target) {
+ index_iter_->Seek(target);
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->Seek(target);
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::SeekToFirst() {
+ index_iter_->SeekToFirst();
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->SeekToFirst();
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::Next() {
+ assert(Valid());
+ data_iter_->Next();
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::SkipEmptyDataBlocksForward() {
+ while (data_iter_ == NULL || !data_iter_->Valid()) {
+ // Move to next block
+ if (!index_iter_->Valid()) {
+ SetDataIterator(NULL);
+ return;
+ }
+ index_iter_->Next();
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->SeekToFirst();
+ }
+}
+
+void TwoLevelIterator::SetDataIterator(Iterator* data_iter) {
+ if (data_iter_ != NULL) {
+ SaveError(data_iter_->status());
+ delete data_iter_;
+ }
+ data_iter_ = data_iter;
+}
+
+void TwoLevelIterator::InitDataBlock() {
+ if (!index_iter_->Valid()) {
+ SetDataIterator(NULL);
+ } else {
+ StringPiece handle = index_iter_->value();
+ if (data_iter_ != NULL && handle.compare(data_block_handle_) == 0) {
+ // data_iter_ is already constructed with this iterator, so
+ // no need to change anything
+ } else {
+ Iterator* iter = (*block_function_)(arg_, handle);
+ data_block_handle_.assign(handle.data(), handle.size());
+ SetDataIterator(iter);
+ }
+ }
+}
+
+} // namespace
+
+Iterator* NewTwoLevelIterator(Iterator* index_iter,
+ BlockFunction block_function, void* arg) {
+ return new TwoLevelIterator(index_iter, block_function, arg);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h
new file mode 100644
index 0000000000..1cc5d2f921
--- /dev/null
+++ b/tensorflow/core/lib/io/two_level_iterator.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
+#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
+
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+// Return a new two level iterator. A two-level iterator contains an
+// index iterator whose values point to a sequence of blocks where
+// each block is itself a sequence of key,value pairs. The returned
+// two-level iterator yields the concatenation of all key/value pairs
+// in the sequence of blocks. Takes ownership of "index_iter" and
+// will delete it when no longer needed.
+//
+// Uses a supplied function to convert an index_iter value into
+// an iterator over the contents of the corresponding block.
+extern Iterator* NewTwoLevelIterator(
+ Iterator* index_iter,
+ Iterator* (*block_function)(void* arg, const StringPiece& index_value),
+ void* arg);
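+
+// For example, Table::NewIterator() is built as
+//   NewTwoLevelIterator(index_block->NewIterator(),
+//                       &Table::BlockReader, table);
+// where BlockReader decodes an index value into a BlockHandle and returns
+// an iterator over the corresponding data block.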
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.cc b/tensorflow/core/lib/jpeg/jpeg_handle.cc
new file mode 100644
index 0000000000..4521be0afb
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.cc
@@ -0,0 +1,162 @@
+// This file implements a memory destination for libjpeg.
+// The design is very similar to jdatadst.c in libjpeg.
+// These functions are not meant to be used directly; see jpeg_mem.h instead.
+// We fill out the stubs required by jpeglib. Those stubs are private to the
+// implementation; we only make JPGMemSrc and JPGMemDest available.
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+
+#include <setjmp.h>
+#include <stddef.h>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+void CatchError(j_common_ptr cinfo) {
+ (*cinfo->err->output_message)(cinfo);
+ jmp_buf *jpeg_jmpbuf = reinterpret_cast<jmp_buf *>(cinfo->client_data);
+ jpeg_destroy(cinfo);
+ longjmp(*jpeg_jmpbuf, 1);
+}
+
+// *****************************************************************************
+// *****************************************************************************
+// *****************************************************************************
+// Destination functions
+
+// -----------------------------------------------------------------------------
+void MemInitDestination(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Initializing buffer=" << dest->bufsize << " bytes";
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = dest->bufsize;
+ dest->datacount = 0;
+ if (dest->dest) {
+ dest->dest->clear();
+ }
+}
+
+// -----------------------------------------------------------------------------
+boolean MemEmptyOutputBuffer(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Writing " << dest->bufsize << " bytes";
+ if (dest->dest) {
+ dest->dest->append(reinterpret_cast<char *>(dest->buffer), dest->bufsize);
+ }
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = dest->bufsize;
+ return TRUE;
+}
+
+// -----------------------------------------------------------------------------
+void MemTermDestination(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Writing " << dest->bufsize - dest->pub.free_in_buffer << " bytes";
+ if (dest->dest) {
+ dest->dest->append(reinterpret_cast<char *>(dest->buffer),
+ dest->bufsize - dest->pub.free_in_buffer);
+ VLOG(1) << "Total size= " << dest->dest->size();
+ }
+ dest->datacount = dest->bufsize - dest->pub.free_in_buffer;
+}
+
+// -----------------------------------------------------------------------------
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) {
+ SetDest(cinfo, buffer, bufsize, NULL);
+}
+
+// -----------------------------------------------------------------------------
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
+ string *destination) {
+ MemDestMgr *dest;
+ if (cinfo->dest == NULL) {
+ cinfo->dest = reinterpret_cast<struct jpeg_destination_mgr *>(
+ (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo),
+ JPOOL_PERMANENT, sizeof(MemDestMgr)));
+ }
+
+ dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ dest->bufsize = bufsize;
+ dest->buffer = static_cast<JOCTET *>(buffer);
+ dest->dest = destination;
+ dest->pub.init_destination = MemInitDestination;
+ dest->pub.empty_output_buffer = MemEmptyOutputBuffer;
+ dest->pub.term_destination = MemTermDestination;
+}
+
+// *****************************************************************************
+// *****************************************************************************
+// *****************************************************************************
+// Source functions
+
+// -----------------------------------------------------------------------------
+void MemInitSource(j_decompress_ptr cinfo) {
+ MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ src->pub.next_input_byte = src->data;
+ src->pub.bytes_in_buffer = src->datasize;
+}
+
+// -----------------------------------------------------------------------------
+// We emulate the same error-handling as fill_input_buffer() from jdatasrc.c,
+// for coherency's sake.
+boolean MemFillInputBuffer(j_decompress_ptr cinfo) {
+ static const JOCTET kEOIBuffer[2] = {0xff, JPEG_EOI};
+ MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ if (src->pub.bytes_in_buffer == 0 && src->pub.next_input_byte == src->data) {
+ // empty file -> treated as an error.
+ ERREXIT(cinfo, JERR_INPUT_EMPTY);
+ return FALSE;
+ } else if (src->pub.bytes_in_buffer) {
+ // if there's still some data left, it's probably corrupted
+ return src->try_recover_truncated_jpeg ? TRUE : FALSE;
+ } else if (src->pub.next_input_byte != kEOIBuffer &&
+ src->try_recover_truncated_jpeg) {
+ // In an attempt to recover truncated files, we insert a fake EOI
+ WARNMS(cinfo, JWRN_JPEG_EOF);
+ src->pub.next_input_byte = kEOIBuffer;
+ src->pub.bytes_in_buffer = 2;
+ return TRUE;
+ } else {
+ // We already inserted a fake EOI and it wasn't enough, so this time
+ // it's really an error.
+ ERREXIT(cinfo, JERR_FILE_READ);
+ return FALSE;
+ }
+}
+
+// -----------------------------------------------------------------------------
+void MemTermSource(j_decompress_ptr cinfo) {}
+
+// -----------------------------------------------------------------------------
+void MemSkipInputData(j_decompress_ptr cinfo, long jump) {
+  MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+  if (jump < 0) return;
+  if (static_cast<size_t>(jump) > src->pub.bytes_in_buffer) {
+    // Skipping past the end of the data: drain the buffer and let
+    // MemFillInputBuffer() signal the truncation (fake EOI or error).
+    src->pub.bytes_in_buffer = 0;
+    (void)MemFillInputBuffer(cinfo);
+  } else {
+    src->pub.bytes_in_buffer -= jump;
+    src->pub.next_input_byte += jump;
+  }
+}
+
+// -----------------------------------------------------------------------------
+void SetSrc(j_decompress_ptr cinfo, const void *data,
+ unsigned long int datasize, bool try_recover_truncated_jpeg) {
+ MemSourceMgr *src;
+
+ cinfo->src = reinterpret_cast<struct jpeg_source_mgr *>(
+ (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo),
+ JPOOL_PERMANENT, sizeof(MemSourceMgr)));
+
+ src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ src->pub.init_source = MemInitSource;
+ src->pub.fill_input_buffer = MemFillInputBuffer;
+ src->pub.skip_input_data = MemSkipInputData;
+ src->pub.resync_to_restart = jpeg_resync_to_restart;
+ src->pub.term_source = MemTermSource;
+ src->data = reinterpret_cast<const unsigned char *>(data);
+ src->datasize = datasize;
+ src->pub.bytes_in_buffer = 0;
+ src->pub.next_input_byte = NULL;
+ src->try_recover_truncated_jpeg = try_recover_truncated_jpeg;
+}
+
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
new file mode 100644
index 0000000000..58f7f6f666
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -0,0 +1,51 @@
+// This file declares the functions and structures for memory I/O with libjpeg.
+// These functions are not meant to be used directly; see jpeg_mem.h instead.
+
+#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+
+extern "C" {
+#include "external/jpeg_archive/jpeg-9a/jinclude.h"
+#include "external/jpeg_archive/jpeg-9a/jpeglib.h"
+#include "external/jpeg_archive/jpeg-9a/jerror.h"
+#include "external/jpeg_archive/jpeg-9a/transupp.h" // for rotations
+}
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// Handler for fatal JPEG library errors: clean up & return
+void CatchError(j_common_ptr cinfo);
+
+typedef struct {
+ struct jpeg_destination_mgr pub;
+ JOCTET *buffer;
+ int bufsize;
+ int datacount;
+ string *dest;
+} MemDestMgr;
+
+typedef struct {
+ struct jpeg_source_mgr pub;
+ const unsigned char *data;
+ unsigned long int datasize;
+ bool try_recover_truncated_jpeg;
+} MemSourceMgr;
+
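+// JPEG source: reads "datasize" bytes from "data". If
+// try_recover_truncated_jpeg is true, a fake EOI marker is inserted past the
+// end of truncated data so decoding can salvage the lines already read.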
+void SetSrc(j_decompress_ptr cinfo, const void *data,
+ unsigned long int datasize, bool try_recover_truncated_jpeg);
+
+// JPEG destination: we will store all the data in a buffer "buffer" of total
+// size "bufsize"; if the buffer overflows, we will be in trouble.
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize);
+// Same as above, except that buffer is only used as a temporary structure and
+// is emptied into "destination" as soon as it fills up.
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
+ string *destination);
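+//
+// Example (illustrative) of wiring a compressor to write into a string
+// through a fixed scratch buffer; "cinfo" is assumed to be an initialized
+// jpeg_compress_struct:
+//
+//   string output;
+//   JOCTET scratch[8192];
+//   SetDest(&cinfo, scratch, sizeof(scratch), &output);
+//   jpeg_start_compress(&cinfo, TRUE);
+//   ... write scanlines ...
+//   jpeg_finish_compress(&cinfo);  // "output" now holds the JPEG bytes.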
+
+} // namespace jpeg
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
new file mode 100644
index 0000000000..556f13e388
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -0,0 +1,557 @@
+// This file defines functions to compress and uncompress JPEG data
+// to and from memory, as well as some direct manipulations of JPEG strings.
+
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+#include <setjmp.h>
+#include <string.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// -----------------------------------------------------------------------------
+// Decompression
+
+namespace {
+
+enum JPEGErrors {
+ JPEGERRORS_OK,
+ JPEGERRORS_UNEXPECTED_END_OF_DATA,
+ JPEGERRORS_BAD_PARAM
+};
+
+// Prevent bad compiler behaviour in ASAN mode by wrapping most of the
+// arguments in a single struct.
+class FewerArgsForCompiler {
+ public:
+ FewerArgsForCompiler(int datasize, const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output)
+ : datasize_(datasize),
+ flags_(flags),
+ pnwarn_(nwarn),
+ allocate_output_(allocate_output),
+ fraction_read_(0.),
+ height_(0),
+ stride_(0) {
+ if (pnwarn_ != nullptr) *pnwarn_ = 0;
+ }
+
+ const int datasize_;
+ const UncompressFlags flags_;
+ int* const pnwarn_;
+ std::function<uint8*(int, int, int)> allocate_output_;
+  float fraction_read_;  // fraction of scanlines successfully read
+ int height_;
+ int stride_;
+};
+
+uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
+ // unpack the argball
+ const int datasize = argball->datasize_;
+ const auto& flags = argball->flags_;
+ const int ratio = flags.ratio;
+ int components = flags.components;
+ int stride = flags.stride; // may be 0
+ int* const nwarn = argball->pnwarn_; // may be NULL
+
+ // can't decode if the ratio is not recognized by libjpeg
+ if ((ratio != 1) && (ratio != 2) && (ratio != 4) && (ratio != 8)) {
+ return nullptr;
+ }
+
+ // if empty image, return
+ if (datasize == 0 || srcdata == NULL) return nullptr;
+
+ // Declare temporary buffer pointer here so that we can free on error paths
+ JSAMPLE* tempdata = nullptr;
+
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ JPEGErrors error = JPEGERRORS_OK;
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ cinfo.err = jpeg_std_error(&jerr);
+ jmp_buf jpeg_jmpbuf;
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ return nullptr;
+ }
+
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, srcdata, datasize, flags.try_recover_truncated_jpeg);
+ jpeg_read_header(&cinfo, TRUE);
+
+ // Set components automatically if desired
+ if (components == 0) components = cinfo.num_components;
+
+ // set grayscale and ratio parameters
+ switch (components) {
+ case 1:
+ cinfo.out_color_space = JCS_GRAYSCALE;
+ break;
+ case 3:
+ case 4:
+ if (cinfo.jpeg_color_space == JCS_CMYK ||
+ cinfo.jpeg_color_space == JCS_YCCK) {
+        // Always request CMYK output for a 4-channel jpeg; libjpeg has a
+        // built-in decoder for it, and the CMYK data is converted to RGB
+        // below.
+ cinfo.out_color_space = JCS_CMYK;
+ } else {
+ cinfo.out_color_space = JCS_RGB;
+ }
+ break;
+ default:
+ LOG(ERROR) << " Invalid components value " << components << std::endl;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+ cinfo.do_fancy_upsampling = boolean(flags.fancy_upscaling);
+ cinfo.scale_num = 1;
+ cinfo.scale_denom = ratio;
+ // Activating this has a quality/speed trade-off implication:
+ // cinfo.dct_method = JDCT_IFAST;
+
+ jpeg_start_decompress(&cinfo);
+
+ // check for compatible stride
+ const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE);
+ if (stride == 0) {
+ stride = min_stride;
+ } else if (stride < min_stride) {
+ LOG(ERROR) << "Incompatible stride: " << stride << " < " << min_stride;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+
+ // Remember stride and height for use in Uncompress
+ argball->height_ = cinfo.output_height;
+ argball->stride_ = stride;
+
+ uint8* const dstdata = argball->allocate_output_(
+ cinfo.output_width, cinfo.output_height, components);
+ if (dstdata == nullptr) {
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+ JSAMPLE* output_line = static_cast<JSAMPLE*>(dstdata);
+
+ // Temporary buffer used for CMYK -> RGB conversion.
+ const bool use_cmyk = (cinfo.out_color_space == JCS_CMYK);
+  tempdata = use_cmyk ? new JSAMPLE[cinfo.output_width * 4] : nullptr;
+
+ // If there is an error reading a line, this aborts the reading.
+ // Save the fraction of the image that has been read.
+ argball->fraction_read_ = 1.0;
+ while (cinfo.output_scanline < cinfo.output_height) {
+ int num_lines_read = 0;
+ if (cinfo.out_color_space == JCS_CMYK) {
+ num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1);
+ // Convert CMYK to RGB
+ for (size_t i = 0; i < cinfo.output_width; ++i) {
+ int c = tempdata[4 * i + 0];
+ int m = tempdata[4 * i + 1];
+ int y = tempdata[4 * i + 2];
+ int k = tempdata[4 * i + 3];
+ int r, g, b;
+ if (cinfo.saw_Adobe_marker) {
+ r = (k * c) / 255;
+ g = (k * m) / 255;
+ b = (k * y) / 255;
+ } else {
+ r = (255 - k) * (255 - c) / 255;
+ g = (255 - k) * (255 - m) / 255;
+ b = (255 - k) * (255 - y) / 255;
+ }
+ output_line[3 * i + 0] = r;
+ output_line[3 * i + 1] = g;
+ output_line[3 * i + 2] = b;
+ }
+ } else {
+ num_lines_read = jpeg_read_scanlines(&cinfo, &output_line, 1);
+ }
+ // Handle error cases
+ if (num_lines_read == 0) {
+ LOG(ERROR) << "Premature end of JPEG data. Stopped at line "
+ << cinfo.output_scanline << "/" << cinfo.output_height;
+ if (!flags.try_recover_truncated_jpeg) {
+ argball->fraction_read_ =
+ static_cast<float>(cinfo.output_scanline) / cinfo.output_height;
+ error = JPEGERRORS_UNEXPECTED_END_OF_DATA;
+ } else {
+ for (size_t line = cinfo.output_scanline; line < cinfo.output_height;
+ ++line) {
+ if (line == 0) {
+ // If even the first line is missing, fill with black color
+ memset(output_line, 0, min_stride);
+ } else {
+ // else, just replicate the line above.
+ memcpy(output_line, output_line - stride, min_stride);
+ }
+ output_line += stride;
+ }
+ argball->fraction_read_ = 1.0; // consider all lines as read
+ // prevent error-on-exit in libjpeg:
+ cinfo.output_scanline = cinfo.output_height;
+ }
+ break;
+ }
+ DCHECK_EQ(num_lines_read, 1);
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_line, min_stride);
+ output_line += stride;
+ }
+ delete[] tempdata;
+
+ // Convert the RGB data to RGBA, with alpha set to 0xFF to indicate
+ // opacity.
+ // RGBRGBRGB... --> RGBARGBARGBA...
+ if (components == 4) {
+ // Start on the last line.
+ JSAMPLE* scanlineptr =
+ static_cast<JSAMPLE*>(dstdata + (cinfo.output_height - 1) * stride);
+ const JSAMPLE kOpaque = -1; // All ones appropriate for JSAMPLE.
+ const int right_rgb = (cinfo.output_width - 1) * 3;
+ const int right_rgba = (cinfo.output_width - 1) * 4;
+
+ for (int y = cinfo.output_height; y-- > 0;) {
+ // We do all the transformations in place, going backwards for each row.
+ const JSAMPLE* rgb_pixel = scanlineptr + right_rgb;
+ JSAMPLE* rgba_pixel = scanlineptr + right_rgba;
+ scanlineptr -= stride;
+ for (int x = cinfo.output_width; x-- > 0;
+ rgba_pixel -= 4, rgb_pixel -= 3) {
+ // We copy the 3 bytes at rgb_pixel into the 4 bytes at rgba_pixel
+ // The "a" channel is set to be opaque.
+ rgba_pixel[3] = kOpaque;
+ rgba_pixel[2] = rgb_pixel[2];
+ rgba_pixel[1] = rgb_pixel[1];
+ rgba_pixel[0] = rgb_pixel[0];
+ }
+ }
+ }
+
+ switch (components) {
+ case 1:
+ if (cinfo.output_components != 1) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ break;
+ case 3:
+ case 4:
+ if (cinfo.out_color_space == JCS_CMYK) {
+ if (cinfo.output_components != 4) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ } else {
+ if (cinfo.output_components != 3) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ }
+ break;
+ default:
+      // Will never happen; this should have been caught by the earlier switch.
+      LOG(ERROR) << "Invalid components value " << components;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+
+ // save number of warnings if requested
+ if (nwarn != nullptr) {
+ *nwarn = cinfo.err->num_warnings;
+ }
+
+ // Handle errors in JPEG
+ switch (error) {
+ case JPEGERRORS_OK:
+ jpeg_finish_decompress(&cinfo);
+ break;
+ case JPEGERRORS_UNEXPECTED_END_OF_DATA:
+ case JPEGERRORS_BAD_PARAM:
+ jpeg_abort(reinterpret_cast<j_common_ptr>(&cinfo));
+ break;
+ default:
+ LOG(ERROR) << "Unhandled case " << error;
+ break;
+ }
+ jpeg_destroy_decompress(&cinfo);
+
+ return dstdata;
+}
+
+} // anonymous namespace
+
+// -----------------------------------------------------------------------------
+// We do the apparently silly thing of packing 5 of the arguments
+// into a structure that is then passed to another routine
+// that does all the work. The reason is that we want to catch
+// fatal JPEG library errors with setjmp/longjmp, and g++ and
+// associated libraries aren't good enough to guarantee that 7
+// parameters won't get clobbered by the longjmp. So we help
+// it out a little.
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output) {
+ FewerArgsForCompiler argball(datasize, flags, nwarn, allocate_output);
+ uint8* const dstdata = UncompressLow(srcdata, &argball);
+ const float fraction_read = argball.fraction_read_;
+ if (dstdata == NULL ||
+ fraction_read < std::min(1.0f, flags.min_acceptable_fraction)) {
+ // Major failure, none or too-partial read returned; get out
+ return NULL;
+ }
+
+ // If there was an error in reading the jpeg data,
+ // set the unread pixels to black
+ if (fraction_read < 1.0) {
+ const int first_bad_line =
+ static_cast<int>(fraction_read * argball.height_);
+ uint8* start = dstdata + first_bad_line * argball.stride_;
+ const int nbytes = (argball.height_ - first_bad_line) * argball.stride_;
+ memset(static_cast<void*>(start), 0, nbytes);
+ }
+
+ return dstdata;
+}
+
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* pwidth, int* pheight,
+ int* pcomponents, int* nwarn) {
+ uint8* buffer = NULL;
+ uint8* result =
+ Uncompress(srcdata, datasize, flags, nwarn,
+ [=, &buffer](int width, int height, int components) {
+ if (pwidth != nullptr) *pwidth = width;
+ if (pheight != nullptr) *pheight = height;
+ if (pcomponents != nullptr) *pcomponents = components;
+ buffer = new uint8[height * width * components];
+ return buffer;
+ });
+ if (!result) delete[] buffer;
+ return result;
+}
+
+// ----------------------------------------------------------------------------
+// Computes image information from jpeg header.
+// Returns true on success; false on failure.
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
+ int* components) {
+ // Init in case of failure
+ if (width) *width = 0;
+ if (height) *height = 0;
+ if (components) *components = 0;
+
+ // If empty image, return
+ if (datasize == 0 || srcdata == NULL) return false;
+
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf;
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ return false;
+ }
+
+ // set up, read header, set image parameters, save size
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, srcdata, datasize, false);
+
+ jpeg_read_header(&cinfo, TRUE);
+ jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo
+ if (width) *width = cinfo.output_width;
+ if (height) *height = cinfo.output_height;
+ if (components) *components = cinfo.output_components;
+
+ jpeg_destroy_decompress(&cinfo);
+
+ return true;
+}
+
+// -----------------------------------------------------------------------------
+// Compression
+
+namespace {
+bool CompressInternal(const uint8* srcdata, int width, int height,
+ const CompressFlags& flags, string* output) {
+ output->clear();
+ const int components = (static_cast<int>(flags.format) & 0xff);
+ int in_stride = flags.stride;
+ if (in_stride == 0) {
+    in_stride = width * components;
+ } else if (in_stride < width * components) {
+ LOG(ERROR) << "Incompatible input stride";
+ return false;
+ }
+
+  JOCTET* buffer = nullptr;
+
+ // NOTE: for broader use xmp_metadata should be made a unicode string
+ CHECK(srcdata != nullptr);
+ CHECK(output != nullptr);
+ // This struct contains the JPEG compression parameters and pointers to
+ // working space
+ struct jpeg_compress_struct cinfo;
+ // This struct represents a JPEG error handler.
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf; // recovery point in case of error
+
+ // Step 1: allocate and initialize JPEG compression object
+ // Use the usual jpeg error manager.
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ output->clear();
+ delete[] buffer;
+ return false;
+ }
+
+ jpeg_create_compress(&cinfo);
+
+ // Step 2: specify data destination
+ // We allocate a buffer of reasonable size. If we have a small image, just
+ // estimate the size of the output using the number of bytes of the input.
+ // If this is getting too big, we will append to the string by chunks of 1MB.
+ // This seems like a reasonable compromise between performance and memory.
+ int bufsize = std::min(width * height * components, 1 << 20);
+ buffer = new JOCTET[bufsize];
+ SetDest(&cinfo, buffer, bufsize, output);
+
+ // Step 3: set parameters for compression
+ cinfo.image_width = width;
+ cinfo.image_height = height;
+ switch (components) {
+ case 1:
+ cinfo.input_components = 1;
+ cinfo.in_color_space = JCS_GRAYSCALE;
+ break;
+ case 3:
+ case 4:
+ cinfo.input_components = 3;
+ cinfo.in_color_space = JCS_RGB;
+ break;
+ default:
+ LOG(ERROR) << " Invalid components value " << components << std::endl;
+ output->clear();
+ delete[] buffer;
+ return false;
+ }
+ jpeg_set_defaults(&cinfo);
+ if (flags.optimize_jpeg_size) cinfo.optimize_coding = TRUE;
+
+ cinfo.density_unit = flags.density_unit; // JFIF code for pixel size units:
+ // 1 = in, 2 = cm
+ cinfo.X_density = flags.x_density; // Horizontal pixel density
+ cinfo.Y_density = flags.y_density; // Vertical pixel density
+ jpeg_set_quality(&cinfo, flags.quality, TRUE);
+
+ if (flags.progressive) {
+ jpeg_simple_progression(&cinfo);
+ }
+
+ if (!flags.chroma_downsampling) {
+ // Turn off chroma subsampling (it is on by default). For more details on
+ // chroma subsampling, see http://en.wikipedia.org/wiki/Chroma_subsampling.
+ for (int i = 0; i < cinfo.num_components; ++i) {
+ cinfo.comp_info[i].h_samp_factor = 1;
+ cinfo.comp_info[i].v_samp_factor = 1;
+ }
+ }
+
+ jpeg_start_compress(&cinfo, TRUE);
+
+ // Embed XMP metadata if any
+ if (!flags.xmp_metadata.empty()) {
+ // XMP metadata is embedded in the APP1 tag of JPEG and requires this
+ // namespace header string (null-terminated)
+ const string name_space = "http://ns.adobe.com/xap/1.0/";
+ const int name_space_length = name_space.size();
+ const int metadata_length = flags.xmp_metadata.size();
+ const int packet_length = metadata_length + name_space_length + 1;
+ std::unique_ptr<JOCTET[]> joctet_packet(new JOCTET[packet_length]);
+
+ for (int i = 0; i < name_space_length; i++) {
+ // Conversion char --> JOCTET
+ joctet_packet[i] = name_space[i];
+ }
+ joctet_packet[name_space_length] = 0; // null-terminate namespace string
+
+ for (int i = 0; i < metadata_length; i++) {
+ // Conversion char --> JOCTET
+ joctet_packet[i + name_space_length + 1] = flags.xmp_metadata[i];
+ }
+ jpeg_write_marker(&cinfo, JPEG_APP0 + 1, joctet_packet.get(),
+ packet_length);
+ }
+
+ // JSAMPLEs per row in image_buffer
+ std::unique_ptr<JSAMPLE[]> row_temp(
+ new JSAMPLE[width * cinfo.input_components]);
+ while (cinfo.next_scanline < cinfo.image_height) {
+ JSAMPROW row_pointer[1]; // pointer to JSAMPLE row[s]
+ const uint8* r = &srcdata[cinfo.next_scanline * in_stride];
+ uint8* p = static_cast<uint8*>(row_temp.get());
+ switch (flags.format) {
+ case FORMAT_RGBA: {
+ for (int i = 0; i < width; ++i, p += 3, r += 4) {
+ p[0] = r[0];
+ p[1] = r[1];
+ p[2] = r[2];
+ }
+ row_pointer[0] = row_temp.get();
+ break;
+ }
+ case FORMAT_ABGR: {
+ for (int i = 0; i < width; ++i, p += 3, r += 4) {
+ p[0] = r[3];
+ p[1] = r[2];
+ p[2] = r[1];
+ }
+ row_pointer[0] = row_temp.get();
+ break;
+ }
+ default: {
+ row_pointer[0] = reinterpret_cast<JSAMPLE*>(const_cast<JSAMPLE*>(r));
+ }
+ }
+ CHECK_EQ(jpeg_write_scanlines(&cinfo, row_pointer, 1), 1);
+ }
+ jpeg_finish_compress(&cinfo);
+
+ // release JPEG compression object
+ jpeg_destroy_compress(&cinfo);
+ delete[] buffer;
+ return true;
+}
+
+} // anonymous namespace
+
+// -----------------------------------------------------------------------------
+
+bool Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags, string* output) {
+ return CompressInternal(static_cast<const uint8*>(srcdata), width, height,
+ flags, output);
+}
+
+string Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags) {
+ string temp;
+ CompressInternal(static_cast<const uint8*>(srcdata), width, height, flags,
+ &temp);
+ // If CompressInternal fails, temp will be empty.
+ return temp;
+}
+
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h
new file mode 100644
index 0000000000..19ba7d4acf
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.h
@@ -0,0 +1,130 @@
+// This file defines functions to compress and uncompress JPEG files
+// to and from memory. It provides interfaces for raw images
+// (data array and size fields).
+// Direct manipulations of JPEG strings are supplied: Flip, Rotate, Crop...
+
+#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// Flags for Uncompress
+struct UncompressFlags {
+  // ratio can be 1, 2, 4, or 8 and represents the denominator of the scaling
+  // factor (e.g. ratio = 4 means that the resulting image will be at 1/4
+  // original size in both directions).
+ int ratio = 1;
+
+ // The number of bytes per pixel (1, 3 or 4), or 0 for autodetect.
+ int components = 0;
+
+ // If true, decoder will use a slower but nicer upscaling of the chroma
+ // planes (yuv420/422 only).
+ bool fancy_upscaling = true;
+
+ // If true, will attempt to fill in missing lines of truncated files
+ bool try_recover_truncated_jpeg = false;
+
+ // The minimum required fraction of lines read before the image is accepted.
+ float min_acceptable_fraction = 1.0;
+
+ // The distance in bytes from one scanline to the other. Should be at least
+ // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride
+ // used will be this minimal value.
+ int stride = 0;
+};
+
+// Uncompress some raw JPEG data given by the pointer srcdata and the length
+// datasize.
+// - width and height are the addresses at which to store the size of the
+//   uncompressed image in pixels. May be nullptr.
+// - components is the address at which the number of components read is
+//   stored. This is *output only*: to request a specific number of
+//   components use flags.components. May be nullptr.
+// - nwarn is the address in which to store the number of warnings.
+// May be nullptr.
+// The function returns a pointer to the raw uncompressed data or NULL if
+// there was an error. The caller of the function is responsible for
+// freeing the memory (using delete []).
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* width, int* height,
+ int* components, // Output only: useful with autodetect
+ int* nwarn);
+
+// Version of Uncompress that allocates memory via a callback. The callback
+// arguments are (width, height, components). If the size is known ahead of
+// time this function can return an existing buffer; passing a callback allows
+// the buffer to be shaped based on the JPEG header. The caller is responsible
+// for freeing the memory *even along error paths*.
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output);
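+//
+// Example (illustrative, assuming a std::string "jpeg_data" holding the
+// encoded image) of decoding into a vector so ownership is automatic:
+//
+//   UncompressFlags flags;
+//   std::vector<uint8> pixels;
+//   uint8* ok = Uncompress(jpeg_data.data(), jpeg_data.size(), flags, nullptr,
+//                          [&pixels](int w, int h, int c) {
+//                            pixels.resize(w * h * c);
+//                            return pixels.data();
+//                          });
+//   if (ok == nullptr) { /* decode failed; pixels may be partially filled */ }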
+
+// Read jpeg header and get image information. Returns true on success.
+// The width, height, and components pointers may be null.
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
+ int* components);
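+//
+// Example (illustrative, "jpeg_data" as above):
+//
+//   int w, h, c;
+//   if (GetImageInfo(jpeg_data.data(), jpeg_data.size(), &w, &h, &c)) {
+//     // w x h image with c components, without decoding any pixels.
+//   }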
+
+// Note: (format & 0xff) = number of components (i.e. bytes per pixel);
+// e.g. FORMAT_ABGR & 0xff == 4.
+enum Format {
+ FORMAT_GRAYSCALE = 0x001, // 1 byte/pixel
+ FORMAT_RGB = 0x003, // 3 bytes/pixel RGBRGBRGBRGB...
+ FORMAT_RGBA = 0x004, // 4 bytes/pixel RGBARGBARGBARGBA...
+ FORMAT_ABGR = 0x104 // 4 bytes/pixel ABGRABGRABGR...
+};
+
+// Flags for compression
+struct CompressFlags {
+ // Encoding of the input data for compression
+ Format format;
+
+ // Quality of the compression from 0-100
+ int quality = 95;
+
+ // If true, create a jpeg image that loads progressively
+ bool progressive = false;
+
+ // If true, reduce jpeg size without changing quality (at the cost of CPU/RAM)
+ bool optimize_jpeg_size = false;
+
+ // See http://en.wikipedia.org/wiki/Chroma_subsampling
+ bool chroma_downsampling = true;
+
+ // Resolution
+ int density_unit = 1; // 1 = in, 2 = cm
+ int x_density = 300;
+ int y_density = 300;
+
+ // If not empty, embed this XMP metadata in the image header
+ StringPiece xmp_metadata;
+
+ // The distance in bytes from one scanline to the other. Should be at least
+ // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride
+ // used will be this minimal value.
+ int stride = 0;
+};
+
+// Compress some raw image given in srcdata. The data is a 2D array of size
+// stride*height with one of the formats enumerated above.
+// The encoded data is returned as a string.
+// If flags.xmp_metadata is not empty, it is embedded in the image header.
+// On error, returns the empty string (which is never a valid jpeg).
+string Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags);
+
+// On error, returns false and sets output to empty.
+bool Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags, string* output);
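+//
+// Example (illustrative, assuming "rgb_pixels" is a width*height*3 buffer of
+// interleaved RGB bytes):
+//
+//   CompressFlags flags;
+//   flags.format = FORMAT_RGB;
+//   flags.quality = 90;
+//   const string jpeg = Compress(rgb_pixels, width, height, flags);
+//   if (jpeg.empty()) { /* compression failed */ }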
+
+} // namespace jpeg
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
new file mode 100644
index 0000000000..23e72f9d57
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
@@ -0,0 +1,304 @@
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+#include <setjmp.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <memory>
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/casts.h"
+
+namespace tensorflow {
+namespace jpeg {
+namespace {
+
+const char kTestData[] = "tensorflow/core/lib/jpeg/testdata/";
+
+int ComputeSumAbsoluteDifference(const uint8* a, const uint8* b, int width,
+ int height, int a_stride, int b_stride) {
+ int totalerr = 0;
+ for (int i = 0; i < height; i++) {
+ const uint8* const pa = a + i * a_stride;
+ const uint8* const pb = b + i * b_stride;
+ for (int j = 0; j < 3 * width; j++) {
+ totalerr += abs(static_cast<int>(pa[j]) - static_cast<int>(pb[j]));
+ }
+ }
+ return totalerr;
+}
+
+// Reads the contents of the file into output
+void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
+ TF_CHECK_OK(ReadFileToString(env, filename, output));
+}
+
+void TestJPEG(Env* env, const string& jpegfile) {
+ // Read the data from the jpeg file into memory
+ string jpeg;
+  ReadFileToStringOrDie(env, jpegfile, &jpeg);
+ const int fsize = jpeg.size();
+ const uint8* const temp = bit_cast<const uint8*>(jpeg.data());
+
+ // try partial decoding (half of the data)
+ int w, h, c;
+ std::unique_ptr<uint8[]> imgdata;
+
+ UncompressFlags flags;
+ flags.components = 3;
+
+ // set min_acceptable_fraction to something insufficient
+ flags.min_acceptable_fraction = 0.8;
+ imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() == NULL);
+
+  // now, use a value for which fsize/2 is enough, black-filling the rest
+ flags.min_acceptable_fraction = 0.01;
+ imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+
+ // finally, uncompress the whole data
+ flags.min_acceptable_fraction = 1.0;
+ imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+
+ // Uncompress the data to RGBA, too
+ flags.min_acceptable_fraction = 1.0;
+ flags.components = 4;
+ imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+}
+
+TEST(JpegMemTest, Jpeg) {
+ Env* env = Env::Default();
+ const string data_path = kTestData;
+
+ // Name of a valid jpeg file on the disk
+ TestJPEG(env, data_path + "jpeg_merge_test1.jpg");
+
+ // Exercise CMYK machinery as well
+ TestJPEG(env, data_path + "jpeg_merge_test1_cmyk.jpg");
+}
+
+TEST(JpegMemTest, Jpeg2) {
+ // create known data, for size in_w x in_h
+ const int in_w = 256;
+ const int in_h = 256;
+ const int stride1 = 3 * in_w;
+ const std::unique_ptr<uint8[]> refdata1(new uint8[stride1 * in_h]);
+ for (int i = 0; i < in_h; i++) {
+ for (int j = 0; j < in_w; j++) {
+ const int offset = i * stride1 + 3 * j;
+ refdata1[offset + 0] = i;
+ refdata1[offset + 1] = j;
+ refdata1[offset + 2] = static_cast<uint8>((i + j) >> 1);
+ }
+ }
+
+ // duplicate with weird input stride
+ const int stride2 = 3 * 357;
+ const std::unique_ptr<uint8[]> refdata2(new uint8[stride2 * in_h]);
+ for (int i = 0; i < in_h; i++) {
+ memcpy(&refdata2[i * stride2], &refdata1[i * stride1], 3 * in_w);
+ }
+
+ // Test compression
+ string cpdata1, cpdata2;
+ {
+ const string kXMP = "XMP_TEST_123";
+
+ // Compress it to JPEG
+ CompressFlags flags;
+ flags.format = FORMAT_RGB;
+ flags.quality = 97;
+ flags.xmp_metadata = kXMP;
+ cpdata1 = Compress(refdata1.get(), in_w, in_h, flags);
+ flags.stride = stride2;
+ cpdata2 = Compress(refdata2.get(), in_w, in_h, flags);
+ // Different input stride shouldn't change the output
+ CHECK_EQ(cpdata1, cpdata2);
+
+ // Verify valid XMP.
+ CHECK_NE(string::npos, cpdata1.find(kXMP));
+
+ // Test the other API, where a storage string is supplied
+ string cptest;
+ flags.stride = 0;
+ Compress(refdata1.get(), in_w, in_h, flags, &cptest);
+ CHECK_EQ(cptest, cpdata1);
+ flags.stride = stride2;
+ Compress(refdata2.get(), in_w, in_h, flags, &cptest);
+ CHECK_EQ(cptest, cpdata2);
+ }
+
+ // Uncompress twice: once with 3 components and once with autodetect
+ std::unique_ptr<uint8[]> imgdata1;
+ for (const int components : {0, 3}) {
+ // Uncompress it
+ UncompressFlags flags;
+ flags.components = components;
+ int w, h, c;
+ imgdata1.reset(
+ Uncompress(cpdata1.c_str(), cpdata1.length(), flags, &w, &h, &c, NULL));
+
+ // Check obvious formatting stuff
+ CHECK_EQ(w, in_w);
+ CHECK_EQ(h, in_h);
+ CHECK_EQ(c, 3);
+ CHECK(imgdata1.get());
+
+ // Compare the two images
+ const int totalerr = ComputeSumAbsoluteDifference(
+ imgdata1.get(), refdata1.get(), in_w, in_h, stride1, stride1);
+ CHECK_LE(totalerr, 85000);
+ }
+
+  // Check the second image too; it should be bitwise identical to the first.
+  // Uncompress it using a weird stride.
+ {
+ UncompressFlags flags;
+ flags.stride = 3 * 411;
+ const std::unique_ptr<uint8[]> imgdata2(new uint8[flags.stride * in_h]);
+ CHECK(imgdata2.get() == Uncompress(cpdata2.c_str(), cpdata2.length(), flags,
+ NULL, [&imgdata2](int w, int h, int c) {
+ CHECK_EQ(w, in_w);
+ CHECK_EQ(h, in_h);
+ CHECK_EQ(c, 3);
+ return imgdata2.get();
+ }));
+ const int totalerr = ComputeSumAbsoluteDifference(
+ imgdata1.get(), imgdata2.get(), in_w, in_h, stride1, flags.stride);
+ CHECK_EQ(totalerr, 0);
+ }
+}
+
+// Takes JPEG data and reads its headers to determine whether or not the JPEG
+// was chroma downsampled.
+bool IsChromaDownsampled(const string& jpegdata) {
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf;
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) return false;
+
+ // set up, read header, set image parameters, save size
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, jpegdata.c_str(), jpegdata.size(), false);
+
+ jpeg_read_header(&cinfo, TRUE);
+ jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo
+ const int components = cinfo.output_components;
+ if (components == 1) return false;
+
+ // Check validity
+ CHECK_EQ(3, components);
+ CHECK_EQ(cinfo.comp_info[1].h_samp_factor, cinfo.comp_info[2].h_samp_factor)
+ << "The h sampling factors should be the same.";
+ CHECK_EQ(cinfo.comp_info[1].v_samp_factor, cinfo.comp_info[2].v_samp_factor)
+ << "The v sampling factors should be the same.";
+ for (int i = 0; i < components; ++i) {
+ CHECK_GT(cinfo.comp_info[i].h_samp_factor, 0) << "Invalid sampling factor.";
+ CHECK_EQ(cinfo.comp_info[i].h_samp_factor, cinfo.comp_info[i].v_samp_factor)
+ << "The sampling factor should be the same in both directions.";
+ }
+
+ // We're downsampled if we use fewer samples for color than for brightness.
+ // Do this before deallocating cinfo.
+ const bool downsampled =
+ cinfo.comp_info[1].h_samp_factor < cinfo.comp_info[0].h_samp_factor;
+
+ jpeg_destroy_decompress(&cinfo);
+ return downsampled;
+}
+
+TEST(JpegMemTest, ChromaDownsampling) {
+ // Read the data from a test jpeg file into memory
+ const string jpegfile = string(kTestData) + "jpeg_merge_test1.jpg";
+ string jpeg;
+ ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg);
+
+ // Verify that compressing the JPEG with chroma downsampling works.
+ //
+ // First, uncompress the JPEG.
+ UncompressFlags unflags;
+ unflags.components = 3;
+ int w, h, c, num_warnings;
+ std::unique_ptr<uint8[]> uncompressed(Uncompress(
+ jpeg.c_str(), jpeg.size(), unflags, &w, &h, &c, &num_warnings));
+ CHECK(uncompressed.get() != NULL);
+ CHECK_EQ(num_warnings, 0);
+
+ // Recompress the JPEG with and without chroma downsampling
+ for (const bool downsample : {false, true}) {
+ CompressFlags flags;
+ flags.format = FORMAT_RGB;
+ flags.quality = 85;
+ flags.chroma_downsampling = downsample;
+ string recompressed;
+ Compress(uncompressed.get(), w, h, flags, &recompressed);
+ CHECK(!recompressed.empty());
+ CHECK_EQ(IsChromaDownsampled(recompressed), downsample);
+ }
+}
+
+void TestBadJPEG(Env* env, const string& bad_jpeg_file, int expected_width,
+ int expected_height, const string& reference_RGB_file,
+ const bool try_recover_truncated_jpeg) {
+ string jpeg;
+ ReadFileToStringOrDie(env, bad_jpeg_file, &jpeg);
+
+ UncompressFlags flags;
+ flags.components = 3;
+ flags.try_recover_truncated_jpeg = try_recover_truncated_jpeg;
+
+ int width, height, components;
+ std::unique_ptr<uint8[]> imgdata;
+ imgdata.reset(Uncompress(jpeg.c_str(), jpeg.size(), flags, &width, &height,
+ &components, NULL));
+ if (expected_width > 0) { // we expect the file to decode into 'something'
+ CHECK_EQ(width, expected_width);
+ CHECK_EQ(height, expected_height);
+ CHECK_EQ(components, 3);
+ CHECK(imgdata.get());
+ if (!reference_RGB_file.empty()) {
+ string ref;
+ ReadFileToStringOrDie(env, reference_RGB_file, &ref);
+ CHECK(!memcmp(ref.data(), imgdata.get(), ref.size()));
+ }
+  } else {  // not decodable
+ CHECK(!imgdata.get()) << "file:" << bad_jpeg_file;
+ }
+}
+
+TEST(JpegMemTest, BadJpeg) {
+ Env* env = Env::Default();
+ const string data_path = kTestData;
+
+ // Test corrupt file
+ TestBadJPEG(env, data_path + "bad_huffman.jpg", 1024, 768, "", false);
+ TestBadJPEG(env, data_path + "corrupt.jpg", 0 /*120*/, 90, "", false);
+
+ // Truncated files, undecodable because of missing lines:
+ TestBadJPEG(env, data_path + "corrupt34_2.jpg", 0, 3300, "", false);
+ TestBadJPEG(env, data_path + "corrupt34_3.jpg", 0, 3300, "", false);
+ TestBadJPEG(env, data_path + "corrupt34_4.jpg", 0, 3300, "", false);
+
+ // Try in 'recover' mode now:
+ TestBadJPEG(env, data_path + "corrupt34_2.jpg", 2544, 3300, "", true);
+ TestBadJPEG(env, data_path + "corrupt34_3.jpg", 2544, 3300, "", true);
+ TestBadJPEG(env, data_path + "corrupt34_4.jpg", 2544, 3300, "", true);
+}
+
+} // namespace
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg
new file mode 100644
index 0000000000..ef5b6f12c5
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg
new file mode 100644
index 0000000000..5e2fe6c56f
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg
new file mode 100644
index 0000000000..4211155c45
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg
new file mode 100644
index 0000000000..c1c2a9d1e1
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg
new file mode 100644
index 0000000000..b8e7308ba0
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg
new file mode 100644
index 0000000000..5e348a12fd
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg
new file mode 100644
index 0000000000..15f895960d
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg
Binary files differ
diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc
new file mode 100644
index 0000000000..43b84e41e0
--- /dev/null
+++ b/tensorflow/core/lib/png/png_io.cc
@@ -0,0 +1,385 @@
+// Functions to read and write images in PNG format.
+
+#include <string.h>
+#include <sys/types.h>
+#include <string>
+#include <utility>
+#include <vector>
+// NOTE(skal): we don't '#include <setjmp.h>' before png/png.h as it otherwise
+// provokes a compile error. We instead let png.h include what is needed.
+
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/png/png_io.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h" // endian
+#include "external/png_archive/libpng-1.2.53/png.h"
+
+namespace tensorflow {
+namespace png {
+
+////////////////////////////////////////////////////////////////////////////////
+// Encode an 8- or 16-bit rgb/grayscale image to PNG string
+////////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+#define PTR_INC(type, ptr, del) (ptr = \
+ reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
+#define CPTR_INC(type, ptr, del) (ptr = \
+ reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + (del)))
+
+// Convert from 8 bit components to 16. This works in-place: because the
+// destination is larger than the source, the copy runs backwards from the
+// last pixel so no input byte is overwritten before it is read.
+static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
+ int width, int height, uint16* p16,
+ int p16_row_bytes) {
+ // Adjust pointers to copy backwards
+ width *= num_comps;
+ CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes +
+ (width - 1) * sizeof(*p8));
+ PTR_INC(uint16, p16, (height - 1) * p16_row_bytes +
+ (width - 1) * sizeof(*p16));
+ int bump8 = width * sizeof(*p8) - p8_row_bytes;
+ int bump16 = width * sizeof(*p16) - p16_row_bytes;
+ for (; height-- != 0;
+ CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
+ for (int w = width; w-- != 0; --p8, --p16) {
+ uint pix = *p8;
+ pix |= pix << 8;
+ *p16 = static_cast<uint16>(pix);
+ }
+ }
+}
+
+#undef PTR_INC
+#undef CPTR_INC
+
+void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
+ DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+ ctx->error_condition = true;
+ // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
+ VLOG(1) << "PNG error: " << msg;
+ longjmp(png_jmpbuf(png_ptr), 1);
+}
+
+void WarningHandler(png_structp png_ptr, png_const_charp msg) {
+ LOG(WARNING) << "PNG warning: " << msg;
+}
+
+void StringReader(png_structp png_ptr,
+ png_bytep data, png_size_t length) {
+ DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+ if (static_cast<png_size_t>(ctx->data_left) < length) {
+ if (!ctx->error_condition) {
+ VLOG(1) << "PNG read decoding error";
+ ctx->error_condition = true;
+ }
+ memset(data, 0, length);
+ } else {
+ memcpy(data, ctx->data, length);
+ ctx->data += length;
+ ctx->data_left -= length;
+ }
+}
+
+void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
+ string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr));
+ s->append(bit_cast<const char*>(data), length);
+}
+
+void StringWriterFlush(png_structp png_ptr) {
+}
+
+char* check_metadata_string(const string& s) {
+ const char* const c_string = s.c_str();
+ const size_t length = s.size();
+ if (strlen(c_string) != length) {
+ LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
+ }
+ return const_cast<char*>(c_string);
+}
+
+} // namespace
+
+// We move CommonInitDecode() and CommonFinishDecode()
+// out of the CommonDecode() template to save code space.
+void CommonFreeDecode(DecodeContext* context) {
+ if (context->png_ptr) {
+ png_destroy_read_struct(&context->png_ptr,
+ context->info_ptr ? &context->info_ptr : NULL, 0);
+ context->png_ptr = nullptr;
+ context->info_ptr = nullptr;
+ }
+}
+
+bool DecodeHeader(StringPiece png_string, int* width, int* height,
+ int* components, int* channel_bit_depth,
+ std::vector<std::pair<string, string> >* metadata) {
+ DecodeContext context;
+  // Ask for 16 bits even if there may be fewer. This ensures that sniffing
+ // the metadata will succeed in all cases.
+ //
+ // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
+ // metadata with setting up the data conversions. These should be separated.
+ constexpr int kDesiredNumChannels = 1;
+ constexpr int kDesiredChannelBits = 16;
+ if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
+ &context)) {
+ return false;
+ }
+ CHECK_NOTNULL(width);
+ *width = static_cast<int>(context.width);
+ CHECK_NOTNULL(height);
+ *height = static_cast<int>(context.height);
+ if (components != NULL) {
+ switch (context.color_type) {
+ case PNG_COLOR_TYPE_PALETTE:
+ *components = (context.info_ptr->valid & PNG_INFO_tRNS) ? 4 : 3;
+ break;
+ case PNG_COLOR_TYPE_GRAY:
+ *components = 1;
+ break;
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ *components = 2;
+ break;
+ case PNG_COLOR_TYPE_RGB:
+ *components = 3;
+ break;
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ *components = 4;
+ break;
+ default:
+ *components = 0;
+ break;
+ }
+ }
+ if (channel_bit_depth != NULL) {
+ *channel_bit_depth = context.bit_depth;
+ }
+ if (metadata != NULL) {
+ metadata->clear();
+ for (int i = 0; i < context.info_ptr->num_text; i++) {
+ const png_text& text = context.info_ptr->text[i];
+ metadata->push_back(std::make_pair(text.key, text.text));
+ }
+ }
+ CommonFreeDecode(&context);
+ return true;
+}
+
+bool CommonInitDecode(StringPiece png_string, int desired_channels,
+ int desired_channel_bits, DecodeContext* context) {
+ CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
+ << "desired_channel_bits = " << desired_channel_bits;
+ CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = "
+ << desired_channels;
+ context->error_condition = false;
+ context->channels = desired_channels;
+ context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
+ ErrorHandler, WarningHandler);
+ if (!context->png_ptr) {
+ VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
+ return false;
+ }
+ if (setjmp(png_jmpbuf(context->png_ptr))) {
+ VLOG(1) << ": DecodePNG error trapped.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ context->info_ptr = png_create_info_struct(context->png_ptr);
+ if (!context->info_ptr || context->error_condition) {
+ VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
+ CommonFreeDecode(context);
+ return false;
+ }
+ context->data = bit_cast<const uint8*>(png_string.data());
+ context->data_left = png_string.size();
+ png_set_read_fn(context->png_ptr, context, StringReader);
+ png_read_info(context->png_ptr, context->info_ptr);
+ png_get_IHDR(context->png_ptr, context->info_ptr,
+ &context->width, &context->height,
+ &context->bit_depth, &context->color_type,
+ 0, 0, 0);
+ if (context->error_condition) {
+ VLOG(1) << ": DecodePNG <- error during header parsing.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ if (context->width <= 0 || context->height <= 0) {
+ VLOG(1) << ": DecodePNG <- invalid dimensions";
+ CommonFreeDecode(context);
+ return false;
+ }
+ if (context->channels == 0) { // Autodetect number of channels
+ context->channels = context->info_ptr->channels;
+ }
+ const bool has_tRNS = (context->info_ptr->valid & PNG_INFO_tRNS) != 0;
+ const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
+ if ((context->channels & 1) == 0) { // We desire alpha
+ if (has_alpha) { // There is alpha
+ } else if (has_tRNS) {
+ png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha
+ } else {
+ png_set_add_alpha(
+ context->png_ptr, (1 << context->bit_depth) - 1, PNG_FILLER_AFTER);
+ }
+ } else { // We don't want alpha
+ if (has_alpha || has_tRNS) { // There is alpha
+ png_set_strip_alpha(context->png_ptr); // Strip alpha
+ }
+ }
+
+ // If we only want 8 bits, but are given 16, strip off the LS 8 bits
+ if (context->bit_depth > 8 && desired_channel_bits <= 8)
+ png_set_strip_16(context->png_ptr);
+
+ context->need_to_synthesize_16 =
+ (context->bit_depth <= 8 && desired_channel_bits == 16);
+
+ png_set_packing(context->png_ptr);
+ context->num_passes = png_set_interlace_handling(context->png_ptr);
+ png_read_update_info(context->png_ptr, context->info_ptr);
+
+#ifdef IS_LITTLE_ENDIAN
+ if (desired_channel_bits > 8)
+ png_set_swap(context->png_ptr);
+#endif // IS_LITTLE_ENDIAN
+
+  // convert palette to rgb(a) if need be.
+ if (context->color_type == PNG_COLOR_TYPE_PALETTE)
+ png_set_palette_to_rgb(context->png_ptr);
+
+ // handle grayscale case for source or destination
+ const bool want_gray = (context->channels < 3);
+ const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
+ if (is_gray) { // upconvert gray to 8-bit if needed.
+ if (context->bit_depth < 8)
+ png_set_gray_1_2_4_to_8(context->png_ptr);
+ }
+ if (want_gray) { // output is grayscale
+ if (!is_gray)
+ png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG
+ } else { // output is rgb(a)
+ if (is_gray)
+ png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion
+ }
+ return true;
+}
+
+bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
+ CHECK_NOTNULL(data);
+
+ // we need to re-set the jump point so that we trap the errors
+ // within *this* function (and not CommonInitDecode())
+ if (setjmp(png_jmpbuf(context->png_ptr))) {
+ VLOG(1) << ": DecodePNG error trapped.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ // png_read_row() takes care of offsetting the pointer based on interlacing
+ for (int p = 0; p < context->num_passes; ++p) {
+ png_bytep row = data;
+ for (int h = context->height; h-- != 0; row += row_bytes) {
+ png_read_row(context->png_ptr, row, NULL);
+ }
+ }
+
+ context->info_ptr->valid |= PNG_INFO_IDAT;
+ png_read_end(context->png_ptr, context->info_ptr);
+
+ // Clean up.
+ const bool ok = !context->error_condition;
+ CommonFreeDecode(context);
+
+ // Synthesize 16 bits from 8 if requested.
+ if (context->need_to_synthesize_16)
+ Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes,
+ context->width, context->height, bit_cast<uint16*>(data),
+ row_bytes);
+ return ok;
+}
+
+bool WriteImageToBuffer(
+ const void* image, int width, int height, int row_bytes, int num_channels,
+ int channel_bits, int compression, string* png_string,
+ const std::vector<std::pair<string, string> >* metadata) {
+ CHECK_NOTNULL(image);
+ CHECK_NOTNULL(png_string);
+ // Although this case is checked inside png.cc and issues an error message,
+ // that error causes memory corruption.
+ if (width == 0 || height == 0)
+ return false;
+
+ png_string->resize(0);
+ png_infop info_ptr = NULL;
+ png_structp png_ptr =
+ png_create_write_struct(PNG_LIBPNG_VER_STRING,
+ NULL, ErrorHandler, WarningHandler);
+ if (png_ptr == NULL) return false;
+ if (setjmp(png_jmpbuf(png_ptr))) {
+ png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : NULL);
+ return false;
+ }
+ info_ptr = png_create_info_struct(png_ptr);
+ if (info_ptr == NULL) {
+ png_destroy_write_struct(&png_ptr, NULL);
+ return false;
+ }
+
+ int color_type = -1;
+ switch (num_channels) {
+ case 1:
+ color_type = PNG_COLOR_TYPE_GRAY;
+ break;
+ case 2:
+ color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
+ break;
+ case 3:
+ color_type = PNG_COLOR_TYPE_RGB;
+ break;
+ case 4:
+ color_type = PNG_COLOR_TYPE_RGB_ALPHA;
+ break;
+ default:
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ return false;
+ }
+
+ png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
+ if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
+ png_set_compression_level(png_ptr, compression);
+ png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
+ // There used to be a call to png_set_filter here turning off filtering
+ // entirely, but it produced pessimal compression ratios. I'm not sure
+ // why it was there.
+ png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
+ PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
+ PNG_FILTER_TYPE_DEFAULT);
+ // If we have metadata write to it.
+ if (metadata && !metadata->empty()) {
+ std::vector<png_text> text;
+ for (const auto& pair : *metadata) {
+ png_text txt;
+ txt.compression = PNG_TEXT_COMPRESSION_NONE;
+ txt.key = check_metadata_string(pair.first);
+ txt.text = check_metadata_string(pair.second);
+ text.push_back(txt);
+ }
+ png_set_text(png_ptr, info_ptr, &text[0], text.size());
+ }
+
+ png_write_info(png_ptr, info_ptr);
+#ifdef IS_LITTLE_ENDIAN
+ if (channel_bits > 8)
+ png_set_swap(png_ptr);
+#endif // IS_LITTLE_ENDIAN
+
+ png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
+ for (; height--; row += row_bytes) png_write_row(png_ptr, row);
+ png_write_end(png_ptr, NULL);
+
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ return true;
+}
+
+} // namespace png
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
new file mode 100644
index 0000000000..df9bff7be8
--- /dev/null
+++ b/tensorflow/core/lib/png/png_io.h
@@ -0,0 +1,88 @@
+// Functions to read and write images in PNG format.
+//
+// The advantage over image/codec/png{enc,dec}oder.h is that this library
+// supports both 8 and 16 bit images.
+//
+// The decoding routine accepts binary image data as a StringPiece. These are
+// implicitly constructed from strings or char* so they're completely
+// transparent to the caller. They're also very cheap to construct so this
+// doesn't introduce any additional overhead.
+//
+// The primary benefit of StringPiece here is that APIs already returning
+// StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer; only when
+// they're flat, though) or protocol buffer fields typed to either of these
+// can be decoded without copying the data into a C++ string.
+
+#ifndef TENSORFLOW_LIB_PNG_PNG_IO_H_
+#define TENSORFLOW_LIB_PNG_PNG_IO_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "external/png_archive/libpng-1.2.53/png.h"
+
+namespace tensorflow {
+namespace png {
+
+// Handy container for decoding information and struct pointers
+struct DecodeContext {
+ const uint8* data;
+ int data_left;
+ png_structp png_ptr;
+ png_infop info_ptr;
+ png_uint_32 width, height;
+ int num_passes;
+ int color_type;
+ int bit_depth;
+ int channels;
+ bool need_to_synthesize_16;
+ bool error_condition;
+ DecodeContext() : png_ptr(NULL), info_ptr(NULL) {}
+};
+
+bool DecodeHeader(StringPiece png_string, int* width, int* height,
+ int* components, int* channel_bit_depth,
+ std::vector<std::pair<string, string> >* metadata);
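+//
+// Sample usage for sniffing a PNG header (illustrative):
+//
+//   int width, height, components, bit_depth;
+//   if (DecodeHeader(png_string, &width, &height, &components, &bit_depth,
+//                    NULL /*metadata*/)) {
+//     // dimensions and channel layout are known without decoding any pixels.
+//   }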
+
+// Sample usage for reading PNG:
+//
+// string png_string; /* fill with input PNG format data */
+// DecodeContext context;
+// CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context));
+// char* image_buffer = new char[3*context.width*context.height];
+// CHECK(CommonFinishDecode(bit_cast<png_byte*>(image_buffer),
+// 3*context.width /*stride*/, &context));
+//
+// desired_channels may be 0, in which case it is detected from the input.
+
+bool CommonInitDecode(StringPiece png_string, int desired_channels,
+ int desired_channel_bits, DecodeContext* context);
+
+bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context);
+
+// Normally called automatically from CommonFinishDecode. If CommonInitDecode
+// is called but not CommonFinishDecode, call this to clean up. Safe to call
+// extra times.
+void CommonFreeDecode(DecodeContext* context);
+
+// Sample usage for writing PNG:
+//
+// uint16* image_buffer = new uint16[width*height]; /* fill with pixels */
+// string png_string;
+// CHECK(WriteImageToBuffer(image_buffer, width, height, 2*width /*stride*/,
+// 1 /*gray*/, 16 /*uint16*/, &png_string, NULL));
+//
+// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow
+// and strong, and -1 is the zlib default.
+
+bool WriteImageToBuffer(
+ const void* image, int width, int height, int row_bytes, int num_channels,
+ int channel_bits, int compression, string* png_string,
+ const std::vector<std::pair<string, string> >* metadata);
+
+} // namespace png
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_PNG_PNG_IO_H_
diff --git a/tensorflow/core/lib/png/testdata/lena_gray.png b/tensorflow/core/lib/png/testdata/lena_gray.png
new file mode 100644
index 0000000000..8bc73159b0
--- /dev/null
+++ b/tensorflow/core/lib/png/testdata/lena_gray.png
Binary files differ
diff --git a/tensorflow/core/lib/png/testdata/lena_rgba.png b/tensorflow/core/lib/png/testdata/lena_rgba.png
new file mode 100644
index 0000000000..79f1f84a62
--- /dev/null
+++ b/tensorflow/core/lib/png/testdata/lena_rgba.png
Binary files differ
diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc
new file mode 100644
index 0000000000..341f1bd595
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.cc
@@ -0,0 +1,80 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <memory>
+#include <vector>
+
+namespace tensorflow {
+namespace random {
+
+DistributionSampler::DistributionSampler(
+ const gtl::ArraySlice<float>& weights) {
+ DCHECK(!weights.empty());
+ int n = weights.size();
+ num_ = n;
+ data_.reset(new std::pair<float, int>[n]);
+
+ std::unique_ptr<double[]> pr(new double[n]);
+
+ double sum = 0.0;
+ for (int i = 0; i < n; i++) {
+ sum += weights[i];
+ set_alt(i, -1);
+ }
+
+ // These are long/short items - called high/low because of reserved keywords.
+ std::vector<int> high;
+ high.reserve(n);
+ std::vector<int> low;
+ low.reserve(n);
+
+  // Compute proportional weights.
+ for (int i = 0; i < n; i++) {
+ double p = (weights[i] * n) / sum;
+ pr[i] = p;
+ if (p < 1.0) {
+ low.push_back(i);
+ } else {
+ high.push_back(i);
+ }
+ }
+
+ // Now pair high with low.
+ while (!high.empty() && !low.empty()) {
+ int l = low.back();
+ low.pop_back();
+ int h = high.back();
+ high.pop_back();
+
+ set_alt(l, h);
+ DCHECK_GE(pr[h], 1.0);
+ double remaining = pr[h] - (1.0 - pr[l]);
+ pr[h] = remaining;
+
+ if (remaining < 1.0) {
+ low.push_back(h);
+ } else {
+ high.push_back(h);
+ }
+ }
+ // Transfer pr to prob with rounding errors.
+ for (int i = 0; i < n; i++) {
+ set_prob(i, pr[i]);
+ }
+  // Because of rounding errors, both high and low may still contain elements
+  // whose probability is close to 1.0.
+ for (size_t i = 0; i < high.size(); i++) {
+ int idx = high[i];
+ set_prob(idx, 1.0);
+ // set alt to self to prevent rounding errors returning 0
+ set_alt(idx, idx);
+ }
+ for (size_t i = 0; i < low.size(); i++) {
+ int idx = low[i];
+ set_prob(idx, 1.0);
+ // set alt to self to prevent rounding errors returning 0
+ set_alt(idx, idx);
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
new file mode 100644
index 0000000000..ab9598a205
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -0,0 +1,79 @@
+// DistributionSampler allows generating a discrete random variable with a given
+// distribution.
+// The values taken by the variable are [0, N) and relative weights for each
+// value are specified using a vector of size N.
+//
+// The algorithm takes O(N) time to precompute data at construction time and
+// O(1) time (two random number generations, two lookups) per sample.
+// The data structure takes O(N) memory.
+//
+// In contrast, util/random/weighted-picker.h provides O(lg N) sampling.
+// The advantage of that implementation is that weights can be adjusted
+// dynamically, while DistributionSampler doesn't allow weight adjustment.
+//
+// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
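+//
+// Sample usage (a sketch; the seed values below are illustrative):
+//
+//   std::vector<float> weights = {1.0f, 2.0f, 3.0f};
+//   DistributionSampler sampler(weights);
+//   PhiloxRandom philox(301, 17);
+//   SimplePhilox rand(&philox);
+//   int idx = sampler.Sample(&rand);  // idx == 2 with probability 3/6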
+
+#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSampler {
+ public:
+ explicit DistributionSampler(const gtl::ArraySlice<float>& weights);
+
+ ~DistributionSampler() {}
+
+ int Sample(SimplePhilox* rand) const {
+ float r = rand->RandFloat();
+ // Since n is typically low, we don't bother with UnbiasedUniform.
+ int idx = rand->Uniform(num_);
+ if (r < prob(idx)) return idx;
+ // else pick alt from that bucket.
+ DCHECK_NE(-1, alt(idx));
+ return alt(idx);
+ }
+
+ int num() const { return num_; }
+
+ private:
+ float prob(int idx) const {
+ DCHECK_LT(idx, num_);
+ return data_[idx].first;
+ }
+
+ int alt(int idx) const {
+ DCHECK_LT(idx, num_);
+ return data_[idx].second;
+ }
+
+ void set_prob(int idx, float f) {
+ DCHECK_LT(idx, num_);
+ data_[idx].first = f;
+ }
+
+ void set_alt(int idx, int val) {
+ DCHECK_LT(idx, num_);
+ data_[idx].second = val;
+ }
+
+ int num_;
+ std::unique_ptr<std::pair<float, int>[]> data_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler);
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
diff --git a/tensorflow/core/lib/random/distribution_sampler_test.cc b/tensorflow/core/lib/random/distribution_sampler_test.cc
new file mode 100644
index 0000000000..d61a8daa0f
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler_test.cc
@@ -0,0 +1,90 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <string.h>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSamplerTest : public ::testing::Test {
+ protected:
+ // Returns the Chi-Squared statistic for the two distributions.
+ float TestWeights(const std::vector<float>& weights, int trials_per_bin) {
+ int iters = weights.size() * trials_per_bin;
+ std::unique_ptr<float[]> counts(new float[weights.size()]);
+ memset(counts.get(), 0, sizeof(float) * weights.size());
+ DistributionSampler sampler(weights);
+ PhiloxRandom philox(testing::RandomSeed(), 17);
+ SimplePhilox random(&philox);
+ for (int i = 0; i < iters; i++) {
+ int r = sampler.Sample(&random);
+ EXPECT_LT(r, weights.size());
+ EXPECT_GE(r, 0);
+ counts[r] += 1.0;
+ }
+ float chi2 = 0.0;
+ for (size_t i = 0; i < weights.size(); i++) {
+ counts[i] /= iters;
+ float err = (counts[i] - weights[i]);
+ chi2 += (err * err) / weights[i];
+ }
+ return chi2;
+ }
+
+ void TestDistribution(float* arr, int n) {
+ std::vector<float> w;
+ w.reserve(n);
+ for (int i = 0; i < n; i++) {
+ w.push_back(arr[i]);
+ }
+ float var = TestWeights(w, 1000);
+ if (var < 0.001) return;
+ // Maybe a statistical skew. Let's try more iterations.
+ var = TestWeights(w, 100000);
+ if (var < 0.001) return;
+    EXPECT_TRUE(false) << "Chi2 is " << var << " in " << n * 100000
+                       << " iterations";
+ }
+};
+
+TEST_F(DistributionSamplerTest, KnownDistribution) {
+ float kEven2[] = {0.5, 0.5};
+ float kEven3[] = {0.33333333, 0.33333333, 0.33333333};
+ float kEven4[] = {0.25, 0.25, 0.25, 0.25};
+
+ float kDist1[] = {0.8, 0.15, 0.05};
+
+ TestDistribution(kEven2, TF_ARRAYSIZE(kEven2));
+ TestDistribution(kEven3, TF_ARRAYSIZE(kEven3));
+ TestDistribution(kEven4, TF_ARRAYSIZE(kEven4));
+ TestDistribution(kDist1, TF_ARRAYSIZE(kDist1));
+}
+
+static void BM_DistributionSampler(int iters, int n) {
+ testing::StopTiming();
+ PhiloxRandom philox(173, 371);
+ SimplePhilox rand(&philox);
+ std::vector<float> weights(n, 0);
+ for (int i = 0; i < n; i++) {
+ weights[i] = rand.Uniform(100);
+ }
+ DistributionSampler picker(weights);
+ testing::StartTiming();
+ int r = 0;
+ for (int i = 0; i < iters; i++) {
+ r |= picker.Sample(&rand);
+ }
+ CHECK_NE(r, kint32max);
+}
+
+BENCHMARK(BM_DistributionSampler)->Arg(10)->Arg(100)->Arg(1000);
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h
new file mode 100644
index 0000000000..616354cc5c
--- /dev/null
+++ b/tensorflow/core/lib/random/exact_uniform_int.h
@@ -0,0 +1,68 @@
+// Exact uniform integers using rejection sampling
+
+#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+
+#include <type_traits>
+
+namespace tensorflow {
+namespace random {
+
+template <typename UintType, typename RandomBits>
+UintType ExactUniformInt(const UintType n, const RandomBits& random) {
+ static_assert(std::is_unsigned<UintType>::value,
+ "UintType must be an unsigned int");
+ static_assert(std::is_same<UintType, decltype(random())>::value,
+ "random() should return UintType");
+ if (n == 0) {
+ // Consume a value anyway
+ // TODO(irving): Assert n != 0, since this case makes no sense.
+ return random() * n;
+ } else if (0 == (n & (n - 1))) {
+ // N is a power of two, so just mask off the lower bits.
+ return random() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // random's output is uniform in the half-open interval [0, 2^{bits}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ const UintType range = ~static_cast<UintType>(0);
+ const UintType rem = (range % n) + 1;
+ UintType rnd;
+
+ // rem = ((2^bits-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = random(); // rnd uniform over [0, 2^{bits})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{bits})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1
+ // = 2^{bits}-1 - ((2^{bits}-1) \bmod n)
+ // = n \cdot \lfloor (2^{bits}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+    // The function v \rightarrow v % n takes values from [rem,
+    // 2^{bits}) to [0, n).  Each integer in the interval [0, n)
+ // will have exactly \lfloor (2^{bits}-1)/n \rfloor preimages from
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
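+
+// Example usage (a sketch; the generator below is an illustrative LCG, not
+// part of this library):
+//
+//   uint32 state = 12345;
+//   auto next = [&state]() -> uint32 {
+//     state = state * 1664525u + 1013904223u;
+//     return state;
+//   };
+//   uint32 v = ExactUniformInt<uint32>(10, next);  // uniform in [0, 10)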
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
new file mode 100644
index 0000000000..2c3cd0c4b9
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -0,0 +1,232 @@
+// Implement the Philox algorithm to generate random numbers in parallel.
+// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/port.h"
+
+// Function qualifiers that need to work on both CPU and GPU.
+#ifdef __CUDA_ARCH__
+// For nvcc.
+#define PHILOX_DEVICE_FUNC __host__ __device__
+#define PHILOX_INLINE __inline__
+#else
+// For non-nvcc.
+#define PHILOX_DEVICE_FUNC
+#define PHILOX_INLINE inline
+#endif
+#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE
+
+#include <math.h>
+
+namespace tensorflow {
+namespace random {
+
+// A class that represents an inline array. It can be used on both CPU and GPU
+// and is trivially copyable between the two.
+// Arguments:
+// T: the array element type;
+// ElementCount: the fixed size of the array;
+template <typename T, int ElementCount>
+class Array {
+ public:
+ PHILOX_DEVICE_INLINE Array() {
+ for (int i = 0; i < ElementCount; ++i) {
+ data_[i] = T();
+ }
+ }
+
+ PHILOX_DEVICE_INLINE const T& operator[](int index) const {
+ return data_[index];
+ }
+
+ PHILOX_DEVICE_INLINE T& operator[](int index) { return data_[index]; }
+
+ size_t size() const { return ElementCount; }
+
+ private:
+ T data_[ElementCount];
+};
+
+// A class that encapsulates all the state for a random number generator using
+// the philox_4x32_10 algorithm. Each invocation returns 128 random bits in
+// the form of four uint32 values.
+// There are multiple variants of this algorithm; we picked the 4x32_10
+// version as the one most suited for our applications.
+// Since this class is meant to be copied between CPU and GPU, it maintains
+// value semantics.
+//
+// For example: To use this class and populate an array of 1024 randoms on CPU
+// with two threads,
+//
+// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) {
+// assert(start % 4 == 0);
+// assert(limit % 4 == 0);
+// rnd.Skip(start / 4);
+// for (int i = start; i < limit; i += 4) {
+// auto sample = rnd();
+// ... copy sample[0..3] to output[i..i+3]
+// }
+// }
+//
+// PhiloxRandom rng(seed);
+// PhiloxRandom rng_copy = rng;
+// rng.Skip(1000/4);
+//
+// ... schedule Fill(rng_copy, output, 0, 512) in thread 1;
+// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2;
+// ... wait for thread 1 & 2 to finish executing Fill().
+//
+// NOTE:
+// 1. PhiloxRandom is trivially copyable.
+// 2. PhiloxRandom is compilable by gcc and nvcc.
+class PhiloxRandom {
+ public:
+ typedef Array<uint32, 4> ResultType;
+ typedef uint32 ResultElementType;
+ // The number of elements that will be returned.
+ static const int kResultElementCount = 4;
+
+ PHILOX_DEVICE_INLINE
+ PhiloxRandom() {}
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64 seed) {
+ key_[0] = static_cast<uint32>(seed);
+ key_[1] = static_cast<uint32>(seed >> 32);
+ }
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64 seed_lo, uint64 seed_hi) {
+ key_[0] = static_cast<uint32>(seed_lo);
+ key_[1] = static_cast<uint32>(seed_lo >> 32);
+ counter_[2] = static_cast<uint32>(seed_hi);
+ counter_[3] = static_cast<uint32>(seed_hi >> 32);
+ }
+
+  // Skip the specified number of 128-bit samples in the current stream.
+ PHILOX_DEVICE_INLINE
+ void Skip(uint64 count) {
+ const uint32 count_lo = static_cast<uint32>(count);
+ uint32 count_hi = static_cast<uint32>(count >> 32);
+
+ counter_[0] += count_lo;
+ if (counter_[0] < count_lo) {
+ ++count_hi;
+ }
+
+ counter_[1] += count_hi;
+ if (counter_[1] < count_hi) {
+ if (++counter_[2] == 0) {
+ ++counter_[3];
+ }
+ }
+ }
+
+ // Returns a group of four random numbers using the underlying Philox
+ // algorithm.
+ PHILOX_DEVICE_INLINE ResultType operator()() {
+ ResultType counter = counter_;
+ Key key = key_;
+
+    // Run the single round ten times. The loop is manually unrolled for
+    // better performance.
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+
+ SkipOne();
+
+ return counter;
+ }
+
+ private:
+  // The type for the 64-bit key, stored as two 32-bit uints and used in the
+  // diffusion process.
+ typedef Array<uint32, 2> Key;
+
+ // We use the same constants as recommended by the original paper.
+ static const uint32 kPhiloxW32A = 0x9E3779B9;
+ static const uint32 kPhiloxW32B = 0xBB67AE85;
+ static const uint32 kPhiloxM4x32A = 0xD2511F53;
+ static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
+
+  // Helper function to skip the next 128-bit sample in the current stream.
+ PHILOX_DEVICE_INLINE void SkipOne() {
+ if (++counter_[0] == 0) {
+ if (++counter_[1] == 0) {
+ if (++counter_[2] == 0) {
+ ++counter_[3];
+ }
+ }
+ }
+ }
+
+  // Helper function to return the lower and higher 32 bits of the product of
+  // two 32-bit integers.
+ PHILOX_DEVICE_INLINE
+ static void MultiplyHighLow(uint32 a, uint32 b, uint32* result_low,
+ uint32* result_high) {
+#ifndef __GCUDACC__
+ const uint64 product = static_cast<uint64>(a) * b;
+ *result_low = static_cast<uint32>(product);
+ *result_high = static_cast<uint32>(product >> 32);
+#else
+ *result_low = a * b;
+ *result_high = __umulhi(a, b);
+#endif
+ }
+
+ // Helper function for a single round of the underlying Philox algorithm.
+ PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound(
+ const ResultType& counter, const Key& key) {
+ uint32 lo0;
+ uint32 hi0;
+ MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
+
+ uint32 lo1;
+ uint32 hi1;
+ MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
+
+ ResultType result;
+ result[0] = hi1 ^ counter[1] ^ key[0];
+ result[1] = lo1;
+ result[2] = hi0 ^ counter[3] ^ key[1];
+ result[3] = lo0;
+ return result;
+ }
+
+ PHILOX_DEVICE_INLINE void RaiseKey(Key* key) {
+ (*key)[0] += kPhiloxW32A;
+ (*key)[1] += kPhiloxW32B;
+ }
+
+ private:
+ ResultType counter_;
+ Key key_;
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc
new file mode 100644
index 0000000000..997c0263b7
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test.cc
@@ -0,0 +1,58 @@
+#include "tensorflow/core/lib/random/philox_random.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// A trivial distribution that returns the raw PhiloxRandom output as its
+// samples.
+class TrivialPhiloxDistribution {
+ public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount;
+ typedef PhiloxRandom::ResultType ResultType;
+ typedef PhiloxRandom::ResultElementType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(PhiloxRandom* gen) { return (*gen)(); }
+};
+
+// This test checks that skipping a certain number of samples is equivalent to
+// generating those samples and discarding them.
+TEST(PhiloxRandomTest, SkipMatchTest) {
+ constexpr int count = 1024;
+ constexpr int skip_count = 2048;
+
+ uint64 test_seed = GetTestSeed();
+ std::vector<uint32> v1(count);
+ {
+ PhiloxRandom gen(test_seed);
+ gen.Skip(skip_count / 4);
+ FillRandoms<TrivialPhiloxDistribution>(gen, &v1[0], v1.size());
+ }
+
+ std::vector<uint32> v2(count + skip_count);
+ {
+ PhiloxRandom gen(test_seed);
+ FillRandoms<TrivialPhiloxDistribution>(gen, &v2[0], v2.size());
+ }
+
+ for (int i = 0; i < count; ++i) {
+ ASSERT_EQ(v1[i], v2[i + skip_count]);
+ }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h
new file mode 100644
index 0000000000..d22f6b36e4
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test_utils.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a random seed.
+inline uint64 GetTestSeed() { return New64(); }
+
+// A utility function to fill the given array with samples from the given
+// distribution.
+template <class Distribution>
+void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
+ int64 size) {
+ const int granularity = Distribution::kResultElementCount;
+
+ CHECK(size % granularity == 0) << " size: " << size
+ << " granularity: " << granularity;
+
+ Distribution dist;
+ for (int i = 0; i < size; i += granularity) {
+ const auto sample = dist(&gen);
+ std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc
new file mode 100644
index 0000000000..2959b05382
--- /dev/null
+++ b/tensorflow/core/lib/random/random.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <random>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+std::mt19937_64* InitRng() {
+ std::random_device device("/dev/random");
+ return new std::mt19937_64(device());
+}
+
+uint64 New64() {
+ static std::mt19937_64* rng = InitRng();
+ static mutex mu;
+ mutex_lock l(mu);
+ return (*rng)();
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h
new file mode 100644
index 0000000000..1a20436c4e
--- /dev/null
+++ b/tensorflow/core/lib/random/random.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a 64-bit random value. Different sequences are generated
+// in different processes.
+uint64 New64();
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
new file mode 100644
index 0000000000..caafcde513
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -0,0 +1,361 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+
+namespace tensorflow {
+namespace random {
+
+// Helper function to convert a 32-bit integer to a float in [0, 1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
+// Helper function to convert two 32-bit integers to a double in [0, 1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1);
+
+// A class that generates uniform distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+//   Generator: a generator type that returns a number of uint32s upon each
+//              invocation. It needs to define kResultElementCount as the
+//              sample count of each invocation, and ResultType as the actual
+//              returned sample type.
+//   RealType: the data type of the real numbers that will be returned by the
+//             distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class UniformDistribution;
+
+template <class Generator>
+class UniformDistribution<Generator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i) {
+ result[i] = Uint32ToFloat(sample[i]);
+ }
+ return result;
+ }
+};
+
+template <class Generator>
+class UniformDistribution<Generator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i) {
+ result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]);
+ }
+ return result;
+ }
+};
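+
+// Example usage (a sketch; `seed` is a placeholder):
+//
+//   PhiloxRandom gen(seed);
+//   UniformDistribution<PhiloxRandom, float> dist;
+//   auto samples = dist(&gen);  // Array<float, 4>, each element in [0, 1)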
+
+// A class that adapts a generator that natively returns multiple samples per
+// invocation into one that returns a single sample at a time.
+template <class Generator>
+class SingleSampleAdapter {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = 1;
+ // The number of elements that will be returned by the underlying generator.
+ static const int kNativeElementCount = Generator::kResultElementCount;
+ typedef typename Generator::ResultElementType ResultType;
+ typedef typename Generator::ResultElementType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ explicit SingleSampleAdapter(Generator* gen)
+ : generator_(gen), used_result_index_(Generator::kResultElementCount) {}
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()() {
+ if (used_result_index_ == Generator::kResultElementCount) {
+ unused_results_ = (*generator_)();
+ used_result_index_ = 0;
+ }
+
+ return unused_results_[used_result_index_++];
+ }
+
+ private:
+ Generator* generator_;
+ typename Generator::ResultType unused_results_;
+ int used_result_index_;
+};
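+
+// Example usage (a sketch; `seed` is a placeholder): draw uint32 samples one
+// at a time from a PhiloxRandom that natively produces four per invocation:
+//
+//   PhiloxRandom philox(seed);
+//   SingleSampleAdapter<PhiloxRandom> single(&philox);
+//   uint32 a = single();
+//   uint32 b = single();  // second element of the same 128-bit block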
+
+// A class that generates unit normal distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+//   Generator: a generator type that returns a number of uint32s upon each
+//              invocation. It needs to define kResultElementCount as the
+//              sample count of each invocation, and ResultType as the actual
+//              returned sample type.
+//   RealType: the data type of the real numbers that will be returned by the
+//             distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class NormalDistribution;
+
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1);
+
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+ double* d1);
+
+template <class Generator>
+class NormalDistribution<Generator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2) {
+ BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator>
+class NormalDistribution<Generator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2) {
+ const int i2 = 2 * i;
+ BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2],
+ sample[i2 + 3], &result[i], &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+// A class that returns samples from the standard normal distribution,
+// truncated to [-kTruncateValue, kTruncateValue].
+// Arguments:
+//   Generator: a generator type that returns a number of uint32s upon each
+//              invocation. It needs to define kResultElementCount as the
+//              sample count of each invocation, and ResultType as the actual
+//              returned sample type.
+//   RealType: the data type of the real numbers that will be returned by the
+//             distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class SingleSampleGenerator, typename RealType>
+class TruncatedNormalDistribution;
+
+// Partial specialization for float.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount =
+ SingleSampleGenerator::kNativeElementCount;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = true;
+ // The threshold where the normal distribution is truncated.
+ const float kTruncateValue = 2.0f;
+
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator* gen) {
+ ResultType results;
+ int index = 0;
+ while (true) {
+ // Repeatedly take samples from the normal distribution, until we have
+ // the desired number of elements that fall within the pre-defined cutoff
+ // threshold.
+ const uint32 x0 = (*gen)();
+ const uint32 x1 = (*gen)();
+ float f[2];
+ BoxMullerFloat(x0, x1, &f[0], &f[1]);
+
+ for (int i = 0; i < 2; ++i) {
+ if (fabs(f[i]) < kTruncateValue) {
+ results[index++] = f[i];
+ if (index >= kResultElementCount) {
+ return results;
+ }
+ }
+ }
+ }
+ }
+};
+
+// Partial specialization for double.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount =
+ (SingleSampleGenerator::kNativeElementCount > 1)
+ ? SingleSampleGenerator::kNativeElementCount / 2
+ : 1;
+  // Indicates that this distribution may take a variable number of samples
+  // during runtime.
+ static const bool kVariableSamplesPerOutput = true;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+ const double kTruncateValue = 2.0;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator* gen) {
+ ResultType results;
+ int index = 0;
+    while (true) {
+ const uint32 x0 = (*gen)();
+ const uint32 x1 = (*gen)();
+ const uint32 x2 = (*gen)();
+ const uint32 x3 = (*gen)();
+ double d[2];
+ BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]);
+
+ for (int i = 0; i < 2; ++i) {
+ if (fabs(d[i]) < kTruncateValue) {
+ results[index++] = d[i];
+ if (index >= kResultElementCount) {
+ return results;
+ }
+ }
+ }
+ }
+ }
+};
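+
+// Example usage (a sketch; `seed` is a placeholder). The truncated
+// distributions may consume a variable number of samples per output, so they
+// take a single-sample generator:
+//
+//   PhiloxRandom philox(seed);
+//   SingleSampleAdapter<PhiloxRandom> single(&philox);
+//   TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, float> d;
+//   auto samples = d(&single);  // all values fall within (-2, 2)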
+
+// Helper function to convert two 32-bit uniform integers to two floats
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) {
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const float epsilon = 1.0e-7f;
+ float u1 = Uint32ToFloat(x0);
+ if (u1 < epsilon) {
+ u1 = epsilon;
+ }
+ const float v1 = 2.0f * M_PI * Uint32ToFloat(x1);
+ const float u2 = sqrt(-2.0f * log(u1));
+#if defined(__linux)
+ sincosf(v1, f0, f1);
+#else
+ *f0 = sinf(v1);
+ *f1 = cosf(v1);
+#endif
+ *f0 *= u2;
+ *f1 *= u2;
+}
+
+// Helper function to convert four 32-bit uniform integers to two doubles
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+ double* d1) {
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const double epsilon = 1.0e-7;
+ double u1 = Uint64ToDouble(x0, x1);
+ if (u1 < epsilon) {
+ u1 = epsilon;
+ }
+ const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3);
+ const double u2 = sqrt(-2.0 * log(u1));
+#if defined(__linux)
+ sincos(v1, d0, d1);
+#else
+ *d0 = sin(v1);
+ *d1 = cos(v1);
+#endif
+ *d0 *= u2;
+ *d1 *= u2;
+}
+
+// Helper function to convert a 32-bit integer to a float in [0, 1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
+ // IEEE754 floats are formatted as follows (MSB first):
+ // sign(1) exponent(8) mantissa(23)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 127 -- an excess 127 representation of a zero exponent
+ // mantissa == 23 random bits
+ const uint32 man = x & 0x7fffffu; // 23 bit mantissa
+ const uint32 exp = static_cast<uint32>(127);
+ const uint32 val = (exp << 23) | man;
+
+ // Assumes that endian-ness is same for float and uint32.
+ float result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0f;
+}
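+
+// For example (a worked check): x == 0 gives val == 0x3f800000, i.e. 1.0f,
+// so the result is 0.0f; x == 0xffffffff gives the largest float below 2.0f,
+// so the result is just below 1.0f.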
+
+// Helper function to convert two 32-bit integers to a double in [0, 1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) {
+ // IEEE754 doubles are formatted as follows (MSB first):
+ // sign(1) exponent(11) mantissa(52)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 1023 -- an excess 1023 representation of a zero exponent
+ // mantissa == 52 random bits
+ const uint32 mhi = x0 & 0xfffffu; // upper 20 bits of mantissa
+ const uint32 mlo = x1; // lower 32 bits of mantissa
+ const uint64 man = (static_cast<uint64>(mhi) << 32) | mlo; // mantissa
+ const uint64 exp = static_cast<uint64>(1023);
+ const uint64 val = (exp << 52) | man;
+ // Assumes that endian-ness is same for double and uint64.
+ double result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0;
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
new file mode 100644
index 0000000000..3ce86a907a
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -0,0 +1,270 @@
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// The largest z-value we want to tolerate. Since the z-test approximates a
+// unit normal distribution, it should almost definitely never exceed 6.
+static constexpr float kZLimit = 6.0;
+
+// A utility function to fill the given array with samples from the given
+// distribution, using the single-sample adapter of the underlying generator.
+template <class Distribution>
+void FillRandomsWithSingles(PhiloxRandom gen,
+ typename Distribution::ResultElementType* p,
+ int64 size) {
+ int granularity = Distribution::kResultElementCount;
+
+ CHECK(size % granularity == 0) << " size: " << size
+ << " granularity: " << granularity;
+
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ Distribution dist;
+ for (int i = 0; i < size; i += granularity) {
+ auto sample = dist(&single_samples);
+ std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+ }
+}
+
+// Check that the given array of samples matches the given theoretical moment
+// function at different orders. The test is considered passing if the z-tests
+// of all statistical moments are below z_limit.
+// typename T in the template argument could be either float or double.
+// Arguments:
+// samples: an array of samples to be tested for their statistical properties;
+//   theoretical_moments: a functor that calculates the theoretical moment of
+//                        an arbitrary order for the given distribution;
+//   max_moments: the largest moment order of the distribution to be tested;
+//   stride: the distance between samples to check for statistical properties.
+//           A stride of 0 tests the n-th moment of each sample; any other
+//           stride tests for spatial correlation between samples;
+//   z_limit: the maximum z-test value for which we consider the test to pass;
+template <typename T>
+bool CheckSamplesMoments(const std::vector<T>& samples,
+ std::function<double(int)> theoretical_moments,
+ int max_moments, int stride, T z_limit) {
+ const T* const samples_data = &samples[0];
+ const int samples_size = samples.size();
+ std::vector<double> moments(max_moments + 1);
+ double* const moments_data = &moments[0];
+ std::vector<int> moments_sample_count(max_moments + 1);
+ int* const moments_sample_count_data = &moments_sample_count[0];
+
+ for (int k = 0; k < samples_size; ++k) {
+ double moment = 1.;
+ for (int i = 0; i <= max_moments; ++i) {
+ int index = k + i * stride;
+ if (index >= samples_size) {
+ break;
+ }
+      // moments[i] stores the measured i-th order moment.
+      // Bypass std::vector::operator[] because it is too slow in debug mode,
+      // given the large number of samples.
+ moments_data[i] += moment;
+ ++moments_sample_count_data[i];
+ moment *= samples_data[index];
+ }
+ }
+
+ // normalize the moments
+ for (int i = 0; i <= max_moments; ++i) {
+ moments[i] /= moments_sample_count[i];
+ }
+
+ bool status = true;
+
+ for (int i = 1; i <= max_moments; ++i) {
+ // Calculate the theoretical mean and variance
+ const double moments_i_mean = (stride == 0)
+ ? theoretical_moments(i)
+ : std::pow(theoretical_moments(1), i);
+ const double moments_i_squared = (stride == 0)
+ ? theoretical_moments(2 * i)
+ : std::pow(theoretical_moments(2), i);
+ const double moments_i_var =
+ moments_i_squared - moments_i_mean * moments_i_mean;
+
+ // assume every operation has a small numerical error.
+ static const double kNumericalError = 1e-6;
+ // it takes i multiplications to calculate one i-th moment.
+ const double error_per_moment = i * kNumericalError;
+ const double total_variance =
+ moments_i_var / moments_sample_count[i] + error_per_moment;
+ // z_test is approximately a unit normal distribution.
+ const double z_test =
+ fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
+
+ if (z_test > z_limit) {
+ LOG(ERROR) << "failing z_test:"
+ << " moment: " << i << " stride: " << stride
+ << " z_test: " << z_test << " z_limit: " << z_limit
+ << " measured moments: " << moments[i]
+ << " theoretical mean of the moments: " << moments_i_mean
+ << " theoretical var of the moments: " << moments_i_var
+ << " sample count: " << moments_sample_count[i];
+ status = false;
+ }
+ }
+
+ return status;
+}
+
+// This test checks that the generated samples match the theoretical moments
+// of the uniform distribution.
+template <typename T>
+void UniformMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ auto uniform_moments = [](int n) -> double { return 1. / (n + 1); };
+
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandoms<UniformDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, uniform_moments, max_moments,
+ stride, z_limit);
+ ASSERT_TRUE(status) << " UniformMomentsTest failing. seed: " << seed;
+ }
+}
+
+// This test checks that the generated samples match the theoretical moments
+// of the unit normal distribution.
+template <typename T>
+void NormalMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ auto normal_moments = [](int n) -> double {
+ if (n % 2 == 1) {
+ // For an odd order, the moment of a unit normal distribution is zero.
+ return 0.;
+ } else {
+      // For an even order, the moment of a unit normal distribution is
+      // (n-1)!!.
+ double v = 1.;
+ for (int i = n - 1; i >= 1; i -= 2) {
+ v *= i;
+ }
+ return v;
+ }
+ };
+
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandoms<NormalDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, normal_moments, max_moments,
+ stride, z_limit);
+ ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed;
+ }
+}
+
+// A functor to calculate the moments for the truncated normal distribution.
+// For any odd order, the moment is zero. For any even n, it can be proven
+// that the following recursive relationship holds for the moments of the
+// truncated standard normal:
+// m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1)
+// where v is the cut-off value, f(v) is the p.d.f of the standard
+// normal, and Phi(v) is the c.d.f of the standard normal.
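+// For example (a worked check): with v = 2,
+//   m(2) = 1 * m(0) - 2 * 2 * f(2) / (2 * Phi(2) - 1)
+//        ~= 1 - 4 * 0.053991 / 0.954500 ~= 0.7737,
+// which matches the variance of the standard normal truncated to [-2, 2].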
+class TruncatedNormalMoments {
+ public:
+ double operator()(int n) {
+ if (n == 0) {
+ return 1;
+ }
+ if (n % 2 == 1) {
+ // For an odd order, the moment is always zero
+ return 0.;
+ }
+
+ // Memoization and check the cached results.
+ auto iter = cached_results_.find(n);
+ if (iter != cached_results_.end()) {
+ return iter->second;
+ }
+
+ // The real computation of the moment.
+ double bias = 2.0 * std::pow(kV, n - 1) * kFV / (2.0 * kPhiV - 1.0);
+ double moment_n_minus_2 = (*this)(n - 2);
+ double moment_n = (n - 1) * moment_n_minus_2 - bias;
+
+ cached_results_[n] = moment_n;
+ return moment_n;
+ }
+
+ private:
+ const double kV = 2.0;
+ // f(v), where f is the p.d.f of the normal distribution and v=2.
+ const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0);
+ // The numerical evaluation of Phi(v), where v is the truncate value.
+ // v = 2 in the current implementation.
+ const double kPhiV = 0.977249868051821;
+ std::unordered_map<int, double> cached_results_;
+};
+
+// This test checks that the generated samples match the theoretical moments
+// of the truncated normal distribution.
+template <typename T>
+void RandomParametersMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandomsWithSingles<
+ TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, T> >(
+ gen, &v1[0], v1.size());
+
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, TruncatedNormalMoments(),
+ max_moments, stride, z_limit);
+    ASSERT_TRUE(status) << " RandomParametersMomentsTest failing. seed: "
+                        << seed;
+ }
+}
+
+TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ UniformMomentsTest<float>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, NormalFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ NormalMomentsTest<float>(8 << 20, 25, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ RandomParametersMomentsTest<float>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, UniformDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ UniformMomentsTest<double>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, NormalDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ NormalMomentsTest<double>(8 << 20, 25, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ RandomParametersMomentsTest<double>(1 << 20, 40, strides, kZLimit);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random_test.cc b/tensorflow/core/lib/random/random_test.cc
new file mode 100644
index 0000000000..7ed37c8b5e
--- /dev/null
+++ b/tensorflow/core/lib/random/random_test.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <set>
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+TEST(New64Test, SanityCheck) {
+ std::set<uint64> values;
+ for (int i = 0; i < 1000000; i++) {
+ uint64 x = New64();
+ EXPECT_TRUE(values.insert(x).second) << "duplicate " << x;
+ }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.cc b/tensorflow/core/lib/random/simple_philox.cc
new file mode 100644
index 0000000000..1035e1f017
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/random/exact_uniform_int.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+uint32 SimplePhilox::Uniform(uint32 n) {
+ return ExactUniformInt<uint32>(n, [this]() { return Rand32(); });
+}
+
+uint64 SimplePhilox::Uniform64(uint64 n) {
+ return ExactUniformInt<uint64>(n, [this]() { return Rand64(); });
+}
+
+uint32 SimplePhilox::Skewed(int max_log) {
+ CHECK(0 <= max_log && max_log <= 32);
+ const int shift = Rand32() % (max_log + 1);
+ const uint32 mask = shift == 32 ? ~static_cast<uint32>(0) : (1 << shift) - 1;
+ return Rand32() & mask;
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h
new file mode 100644
index 0000000000..12b15d7616
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.h
@@ -0,0 +1,61 @@
+#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace random {
+
+// A simple imperative interface to Philox
+class SimplePhilox {
+ public:
+ PHILOX_DEVICE_INLINE
+ explicit SimplePhilox(PhiloxRandom* gen) : single_(gen) {}
+
+ // 32 random bits
+ PHILOX_DEVICE_INLINE uint32 Rand32() { return single_(); }
+
+ // 64 random bits
+ PHILOX_DEVICE_INLINE uint64 Rand64() {
+ const uint32 lo = single_(), hi = single_();
+ return lo | static_cast<uint64>(hi) << 32;
+ }
+
+ // Uniform float in [0, 1)
+ PHILOX_DEVICE_INLINE float RandFloat() { return Uint32ToFloat(single_()); }
+
+ // Uniform double in [0, 1)
+ PHILOX_DEVICE_INLINE double RandDouble() {
+ const uint32 x0 = single_(), x1 = single_();
+ return Uint64ToDouble(x0, x1);
+ }
+
+ // Uniform integer in [0, n).
+ // Uses rejection sampling, so may need more than one 32-bit sample.
+ uint32 Uniform(uint32 n);
+
+  // Uniform integer in [0, n).
+ // Uses rejection sampling, so may need more than one 64-bit sample.
+ uint64 Uniform64(uint64 n);
+
+ // True with probability 1/n.
+ bool OneIn(uint32 n) { return Uniform(n) == 0; }
+
+ // Skewed: pick "base" uniformly from range [0,max_log] and then
+ // return "base" random bits. The effect is to pick a number in the
+ // range [0,2^max_log-1] with bias towards smaller numbers.
+ uint32 Skewed(int max_log);
+
+ private:
+ SingleSampleAdapter<PhiloxRandom> single_;
+};
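+
+// Example usage (a sketch; the seed values are illustrative):
+//
+//   PhiloxRandom philox(42, 17);
+//   SimplePhilox rand(&philox);
+//   uint32 die = 1 + rand.Uniform(6);  // uniform in [1, 6]
+//   float f = rand.RandFloat();        // uniform in [0, 1)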
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
diff --git a/tensorflow/core/lib/random/simple_philox_test.cc b/tensorflow/core/lib/random/simple_philox_test.cc
new file mode 100644
index 0000000000..4246b8b4dd
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox_test.cc
@@ -0,0 +1,120 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+#include <set>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+TEST(SimplePhiloxTest, FloatTest) {
+ PhiloxRandom philox(7, 7);
+ SimplePhilox gen(&philox);
+ static const int kIters = 1000000;
+ for (int i = 0; i < kIters; ++i) {
+ float f = gen.RandFloat();
+ EXPECT_LE(0.0f, f);
+ EXPECT_GT(1.0f, f);
+ }
+ for (int i = 0; i < kIters; ++i) {
+ double d = gen.RandDouble();
+ EXPECT_LE(0.0, d);
+ EXPECT_GT(1.0, d);
+ }
+}
+
+static void DifferenceTest(const char *names, SimplePhilox *gen1,
+ SimplePhilox *gen2) {
+ static const int kIters = 100;
+ bool different = false;
+ for (int i = 0; i < kIters; ++i) {
+ if (gen1->Rand32() != gen2->Rand32()) {
+ different = true;
+ break;
+ }
+ }
+ CHECK(different) << "different seeds but same output!";
+}
+
+TEST(SimplePhiloxTest, DifferenceTest) {
+ PhiloxRandom philox1(1, 1), philox2(17, 17);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ DifferenceTest("SimplePhilox: different seeds", &gen1, &gen2);
+}
+
+TEST(SimplePhiloxTest, DifferenceTestCloseSeeds) {
+ PhiloxRandom philox1(1, 1), philox2(2, 1);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ DifferenceTest("SimplePhilox: close seeds", &gen1, &gen2);
+}
+
+TEST(SimplePhiloxTest, Regression_CloseSeedsAreDifferent) {
+ const int kCount = 1000;
+
+ // Two seeds differ only by the last bit.
+ PhiloxRandom philox1(0, 1), philox2(1, 1);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ std::set<uint32> first;
+ std::set<uint32> all;
+ for (int i = 0; i < kCount; ++i) {
+ uint32 v = gen1.Rand32();
+ first.insert(v);
+ all.insert(v);
+ all.insert(gen2.Rand32());
+ }
+
+  // A broken array initialization implementation (before 2009-08-18) using
+  // the above seeds returned <1000, 1007>, generating output that is >99%
+  // similar. The fix returns <1000, 2000> for completely disjoint sets.
+ EXPECT_EQ(kCount, first.size());
+ EXPECT_EQ(2 * kCount, all.size());
+}
+
+TEST(SimplePhiloxTest, TestUniform) {
+ PhiloxRandom philox(17, 17);
+ SimplePhilox gen(&philox);
+
+ uint32 range = 3 * (1L << 29);
+ uint32 threshold = 1L << 30;
+
+ size_t count = 0;
+ static const int kTrials = 100000;
+ for (int i = 0; i < kTrials; ++i) {
+ uint32 rnd = gen.Uniform(range);
+ if (rnd < threshold) {
+ ++count;
+ }
+ }
+
+ EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+TEST(SimplePhiloxTest, TestUniform64) {
+ PhiloxRandom philox(17, 17);
+ SimplePhilox gen(&philox);
+
+ uint64 range = 3 * (1LL << 59);
+ uint64 threshold = 1LL << 60;
+
+ size_t count = 0;
+ static const int kTrials = 100000;
+ for (int i = 0; i < kTrials; ++i) {
+ uint64 rnd = gen.Uniform64(range);
+ if (rnd < threshold) {
+ ++count;
+ }
+ }
+
+ EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc
new file mode 100644
index 0000000000..f96da578ec
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.cc
@@ -0,0 +1,203 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace random {
+
+WeightedPicker::WeightedPicker(int N) {
+ CHECK_GE(N, 0);
+ N_ = N;
+
+ // Find the number of levels
+ num_levels_ = 1;
+ while (LevelSize(num_levels_ - 1) < N) {
+ num_levels_++;
+ }
+
+ // Initialize the levels
+ level_ = new int32*[num_levels_];
+ for (int l = 0; l < num_levels_; l++) {
+ level_[l] = new int32[LevelSize(l)];
+ }
+
+ SetAllWeights(1);
+}
+
+WeightedPicker::~WeightedPicker() {
+ for (int l = 0; l < num_levels_; l++) {
+ delete[] level_[l];
+ }
+ delete[] level_;
+}
+
+static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
+ CHECK_LE(0, n);
+ const uint32 range = ~static_cast<uint32>(0);
+ if (n == 0) {
+ return r->Rand32() * n;
+ } else if (0 == (n & (n - 1))) {
+ // N is a power of two, so just mask off the lower bits.
+ return r->Rand32() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // Rand32's output is uniform in the half-open interval [0, 2^{32}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ uint32 rem = (range % n) + 1;
+ uint32 rnd;
+
+ // rem = ((2^{32}-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = r->Rand32(); // rnd uniform over [0, 2^{32})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{32})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
+ // = 2^{32}-1 - ((2^{32}-1) \bmod n)
+ // = n \cdot \lfloor (2^{32}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+    // The function v \rightarrow v % n takes values from [rem,
+    // 2^{32}) to [0, n).  Each integer in the interval [0, n)
+ // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
+
+int WeightedPicker::Pick(SimplePhilox* rnd) const {
+ if (total_weight() == 0) return -1;
+
+  // Use an unbiased uniform distribution to avoid the bias toward low
+  // elements that could result from very large weights.
+ return PickAt(UnbiasedUniform(rnd, total_weight()));
+}
+
+int WeightedPicker::PickAt(int32 weight_index) const {
+ if (weight_index < 0 || weight_index >= total_weight()) return -1;
+
+ int32 position = weight_index;
+ int index = 0;
+
+ for (int l = 1; l < num_levels_; l++) {
+ // Pick left or right child of "level_[l-1][index]"
+ const int32 left_weight = level_[l][2 * index];
+ if (position < left_weight) {
+ // Descend to left child
+ index = 2 * index;
+ } else {
+ // Descend to right child
+ index = 2 * index + 1;
+ position -= left_weight;
+ }
+ }
+ CHECK_GE(index, 0);
+ CHECK_LT(index, N_);
+ CHECK_LE(position, level_[num_levels_ - 1][index]);
+ return index;
+}
+
+void WeightedPicker::set_weight(int index, int32 weight) {
+ assert(index >= 0);
+ assert(index < N_);
+
+ // Adjust the sums all the way up to the root
+ const int32 delta = weight - get_weight(index);
+ for (int l = num_levels_ - 1; l >= 0; l--) {
+ level_[l][index] += delta;
+ index >>= 1;
+ }
+}
+
+void WeightedPicker::SetAllWeights(int32 weight) {
+ // Initialize leaves
+ int32* leaves = level_[num_levels_ - 1];
+ for (int i = 0; i < N_; i++) leaves[i] = weight;
+ for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
+
+ // Now sum up towards the root
+ RebuildTreeWeights();
+}
+
+void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
+ Resize(N);
+
+ // Initialize leaves
+ int32* leaves = level_[num_levels_ - 1];
+ for (int i = 0; i < N_; i++) leaves[i] = weights[i];
+ for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
+
+ // Now sum up towards the root
+ RebuildTreeWeights();
+}
+
+void WeightedPicker::RebuildTreeWeights() {
+ for (int l = num_levels_ - 2; l >= 0; l--) {
+ int32* level = level_[l];
+ int32* children = level_[l + 1];
+ for (int i = 0; i < LevelSize(l); i++) {
+ level[i] = children[2 * i] + children[2 * i + 1];
+ }
+ }
+}
+
+void WeightedPicker::Append(int32 weight) {
+ Resize(num_elements() + 1);
+ set_weight(num_elements() - 1, weight);
+}
+
+void WeightedPicker::Resize(int new_size) {
+ CHECK_GE(new_size, 0);
+ if (new_size <= LevelSize(num_levels_ - 1)) {
+ // The new picker fits in the existing levels.
+
+ // First zero out any of the weights that are being dropped so
+ // that the levels are correct (only needed when shrinking)
+ for (int i = new_size; i < N_; i++) {
+ set_weight(i, 0);
+ }
+
+ // We do not need to set any new weights when enlarging because
+ // the unneeded entries always have weight zero.
+ N_ = new_size;
+ return;
+ }
+
+ // We follow the simple strategy of just copying the old
+ // WeightedPicker into a new WeightedPicker. The cost is
+ // O(N) regardless.
+ assert(new_size > N_);
+ WeightedPicker new_picker(new_size);
+ int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
+ int32* src = this->level_[this->num_levels_ - 1];
+ memcpy(dst, src, sizeof(dst[0]) * N_);
+ memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_));
+ new_picker.RebuildTreeWeights();
+
+ // Now swap the two pickers
+ std::swap(new_picker.N_, this->N_);
+ std::swap(new_picker.num_levels_, this->num_levels_);
+ std::swap(new_picker.level_, this->level_);
+ assert(this->N_ == new_size);
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h
new file mode 100644
index 0000000000..3d2c2dbb39
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.h
@@ -0,0 +1,118 @@
+
+// An abstraction to pick from one of N elements with a specified
+// weight per element.
+//
+// The weight for a given element can be changed in O(lg N) time
+// An element can be picked in O(lg N) time.
+//
+// Uses O(N) bytes of memory.
+//
+// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
+// adjustment after construction.
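+//
+// Example usage (an illustrative sketch; "rnd" is a SimplePhilox*):
+//
+//   WeightedPicker picker(3);      // three elements, weight 1 each
+//   picker.set_weight(0, 5);       // element 0 becomes 5x as likely
+//   int elem = picker.Pick(rnd);   // 0 w.p. 5/7; 1 or 2 w.p. 1/7 each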
+
+#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class SimplePhilox;
+
+class WeightedPicker {
+ public:
+ // REQUIRES N >= 0
+ // Initializes the elements with a weight of one per element
+ explicit WeightedPicker(int N);
+
+ // Releases all resources
+ ~WeightedPicker();
+
+ // Pick a random element with probability proportional to its weight.
+ // If total weight is zero, returns -1.
+ int Pick(SimplePhilox* rnd) const;
+
+ // Deterministically pick element x whose weight covers the
+ // specified weight_index.
+ // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
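+ // For example (illustrative): with weights {1, 0, 200}, PickAt(0)
+ // returns 0 and PickAt(1) through PickAt(200) all return 2.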
+ int PickAt(int32 weight_index) const;
+
+ // Get the weight associated with an element
+ // REQUIRES 0 <= index < N
+ int32 get_weight(int index) const;
+
+ // Set the weight associated with an element
+ // REQUIRES weight >= 0.0f
+ // REQUIRES 0 <= index < N
+ void set_weight(int index, int32 weight);
+
+ // Get the total combined weight of all elements
+ int32 total_weight() const;
+
+ // Get the number of elements in the picker
+ int num_elements() const;
+
+ // Set weight of each element to "weight"
+ void SetAllWeights(int32 weight);
+
+ // Resizes the picker to N and
+ // sets the weight of each element i to weight[i].
+ // The sum of the weights should not exceed 2^31 - 2
+ // Complexity O(N).
+ void SetWeightsFromArray(int N, const int32* weights);
+
+ // REQUIRES N >= 0
+ //
+ // Resize the weighted picker so that it has "N" elements.
+ // Any newly added entries have zero weight.
+ //
+ // Note: Resizing to a smaller size than num_elements() will
+ // not reclaim any memory. If you wish to reduce memory usage,
+ // allocate a new WeightedPicker of the appropriate size.
+ //
+ // It is efficient to use repeated calls to Resize(num_elements() + 1)
+ // to grow the picker to size X (takes total time O(X)).
+ void Resize(int N);
+
+ // Grow the picker by one and set the weight of the new entry to "weight".
+ //
+ // Repeated calls to Append() in order to grow the
+ // picker to size X takes a total time of O(X lg(X)).
+ // Consider using SetWeightsFromArray instead.
+ void Append(int32 weight);
+
+ private:
+ // We keep a binary tree with N leaves. The "i"th leaf contains
+ // the weight of the "i"th element. An internal node contains
+ // the sum of the weights of its children.
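+ //
+ // For example (illustrative): with N = 5 the leaf level must hold
+ // at least 5 entries, so there are 4 levels of sizes 1, 2, 4 and 8;
+ // leaves 5 through 7 simply carry weight zero.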
+ int N_; // Number of elements
+ int num_levels_; // Number of levels in tree (level-0 is root)
+ int32** level_; // Array that holds nodes per level
+
+ // Size of each level
+ static int LevelSize(int level) { return 1 << level; }
+
+ // Rebuild the tree weights using the leaf weights
+ void RebuildTreeWeights();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
+};
+
+inline int32 WeightedPicker::get_weight(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, N_);
+ return level_[num_levels_ - 1][index];
+}
+
+inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
+
+inline int WeightedPicker::num_elements() const { return N_; }
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc
new file mode 100644
index 0000000000..0b27d437d5
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker_test.cc
@@ -0,0 +1,254 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+static void TestPicker(SimplePhilox* rnd, int size);
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void TestPickAt(int items, const int32* weights);
+
+TEST(WeightedPicker, Simple) {
+ PhiloxRandom philox(testing::RandomSeed(), 17);
+ SimplePhilox rnd(&philox);
+
+ {
+ VLOG(0) << "======= Zero-length picker";
+ WeightedPicker picker(0);
+ EXPECT_EQ(picker.Pick(&rnd), -1);
+ }
+
+ {
+ VLOG(0) << "======= Singleton picker";
+ WeightedPicker picker(1);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker";
+ WeightedPicker picker(0);
+ for (int i = 0; i < 10; i++) {
+ picker.Append(1);
+ }
+ CheckUniform(&rnd, &picker, 100000);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker with zero weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Shrink picker and check weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ for (int i = 0; i < 10; i++) {
+ picker.set_weight(i, i);
+ }
+ EXPECT_EQ(picker.total_weight(), 45);
+ picker.Resize(5);
+ EXPECT_EQ(picker.total_weight(), 10);
+ picker.Resize(2);
+ EXPECT_EQ(picker.total_weight(), 1);
+ picker.Resize(1);
+ EXPECT_EQ(picker.total_weight(), 0);
+ }
+}
+
+TEST(WeightedPicker, BigWeights) {
+ PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ SimplePhilox rnd(&philox);
+ VLOG(0) << "======= Check uniform with big weights";
+ WeightedPicker picker(2);
+ picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3
+ CheckUniform(&rnd, &picker, 100000);
+}
+
+TEST(WeightedPicker, Deterministic) {
+ VLOG(0) << "======= Testing deterministic pick";
+ static const int32 weights[] = {1, 0, 200, 5, 42};
+ TestPickAt(TF_ARRAYSIZE(weights), weights);
+}
+
+TEST(WeightedPicker, Randomized) {
+ PhiloxRandom philox(testing::RandomSeed() + 10, 17);
+ SimplePhilox rnd(&philox);
+ TestPicker(&rnd, 1);
+ TestPicker(&rnd, 2);
+ TestPicker(&rnd, 3);
+ TestPicker(&rnd, 4);
+ TestPicker(&rnd, 7);
+ TestPicker(&rnd, 8);
+ TestPicker(&rnd, 9);
+ TestPicker(&rnd, 10);
+ TestPicker(&rnd, 100);
+}
+
+static void TestPicker(SimplePhilox* rnd, int size) {
+ VLOG(0) << "======= Testing size " << size;
+
+ // Check that empty picker returns -1
+ {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1);
+ }
+
+ // Create zero weights array
+ std::vector<int32> weights(size);
+ for (int elem = 0; elem < size; elem++) {
+ weights[elem] = 0;
+ }
+
+ // Check that singleton picker always returns the same element
+ for (int elem = 0; elem < size; elem++) {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ picker.set_weight(elem, elem + 1);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 10;
+ picker.SetWeightsFromArray(size, &weights[0]);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 0;
+ }
+
+ // Check that uniform picker generates elements roughly uniformly
+ {
+ WeightedPicker picker(size);
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check uniform picker that was grown piecemeal
+ if (size / 3 > 0) {
+ WeightedPicker picker(size / 3);
+ while (picker.num_elements() != size) {
+ picker.Append(1);
+ }
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check that skewed distribution works
+ if (size <= 10) {
+ // When picker grows one element at a time
+ WeightedPicker picker(size);
+ int32 weight = 1;
+ for (int elem = 0; elem < size; elem++) {
+ picker.set_weight(elem, weight);
+ weights[elem] = weight;
+ weight *= 2;
+ }
+ CheckSkewed(rnd, &picker, 1000000);
+
+ // When picker is created from an array
+ WeightedPicker array_picker(0);
+ array_picker.SetWeightsFromArray(size, &weights[0]);
+ CheckSkewed(rnd, &array_picker, 1000000);
+ }
+}
+
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
+ int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+ const int expected_min = int(0.9 * trials);
+ const int expected_max = int(1.1 * trials);
+ for (int i = 0; i < size; i++) {
+ EXPECT_GE(count[i], expected_min);
+ EXPECT_LE(count[i], expected_max);
+ }
+ delete[] count;
+}
+
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+
+ for (int i = 0; i < size - 1; i++) {
+ LOG(INFO) << i << ": " << count[i];
+ const float ratio = float(count[i + 1]) / float(count[i]);
+ EXPECT_GE(ratio, 1.6f);
+ EXPECT_LE(ratio, 2.4f);
+ }
+ delete[] count;
+}
+
+static void TestPickAt(int items, const int32* weights) {
+ WeightedPicker picker(items);
+ picker.SetWeightsFromArray(items, weights);
+ int weight_index = 0;
+ for (int i = 0; i < items; ++i) {
+ for (int j = 0; j < weights[i]; ++j) {
+ int pick = picker.PickAt(weight_index);
+ EXPECT_EQ(pick, i);
+ ++weight_index;
+ }
+ }
+ EXPECT_EQ(weight_index, picker.total_weight());
+}
+
+static void BM_Create(int iters, int arg) {
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ }
+}
+BENCHMARK(BM_Create)->Range(1, 1024);
+
+static void BM_CreateAndSetWeights(int iters, int arg) {
+ std::vector<int32> weights(arg);
+ for (int i = 0; i < arg; i++) {
+ weights[i] = i * 10;
+ }
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ p.SetWeightsFromArray(arg, &weights[0]);
+ }
+}
+BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);
+
+static void BM_Pick(int iters, int arg) {
+ PhiloxRandom philox(301, 17);
+ SimplePhilox rnd(&philox);
+ WeightedPicker p(arg);
+ int result = 0;
+ while (--iters > 0) {
+ result += p.Pick(&rnd);
+ }
+ VLOG(4) << result; // Dummy use
+}
+BENCHMARK(BM_Pick)->Range(1, 1024);
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
new file mode 100644
index 0000000000..d61129fb3f
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -0,0 +1,260 @@
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include <float.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <algorithm>
+#include <cmath>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace strings {
+
+char* FastInt32ToBufferLeft(int32 i, char* buffer) {
+ uint32 u = i;
+ if (i < 0) {
+ *buffer++ = '-';
+ // We need to do the negation in modular (i.e., "unsigned")
+ // arithmetic; MSVC++ apprently warns for plain "-u", so
+ // we write the equivalent expression "0 - u" instead.
+ u = 0 - u;
+ }
+ return FastUInt32ToBufferLeft(u, buffer);
+}
+
+char* FastUInt32ToBufferLeft(uint32 i, char* buffer) {
+ char* start = buffer;
+ do {
+ *buffer++ = ((i % 10) + '0');
+ i /= 10;
+ } while (i > 0);
+ *buffer = 0;
+ std::reverse(start, buffer);
+ return buffer;
+}
+
+char* FastInt64ToBufferLeft(int64 i, char* buffer) {
+ uint64 u = i;
+ if (i < 0) {
+ *buffer++ = '-';
+ u = 0 - u;
+ }
+ return FastUInt64ToBufferLeft(u, buffer);
+}
+
+char* FastUInt64ToBufferLeft(uint64 i, char* buffer) {
+ char* start = buffer;
+ do {
+ *buffer++ = ((i % 10) + '0');
+ i /= 10;
+ } while (i > 0);
+ *buffer = 0;
+ std::reverse(start, buffer);
+ return buffer;
+}
+
+static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;
+
+char* DoubleToBuffer(double value, char* buffer) {
+ // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
+ // platforms these days. Just in case some system exists where DBL_DIG
+ // is significantly larger -- and risks overflowing our buffer -- we have
+ // this assert.
+ static_assert(DBL_DIG < 20, "DBL_DIG is too big");
+
+ bool full_precision_needed = true;
+ if (std::abs(value) <= kDoublePrecisionCheckMax) {
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value);
+
+ // The snprintf should never overflow because the buffer is significantly
+ // larger than the precision we asked for.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+
+ full_precision_needed = strtod(buffer, NULL) != value;
+ }
+
+ if (full_precision_needed) {
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value);
+
+ // Should never overflow; see above.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+ }
+ return buffer;
+}
+
+bool safe_strto64(const char* str, int64* value) {
+ if (!str) return false;
+
+ // Skip leading space.
+ while (isspace(*str)) ++str;
+
+ int64 vlimit = kint64max;
+ int sign = 1;
+ if (*str == '-') {
+ sign = -1;
+ ++str;
+ // Different limit for positive and negative integers.
+ vlimit = kint64min;
+ }
+
+ if (!isdigit(*str)) return false;
+
+ int64 result = 0;
+ if (sign == 1) {
+ do {
+ int digit = *str - '0';
+ if ((vlimit - digit) / 10 < result) {
+ return false;
+ }
+ result = result * 10 + digit;
+ ++str;
+ } while (isdigit(*str));
+ } else {
+ do {
+ int digit = *str - '0';
+ if ((vlimit + digit) / 10 > result) {
+ return false;
+ }
+ result = result * 10 - digit;
+ ++str;
+ } while (isdigit(*str));
+ }
+
+ // Skip trailing space.
+ while (isspace(*str)) ++str;
+
+ if (*str) return false;
+
+ *value = result;
+ return true;
+}
+
+bool safe_strto32(const char* str, int32* value) {
+ if (!str) return false;
+
+ // Skip leading space.
+ while (isspace(*str)) ++str;
+
+ int64 vmax = kint32max;
+ int sign = 1;
+ if (*str == '-') {
+ sign = -1;
+ ++str;
+ // Different max for positive and negative integers.
+ ++vmax;
+ }
+
+ if (!isdigit(*str)) return false;
+
+ int64 result = 0;
+ do {
+ result = result * 10 + *str - '0';
+ if (result > vmax) {
+ return false;
+ }
+ ++str;
+ } while (isdigit(*str));
+
+ // Skip trailing space.
+ while (isspace(*str)) ++str;
+
+ if (*str) return false;
+
+ *value = result * sign;
+ return true;
+}
+
+bool safe_strtof(const char* str, float* value) {
+ char* endptr;
+ *value = strtof(str, &endptr);
+ while (isspace(*endptr)) ++endptr;
+ // Ignore range errors from strtof: the values it returns on
+ // underflow and overflow are the right fallback in a robust setting.
+ return *str != '\0' && *endptr == '\0';
+}
+
+char* FloatToBuffer(float value, char* buffer) {
+ // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
+ // platforms these days. Just in case some system exists where FLT_DIG
+ // is significantly larger -- and risks overflowing our buffer -- we have
+ // this assert.
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value);
+
+ // The snprintf should never overflow because the buffer is significantly
+ // larger than the precision we asked for.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+
+ float parsed_value;
+ if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
+ snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value);
+
+ // Should never overflow; see above.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+ }
+ return buffer;
+}
+
+string FpToString(Fprint fp) {
+ char buf[17];
+ snprintf(buf, sizeof(buf), "%016llx", static_cast<uint64>(fp));
+ return string(buf);
+}
+
+bool StringToFp(const string& s, Fprint* fp) {
+ char junk;
+ uint64 result;
+ if (sscanf(s.c_str(), "%llx%c", &result, &junk) == 1) {
+ *fp = result;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+string HumanReadableNumBytes(int64 num_bytes) {
+ if (num_bytes == kint64min) {
+ // Special case for number with not representable negation.
+ return "-8E";
+ }
+
+ const char* neg_str = (num_bytes < 0) ? "-" : "";
+ if (num_bytes < 0) {
+ num_bytes = -num_bytes;
+ }
+
+ // Special case for bytes.
+ if (num_bytes < 1024) {
+ // No fractions for bytes.
+ char buf[8]; // Longest possible string is '-XXXXB'
+ snprintf(buf, sizeof(buf), "%s%lldB", neg_str,
+ static_cast<int64>(num_bytes));
+ return string(buf);
+ }
+
+ static const char units[] = "KMGTPE"; // int64 only goes up to E.
+ const char* unit = units;
+ while (num_bytes >= static_cast<int64>(1024) * 1024) {
+ num_bytes /= 1024;
+ ++unit;
+ CHECK(unit < units + TF_ARRAYSIZE(units));
+ }
+
+ // We use SI prefixes.
+ char buf[16];
+ snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"),
+ neg_str, num_bytes / 1024.0, *unit);
+ return string(buf);
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
new file mode 100644
index 0000000000..a30a862279
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace strings {
+
+// ----------------------------------------------------------------------
+// FastIntToBufferLeft()
+// These are intended for speed.
+//
+// All functions take the output buffer as an arg and return a pointer
+// into that buffer; see the comments below for the exact return value
+// of each routine.
+// ----------------------------------------------------------------------
+
+// Previously documented minimums -- the buffers provided must be at least
+// this long, though these numbers are subject to change:
+// Int32, UInt32: 12 bytes
+// Int64, UInt64: 22 bytes
+// Use kFastToBufferSize rather than hardcoding constants.
+static const int kFastToBufferSize = 32;
+
+// ----------------------------------------------------------------------
+// FastInt32ToBufferLeft()
+// FastUInt32ToBufferLeft()
+// FastInt64ToBufferLeft()
+// FastUInt64ToBufferLeft()
+//
+// These functions convert their numeric argument to an ASCII
+// representation of the numeric value in base 10, with the
+// representation being left-aligned in the buffer. The caller is
+// responsible for ensuring that the buffer has enough space to hold
+// the output. The buffer should typically be at least kFastToBufferSize
+// bytes.
+//
+// Returns a pointer to the end of the string (i.e. the null character
+// terminating the string).
+// ----------------------------------------------------------------------
+
+char* FastInt32ToBufferLeft(int32 i, char* buffer); // at least 12 bytes
+char* FastUInt32ToBufferLeft(uint32 i, char* buffer); // at least 12 bytes
+char* FastInt64ToBufferLeft(int64 i, char* buffer); // at least 22 bytes
+char* FastUInt64ToBufferLeft(uint64 i, char* buffer); // at least 22 bytes
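+
+// Example usage (an illustrative sketch):
+//
+//   char buf[kFastToBufferSize];
+//   char* end = FastInt32ToBufferLeft(-1234, buf);  // buf holds "-1234"
+//   size_t len = end - buf;                         // 5 characters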
+
+// Required buffer size for DoubleToBuffer is kFastToBufferSize.
+// Required buffer size for FloatToBuffer is kFastToBufferSize.
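+// For example (illustrative): DoubleToBuffer(0.1, buf) yields "0.1"
+// rather than "0.10000000000000001", because the 15-digit form already
+// parses back to the same double.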
+char* DoubleToBuffer(double i, char* buffer);
+char* FloatToBuffer(float i, char* buffer);
+
+// Convert a 64-bit fingerprint value to an ASCII representation.
+string FpToString(Fprint fp);
+
+// Attempt to parse a fingerprint in the form encoded by FpToString. If
+// successful, stores the fingerprint in *fp and returns true. Otherwise,
+// returns false.
+bool StringToFp(const string& s, Fprint* fp);
+
+// Convert strings to 32bit integer values.
+// Leading and trailing spaces are allowed.
+// Return false with overflow or invalid input.
+bool safe_strto32(const char* str, int32* value);
+
+// Convert strings to 64bit integer values.
+// Leading and trailing spaces are allowed.
+// Return false with overflow or invalid input.
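+// For example (illustrative): safe_strto64(" -123 ", &v) sets v to -123
+// and returns true, while safe_strto64("123 a", &v) returns false.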
+bool safe_strto64(const char* str, int64* value);
+
+// Convert strings to floating point values.
+// Leading and trailing spaces are allowed.
+// Values may be rounded on over- and underflow.
+bool safe_strtof(const char* str, float* value);
+
+// Converts from an int64 representing a number of bytes to a
+// human readable string representing the same number.
+// e.g. 12345678 -> "11.77MiB".
+string HumanReadableNumBytes(int64 num_bytes);
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
new file mode 100644
index 0000000000..b178e6af53
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include <string>
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+
+// NOTE: most of the routines in numbers.h are tested indirectly through
+// strcat_test.cc in this directory.
+
+// Test that FpToString and StringToFp round-trip values near powers of 2.
+TEST(FpToString, Ints) {
+ for (int s = 0; s < 64; s++) {
+ for (int delta = -1; delta <= 1; delta++) {
+ uint64 fp = (1ull << s) + delta;
+ string str = FpToString(fp);
+ uint64 fp2;
+ EXPECT_TRUE(StringToFp(str, &fp2));
+ EXPECT_EQ(fp, fp2);
+ }
+ }
+ Fprint dummy;
+ EXPECT_FALSE(StringToFp("", &dummy));
+ EXPECT_FALSE(StringToFp("xyz", &dummy));
+ EXPECT_FALSE(StringToFp("0000000000000000xyz", &dummy));
+}
+
+TEST(HumanReadableNumBytes, Bytes) {
+ EXPECT_EQ("0B", HumanReadableNumBytes(0));
+ EXPECT_EQ("4B", HumanReadableNumBytes(4));
+ EXPECT_EQ("1023B", HumanReadableNumBytes(1023));
+
+ EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1024));
+ EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1025));
+ EXPECT_EQ("1.5KiB", HumanReadableNumBytes(1500));
+ EXPECT_EQ("1.9KiB", HumanReadableNumBytes(1927));
+
+ EXPECT_EQ("2.0KiB", HumanReadableNumBytes(2048));
+ EXPECT_EQ("1.00MiB", HumanReadableNumBytes(1 << 20));
+ EXPECT_EQ("11.77MiB", HumanReadableNumBytes(12345678));
+ EXPECT_EQ("1.00GiB", HumanReadableNumBytes(1 << 30));
+
+ EXPECT_EQ("1.00TiB", HumanReadableNumBytes(1LL << 40));
+ EXPECT_EQ("1.00PiB", HumanReadableNumBytes(1LL << 50));
+ EXPECT_EQ("1.00EiB", HumanReadableNumBytes(1LL << 60));
+
+ // Try a few negative numbers
+ EXPECT_EQ("-1B", HumanReadableNumBytes(-1));
+ EXPECT_EQ("-4B", HumanReadableNumBytes(-4));
+ EXPECT_EQ("-1000B", HumanReadableNumBytes(-1000));
+ EXPECT_EQ("-11.77MiB", HumanReadableNumBytes(-12345678));
+ EXPECT_EQ("-8E", HumanReadableNumBytes(kint64min));
+}
+
+TEST(safe_strto32, Int32s) {
+ int32 result;
+
+ EXPECT_EQ(true, safe_strto32("1", &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto32("123", &result));
+ EXPECT_EQ(123, result);
+ EXPECT_EQ(true, safe_strto32(" -123 ", &result));
+ EXPECT_EQ(-123, result);
+ EXPECT_EQ(true, safe_strto32("2147483647", &result));
+ EXPECT_EQ(2147483647, result);
+ EXPECT_EQ(true, safe_strto32("-2147483648", &result));
+ EXPECT_EQ(-2147483648, result);
+
+ // Invalid argument
+ EXPECT_EQ(false, safe_strto32(" 132as ", &result));
+ EXPECT_EQ(false, safe_strto32(" 132.2 ", &result));
+ EXPECT_EQ(false, safe_strto32(" -", &result));
+ EXPECT_EQ(false, safe_strto32("", &result));
+ EXPECT_EQ(false, safe_strto32(" ", &result));
+ EXPECT_EQ(false, safe_strto32("123 a", &result));
+
+ // Overflow
+ EXPECT_EQ(false, safe_strto32("2147483648", &result));
+ EXPECT_EQ(false, safe_strto32("-2147483649", &result));
+}
+
+TEST(safe_strto64, Int64s) {
+ int64 result;
+
+ EXPECT_EQ(true, safe_strto64("1", &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto64("123", &result));
+ EXPECT_EQ(123, result);
+ EXPECT_EQ(true, safe_strto64(" -123 ", &result));
+ EXPECT_EQ(-123, result);
+ EXPECT_EQ(true, safe_strto64("9223372036854775807", &result));
+ EXPECT_EQ(9223372036854775807, result);
+ EXPECT_EQ(true, safe_strto64("-9223372036854775808", &result));
+ // kint64min == -9223372036854775808
+ // Use -9223372036854775808 directly results in out of range error
+ EXPECT_EQ(kint64min, result);
+
+ // Invalid argument
+ EXPECT_EQ(false, safe_strto64(" 132as ", &result));
+ EXPECT_EQ(false, safe_strto64(" 132.2 ", &result));
+ EXPECT_EQ(false, safe_strto64(" -", &result));
+ EXPECT_EQ(false, safe_strto64("", &result));
+ EXPECT_EQ(false, safe_strto64(" ", &result));
+ EXPECT_EQ(false, safe_strto64("123 a", &result));
+
+ // Overflow
+ EXPECT_EQ(false, safe_strto64("9223372036854775808", &result));
+ EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result));
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc
new file mode 100644
index 0000000000..ec67595ebb
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code.cc
@@ -0,0 +1,515 @@
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+#include <assert.h>
+#include <stddef.h>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace strings {
+
+// We encode a string in different ways depending on whether the item
+// should be in lexicographically increasing or decreasing order.
+//
+//
+// Lexicographically increasing order
+//
+// We want a string-to-string mapping F(x) such that for any two strings
+//
+// x < y => F(x) < F(y)
+//
+// In addition to the normal characters '\x00' through '\xff', we want to
+// encode a few extra symbols in strings:
+//
+// <sep> Separator between items
+// <infinity> Infinite string
+//
+// Therefore we need an alphabet with at least 258 symbols. Each
+// character '\1' through '\xfe' is mapped to itself. The other four are
+// encoded into two-letter sequences starting with '\0' and '\xff':
+//
+// <sep> encoded as => \0\1
+// \0 encoded as => \0\xff
+// \xff encoded as => \xff\x00
+// <infinity> encoded as => \xff\xff
+//
+// The remaining two-letter sequences starting with '\0' and '\xff' are
+// currently unused.
+//
+// F(<infinity>) is defined above. For any finite string x, F(x) is the
+// encodings of x's characters followed by the encoding for <sep>. The
+// ordering of two finite strings is the same as the ordering of the
+// respective characters at the first position where they differ, which in
+// turn is the same as the ordering of the encodings of those two
+// characters. Moreover, for every finite string x, F(x) < F(<infinity>).
+//
+//
+// Lexicographically decreasing order
+//
+// We want a string-to-string mapping G(x) such that for any two strings,
+// whether finite or not,
+//
+// x < y => G(x) > G(y)
+//
+// To achieve this, define G(x) to be the inversion of F(x): I(F(x)). In
+// other words, invert every bit in F(x) to get G(x). For example,
+//
+// x = \x00\x13\xff
+// F(x) = \x00\xff\x13\xff\x00\x00\x01 escape \0, \xff, append F(<sep>)
+// G(x) = \xff\x00\xec\x00\xff\xff\xfe invert every bit in F(x)
+//
+// x = <infinity>
+// F(x) = \xff\xff
+// G(x) = \x00\x00
+//
+// Another example is
+//
+// x F(x) G(x) = I(F(x))
+// - ---- --------------
+// <infinity> \xff\xff \x00\x00
+// "foo" foo\0\1 \x99\x90\x90\xff\xfe
+// "aaa" aaa\0\1 \x9e\x9e\x9e\xff\xfe
+// "aa" aa\0\1 \x9e\x9e\xff\xfe
+// "" \0\1 \xff\xfe
+//
+// More generally and rigorously, if for any two strings x and y
+//
+// F(x) < F(y) => I(F(x)) > I(F(y)) (1)
+//
+// it would follow that x < y => G(x) > G(y) because
+//
+// x < y => F(x) < F(y) => G(x) = I(F(x)) > I(F(y)) = G(y)
+//
+// We now show why (1) is true, in two parts. Notice that for any two
+// strings x < y, F(x) is *not* a proper prefix of F(y). Suppose x is a
+// proper prefix of y (say, x="abc" < y="abcd"). F(x) and F(y) diverge at
+// the F(<sep>) in F(x) (v. F('d') in the example). Suppose x is not a
+// proper prefix of y (say, x="abce" < y="abd"), F(x) and F(y) diverge at
+// their respective encodings of the characters where x and y diverge
+// (F('c') v. F('d')). Finally, if y=<infinity>, we can see that
+// F(y)=\xff\xff is not the prefix of F(x) for any finite string x, simply
+// by considering all the possible first characters of F(x).
+//
+// Given that F(x) is not a proper prefix F(y), the order of F(x) and F(y)
+// is determined by the byte where F(x) and F(y) diverge. For example, the
+// order of F(x)="eefh" and F(y)="eeg" is determined by their third
+// characters. I(p) inverts each byte in p, which effectively subtracts
+// each byte from 0xff. So, in this example, I('f') > I('g'), and thus
+// I(F(x)) > I(F(y)).
+//
+//
+// Implementation
+//
+// To implement G(x) efficiently, we use C++ template to instantiate two
+// versions of the code to produce F(x), one for normal encoding (giving us
+// F(x)) and one for inverted encoding (giving us G(x) = I(F(x))).
+
+static const char kEscape1 = '\000';
+static const char kNullCharacter = '\xff'; // Combined with kEscape1
+static const char kSeparator = '\001'; // Combined with kEscape1
+
+static const char kEscape2 = '\xff';
+static const char kInfinity = '\xff'; // Combined with kEscape2
+static const char kFFCharacter = '\000'; // Combined with kEscape2
+
+static const char kEscape1_Separator[2] = {kEscape1, kSeparator};
+
+// Append to "*dest" the "len" bytes starting from "*src".
+inline static void AppendBytes(string* dest, const char* src, int len) {
+ dest->append(src, len);
+}
+
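+// Returns true iff c is 0x00 or 0xff: adding 1 in unsigned arithmetic
+// maps exactly those two bytes to {0, 1}.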
+inline bool IsSpecialByte(char c) { return ((unsigned char)(c + 1)) < 2; }
+
+// Return a pointer to the first byte in the range "[start..limit)"
+// whose value is 0 or 255 (kEscape1 or kEscape2). If no such byte
+// exists in the range, returns "limit".
+inline const char* SkipToNextSpecialByte(const char* start, const char* limit) {
+ // If these constants were ever changed, this routine needs to change
+ DCHECK_EQ(kEscape1, 0);
+ DCHECK_EQ(kEscape2 & 0xffu, 255u);
+ const char* p = start;
+ while (p < limit && !IsSpecialByte(*p)) {
+ p++;
+ }
+ return p;
+}
+
+// Expose SkipToNextSpecialByte for testing purposes
+const char* OrderedCode::TEST_SkipToNextSpecialByte(const char* start,
+ const char* limit) {
+ return SkipToNextSpecialByte(start, limit);
+}
+
+// Helper routine to encode "s" and append to "*dest", escaping special
+// characters.
+inline static void EncodeStringFragment(string* dest, StringPiece s) {
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ const char* copy_start = p;
+ while (true) {
+ p = SkipToNextSpecialByte(p, limit);
+ if (p >= limit) break; // No more special characters that need escaping
+ char c = *(p++);
+ DCHECK(IsSpecialByte(c));
+ if (c == kEscape1) {
+ AppendBytes(dest, copy_start, p - copy_start - 1);
+ dest->push_back(kEscape1);
+ dest->push_back(kNullCharacter);
+ copy_start = p;
+ } else {
+ assert(c == kEscape2);
+ AppendBytes(dest, copy_start, p - copy_start - 1);
+ dest->push_back(kEscape2);
+ dest->push_back(kFFCharacter);
+ copy_start = p;
+ }
+ }
+ if (p > copy_start) {
+ AppendBytes(dest, copy_start, p - copy_start);
+ }
+}
+
+void OrderedCode::WriteString(string* dest, StringPiece s) {
+ EncodeStringFragment(dest, s);
+ AppendBytes(dest, kEscape1_Separator, 2);
+}
+
+void OrderedCode::WriteNumIncreasing(string* dest, uint64 val) {
+ // Values are encoded with a single byte length prefix, followed
+ // by the actual value in big-endian format with leading 0 bytes
+ // dropped.
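+ //
+ // For example (illustrative): 0x1234 encodes as the three bytes
+ // 0x02 0x12 0x34, and zero encodes as the single length byte 0x00.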
+ unsigned char buf[9]; // 8 bytes for value plus one byte for length
+ int len = 0;
+ while (val > 0) {
+ len++;
+ buf[9 - len] = (val & 0xff);
+ val >>= 8;
+ }
+ buf[9 - len - 1] = (unsigned char)len;
+ len++;
+ AppendBytes(dest, reinterpret_cast<const char*>(buf + 9 - len), len);
+}
+
+// Parse the encoding of a previously encoded string.
+// If parse succeeds, return true, consume encoding from
+// "*src", and if result != NULL append the decoded string to "*result".
+// Otherwise, return false and leave both undefined.
+inline static bool ReadStringInternal(StringPiece* src, string* result) {
+ const char* start = src->data();
+ const char* string_limit = src->data() + src->size();
+
+ // We only scan up to "limit-2" since a valid string must end with
+ // a two character terminator: 'kEscape1 kSeparator'
+ const char* limit = string_limit - 1;
+ const char* copy_start = start;
+ while (true) {
+ start = SkipToNextSpecialByte(start, limit);
+ if (start >= limit) break; // No terminator sequence found
+ const char c = *(start++);
+ // If inversion is required, instead of inverting 'c', we invert the
+ // character constants to which 'c' is compared. We get the same
+ // behavior but save the runtime cost of inverting 'c'.
+ DCHECK(IsSpecialByte(c));
+ if (c == kEscape1) {
+ if (result) {
+ AppendBytes(result, copy_start, start - copy_start - 1);
+ }
+ // kEscape1 kSeparator ends component
+ // kEscape1 kNullCharacter represents '\0'
+ const char next = *(start++);
+ if (next == kSeparator) {
+ src->remove_prefix(start - src->data());
+ return true;
+ } else if (next == kNullCharacter) {
+ if (result) {
+ *result += '\0';
+ }
+ } else {
+ return false;
+ }
+ copy_start = start;
+ } else {
+ assert(c == kEscape2);
+ if (result) {
+ AppendBytes(result, copy_start, start - copy_start - 1);
+ }
+ // kEscape2 kFFCharacter represents '\xff'
+ // kEscape2 kInfinity is an error
+ const char next = *(start++);
+ if (next == kFFCharacter) {
+ if (result) {
+ *result += '\xff';
+ }
+ } else {
+ return false;
+ }
+ copy_start = start;
+ }
+ }
+ return false;
+}
+
+bool OrderedCode::ReadString(StringPiece* src, string* result) {
+ return ReadStringInternal(src, result);
+}
+
+bool OrderedCode::ReadNumIncreasing(StringPiece* src, uint64* result) {
+ if (src->empty()) {
+ return false; // Not enough bytes
+ }
+
+ // Decode length byte
+ const size_t len = static_cast<unsigned char>((*src)[0]);
+
+ // If len > 0 and src is longer than 1, the first byte of "payload"
+ // must be non-zero (otherwise the encoding is not minimal).
+ // In opt mode, we don't enforce that encodings must be minimal.
+ DCHECK(0 == len || src->size() == 1 || (*src)[1] != '\0')
+ << "invalid encoding";
+
+ if (len + 1 > src->size() || len > 8) {
+ return false; // Not enough bytes or too many bytes
+ }
+
+ if (result) {
+ uint64 tmp = 0;
+ for (size_t i = 0; i < len; i++) {
+ tmp <<= 8;
+ tmp |= static_cast<unsigned char>((*src)[1 + i]);
+ }
+ *result = tmp;
+ }
+ src->remove_prefix(len + 1);
+ return true;
+}
+
+void OrderedCode::TEST_Corrupt(string* str, int k) {
+ int seen_seps = 0;
+ for (size_t i = 0; i + 1 < str->size(); i++) {
+ if ((*str)[i] == kEscape1 && (*str)[i + 1] == kSeparator) {
+ seen_seps++;
+ if (seen_seps == k) {
+ (*str)[i + 1] = kSeparator + 1;
+ return;
+ }
+ }
+ }
+}
+
+// Signed number encoding/decoding /////////////////////////////////////
+//
+// The format is as follows:
+//
+// The first bit (the most significant bit of the first byte)
+// represents the sign, 0 if the number is negative and
+// 1 if the number is >= 0.
+//
+// Any unbroken sequence of successive bits with the same value as the sign
+// bit, up to 9 (the 8th and 9th are the most significant bits of the next
+// byte), are size bits that count the number of bytes after the first byte.
+// That is, the total length is between 1 and 10 bytes.
+//
+// The value occupies the bits after the sign bit and the "size bits"
+// till the end of the string, in network byte order. If the number
+// is negative, the bits are in 2-complement.
+//
+//
+// Example 1: number 0x424242 -> 4 byte big-endian hex string 0xf0424242:
+//
+// +---------------+---------------+---------------+---------------+
+// 1 1 1 1 0 0 0 0 0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 1 0
+// +---------------+---------------+---------------+---------------+
+// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// | | | | payload: the remaining bits after the sign and size bits
+// | | | | and the delimiter bit, the value is 0x424242
+// | | | |
+// | size bits: 3 successive bits with the same value as the sign bit
+// | (followed by a delimiter bit with the opposite value)
+// | mean that there are 3 bytes after the first byte, 4 total
+// |
+// sign bit: 1 means that the number is non-negative
+//
+// Example 2: negative number -0x800 -> 2 byte big-endian hex string 0x3800:
+//
+// +---------------+---------------+
+// 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0
+// +---------------+---------------+
+// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+// | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// | | payload: the remaining bits after the sign and size bits and the
+// | | delimiter bit, 2-complement because of the negative sign,
+// | | value is ~0x7ff, represents the value -0x800
+// | |
+// | size bits: 1 bit with the same value as the sign bit
+// | (followed by a delimiter bit with the opposite value)
+// | means that there is 1 byte after the first byte, 2 total
+// |
+// sign bit: 0 means that the number is negative
+//
+//
+// Compared with the simpler unsigned format used for uint64 numbers,
+// this format is more compact for small numbers, namely one byte encodes
+// numbers in the range [-64,64), two bytes cover the range [-2^13,2^13), etc.
+// In general, n bytes encode numbers in the range [-2^(n*7-1),2^(n*7-1)).
+// (The cross-over point for compactness of representation is 8 bytes,
+// where this format only covers the range [-2^55,2^55),
+// whereas an encoding with sign bit and length in the first byte and
+// payload in all following bytes would cover [-2^56,2^56).)
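+//
+// For example (illustrative), the one-byte encodings are monotone:
+// -64 -> 0x40, -1 -> 0x7f, 0 -> 0x80, 63 -> 0xbf.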
+
+static const int kMaxSigned64Length = 10;
+
+// This array maps encoding length to header bits in the first two bytes.
+static const char kLengthToHeaderBits[1 + kMaxSigned64Length][2] = {
+ {0, 0}, {'\x80', 0}, {'\xc0', 0}, {'\xe0', 0},
+ {'\xf0', 0}, {'\xf8', 0}, {'\xfc', 0}, {'\xfe', 0},
+ {'\xff', 0}, {'\xff', '\x80'}, {'\xff', '\xc0'}};
+
+// This array maps encoding lengths to the header bits that overlap with
+// the payload and need fixing when reading.
+static const uint64 kLengthToMask[1 + kMaxSigned64Length] = {
+ 0ULL,
+ 0x80ULL,
+ 0xc000ULL,
+ 0xe00000ULL,
+ 0xf0000000ULL,
+ 0xf800000000ULL,
+ 0xfc0000000000ULL,
+ 0xfe000000000000ULL,
+ 0xff00000000000000ULL,
+ 0x8000000000000000ULL,
+ 0ULL};
+
+// This array maps the number of bits in a number to the encoding
+// length produced by WriteSignedNumIncreasing.
+// For positive numbers, the number of bits is 1 plus the most significant
+// bit position (the highest bit position in a positive int64 is 63).
+// For a negative number n, we count the bits in ~n.
+// That is, length = kBitsToLength[Bits::Log2Floor64(n < 0 ? ~n : n) + 1].
+static const int8 kBitsToLength[1 + 63] = {
+ 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4,
+ 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7,
+ 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10};
+
+#if defined(__GNUC__)
+// Returns floor(lg(n)). Returns -1 if n == 0.
+static int Log2Floor64(uint64 n) {
+ return n == 0 ? -1 : 63 ^ __builtin_clzll(n);
+}
+#else
+// Portable slow version
+static int Log2Floor32_Portable(uint32 n) {
+ if (n == 0) return -1;
+ int log = 0;
+ uint32 value = n;
+ for (int i = 4; i >= 0; --i) {
+ int shift = (1 << i);
+ uint32 x = value >> shift;
+ if (x != 0) {
+ value = x;
+ log += shift;
+ }
+ }
+ assert(value == 1);
+ return log;
+}
+// Returns floor(lg(n)). Returns -1 if n == 0.
+static int Log2Floor64(uint64 n) {
+ const uint32 topbits = static_cast<uint32>(n >> 32);
+ if (topbits == 0) {
+ // Top bits are zero, so scan in bottom bits
+ return Log2Floor32_Portable(static_cast<uint32>(n));
+ } else {
+ return 32 + Log2Floor32_Portable(topbits);
+ }
+}
+#endif
+
+// Calculates the encoding length in bytes of the signed number n.
+static inline int SignedEncodingLength(int64 n) {
+ return kBitsToLength[Log2Floor64(n < 0 ? ~n : n) + 1];
+}
+
+static void StoreBigEndian64(char* dst, uint64 v) {
+ for (int i = 0; i < 8; i++) {
+ dst[i] = (v >> (56 - 8 * i)) & 0xff;
+ }
+}
+
+static uint64 LoadBigEndian64(const char* src) {
+ uint64 result = 0;
+ for (int i = 0; i < 8; i++) {
+ unsigned char c = static_cast<unsigned char>(src[i]);
+ result |= static_cast<uint64>(c) << (56 - 8 * i);
+ }
+ return result;
+}
+
+void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) {
+ const uint64 x = val < 0 ? ~val : val;
+ if (x < 64) { // fast path for encoding length == 1
+ *dest += kLengthToHeaderBits[1][0] ^ val;
+ return;
+ }
+ // buf = val in network byte order, sign extended to 10 bytes
+ const char sign_byte = val < 0 ? '\xff' : '\0';
+ char buf[10] = {
+ sign_byte, sign_byte,
+ };
+ StoreBigEndian64(buf + 2, val);
+ static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch");
+ const int len = SignedEncodingLength(x);
+ DCHECK_GE(len, 2);
+ char* const begin = buf + sizeof(buf) - len;
+ begin[0] ^= kLengthToHeaderBits[len][0];
+ begin[1] ^= kLengthToHeaderBits[len][1]; // ok because len >= 2
+ dest->append(begin, len);
+}
+
+bool OrderedCode::ReadSignedNumIncreasing(StringPiece* src, int64* result) {
+ if (src->empty()) return false;
+ const uint64 xor_mask = (!((*src)[0] & 0x80)) ? ~0ULL : 0ULL;
+ const unsigned char first_byte = (*src)[0] ^ (xor_mask & 0xff);
+
+ // now calculate and test length, and set x to raw (unmasked) result
+ int len;
+ uint64 x;
+ if (first_byte != 0xff) {
+ len = 7 - Log2Floor64(first_byte ^ 0xff);
+ if (src->size() < static_cast<size_t>(len)) return false;
+ x = xor_mask; // sign extend using xor_mask
+ for (int i = 0; i < len; ++i)
+ x = (x << 8) | static_cast<unsigned char>((*src)[i]);
+ } else {
+ len = 8;
+ if (src->size() < static_cast<size_t>(len)) return false;
+ const unsigned char second_byte = (*src)[1] ^ (xor_mask & 0xff);
+ if (second_byte >= 0x80) {
+ if (second_byte < 0xc0) {
+ len = 9;
+ } else {
+ const unsigned char third_byte = (*src)[2] ^ (xor_mask & 0xff);
+ if (second_byte == 0xc0 && third_byte < 0x80) {
+ len = 10;
+ } else {
+ return false; // either len > 10 or len == 10 and #bits > 63
+ }
+ }
+ if (src->size() < static_cast<size_t>(len)) return false;
+ }
+ x = LoadBigEndian64(src->data() + len - 8);
+ }
+
+ x ^= kLengthToMask[len]; // remove spurious header bits
+
+ DCHECK_EQ(len, SignedEncodingLength(x)) << "invalid encoding";
+
+ if (result) *result = x;
+ src->remove_prefix(len);
+ return true;
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/ordered_code.h b/tensorflow/core/lib/strings/ordered_code.h
new file mode 100644
index 0000000000..39f1df9a94
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code.h
@@ -0,0 +1,77 @@
+// This module provides routines for encoding a sequence of typed
+// entities into a string. The resulting strings can be
+// lexicographically compared to yield the same comparison value that
+// would have been generated if the encoded items had been compared
+// one by one according to their type.
+//
+// More precisely, suppose:
+// 1. string A is generated by encoding the sequence of items [A_1..A_n]
+// 2. string B is generated by encoding the sequence of items [B_1..B_n]
+// 3. The types match; i.e., for all i: A_i was encoded using
+// the same routine as B_i
+// Then:
+// Comparing A vs. B lexicographically is the same as comparing
+// the vectors [A_1..A_n] and [B_1..B_n] lexicographically.
+//
+// Furthermore, if n < m, the encoding of [A_1..A_n] is a strict prefix of
+// [A_1..A_m] (unless m = n+1 and A_m is the empty string encoded with
+// WriteTrailingString, in which case the encodings are equal).
+//
+// This module is often useful when generating multi-part sstable
+// keys that have to be ordered in a particular fashion.
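+//
+// Example usage (an illustrative sketch):
+//
+//   string key;
+//   OrderedCode::WriteString(&key, "user");
+//   OrderedCode::WriteNumIncreasing(&key, 42);
+//   // Keys built this way sort first by the string, then by the number.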
+
+#ifndef TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
+#define TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
+
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+class StringPiece;
+
+namespace strings {
+
+class OrderedCode {
+ public:
+ // -------------------------------------------------------------------
+ // Encoding routines: each of the following routines appends
+ // one item to "*dest" in an encoding where larger values are
+ // ordered lexicographically after smaller values.
+ static void WriteString(string* dest, StringPiece str);
+ static void WriteNumIncreasing(string* dest, uint64 num);
+ static void WriteSignedNumIncreasing(string* dest, int64 num);
+
+ // -------------------------------------------------------------------
+ // Decoding routines: these extract an item earlier encoded using
+ // the corresponding WriteXXX() routines above. The item is read
+ // from "*src"; "*src" is modified to point past the decoded item;
+ // and if "result" is non-NULL, "*result" is modified to contain the
+ // result. In case of string result, the decoded string is appended to
+ // "*result". Returns true if the next item was read successfully, false
+ // otherwise.
+ static bool ReadString(StringPiece* src, string* result);
+ static bool ReadNumIncreasing(StringPiece* src, uint64* result);
+ static bool ReadSignedNumIncreasing(StringPiece* src, int64* result);
+
+ // Helper for testing: corrupt "*str" by changing the kth item separator
+ // in the string.
+ static void TEST_Corrupt(string* str, int k);
+
+ // Helper for testing.
+ // SkipToNextSpecialByte is an internal routine defined in the .cc file
+ // with the following semantics. Return a pointer to the first byte
+ // in the range "[start..limit)" whose value is 0 or 255. If no such
+ // byte exists in the range, returns "limit".
+ static const char* TEST_SkipToNextSpecialByte(const char* start,
+ const char* limit);
+
+ private:
+ // This has only static methods, so disallow construction entirely
+ OrderedCode();
+ TF_DISALLOW_COPY_AND_ASSIGN(OrderedCode);
+};
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc
new file mode 100644
index 0000000000..d517d14f4a
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code_test.cc
@@ -0,0 +1,1183 @@
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+#include <float.h>
+#include <stddef.h>
+#include <limits>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace strings {
+
+static string RandomString(random::SimplePhilox* rnd, int len) {
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd->Uniform(256);
+ }
+ return x;
+}
+
+// ---------------------------------------------------------------------
+// Utility template functions (they help templatize the tests below)
+
+// Read/WriteIncreasing are defined for string, uint64, int64 below.
+template <typename T>
+static void OCWriteIncreasing(string* dest, const T& val);
+template <typename T>
+static bool OCReadIncreasing(StringPiece* src, T* result);
+
+// Read/WriteIncreasing<string>
+template <>
+void OCWriteIncreasing<string>(string* dest, const string& val) {
+ OrderedCode::WriteString(dest, val);
+}
+template <>
+bool OCReadIncreasing<string>(StringPiece* src, string* result) {
+ return OrderedCode::ReadString(src, result);
+}
+
+// Read/WriteIncreasing<uint64>
+template <>
+void OCWriteIncreasing<uint64>(string* dest, const uint64& val) {
+ OrderedCode::WriteNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<uint64>(StringPiece* src, uint64* result) {
+ return OrderedCode::ReadNumIncreasing(src, result);
+}
+
+// Read/WriteIncreasing<int64>
+template <>
+void OCWriteIncreasing<int64>(string* dest, const int64& val) {
+ OrderedCode::WriteSignedNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<int64>(StringPiece* src, int64* result) {
+ return OrderedCode::ReadSignedNumIncreasing(src, result);
+}
+
+template <typename T>
+string OCWrite(T val) {
+ string result;
+ OCWriteIncreasing<T>(&result, val);
+ return result;
+}
+
+template <typename T>
+void OCWriteToString(string* result, T val) {
+ OCWriteIncreasing<T>(result, val);
+}
+
+template <typename T>
+bool OCRead(StringPiece* s, T* val) {
+ return OCReadIncreasing<T>(s, val);
+}
+
+// ---------------------------------------------------------------------
+// Numbers
+
+template <typename T>
+static T TestRead(const string& a) {
+ // gracefully reject any proper prefix of an encoding
+ for (size_t i = 0; i + 1 < a.size(); ++i) {
+ StringPiece s(a.data(), i);
+ CHECK(!OCRead<T>(&s, NULL));
+ CHECK_EQ(s, a.substr(0, i));
+ }
+
+ StringPiece s(a);
+ T v;
+ CHECK(OCRead<T>(&s, &v));
+ CHECK(s.empty());
+ return v;
+}
+
+template <typename T>
+static void TestWriteRead(T expected) {
+ EXPECT_EQ(expected, TestRead<T>(OCWrite<T>(expected)));
+}
+
+// Verifies that the second Write* call appends a non-empty string to its
+// output.
+template <typename T, typename U>
+static void TestWriteAppends(T first, U second) {
+ string encoded;
+ OCWriteToString<T>(&encoded, first);
+ string encoded_first_only = encoded;
+ OCWriteToString<U>(&encoded, second);
+ EXPECT_NE(encoded, encoded_first_only);
+ EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only));
+}
+
+template <typename T>
+static void TestNumbers(T multiplier) {
+ // first test powers of 2 (and nearby numbers)
+ for (T x = std::numeric_limits<T>().max(); x != 0; x /= 2) {
+ TestWriteRead(multiplier * (x - 1));
+ TestWriteRead(multiplier * x);
+ if (x != std::numeric_limits<T>::max()) {
+ TestWriteRead(multiplier * (x + 1));
+ } else if (multiplier == -1) {
+ TestWriteRead(-x - 1);
+ }
+ }
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int bits = 1; bits <= std::numeric_limits<T>().digits; ++bits) {
+ // test random non-negative numbers with given number of significant bits
+ const uint64 mask = (~0ULL) >> (64 - bits);
+ for (int i = 0; i < 1000; i++) {
+ T x = rnd.Rand64() & mask;
+ TestWriteRead(multiplier * x);
+ T y = rnd.Rand64() & mask;
+ TestWriteAppends(multiplier * x, multiplier * y);
+ }
+ }
+}
+
+// Return true iff 'a' is "before" 'b'
+static bool CompareStrings(const string& a, const string& b) { return (a < b); }
+
+template <typename T>
+static void TestNumberOrdering() {
+ // first the negative numbers (if T is signed, otherwise no-op)
+ string laststr = OCWrite<T>(std::numeric_limits<T>().min());
+ for (T num = std::numeric_limits<T>().min() / 2; num != 0; num /= 2) {
+ string strminus1 = OCWrite<T>(num - 1);
+ string str = OCWrite<T>(num);
+ string strplus1 = OCWrite<T>(num + 1);
+
+ CHECK(CompareStrings(strminus1, str));
+ CHECK(CompareStrings(str, strplus1));
+
+ // Compare 'str' with 'laststr'. When we approach 0, 'laststr' is
+ // not necessarily before 'strminus1'.
+ CHECK(CompareStrings(laststr, str));
+ laststr = str;
+ }
+
+ // then the positive numbers
+ laststr = OCWrite<T>(0);
+ T num = 1;
+ while (num < std::numeric_limits<T>().max() / 2) {
+ num *= 2;
+ string strminus1 = OCWrite<T>(num - 1);
+ string str = OCWrite<T>(num);
+ string strplus1 = OCWrite<T>(num + 1);
+
+ CHECK(CompareStrings(strminus1, str));
+ CHECK(CompareStrings(str, strplus1));
+
+ // Compare 'str' with 'laststr'.
+ CHECK(CompareStrings(laststr, str));
+ laststr = str;
+ }
+}
+
+// Helper routine for testing TEST_SkipToNextSpecialByte
+static int FindSpecial(const string& x) {
+ const char* p = x.data();
+ const char* limit = p + x.size();
+ const char* result = OrderedCode::TEST_SkipToNextSpecialByte(p, limit);
+ return result - p;
+}
+
+TEST(OrderedCode, SkipToNextSpecialByte) {
+ for (size_t len = 0; len < 256; len++) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ while (x.size() < len) {
+ char c = 1 + rnd.Uniform(254);
+ ASSERT_NE(c, 0);
+ ASSERT_NE(c, 255);
+ x += c; // No 0 bytes, no 255 bytes
+ }
+ EXPECT_EQ(FindSpecial(x), x.size());
+ for (size_t special_pos = 0; special_pos < len; special_pos++) {
+ for (size_t special_test = 0; special_test < 2; special_test++) {
+ const char special_byte = (special_test == 0) ? 0 : 255;
+ string y = x;
+ y[special_pos] = special_byte;
+ EXPECT_EQ(FindSpecial(y), special_pos);
+ if (special_pos < 16) {
+ // Add some special bytes after the one at special_pos to make sure
+ // we still return the earliest special byte in the string
+ for (size_t rest = special_pos + 1; rest < len; rest++) {
+ if (rnd.OneIn(3)) {
+ y[rest] = rnd.OneIn(2) ? 0 : 255;
+ EXPECT_EQ(FindSpecial(y), special_pos);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(OrderedCode, ExhaustiveFindSpecial) {
+ char buf[16];
+ char* limit = buf + sizeof(buf);
+ int count = 0;
+ for (int start_offset = 0; start_offset <= 5; start_offset += 5) {
+ // We test exhaustively with all combinations of 3 bytes starting
+ // at offset 0 and offset 5 (so as to test with the bytes at both
+ // ends of a 64-bit word).
+ for (size_t i = 0; i < sizeof(buf); i++) {
+ buf[i] = 'a'; // Not a special byte
+ }
+ for (int b0 = 0; b0 < 256; b0++) {
+ for (int b1 = 0; b1 < 256; b1++) {
+ for (int b2 = 0; b2 < 256; b2++) {
+ buf[start_offset + 0] = b0;
+ buf[start_offset + 1] = b1;
+ buf[start_offset + 2] = b2;
+ char* expected;
+ if (b0 == 0 || b0 == 255) {
+ expected = &buf[start_offset];
+ } else if (b1 == 0 || b1 == 255) {
+ expected = &buf[start_offset + 1];
+ } else if (b2 == 0 || b2 == 255) {
+ expected = &buf[start_offset + 2];
+ } else {
+ expected = limit;
+ }
+ count++;
+ EXPECT_EQ(expected,
+ OrderedCode::TEST_SkipToNextSpecialByte(buf, limit));
+ }
+ }
+ }
+ }
+ EXPECT_EQ(count, 256 * 256 * 256 * 2);
+}
+
+TEST(Uint64, EncodeDecode) { TestNumbers<uint64>(1); }
+
+TEST(Uint64, Ordering) { TestNumberOrdering<uint64>(); }
+
+TEST(Int64, EncodeDecode) {
+ TestNumbers<int64>(1);
+ TestNumbers<int64>(-1);
+}
+
+TEST(Int64, Ordering) { TestNumberOrdering<int64>(); }
+
+// Returns the bitwise complement of s.
+static inline string StrNot(const string& s) {
+ string result;
+ for (string::const_iterator it = s.begin(); it != s.end(); ++it)
+ result.push_back(~*it);
+ return result;
+}
+
+template <typename T>
+static void TestInvalidEncoding(const string& s) {
+ StringPiece p(s);
+ EXPECT_FALSE(OCRead<T>(&p, static_cast<T*>(NULL)));
+ EXPECT_EQ(s, p);
+}
+
+TEST(OrderedCodeInvalidEncodingsTest, Overflow) {
+ // 1U << 64, increasing and decreasing
+ const string k2xx64U = "\x09\x01" + string(8, 0);
+ TestInvalidEncoding<uint64>(k2xx64U);
+
+ // 1 << 63 and ~(1 << 63), increasing and decreasing
+ const string k2xx63 = "\xff\xc0\x80" + string(7, 0);
+ TestInvalidEncoding<int64>(k2xx63);
+ TestInvalidEncoding<int64>(StrNot(k2xx63));
+}
+
+TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) {
+ // Test "ambiguous"/"non-canonical" encodings.
+ // These are non-minimal (but otherwise "valid") encodings that
+ // differ from the minimal encoding chosen by OrderedCode::WriteXXX
+ // and thus should be avoided to not mess up the string ordering of
+ // encodings.
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+
+ for (int n = 2; n <= 9; ++n) {
+ // The zero in non_minimal[1] is "redundant".
+ string non_minimal =
+ string(1, n - 1) + string(1, 0) + RandomString(&rnd, n - 2);
+ EXPECT_EQ(n, non_minimal.length());
+
+ EXPECT_NE(OCWrite<uint64>(0), non_minimal);
+#ifndef NDEBUG
+ StringPiece s(non_minimal);
+ EXPECT_DEATH(OrderedCode::ReadNumIncreasing(&s, NULL), "invalid encoding");
+#else
+ TestRead<uint64>(non_minimal);
+#endif
+ }
+
+ for (int n = 2; n <= 10; ++n) {
+ // Header with 1 sign bit and n-1 size bits.
+ string header = string(n / 8, 0xff) + string(1, 0xff << (8 - (n % 8)));
+ // There are more than 7 zero bits between header bits and "payload".
+ string non_minimal = header +
+ string(1, rnd.Uniform(256) & ~*header.rbegin()) +
+ RandomString(&rnd, n - header.length() - 1);
+ EXPECT_EQ(n, non_minimal.length());
+
+ EXPECT_NE(OCWrite<int64>(0), non_minimal);
+#ifndef NDEBUG
+ StringPiece s(non_minimal);
+ EXPECT_DEATH(OrderedCode::ReadSignedNumIncreasing(&s, NULL),
+ "invalid encoding")
+ << n;
+#else
+ TestRead<int64>(non_minimal);
+#endif
+ }
+}
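+
+// Concretely (values as in the EncodingIsExpected tables below): the
+// minimal encoding of 1 is "\x01\x01" (length byte 1, payload 0x01). The
+// bytes "\x02\x00\x01" carry the same value but with a redundant leading
+// zero payload byte; WriteNumIncreasing never produces such a form, and
+// the debug-mode reads above die on it.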
+
+// Returns a random number with the specified number of bits,
+// i.e., in the range [2^(bits-1),2^bits).
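+// For example, NextBits(&rnd, 3) is uniform in [4, 8) -- the high bit of a
+// 3-bit value is always set -- and NextBits(&rnd, 0) returns 0.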
+static uint64 NextBits(random::SimplePhilox* rnd, int bits) {
+ return (bits != 0)
+ ? (rnd->Rand64() % (1LL << (bits - 1))) + (1LL << (bits - 1))
+ : 0;
+}
+
+template <typename T>
+static void BM_WriteNum(int n, T multiplier) {
+ static const int kValues = 64;
+ T values[kValues];
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ // Use enough distinct values to confuse the branch predictor
+ for (int i = 0; i < kValues; i++) {
+ values[i] = NextBits(&rnd, n % 64) * multiplier;
+ }
+ string result;
+ int index = 0;
+ while (n-- > 0) {
+ result.clear();
+ OCWriteToString<T>(&result, values[index % kValues]);
+ index++;
+ }
+}
+
+template <typename T>
+static void BM_ReadNum(int n, T multiplier) {
+ string x;
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ // Use enough distinct values to confuse the branch predictor
+ static const int kValues = 64;
+ string values[kValues];
+ for (int i = 0; i < kValues; i++) {
+ T val = NextBits(&rnd, i % 64) * multiplier;
+ values[i] = OCWrite<T>(val);
+ }
+ uint32 index = 0;
+ while (n-- > 0) {
+ T val;
+ StringPiece s = values[index++ % kValues];
+ OCRead<T>(&s, &val);
+ }
+}
+
+#define BENCHMARK_NUM(name, T, multiplier) \
+ static void BM_Write##name(int n) { BM_WriteNum<T>(n, multiplier); } \
+ BENCHMARK(BM_Write##name); \
+ static void BM_Read##name(int n) { BM_ReadNum<T>(n, multiplier); } \
+ BENCHMARK(BM_Read##name)
+
+BENCHMARK_NUM(NumIncreasing, uint64, 1);
+BENCHMARK_NUM(SignedNum, int64, 1);
+BENCHMARK_NUM(SignedNumNegative, int64, -1);
+
+#undef BENCHMARK_NUM
+
+// ---------------------------------------------------------------------
+// Strings
+
+TEST(String, EncodeDecode) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+
+ for (int len = 0; len < 256; len++) {
+ const string a = RandomString(&rnd, len);
+ TestWriteRead(a);
+ for (int len2 = 0; len2 < 64; len2++) {
+ const string b = RandomString(&rnd, len2);
+
+ TestWriteAppends(a, b);
+
+ string out;
+ OCWriteToString<string>(&out, a);
+ OCWriteToString<string>(&out, b);
+
+ string a2, b2, dummy;
+ StringPiece s = out;
+ StringPiece s2 = out;
+ CHECK(OCRead<string>(&s, &a2));
+ CHECK(OCRead<string>(&s2, NULL));
+ CHECK_EQ(s, s2);
+
+ CHECK(OCRead<string>(&s, &b2));
+ CHECK(OCRead<string>(&s2, NULL));
+ CHECK_EQ(s, s2);
+
+ CHECK(!OCRead<string>(&s, &dummy));
+ CHECK(!OCRead<string>(&s2, NULL));
+ CHECK_EQ(a, a2);
+ CHECK_EQ(b, b2);
+ CHECK(s.empty());
+ CHECK(s2.empty());
+ }
+ }
+}
+
+// 'str' is a static C-style string that may contain '\0'
+#define STATIC_STR(str) StringPiece((str), sizeof(str) - 1)
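+// For example, STATIC_STR("\x00") is a one-byte StringPiece, whereas the
+// StringPiece(const char*) constructor would call strlen() and yield an
+// empty piece:
+//
+// StringPiece a = STATIC_STR("\x00"); // a.size() == 1
+// StringPiece b("\x00"); // b.size() == 0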
+
+static string EncodeStringIncreasing(StringPiece value) {
+ string encoded;
+ OrderedCode::WriteString(&encoded, value);
+ return encoded;
+}
+
+TEST(String, Increasing) {
+ // Here are a series of strings in non-decreasing order, including
+ // consecutive strings such that the second one is equal to, a proper
+ // prefix of, or has the same length as the first one. Most also contain
+ // the special escaping characters '\x00' and '\xff'.
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("")),
+ EncodeStringIncreasing(STATIC_STR("")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("")),
+ EncodeStringIncreasing(STATIC_STR("\x00")));
+
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("\x00")),
+ EncodeStringIncreasing(STATIC_STR("\x00")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x00")),
+ EncodeStringIncreasing(STATIC_STR("\x01")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x01")),
+ EncodeStringIncreasing(STATIC_STR("a")));
+
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("a")),
+ EncodeStringIncreasing(STATIC_STR("a")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("a")),
+ EncodeStringIncreasing(STATIC_STR("aa")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("aa")),
+ EncodeStringIncreasing(STATIC_STR("\xff")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff")),
+ EncodeStringIncreasing(STATIC_STR("\xff\x00")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff\x00")),
+ EncodeStringIncreasing(STATIC_STR("\xff\x01")));
+}
+
+TEST(EncodingIsExpected, String) {
+ std::vector<std::pair<string, string>> data = {
+ {"", string("\x00\x01", 2)},
+ {"foo", string("foo\x00\x01", 5)},
+ {"hello", string("hello\x00\x01", 7)},
+ {string("\x00\x01\xff", 3), string("\x00\xff\x01\xff\x00\x00\x01", 7)},
+ };
+ for (const auto& t : data) {
+ string result;
+ OrderedCode::WriteString(&result, t.first);
+ EXPECT_EQ(t.second, result);
+
+ StringPiece in = result;
+ string decoded;
+ EXPECT_TRUE(OrderedCode::ReadString(&in, &decoded));
+ EXPECT_EQ(t.first, decoded);
+ EXPECT_EQ("", in);
+ }
+}
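+
+// The escaping scheme implied by the table above: a source byte 0x00 is
+// written as 0x00 0xff, a source byte 0xff as 0xff 0x00, every other byte
+// is copied verbatim, and the encoding ends with the terminator 0x00 0x01.
+// E.g. the last row: "\x00\x01\xff" becomes
+// \x00\xff | \x01 | \xff\x00 | \x00\x01.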
+
+TEST(EncodingIsExpected, Unsigned) {
+ std::vector<std::pair<uint64, string>> data = {
+ {0x0ull, string("\000", 1)},
+ {0x1ull, string("\001\001", 2)},
+ {0x2ull, string("\001\002", 2)},
+ {0x1ull, string("\001\001", 2)},
+ {0x2ull, string("\001\002", 2)},
+ {0x3ull, string("\001\003", 2)},
+ {0x3ull, string("\001\003", 2)},
+ {0x4ull, string("\001\004", 2)},
+ {0x5ull, string("\001\005", 2)},
+ {0x7ull, string("\001\007", 2)},
+ {0x8ull, string("\001\010", 2)},
+ {0x9ull, string("\001\t", 2)},
+ {0xfull, string("\001\017", 2)},
+ {0x10ull, string("\001\020", 2)},
+ {0x11ull, string("\001\021", 2)},
+ {0x1full, string("\001\037", 2)},
+ {0x20ull, string("\001 ", 2)},
+ {0x21ull, string("\001!", 2)},
+ {0x3full, string("\001?", 2)},
+ {0x40ull, string("\001@", 2)},
+ {0x41ull, string("\001A", 2)},
+ {0x7full, string("\001\177", 2)},
+ {0x80ull, string("\001\200", 2)},
+ {0x81ull, string("\001\201", 2)},
+ {0xffull, string("\001\377", 2)},
+ {0x100ull, string("\002\001\000", 3)},
+ {0x101ull, string("\002\001\001", 3)},
+ {0x1ffull, string("\002\001\377", 3)},
+ {0x200ull, string("\002\002\000", 3)},
+ {0x201ull, string("\002\002\001", 3)},
+ {0x3ffull, string("\002\003\377", 3)},
+ {0x400ull, string("\002\004\000", 3)},
+ {0x401ull, string("\002\004\001", 3)},
+ {0x7ffull, string("\002\007\377", 3)},
+ {0x800ull, string("\002\010\000", 3)},
+ {0x801ull, string("\002\010\001", 3)},
+ {0xfffull, string("\002\017\377", 3)},
+ {0x1000ull, string("\002\020\000", 3)},
+ {0x1001ull, string("\002\020\001", 3)},
+ {0x1fffull, string("\002\037\377", 3)},
+ {0x2000ull, string("\002 \000", 3)},
+ {0x2001ull, string("\002 \001", 3)},
+ {0x3fffull, string("\002?\377", 3)},
+ {0x4000ull, string("\002@\000", 3)},
+ {0x4001ull, string("\002@\001", 3)},
+ {0x7fffull, string("\002\177\377", 3)},
+ {0x8000ull, string("\002\200\000", 3)},
+ {0x8001ull, string("\002\200\001", 3)},
+ {0xffffull, string("\002\377\377", 3)},
+ {0x10000ull, string("\003\001\000\000", 4)},
+ {0x10001ull, string("\003\001\000\001", 4)},
+ {0x1ffffull, string("\003\001\377\377", 4)},
+ {0x20000ull, string("\003\002\000\000", 4)},
+ {0x20001ull, string("\003\002\000\001", 4)},
+ {0x3ffffull, string("\003\003\377\377", 4)},
+ {0x40000ull, string("\003\004\000\000", 4)},
+ {0x40001ull, string("\003\004\000\001", 4)},
+ {0x7ffffull, string("\003\007\377\377", 4)},
+ {0x80000ull, string("\003\010\000\000", 4)},
+ {0x80001ull, string("\003\010\000\001", 4)},
+ {0xfffffull, string("\003\017\377\377", 4)},
+ {0x100000ull, string("\003\020\000\000", 4)},
+ {0x100001ull, string("\003\020\000\001", 4)},
+ {0x1fffffull, string("\003\037\377\377", 4)},
+ {0x200000ull, string("\003 \000\000", 4)},
+ {0x200001ull, string("\003 \000\001", 4)},
+ {0x3fffffull, string("\003?\377\377", 4)},
+ {0x400000ull, string("\003@\000\000", 4)},
+ {0x400001ull, string("\003@\000\001", 4)},
+ {0x7fffffull, string("\003\177\377\377", 4)},
+ {0x800000ull, string("\003\200\000\000", 4)},
+ {0x800001ull, string("\003\200\000\001", 4)},
+ {0xffffffull, string("\003\377\377\377", 4)},
+ {0x1000000ull, string("\004\001\000\000\000", 5)},
+ {0x1000001ull, string("\004\001\000\000\001", 5)},
+ {0x1ffffffull, string("\004\001\377\377\377", 5)},
+ {0x2000000ull, string("\004\002\000\000\000", 5)},
+ {0x2000001ull, string("\004\002\000\000\001", 5)},
+ {0x3ffffffull, string("\004\003\377\377\377", 5)},
+ {0x4000000ull, string("\004\004\000\000\000", 5)},
+ {0x4000001ull, string("\004\004\000\000\001", 5)},
+ {0x7ffffffull, string("\004\007\377\377\377", 5)},
+ {0x8000000ull, string("\004\010\000\000\000", 5)},
+ {0x8000001ull, string("\004\010\000\000\001", 5)},
+ {0xfffffffull, string("\004\017\377\377\377", 5)},
+ {0x10000000ull, string("\004\020\000\000\000", 5)},
+ {0x10000001ull, string("\004\020\000\000\001", 5)},
+ {0x1fffffffull, string("\004\037\377\377\377", 5)},
+ {0x20000000ull, string("\004 \000\000\000", 5)},
+ {0x20000001ull, string("\004 \000\000\001", 5)},
+ {0x3fffffffull, string("\004?\377\377\377", 5)},
+ {0x40000000ull, string("\004@\000\000\000", 5)},
+ {0x40000001ull, string("\004@\000\000\001", 5)},
+ {0x7fffffffull, string("\004\177\377\377\377", 5)},
+ {0x80000000ull, string("\004\200\000\000\000", 5)},
+ {0x80000001ull, string("\004\200\000\000\001", 5)},
+ {0xffffffffull, string("\004\377\377\377\377", 5)},
+ {0x100000000ull, string("\005\001\000\000\000\000", 6)},
+ {0x100000001ull, string("\005\001\000\000\000\001", 6)},
+ {0x1ffffffffull, string("\005\001\377\377\377\377", 6)},
+ {0x200000000ull, string("\005\002\000\000\000\000", 6)},
+ {0x200000001ull, string("\005\002\000\000\000\001", 6)},
+ {0x3ffffffffull, string("\005\003\377\377\377\377", 6)},
+ {0x400000000ull, string("\005\004\000\000\000\000", 6)},
+ {0x400000001ull, string("\005\004\000\000\000\001", 6)},
+ {0x7ffffffffull, string("\005\007\377\377\377\377", 6)},
+ {0x800000000ull, string("\005\010\000\000\000\000", 6)},
+ {0x800000001ull, string("\005\010\000\000\000\001", 6)},
+ {0xfffffffffull, string("\005\017\377\377\377\377", 6)},
+ {0x1000000000ull, string("\005\020\000\000\000\000", 6)},
+ {0x1000000001ull, string("\005\020\000\000\000\001", 6)},
+ {0x1fffffffffull, string("\005\037\377\377\377\377", 6)},
+ {0x2000000000ull, string("\005 \000\000\000\000", 6)},
+ {0x2000000001ull, string("\005 \000\000\000\001", 6)},
+ {0x3fffffffffull, string("\005?\377\377\377\377", 6)},
+ {0x4000000000ull, string("\005@\000\000\000\000", 6)},
+ {0x4000000001ull, string("\005@\000\000\000\001", 6)},
+ {0x7fffffffffull, string("\005\177\377\377\377\377", 6)},
+ {0x8000000000ull, string("\005\200\000\000\000\000", 6)},
+ {0x8000000001ull, string("\005\200\000\000\000\001", 6)},
+ {0xffffffffffull, string("\005\377\377\377\377\377", 6)},
+ {0x10000000000ull, string("\006\001\000\000\000\000\000", 7)},
+ {0x10000000001ull, string("\006\001\000\000\000\000\001", 7)},
+ {0x1ffffffffffull, string("\006\001\377\377\377\377\377", 7)},
+ {0x20000000000ull, string("\006\002\000\000\000\000\000", 7)},
+ {0x20000000001ull, string("\006\002\000\000\000\000\001", 7)},
+ {0x3ffffffffffull, string("\006\003\377\377\377\377\377", 7)},
+ {0x40000000000ull, string("\006\004\000\000\000\000\000", 7)},
+ {0x40000000001ull, string("\006\004\000\000\000\000\001", 7)},
+ {0x7ffffffffffull, string("\006\007\377\377\377\377\377", 7)},
+ {0x80000000000ull, string("\006\010\000\000\000\000\000", 7)},
+ {0x80000000001ull, string("\006\010\000\000\000\000\001", 7)},
+ {0xfffffffffffull, string("\006\017\377\377\377\377\377", 7)},
+ {0x100000000000ull, string("\006\020\000\000\000\000\000", 7)},
+ {0x100000000001ull, string("\006\020\000\000\000\000\001", 7)},
+ {0x1fffffffffffull, string("\006\037\377\377\377\377\377", 7)},
+ {0x200000000000ull, string("\006 \000\000\000\000\000", 7)},
+ {0x200000000001ull, string("\006 \000\000\000\000\001", 7)},
+ {0x3fffffffffffull, string("\006?\377\377\377\377\377", 7)},
+ {0x400000000000ull, string("\006@\000\000\000\000\000", 7)},
+ {0x400000000001ull, string("\006@\000\000\000\000\001", 7)},
+ {0x7fffffffffffull, string("\006\177\377\377\377\377\377", 7)},
+ {0x800000000000ull, string("\006\200\000\000\000\000\000", 7)},
+ {0x800000000001ull, string("\006\200\000\000\000\000\001", 7)},
+ {0xffffffffffffull, string("\006\377\377\377\377\377\377", 7)},
+ {0x1000000000000ull, string("\007\001\000\000\000\000\000\000", 8)},
+ {0x1000000000001ull, string("\007\001\000\000\000\000\000\001", 8)},
+ {0x1ffffffffffffull, string("\007\001\377\377\377\377\377\377", 8)},
+ {0x2000000000000ull, string("\007\002\000\000\000\000\000\000", 8)},
+ {0x2000000000001ull, string("\007\002\000\000\000\000\000\001", 8)},
+ {0x3ffffffffffffull, string("\007\003\377\377\377\377\377\377", 8)},
+ {0x4000000000000ull, string("\007\004\000\000\000\000\000\000", 8)},
+ {0x4000000000001ull, string("\007\004\000\000\000\000\000\001", 8)},
+ {0x7ffffffffffffull, string("\007\007\377\377\377\377\377\377", 8)},
+ {0x8000000000000ull, string("\007\010\000\000\000\000\000\000", 8)},
+ {0x8000000000001ull, string("\007\010\000\000\000\000\000\001", 8)},
+ {0xfffffffffffffull, string("\007\017\377\377\377\377\377\377", 8)},
+ {0x10000000000000ull, string("\007\020\000\000\000\000\000\000", 8)},
+ {0x10000000000001ull, string("\007\020\000\000\000\000\000\001", 8)},
+ {0x1fffffffffffffull, string("\007\037\377\377\377\377\377\377", 8)},
+ {0x20000000000000ull, string("\007 \000\000\000\000\000\000", 8)},
+ {0x20000000000001ull, string("\007 \000\000\000\000\000\001", 8)},
+ {0x3fffffffffffffull, string("\007?\377\377\377\377\377\377", 8)},
+ {0x40000000000000ull, string("\007@\000\000\000\000\000\000", 8)},
+ {0x40000000000001ull, string("\007@\000\000\000\000\000\001", 8)},
+ {0x7fffffffffffffull, string("\007\177\377\377\377\377\377\377", 8)},
+ {0x80000000000000ull, string("\007\200\000\000\000\000\000\000", 8)},
+ {0x80000000000001ull, string("\007\200\000\000\000\000\000\001", 8)},
+ {0xffffffffffffffull, string("\007\377\377\377\377\377\377\377", 8)},
+ {0x100000000000000ull, string("\010\001\000\000\000\000\000\000\000", 9)},
+ {0x100000000000001ull, string("\010\001\000\000\000\000\000\000\001", 9)},
+ {0x1ffffffffffffffull, string("\010\001\377\377\377\377\377\377\377", 9)},
+ {0x200000000000000ull, string("\010\002\000\000\000\000\000\000\000", 9)},
+ {0x200000000000001ull, string("\010\002\000\000\000\000\000\000\001", 9)},
+ {0x3ffffffffffffffull, string("\010\003\377\377\377\377\377\377\377", 9)},
+ {0x400000000000000ull, string("\010\004\000\000\000\000\000\000\000", 9)},
+ {0x400000000000001ull, string("\010\004\000\000\000\000\000\000\001", 9)},
+ {0x7ffffffffffffffull, string("\010\007\377\377\377\377\377\377\377", 9)},
+ {0x800000000000000ull, string("\010\010\000\000\000\000\000\000\000", 9)},
+ {0x800000000000001ull, string("\010\010\000\000\000\000\000\000\001", 9)},
+ {0xfffffffffffffffull, string("\010\017\377\377\377\377\377\377\377", 9)},
+ {0x1000000000000000ull,
+ string("\010\020\000\000\000\000\000\000\000", 9)},
+ {0x1000000000000001ull,
+ string("\010\020\000\000\000\000\000\000\001", 9)},
+ {0x1fffffffffffffffull,
+ string("\010\037\377\377\377\377\377\377\377", 9)},
+ {0x2000000000000000ull, string("\010 \000\000\000\000\000\000\000", 9)},
+ {0x2000000000000001ull, string("\010 \000\000\000\000\000\000\001", 9)},
+ {0x3fffffffffffffffull, string("\010?\377\377\377\377\377\377\377", 9)},
+ {0x4000000000000000ull, string("\010@\000\000\000\000\000\000\000", 9)},
+ {0x4000000000000001ull, string("\010@\000\000\000\000\000\000\001", 9)},
+ {0x7fffffffffffffffull,
+ string("\010\177\377\377\377\377\377\377\377", 9)},
+ {0x8000000000000000ull,
+ string("\010\200\000\000\000\000\000\000\000", 9)},
+ {0x8000000000000001ull,
+ string("\010\200\000\000\000\000\000\000\001", 9)},
+ };
+ for (const auto& t : data) {
+ uint64 num = t.first;
+ string result;
+ OrderedCode::WriteNumIncreasing(&result, num);
+ EXPECT_EQ(t.second, result) << std::hex << num;
+
+ StringPiece in = result;
+ uint64 decoded;
+ EXPECT_TRUE(OrderedCode::ReadNumIncreasing(&in, &decoded));
+ EXPECT_EQ(num, decoded);
+ EXPECT_EQ("", in);
+ }
+}
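+
+// The pattern in the table above: WriteNumIncreasing emits one length byte
+// (the count of significant payload bytes, 0 through 8) followed by the
+// value in big-endian order with no leading zero bytes. A worked example:
+//
+// string s;
+// OrderedCode::WriteNumIncreasing(&s, 0x1234);
+// CHECK_EQ(s, string("\x02\x12\x34", 3)); // two payload bytes, big-endian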
+
+TEST(EncodingIsExpected, Signed) {
+ std::vector<std::pair<int64, string>> data = {
+ {0ll, string("\200", 1)},
+ {1ll, string("\201", 1)},
+ {2ll, string("\202", 1)},
+ {1ll, string("\201", 1)},
+ {2ll, string("\202", 1)},
+ {3ll, string("\203", 1)},
+ {3ll, string("\203", 1)},
+ {4ll, string("\204", 1)},
+ {5ll, string("\205", 1)},
+ {7ll, string("\207", 1)},
+ {8ll, string("\210", 1)},
+ {9ll, string("\211", 1)},
+ {15ll, string("\217", 1)},
+ {16ll, string("\220", 1)},
+ {17ll, string("\221", 1)},
+ {31ll, string("\237", 1)},
+ {32ll, string("\240", 1)},
+ {33ll, string("\241", 1)},
+ {63ll, string("\277", 1)},
+ {64ll, string("\300@", 2)},
+ {65ll, string("\300A", 2)},
+ {127ll, string("\300\177", 2)},
+ {128ll, string("\300\200", 2)},
+ {129ll, string("\300\201", 2)},
+ {255ll, string("\300\377", 2)},
+ {256ll, string("\301\000", 2)},
+ {257ll, string("\301\001", 2)},
+ {511ll, string("\301\377", 2)},
+ {512ll, string("\302\000", 2)},
+ {513ll, string("\302\001", 2)},
+ {1023ll, string("\303\377", 2)},
+ {1024ll, string("\304\000", 2)},
+ {1025ll, string("\304\001", 2)},
+ {2047ll, string("\307\377", 2)},
+ {2048ll, string("\310\000", 2)},
+ {2049ll, string("\310\001", 2)},
+ {4095ll, string("\317\377", 2)},
+ {4096ll, string("\320\000", 2)},
+ {4097ll, string("\320\001", 2)},
+ {8191ll, string("\337\377", 2)},
+ {8192ll, string("\340 \000", 3)},
+ {8193ll, string("\340 \001", 3)},
+ {16383ll, string("\340?\377", 3)},
+ {16384ll, string("\340@\000", 3)},
+ {16385ll, string("\340@\001", 3)},
+ {32767ll, string("\340\177\377", 3)},
+ {32768ll, string("\340\200\000", 3)},
+ {32769ll, string("\340\200\001", 3)},
+ {65535ll, string("\340\377\377", 3)},
+ {65536ll, string("\341\000\000", 3)},
+ {65537ll, string("\341\000\001", 3)},
+ {131071ll, string("\341\377\377", 3)},
+ {131072ll, string("\342\000\000", 3)},
+ {131073ll, string("\342\000\001", 3)},
+ {262143ll, string("\343\377\377", 3)},
+ {262144ll, string("\344\000\000", 3)},
+ {262145ll, string("\344\000\001", 3)},
+ {524287ll, string("\347\377\377", 3)},
+ {524288ll, string("\350\000\000", 3)},
+ {524289ll, string("\350\000\001", 3)},
+ {1048575ll, string("\357\377\377", 3)},
+ {1048576ll, string("\360\020\000\000", 4)},
+ {1048577ll, string("\360\020\000\001", 4)},
+ {2097151ll, string("\360\037\377\377", 4)},
+ {2097152ll, string("\360 \000\000", 4)},
+ {2097153ll, string("\360 \000\001", 4)},
+ {4194303ll, string("\360?\377\377", 4)},
+ {4194304ll, string("\360@\000\000", 4)},
+ {4194305ll, string("\360@\000\001", 4)},
+ {8388607ll, string("\360\177\377\377", 4)},
+ {8388608ll, string("\360\200\000\000", 4)},
+ {8388609ll, string("\360\200\000\001", 4)},
+ {16777215ll, string("\360\377\377\377", 4)},
+ {16777216ll, string("\361\000\000\000", 4)},
+ {16777217ll, string("\361\000\000\001", 4)},
+ {33554431ll, string("\361\377\377\377", 4)},
+ {33554432ll, string("\362\000\000\000", 4)},
+ {33554433ll, string("\362\000\000\001", 4)},
+ {67108863ll, string("\363\377\377\377", 4)},
+ {67108864ll, string("\364\000\000\000", 4)},
+ {67108865ll, string("\364\000\000\001", 4)},
+ {134217727ll, string("\367\377\377\377", 4)},
+ {134217728ll, string("\370\010\000\000\000", 5)},
+ {134217729ll, string("\370\010\000\000\001", 5)},
+ {268435455ll, string("\370\017\377\377\377", 5)},
+ {268435456ll, string("\370\020\000\000\000", 5)},
+ {268435457ll, string("\370\020\000\000\001", 5)},
+ {536870911ll, string("\370\037\377\377\377", 5)},
+ {536870912ll, string("\370 \000\000\000", 5)},
+ {536870913ll, string("\370 \000\000\001", 5)},
+ {1073741823ll, string("\370?\377\377\377", 5)},
+ {1073741824ll, string("\370@\000\000\000", 5)},
+ {1073741825ll, string("\370@\000\000\001", 5)},
+ {2147483647ll, string("\370\177\377\377\377", 5)},
+ {2147483648ll, string("\370\200\000\000\000", 5)},
+ {2147483649ll, string("\370\200\000\000\001", 5)},
+ {4294967295ll, string("\370\377\377\377\377", 5)},
+ {4294967296ll, string("\371\000\000\000\000", 5)},
+ {4294967297ll, string("\371\000\000\000\001", 5)},
+ {8589934591ll, string("\371\377\377\377\377", 5)},
+ {8589934592ll, string("\372\000\000\000\000", 5)},
+ {8589934593ll, string("\372\000\000\000\001", 5)},
+ {17179869183ll, string("\373\377\377\377\377", 5)},
+ {17179869184ll, string("\374\004\000\000\000\000", 6)},
+ {17179869185ll, string("\374\004\000\000\000\001", 6)},
+ {34359738367ll, string("\374\007\377\377\377\377", 6)},
+ {34359738368ll, string("\374\010\000\000\000\000", 6)},
+ {34359738369ll, string("\374\010\000\000\000\001", 6)},
+ {68719476735ll, string("\374\017\377\377\377\377", 6)},
+ {68719476736ll, string("\374\020\000\000\000\000", 6)},
+ {68719476737ll, string("\374\020\000\000\000\001", 6)},
+ {137438953471ll, string("\374\037\377\377\377\377", 6)},
+ {137438953472ll, string("\374 \000\000\000\000", 6)},
+ {137438953473ll, string("\374 \000\000\000\001", 6)},
+ {274877906943ll, string("\374?\377\377\377\377", 6)},
+ {274877906944ll, string("\374@\000\000\000\000", 6)},
+ {274877906945ll, string("\374@\000\000\000\001", 6)},
+ {549755813887ll, string("\374\177\377\377\377\377", 6)},
+ {549755813888ll, string("\374\200\000\000\000\000", 6)},
+ {549755813889ll, string("\374\200\000\000\000\001", 6)},
+ {1099511627775ll, string("\374\377\377\377\377\377", 6)},
+ {1099511627776ll, string("\375\000\000\000\000\000", 6)},
+ {1099511627777ll, string("\375\000\000\000\000\001", 6)},
+ {2199023255551ll, string("\375\377\377\377\377\377", 6)},
+ {2199023255552ll, string("\376\002\000\000\000\000\000", 7)},
+ {2199023255553ll, string("\376\002\000\000\000\000\001", 7)},
+ {4398046511103ll, string("\376\003\377\377\377\377\377", 7)},
+ {4398046511104ll, string("\376\004\000\000\000\000\000", 7)},
+ {4398046511105ll, string("\376\004\000\000\000\000\001", 7)},
+ {8796093022207ll, string("\376\007\377\377\377\377\377", 7)},
+ {8796093022208ll, string("\376\010\000\000\000\000\000", 7)},
+ {8796093022209ll, string("\376\010\000\000\000\000\001", 7)},
+ {17592186044415ll, string("\376\017\377\377\377\377\377", 7)},
+ {17592186044416ll, string("\376\020\000\000\000\000\000", 7)},
+ {17592186044417ll, string("\376\020\000\000\000\000\001", 7)},
+ {35184372088831ll, string("\376\037\377\377\377\377\377", 7)},
+ {35184372088832ll, string("\376 \000\000\000\000\000", 7)},
+ {35184372088833ll, string("\376 \000\000\000\000\001", 7)},
+ {70368744177663ll, string("\376?\377\377\377\377\377", 7)},
+ {70368744177664ll, string("\376@\000\000\000\000\000", 7)},
+ {70368744177665ll, string("\376@\000\000\000\000\001", 7)},
+ {140737488355327ll, string("\376\177\377\377\377\377\377", 7)},
+ {140737488355328ll, string("\376\200\000\000\000\000\000", 7)},
+ {140737488355329ll, string("\376\200\000\000\000\000\001", 7)},
+ {281474976710655ll, string("\376\377\377\377\377\377\377", 7)},
+ {281474976710656ll, string("\377\001\000\000\000\000\000\000", 8)},
+ {281474976710657ll, string("\377\001\000\000\000\000\000\001", 8)},
+ {562949953421311ll, string("\377\001\377\377\377\377\377\377", 8)},
+ {562949953421312ll, string("\377\002\000\000\000\000\000\000", 8)},
+ {562949953421313ll, string("\377\002\000\000\000\000\000\001", 8)},
+ {1125899906842623ll, string("\377\003\377\377\377\377\377\377", 8)},
+ {1125899906842624ll, string("\377\004\000\000\000\000\000\000", 8)},
+ {1125899906842625ll, string("\377\004\000\000\000\000\000\001", 8)},
+ {2251799813685247ll, string("\377\007\377\377\377\377\377\377", 8)},
+ {2251799813685248ll, string("\377\010\000\000\000\000\000\000", 8)},
+ {2251799813685249ll, string("\377\010\000\000\000\000\000\001", 8)},
+ {4503599627370495ll, string("\377\017\377\377\377\377\377\377", 8)},
+ {4503599627370496ll, string("\377\020\000\000\000\000\000\000", 8)},
+ {4503599627370497ll, string("\377\020\000\000\000\000\000\001", 8)},
+ {9007199254740991ll, string("\377\037\377\377\377\377\377\377", 8)},
+ {9007199254740992ll, string("\377 \000\000\000\000\000\000", 8)},
+ {9007199254740993ll, string("\377 \000\000\000\000\000\001", 8)},
+ {18014398509481983ll, string("\377?\377\377\377\377\377\377", 8)},
+ {18014398509481984ll, string("\377@\000\000\000\000\000\000", 8)},
+ {18014398509481985ll, string("\377@\000\000\000\000\000\001", 8)},
+ {36028797018963967ll, string("\377\177\377\377\377\377\377\377", 8)},
+ {36028797018963968ll, string("\377\200\200\000\000\000\000\000\000", 9)},
+ {36028797018963969ll, string("\377\200\200\000\000\000\000\000\001", 9)},
+ {72057594037927935ll, string("\377\200\377\377\377\377\377\377\377", 9)},
+ {72057594037927936ll, string("\377\201\000\000\000\000\000\000\000", 9)},
+ {72057594037927937ll, string("\377\201\000\000\000\000\000\000\001", 9)},
+ {144115188075855871ll, string("\377\201\377\377\377\377\377\377\377", 9)},
+ {144115188075855872ll, string("\377\202\000\000\000\000\000\000\000", 9)},
+ {144115188075855873ll, string("\377\202\000\000\000\000\000\000\001", 9)},
+ {288230376151711743ll, string("\377\203\377\377\377\377\377\377\377", 9)},
+ {288230376151711744ll, string("\377\204\000\000\000\000\000\000\000", 9)},
+ {288230376151711745ll, string("\377\204\000\000\000\000\000\000\001", 9)},
+ {576460752303423487ll, string("\377\207\377\377\377\377\377\377\377", 9)},
+ {576460752303423488ll, string("\377\210\000\000\000\000\000\000\000", 9)},
+ {576460752303423489ll, string("\377\210\000\000\000\000\000\000\001", 9)},
+ {1152921504606846975ll,
+ string("\377\217\377\377\377\377\377\377\377", 9)},
+ {1152921504606846976ll,
+ string("\377\220\000\000\000\000\000\000\000", 9)},
+ {1152921504606846977ll,
+ string("\377\220\000\000\000\000\000\000\001", 9)},
+ {2305843009213693951ll,
+ string("\377\237\377\377\377\377\377\377\377", 9)},
+ {2305843009213693952ll,
+ string("\377\240\000\000\000\000\000\000\000", 9)},
+ {2305843009213693953ll,
+ string("\377\240\000\000\000\000\000\000\001", 9)},
+ {4611686018427387903ll,
+ string("\377\277\377\377\377\377\377\377\377", 9)},
+ {4611686018427387904ll,
+ string("\377\300@\000\000\000\000\000\000\000", 10)},
+ {4611686018427387905ll,
+ string("\377\300@\000\000\000\000\000\000\001", 10)},
+ {9223372036854775807ll,
+ string("\377\300\177\377\377\377\377\377\377\377", 10)},
+ {-9223372036854775807ll,
+ string("\000?\200\000\000\000\000\000\000\001", 10)},
+ {0ll, string("\200", 1)},
+ {-1ll, string("\177", 1)},
+ {-2ll, string("~", 1)},
+ {-1ll, string("\177", 1)},
+ {-2ll, string("~", 1)},
+ {-3ll, string("}", 1)},
+ {-3ll, string("}", 1)},
+ {-4ll, string("|", 1)},
+ {-5ll, string("{", 1)},
+ {-7ll, string("y", 1)},
+ {-8ll, string("x", 1)},
+ {-9ll, string("w", 1)},
+ {-15ll, string("q", 1)},
+ {-16ll, string("p", 1)},
+ {-17ll, string("o", 1)},
+ {-31ll, string("a", 1)},
+ {-32ll, string("`", 1)},
+ {-33ll, string("_", 1)},
+ {-63ll, string("A", 1)},
+ {-64ll, string("@", 1)},
+ {-65ll, string("?\277", 2)},
+ {-127ll, string("?\201", 2)},
+ {-128ll, string("?\200", 2)},
+ {-129ll, string("?\177", 2)},
+ {-255ll, string("?\001", 2)},
+ {-256ll, string("?\000", 2)},
+ {-257ll, string(">\377", 2)},
+ {-511ll, string(">\001", 2)},
+ {-512ll, string(">\000", 2)},
+ {-513ll, string("=\377", 2)},
+ {-1023ll, string("<\001", 2)},
+ {-1024ll, string("<\000", 2)},
+ {-1025ll, string(";\377", 2)},
+ {-2047ll, string("8\001", 2)},
+ {-2048ll, string("8\000", 2)},
+ {-2049ll, string("7\377", 2)},
+ {-4095ll, string("0\001", 2)},
+ {-4096ll, string("0\000", 2)},
+ {-4097ll, string("/\377", 2)},
+ {-8191ll, string(" \001", 2)},
+ {-8192ll, string(" \000", 2)},
+ {-8193ll, string("\037\337\377", 3)},
+ {-16383ll, string("\037\300\001", 3)},
+ {-16384ll, string("\037\300\000", 3)},
+ {-16385ll, string("\037\277\377", 3)},
+ {-32767ll, string("\037\200\001", 3)},
+ {-32768ll, string("\037\200\000", 3)},
+ {-32769ll, string("\037\177\377", 3)},
+ {-65535ll, string("\037\000\001", 3)},
+ {-65536ll, string("\037\000\000", 3)},
+ {-65537ll, string("\036\377\377", 3)},
+ {-131071ll, string("\036\000\001", 3)},
+ {-131072ll, string("\036\000\000", 3)},
+ {-131073ll, string("\035\377\377", 3)},
+ {-262143ll, string("\034\000\001", 3)},
+ {-262144ll, string("\034\000\000", 3)},
+ {-262145ll, string("\033\377\377", 3)},
+ {-524287ll, string("\030\000\001", 3)},
+ {-524288ll, string("\030\000\000", 3)},
+ {-524289ll, string("\027\377\377", 3)},
+ {-1048575ll, string("\020\000\001", 3)},
+ {-1048576ll, string("\020\000\000", 3)},
+ {-1048577ll, string("\017\357\377\377", 4)},
+ {-2097151ll, string("\017\340\000\001", 4)},
+ {-2097152ll, string("\017\340\000\000", 4)},
+ {-2097153ll, string("\017\337\377\377", 4)},
+ {-4194303ll, string("\017\300\000\001", 4)},
+ {-4194304ll, string("\017\300\000\000", 4)},
+ {-4194305ll, string("\017\277\377\377", 4)},
+ {-8388607ll, string("\017\200\000\001", 4)},
+ {-8388608ll, string("\017\200\000\000", 4)},
+ {-8388609ll, string("\017\177\377\377", 4)},
+ {-16777215ll, string("\017\000\000\001", 4)},
+ {-16777216ll, string("\017\000\000\000", 4)},
+ {-16777217ll, string("\016\377\377\377", 4)},
+ {-33554431ll, string("\016\000\000\001", 4)},
+ {-33554432ll, string("\016\000\000\000", 4)},
+ {-33554433ll, string("\r\377\377\377", 4)},
+ {-67108863ll, string("\014\000\000\001", 4)},
+ {-67108864ll, string("\014\000\000\000", 4)},
+ {-67108865ll, string("\013\377\377\377", 4)},
+ {-134217727ll, string("\010\000\000\001", 4)},
+ {-134217728ll, string("\010\000\000\000", 4)},
+ {-134217729ll, string("\007\367\377\377\377", 5)},
+ {-268435455ll, string("\007\360\000\000\001", 5)},
+ {-268435456ll, string("\007\360\000\000\000", 5)},
+ {-268435457ll, string("\007\357\377\377\377", 5)},
+ {-536870911ll, string("\007\340\000\000\001", 5)},
+ {-536870912ll, string("\007\340\000\000\000", 5)},
+ {-536870913ll, string("\007\337\377\377\377", 5)},
+ {-1073741823ll, string("\007\300\000\000\001", 5)},
+ {-1073741824ll, string("\007\300\000\000\000", 5)},
+ {-1073741825ll, string("\007\277\377\377\377", 5)},
+ {-2147483647ll, string("\007\200\000\000\001", 5)},
+ {-2147483648ll, string("\007\200\000\000\000", 5)},
+ {-2147483649ll, string("\007\177\377\377\377", 5)},
+ {-4294967295ll, string("\007\000\000\000\001", 5)},
+ {-4294967296ll, string("\007\000\000\000\000", 5)},
+ {-4294967297ll, string("\006\377\377\377\377", 5)},
+ {-8589934591ll, string("\006\000\000\000\001", 5)},
+ {-8589934592ll, string("\006\000\000\000\000", 5)},
+ {-8589934593ll, string("\005\377\377\377\377", 5)},
+ {-17179869183ll, string("\004\000\000\000\001", 5)},
+ {-17179869184ll, string("\004\000\000\000\000", 5)},
+ {-17179869185ll, string("\003\373\377\377\377\377", 6)},
+ {-34359738367ll, string("\003\370\000\000\000\001", 6)},
+ {-34359738368ll, string("\003\370\000\000\000\000", 6)},
+ {-34359738369ll, string("\003\367\377\377\377\377", 6)},
+ {-68719476735ll, string("\003\360\000\000\000\001", 6)},
+ {-68719476736ll, string("\003\360\000\000\000\000", 6)},
+ {-68719476737ll, string("\003\357\377\377\377\377", 6)},
+ {-137438953471ll, string("\003\340\000\000\000\001", 6)},
+ {-137438953472ll, string("\003\340\000\000\000\000", 6)},
+ {-137438953473ll, string("\003\337\377\377\377\377", 6)},
+ {-274877906943ll, string("\003\300\000\000\000\001", 6)},
+ {-274877906944ll, string("\003\300\000\000\000\000", 6)},
+ {-274877906945ll, string("\003\277\377\377\377\377", 6)},
+ {-549755813887ll, string("\003\200\000\000\000\001", 6)},
+ {-549755813888ll, string("\003\200\000\000\000\000", 6)},
+ {-549755813889ll, string("\003\177\377\377\377\377", 6)},
+ {-1099511627775ll, string("\003\000\000\000\000\001", 6)},
+ {-1099511627776ll, string("\003\000\000\000\000\000", 6)},
+ {-1099511627777ll, string("\002\377\377\377\377\377", 6)},
+ {-2199023255551ll, string("\002\000\000\000\000\001", 6)},
+ {-2199023255552ll, string("\002\000\000\000\000\000", 6)},
+ {-2199023255553ll, string("\001\375\377\377\377\377\377", 7)},
+ {-4398046511103ll, string("\001\374\000\000\000\000\001", 7)},
+ {-4398046511104ll, string("\001\374\000\000\000\000\000", 7)},
+ {-4398046511105ll, string("\001\373\377\377\377\377\377", 7)},
+ {-8796093022207ll, string("\001\370\000\000\000\000\001", 7)},
+ {-8796093022208ll, string("\001\370\000\000\000\000\000", 7)},
+ {-8796093022209ll, string("\001\367\377\377\377\377\377", 7)},
+ {-17592186044415ll, string("\001\360\000\000\000\000\001", 7)},
+ {-17592186044416ll, string("\001\360\000\000\000\000\000", 7)},
+ {-17592186044417ll, string("\001\357\377\377\377\377\377", 7)},
+ {-35184372088831ll, string("\001\340\000\000\000\000\001", 7)},
+ {-35184372088832ll, string("\001\340\000\000\000\000\000", 7)},
+ {-35184372088833ll, string("\001\337\377\377\377\377\377", 7)},
+ {-70368744177663ll, string("\001\300\000\000\000\000\001", 7)},
+ {-70368744177664ll, string("\001\300\000\000\000\000\000", 7)},
+ {-70368744177665ll, string("\001\277\377\377\377\377\377", 7)},
+ {-140737488355327ll, string("\001\200\000\000\000\000\001", 7)},
+ {-140737488355328ll, string("\001\200\000\000\000\000\000", 7)},
+ {-140737488355329ll, string("\001\177\377\377\377\377\377", 7)},
+ {-281474976710655ll, string("\001\000\000\000\000\000\001", 7)},
+ {-281474976710656ll, string("\001\000\000\000\000\000\000", 7)},
+ {-281474976710657ll, string("\000\376\377\377\377\377\377\377", 8)},
+ {-562949953421311ll, string("\000\376\000\000\000\000\000\001", 8)},
+ {-562949953421312ll, string("\000\376\000\000\000\000\000\000", 8)},
+ {-562949953421313ll, string("\000\375\377\377\377\377\377\377", 8)},
+ {-1125899906842623ll, string("\000\374\000\000\000\000\000\001", 8)},
+ {-1125899906842624ll, string("\000\374\000\000\000\000\000\000", 8)},
+ {-1125899906842625ll, string("\000\373\377\377\377\377\377\377", 8)},
+ {-2251799813685247ll, string("\000\370\000\000\000\000\000\001", 8)},
+ {-2251799813685248ll, string("\000\370\000\000\000\000\000\000", 8)},
+ {-2251799813685249ll, string("\000\367\377\377\377\377\377\377", 8)},
+ {-4503599627370495ll, string("\000\360\000\000\000\000\000\001", 8)},
+ {-4503599627370496ll, string("\000\360\000\000\000\000\000\000", 8)},
+ {-4503599627370497ll, string("\000\357\377\377\377\377\377\377", 8)},
+ {-9007199254740991ll, string("\000\340\000\000\000\000\000\001", 8)},
+ {-9007199254740992ll, string("\000\340\000\000\000\000\000\000", 8)},
+ {-9007199254740993ll, string("\000\337\377\377\377\377\377\377", 8)},
+ {-18014398509481983ll, string("\000\300\000\000\000\000\000\001", 8)},
+ {-18014398509481984ll, string("\000\300\000\000\000\000\000\000", 8)},
+ {-18014398509481985ll, string("\000\277\377\377\377\377\377\377", 8)},
+ {-36028797018963967ll, string("\000\200\000\000\000\000\000\001", 8)},
+ {-36028797018963968ll, string("\000\200\000\000\000\000\000\000", 8)},
+ {-36028797018963969ll, string("\000\177\177\377\377\377\377\377\377", 9)},
+ {-72057594037927935ll, string("\000\177\000\000\000\000\000\000\001", 9)},
+ {-72057594037927936ll, string("\000\177\000\000\000\000\000\000\000", 9)},
+ {-72057594037927937ll, string("\000~\377\377\377\377\377\377\377", 9)},
+ {-144115188075855871ll, string("\000~\000\000\000\000\000\000\001", 9)},
+ {-144115188075855872ll, string("\000~\000\000\000\000\000\000\000", 9)},
+ {-144115188075855873ll, string("\000}\377\377\377\377\377\377\377", 9)},
+ {-288230376151711743ll, string("\000|\000\000\000\000\000\000\001", 9)},
+ {-288230376151711744ll, string("\000|\000\000\000\000\000\000\000", 9)},
+ {-288230376151711745ll, string("\000{\377\377\377\377\377\377\377", 9)},
+ {-576460752303423487ll, string("\000x\000\000\000\000\000\000\001", 9)},
+ {-576460752303423488ll, string("\000x\000\000\000\000\000\000\000", 9)},
+ {-576460752303423489ll, string("\000w\377\377\377\377\377\377\377", 9)},
+ {-1152921504606846975ll, string("\000p\000\000\000\000\000\000\001", 9)},
+ {-1152921504606846976ll, string("\000p\000\000\000\000\000\000\000", 9)},
+ {-1152921504606846977ll, string("\000o\377\377\377\377\377\377\377", 9)},
+ {-2305843009213693951ll, string("\000`\000\000\000\000\000\000\001", 9)},
+ {-2305843009213693952ll, string("\000`\000\000\000\000\000\000\000", 9)},
+ {-2305843009213693953ll, string("\000_\377\377\377\377\377\377\377", 9)},
+ {-4611686018427387903ll, string("\000@\000\000\000\000\000\000\001", 9)},
+ {-4611686018427387904ll, string("\000@\000\000\000\000\000\000\000", 9)},
+ {-4611686018427387905ll,
+ string("\000?\277\377\377\377\377\377\377\377", 10)},
+ {-9223372036854775807ll,
+ string("\000?\200\000\000\000\000\000\000\001", 10)},
+ {9223372036854775807ll,
+ string("\377\300\177\377\377\377\377\377\377\377", 10)},
+ };
+ for (const auto& t : data) {
+ int64 num = t.first;
+ string result;
+ OrderedCode::WriteSignedNumIncreasing(&result, num);
+ EXPECT_EQ(t.second, result) << std::hex << num;
+
+ StringPiece in = result;
+ int64 decoded;
+ EXPECT_TRUE(OrderedCode::ReadSignedNumIncreasing(&in, &decoded));
+ EXPECT_EQ(num, decoded);
+ EXPECT_EQ("", in);
+ }
+}
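+
+// The signed scheme in the table above folds the sign and the payload
+// length into the leading prefix bits, so raw byte comparison still orders
+// values numerically. Restating a few rows as code:
+//
+// string s;
+// OrderedCode::WriteSignedNumIncreasing(&s, 0); // s == "\x80"
+// s.clear();
+// OrderedCode::WriteSignedNumIncreasing(&s, 64); // s == "\xc0\x40"
+// s.clear();
+// OrderedCode::WriteSignedNumIncreasing(&s, -64); // s == "\x40"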
+
+static void BM_WriteString(int n, int len) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd.Uniform(256);
+ }
+ string y;
+
+ testing::BytesProcessed(static_cast<int64>(n) * len);
+ testing::StartTiming();
+ while (n-- > 0) {
+ y.clear();
+ OCWriteToString<string>(&y, x);
+ }
+}
+
+static void BM_ReadString(int n, int len) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd.Uniform(256);
+ }
+ string data;
+ OCWriteToString<string>(&data, x);
+ string result;
+
+ testing::BytesProcessed(static_cast<int64>(n) * len);
+ testing::StartTiming();
+ while (n-- > 0) {
+ result.clear();
+ StringPiece s = data;
+ OCRead<string>(&s, &result);
+ }
+}
+
+static void BM_WriteStringIncreasing(int n, int len) { BM_WriteString(n, len); }
+static void BM_ReadStringIncreasing(int n, int len) { BM_ReadString(n, len); }
+
+BENCHMARK(BM_WriteStringIncreasing)->Range(0, 1024);
+BENCHMARK(BM_ReadStringIncreasing)->Range(0, 1024);
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
new file mode 100644
index 0000000000..cccd50c7ff
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -0,0 +1,312 @@
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <ctype.h>
+
+namespace tensorflow {
+namespace str_util {
+
+static char hex_char[] = "0123456789abcdef";
+
+string CEscape(const string& src) {
+ string dest;
+
+ for (unsigned char c : src) {
+ switch (c) {
+ case '\n':
+ dest.append("\\n");
+ break;
+ case '\r':
+ dest.append("\\r");
+ break;
+ case '\t':
+ dest.append("\\t");
+ break;
+ case '\"':
+ dest.append("\\\"");
+ break;
+ case '\'':
+ dest.append("\\'");
+ break;
+ case '\\':
+ dest.append("\\\\");
+ break;
+ default:
+ // We emit three octal digits ("\ooo") rather than "\xNN": a hex escape
+ // has no fixed length, so if the next source character were a hex digit
+ // it would also have to be escaped to keep it from being read as part
+ // of the character code. Three octal digits are always unambiguous.
+ if ((c >= 0x80) || !isprint(c)) {
+ dest.append("\\");
+ dest.push_back(hex_char[c / 64]);
+ dest.push_back(hex_char[(c % 64) / 8]);
+ dest.push_back(hex_char[c % 8]);
+ } else {
+ dest.push_back(c);
+ break;
+ }
+ }
+ }
+
+ return dest;
+}
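+
+// Examples of the function above (results shown as C string literals):
+//
+// CEscape("hi\n") == "hi\\n"
+// CEscape("\x01hi") == "\\001hi" // unprintable bytes get 3 octal digits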
+
+namespace { // Private helpers for CUnescape().
+
+inline bool is_octal_digit(unsigned char c) { return c >= '0' && c <= '7'; }
+
+inline bool ascii_isxdigit(unsigned char c) {
+ return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') ||
+ (c >= 'A' && c <= 'F');
+}
+
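+// Maps '0'-'9', 'a'-'f', 'A'-'F' to their values 0..15. The "+= 9" works
+// because the low four bits of the ASCII codes for 'a'..'f' and 'A'..'F'
+// are 1..6, so adding 9 before masking with 0xf yields 10..15.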
+inline int hex_digit_to_int(char c) {
+ int x = static_cast<unsigned char>(c);
+ if (x > '9') {
+ x += 9;
+ }
+ return x & 0xf;
+}
+
+bool CUnescapeInternal(StringPiece source, char* dest, int* dest_len,
+ string* error) {
+ char* d = dest;
+ const char* p = source.data();
+ const char* end = source.end();
+ const char* last_byte = end - 1;
+
+ // Small optimization for case where source = dest and there's no escaping
+ while (p == d && p < end && *p != '\\') p++, d++;
+
+ while (p < end) {
+ if (*p != '\\') {
+ *d++ = *p++;
+ } else {
+ if (++p > last_byte) { // skip past the '\\'
+ if (error) *error = "String cannot end with \\";
+ return false;
+ }
+ switch (*p) {
+ case 'a':
+ *d++ = '\a';
+ break;
+ case 'b':
+ *d++ = '\b';
+ break;
+ case 'f':
+ *d++ = '\f';
+ break;
+ case 'n':
+ *d++ = '\n';
+ break;
+ case 'r':
+ *d++ = '\r';
+ break;
+ case 't':
+ *d++ = '\t';
+ break;
+ case 'v':
+ *d++ = '\v';
+ break;
+ case '\\':
+ *d++ = '\\';
+ break;
+ case '?':
+ *d++ = '\?';
+ break; // \? Who knew?
+ case '\'':
+ *d++ = '\'';
+ break;
+ case '"':
+ *d++ = '\"';
+ break;
+ case '0':
+ case '1':
+ case '2':
+ case '3': // octal digit: 1 to 3 digits
+ case '4':
+ case '5':
+ case '6':
+ case '7': {
+ const char* octal_start = p;
+ unsigned int ch = *p - '0';
+ if (p < last_byte && is_octal_digit(p[1])) ch = ch * 8 + *++p - '0';
+ if (p < last_byte && is_octal_digit(p[1]))
+ ch = ch * 8 + *++p - '0'; // now points at last digit
+ if (ch > 0xff) {
+ if (error) {
+ *error = "Value of \\" +
+ string(octal_start, p + 1 - octal_start) +
+ " exceeds 0xff";
+ }
+ return false;
+ }
+ *d++ = ch;
+ break;
+ }
+ case 'x':
+ case 'X': {
+ if (p >= last_byte) {
+ if (error) *error = "String cannot end with \\x";
+ return false;
+ } else if (!ascii_isxdigit(p[1])) {
+ if (error) *error = "\\x cannot be followed by a non-hex digit";
+ return false;
+ }
+ unsigned int ch = 0;
+ const char* hex_start = p;
+ while (p < last_byte && ascii_isxdigit(p[1]))
+ // Arbitrarily many hex digits
+ ch = (ch << 4) + hex_digit_to_int(*++p);
+ if (ch > 0xFF) {
+ if (error) {
+ *error = "Value of \\" + string(hex_start, p + 1 - hex_start) +
+ " exceeds 0xff";
+ }
+ return false;
+ }
+ *d++ = ch;
+ break;
+ }
+ default: {
+ if (error) *error = string("Unknown escape sequence: \\") + *p;
+ return false;
+ }
+ }
+ p++; // read past letter we escaped
+ }
+ }
+ *dest_len = d - dest;
+ return true;
+}
+
+} // namespace
+
+bool CUnescape(StringPiece source, string* dest, string* error) {
+ dest->resize(source.size());
+ int dest_size;
+ if (!CUnescapeInternal(source, const_cast<char*>(dest->data()), &dest_size,
+ error)) {
+ return false;
+ }
+ dest->erase(dest_size);
+ return true;
+}
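+
+// Note on CUnescape() above: unescaping can only shrink the text, so the
+// output is sized to the input up front, written in place, and then
+// trimmed. A usage sketch:
+//
+// string out, err;
+// if (str_util::CUnescape("a\\n", &out, &err)) {
+// // out == "a\n" (two bytes)
+// }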
+
+bool NumericParse32(const string& text, int32* val) {
+ // Slow, but this code is not performance critical, and this
+ // doesn't bring in any new dependencies
+ char junk;
+ return sscanf(text.c_str(), "%d%c", val, &junk) == 1;
+}
+
+void StripTrailingWhitespace(string* s) {
+ string::size_type i;
+ for (i = s->size();
+ i > 0 && isspace(static_cast<unsigned char>((*s)[i - 1])); --i) {
+ }
+ s->resize(i);
+}
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s) {
+ string result(s.data(), s.size());
+ for (char& c : result) {
+ c = tolower(static_cast<unsigned char>(c));
+ }
+ return result;
+}
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s) {
+ string result(s.data(), s.size());
+ for (char& c : result) {
+ c = toupper(static_cast<unsigned char>(c));
+ }
+ return result;
+}
+
+void TitlecaseString(string* s, StringPiece delimiters) {
+ bool upper = true;
+ for (string::iterator ss = s->begin(); ss != s->end(); ++ss) {
+ if (upper) {
+ *ss = toupper(static_cast<unsigned char>(*ss));
+ }
+ upper = (delimiters.find(*ss) != StringPiece::npos);
+ }
+}
+
+size_t RemoveLeadingWhitespace(StringPiece* text) {
+ size_t count = 0;
+ const char* ptr = text->data();
+ while (count < text->size() && isspace(static_cast<unsigned char>(*ptr))) {
+ count++;
+ ptr++;
+ }
+ text->remove_prefix(count);
+ return count;
+}
+
+size_t RemoveTrailingWhitespace(StringPiece* text) {
+ size_t count = 0;
+ // 'ptr' stays one past the next byte to examine, so no pointer before
+ // text->data() is ever formed (even for empty input).
+ const char* ptr = text->data() + text->size();
+ while (count < text->size() &&
+ isspace(static_cast<unsigned char>(ptr[-1]))) {
+ ++count;
+ --ptr;
+ }
+ text->remove_suffix(count);
+ return count;
+}
+
+size_t RemoveWhitespaceContext(StringPiece* text) {
+ // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job
+ return (RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text));
+}
+
+bool ConsumePrefix(StringPiece* s, StringPiece expected) {
+ if (s->starts_with(expected)) {
+ s->remove_prefix(expected.size());
+ return true;
+ }
+ return false;
+}
+
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val) {
+ const char* p = s->data();
+ const char* limit = p + s->size();
+ uint64 v = 0;
+ while (p < limit) {
+ const char c = *p;
+ if (c < '0' || c > '9') break;
+ if (v > (~static_cast<uint64>(0) - (c - '0')) / 10) {
+ // Overflow: v * 10 + digit would not fit in a uint64. (v * 10 can
+ // wrap to a value still greater than v, so a post-hoc "new_v < v"
+ // test would miss some overflows.)
+ return false;
+ }
+ uint64 new_v = (v * 10) + (c - '0');
+ v = new_v;
+ p++;
+ }
+ if (p > s->data()) {
+ // Consumed at least one digit; commit the parse.
+ s->remove_prefix(p - s->data());
+ *val = v;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool SplitAndParseAsInts(StringPiece text, char delim,
+ std::vector<int32>* result) {
+ result->clear();
+ std::vector<string> num_strings = Split(text, delim);
+ for (const auto& s : num_strings) {
+ int32 num;
+ if (!NumericParse32(s, &num)) return false;
+ result->push_back(num);
+ }
+ return true;
+}
+
+} // namespace str_util
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
new file mode 100644
index 0000000000..34ea462b2d
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -0,0 +1,149 @@
+#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+// Basic string utility routines
+namespace tensorflow {
+namespace str_util {
+
+// Returns a version of 'src' where unprintable characters have been
+// escaped using C-style escape sequences.
+string CEscape(const string& src);
+
+// Copies "source" to "dest", rewriting C-style escape sequences --
+// '\n', '\r', '\\', '\ooo', etc -- to their ASCII equivalents.
+//
+// Errors: Sets the description of the first encountered error in
+// 'error'. To disable error reporting, set 'error' to NULL.
+//
+// NOTE: Does not support \u or \U!
+bool CUnescape(StringPiece source, string* dest, string* error);
+
+// If "text" can be successfully parsed as the ASCII representation of
+// an integer, sets "*val" to the value and returns true. Otherwise,
+// returns false.
+bool NumericParse32(const string& text, int32* val);
+
+// Removes any trailing whitespace from "*s".
+void StripTrailingWhitespace(string* s);
+
+// Removes leading ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveLeadingWhitespace(StringPiece* text);
+
+// Removes trailing ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveTrailingWhitespace(StringPiece* text);
+
+// Removes leading and trailing ascii_isspace() chars.
+// Returns number of chars removed.
+size_t RemoveWhitespaceContext(StringPiece* text);
+
+// Consume a leading unsigned integer value. If any digits were
+// found, stores the value in "*val", advances "*s" past the consumed
+// digits, and returns true. Returns false, leaving "*s" unchanged, if
+// "*s" does not begin with a digit or if the value would overflow a
+// uint64.
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val);
+
+// If "*s" starts with "expected", consume it and return true.
+// Otherwise, return false.
+bool ConsumePrefix(StringPiece* s, StringPiece expected);
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s);
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s);
+
+// Capitalize first character of each word in "*s". "delimiters" is a
+// set of characters that can be used as word boundaries.
+void TitlecaseString(string* s, StringPiece delimiters);
+
+// Join functionality
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep);
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep);
+
+struct AllowEmpty {
+ bool operator()(StringPiece sp) const { return true; }
+};
+struct SkipEmpty {
+ bool operator()(StringPiece sp) const { return !sp.empty(); }
+};
+struct SkipWhitespace {
+ bool operator()(StringPiece sp) const {
+ RemoveTrailingWhitespace(&sp);
+ return !sp.empty();
+ }
+};
+
+std::vector<string> Split(StringPiece text, char delim);
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p);
+
+// Split "text" at "delim" characters, and parse each component as
+// an integer. If successful, adds the individual numbers in order
+// to "*result" and returns true. Otherwise returns false.
+bool SplitAndParseAsInts(StringPiece text, char delim,
+ std::vector<int32>* result);
+
+// ------------------------------------------------------------------
+// Implementation details below
+namespace internal {
+template <typename T>
+string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin,
+ typename gtl::ArraySlice<T>::const_iterator end,
+ const char* sep) {
+ string result;
+ bool first = true;
+ for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end;
+ ++it) {
+ tensorflow::strings::StrAppend(&result, (first ? "" : sep), *it);
+ first = false;
+ }
+ return result;
+}
+} // namespace internal
+
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep) {
+ return Join<T>(gtl::ArraySlice<T>(s), sep);
+}
+
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep) {
+ return internal::JoinHelper<T>(s.begin(), s.end(), sep);
+}
+
+inline std::vector<string> Split(StringPiece text, char delim) {
+ return Split(text, delim, AllowEmpty());
+}
+
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p) {
+ std::vector<string> result;
+ size_t token_start = 0;
+ if (!text.empty()) {
+ for (size_t i = 0; i < text.size() + 1; i++) {
+ if ((i == text.size()) || (text[i] == delim)) {
+ StringPiece token(text.data() + token_start, i - token_start);
+ if (p(token)) {
+ result.push_back(token.ToString());
+ }
+ token_start = i + 1;
+ }
+ }
+ }
+ return result;
+}
+
+} // namespace str_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc
new file mode 100644
index 0000000000..f71cc6c609
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util_test.cc
@@ -0,0 +1,258 @@
+#include "tensorflow/core/lib/strings/str_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(CEscape, Basic) {
+ EXPECT_EQ(str_util::CEscape("hello"), "hello");
+ EXPECT_EQ(str_util::CEscape("hello\n"), "hello\\n");
+ EXPECT_EQ(str_util::CEscape("hello\r"), "hello\\r");
+ EXPECT_EQ(str_util::CEscape("\t\r\"'"), "\\t\\r\\\"\\'");
+ EXPECT_EQ(str_util::CEscape("\320hi\200"), "\\320hi\\200");
+}
+
+string ExpectCUnescapeSuccess(StringPiece source) {
+ string dest;
+ string error;
+ EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)) << error;
+ return dest;
+}
+
+TEST(CUnescape, Basic) {
+ EXPECT_EQ("hello", ExpectCUnescapeSuccess("hello"));
+ EXPECT_EQ("hello\n", ExpectCUnescapeSuccess("hello\\n"));
+ EXPECT_EQ("hello\r", ExpectCUnescapeSuccess("hello\\r"));
+ EXPECT_EQ("\t\r\"'", ExpectCUnescapeSuccess("\\t\\r\\\"\\'"));
+ EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200"));
+}
+
+TEST(NumericParse32, Basic) {
+ int32 val = -1234;
+ EXPECT_TRUE(str_util::NumericParse32("0", &val) && val == 0);
+ EXPECT_TRUE(str_util::NumericParse32("123", &val) && val == 123);
+ EXPECT_TRUE(str_util::NumericParse32("-375", &val) && val == -375);
+ EXPECT_FALSE(str_util::NumericParse32("123hello", &val));
+ EXPECT_FALSE(str_util::NumericParse32("hello123", &val));
+}
+
+TEST(StripTrailingWhitespace, Basic) {
+ string test;
+ test = "hello";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "hello");
+
+ test = "foo ";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "foo");
+
+ test = " ";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "");
+
+ test = "";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "");
+
+ test = " abc\t";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, " abc");
+}
+
+TEST(RemoveLeadingWhitespace, Basic) {
+ string text = "  \t  \n  \r  Quick\t"; // 11 bytes of leading whitespace
+ StringPiece data(text);
+ // check that all whitespace is removed
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 11);
+ EXPECT_EQ(data, StringPiece("Quick\t"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece("Quick\t"));
+}
+
+TEST(RemoveLeadingWhitespace, TerminationHandling) {
+ // check termination handling
+ string text = "\t";
+ StringPiece data(text);
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 1);
+ EXPECT_EQ(data, StringPiece(""));
+
+ // check termination handling again
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+TEST(RemoveTrailingWhitespace, Basic) {
+ string text = "  \t  \n  \r  Quick \t";
+ StringPiece data(text);
+ // check that all trailing whitespace is removed
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 2);
+ EXPECT_EQ(data, StringPiece("  \t  \n  \r  Quick"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece("  \t  \n  \r  Quick"));
+}
+
+TEST(RemoveTrailingWhitespace, TerminationHandling) {
+ // check termination handling
+ string text = "\t";
+ StringPiece data(text);
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 1);
+ EXPECT_EQ(data, StringPiece(""));
+
+ // check termination handling again
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+TEST(RemoveWhitespaceContext, Basic) {
+ string text = "  \t  \n  \r  Quick \t"; // 11 leading + 2 trailing
+ StringPiece data(text);
+ // check that all whitespace is removed
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 13);
+ EXPECT_EQ(data, StringPiece("Quick"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0);
+ EXPECT_EQ(data, StringPiece("Quick"));
+
+ // Test empty string
+ text = "";
+ data = text;
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+void TestConsumeLeadingDigits(StringPiece s, int64 expected,
+ StringPiece remaining) {
+ uint64 v;
+ StringPiece input(s);
+ if (str_util::ConsumeLeadingDigits(&input, &v)) {
+ EXPECT_EQ(v, static_cast<uint64>(expected));
+ EXPECT_EQ(input, remaining);
+ } else {
+ EXPECT_LT(expected, 0);
+ EXPECT_EQ(input, remaining);
+ }
+}
+
+TEST(ConsumeLeadingDigits, Basic) {
+ TestConsumeLeadingDigits("123", 123, "");
+ TestConsumeLeadingDigits("a123", -1, "a123");
+ TestConsumeLeadingDigits("9_", 9, "_");
+ TestConsumeLeadingDigits("11111111111xyz", 11111111111ll, "xyz");
+
+ // Overflow case
+ TestConsumeLeadingDigits("1111111111111111111111111111111xyz", -1,
+ "1111111111111111111111111111111xyz");
+
+ // 2^64
+ TestConsumeLeadingDigits("18446744073709551616xyz", -1,
+ "18446744073709551616xyz");
+ // 2^64-1
+ TestConsumeLeadingDigits("18446744073709551615xyz", 18446744073709551615ull,
+ "xyz");
+}
+
+TEST(ConsumePrefix, Basic) {
+ string s("abcdef");
+ StringPiece input(s);
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdefg"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abce"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, ""));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdeg"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcdef"));
+ EXPECT_EQ(input, "");
+
+ input = s;
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcde"));
+ EXPECT_EQ(input, "f");
+}
+
+TEST(JoinStrings, Basic) {
+ std::vector<string> s;
+ s = {"hi"};
+ EXPECT_EQ(str_util::Join(s, " "), "hi");
+ s = {"hi", "there", "strings"};
+ EXPECT_EQ(str_util::Join(s, " "), "hi there strings");
+
+ std::vector<StringPiece> sp;
+ sp = {"hi"};
+ EXPECT_EQ(str_util::Join(sp, ",,"), "hi");
+ sp = {"hi", "there", "strings"};
+ EXPECT_EQ(str_util::Join(sp, "--"), "hi--there--strings");
+}
+
+TEST(Split, Basic) {
+ EXPECT_TRUE(str_util::Split("", ',').empty());
+ EXPECT_EQ(str_util::Join(str_util::Split("a", ','), "|"), "a");
+ EXPECT_EQ(str_util::Join(str_util::Split(",", ','), "|"), "|");
+ EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c");
+ EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"),
+ "a|||b||c|");
+ EXPECT_EQ(str_util::Join(
+ str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"),
+ "a|b|c");
+ EXPECT_EQ(
+ str_util::Join(
+ str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"),
+ "a|b|c");
+}
+
+TEST(SplitAndParseAsInts, Basic) {
+ std::vector<int32> nums;
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("", ',', &nums));
+ EXPECT_EQ(nums.size(), 0);
+
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("134", ',', &nums));
+ EXPECT_EQ(nums.size(), 1);
+ EXPECT_EQ(nums[0], 134);
+
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("134,2,13,-5", ',', &nums));
+ EXPECT_EQ(nums.size(), 4);
+ EXPECT_EQ(nums[0], 134);
+ EXPECT_EQ(nums[1], 2);
+ EXPECT_EQ(nums[2], 13);
+ EXPECT_EQ(nums[3], -5);
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("abc", ',', &nums));
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("-13,abc", ',', &nums));
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("13,abc,5", ',', &nums));
+}
+
+TEST(Lowercase, Basic) {
+ EXPECT_EQ("", str_util::Lowercase(""));
+ EXPECT_EQ("hello", str_util::Lowercase("hello"));
+ EXPECT_EQ("hello world", str_util::Lowercase("Hello World"));
+}
+
+TEST(Uppercase, Basic) {
+ EXPECT_EQ("", str_util::Uppercase(""));
+ EXPECT_EQ("HELLO", str_util::Uppercase("hello"));
+ EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World"));
+}
+
+TEST(TitlecaseString, Basic) {
+ string s = "sparse_lookup";
+ str_util::TitlecaseString(&s, "_");
+ ASSERT_EQ(s, "Sparse_Lookup");
+
+ s = "sparse_lookup";
+ str_util::TitlecaseString(&s, " ");
+ ASSERT_EQ(s, "Sparse_lookup");
+
+ s = "dense";
+ str_util::TitlecaseString(&s, " ");
+ ASSERT_EQ(s, "Dense");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc
new file mode 100644
index 0000000000..e564b9eb73
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat.cc
@@ -0,0 +1,194 @@
+#include "tensorflow/core/lib/strings/strcat.h"
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+
+namespace tensorflow {
+namespace strings {
+
+AlphaNum gEmptyAlphaNum("");
+
+AlphaNum::AlphaNum(Hex hex) {
+ char *const end = &digits_[kFastToBufferSize];
+ char *writer = end;
+ uint64 value = hex.value;
+ uint64 width = hex.spec;
+ // We accomplish minimum width by OR'ing in 0x10000 to the user's value,
+ // where 0x10000 is the smallest hex number that is as wide as the user
+ // asked for.
+ uint64 mask = ((static_cast<uint64>(1) << (width - 1) * 4)) | value;
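+ // e.g. with ZERO_PAD_5 and value 1, mask is 0x10001, so "00001" is emitted.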
+ static const char hexdigits[] = "0123456789abcdef";
+ do {
+ *--writer = hexdigits[value & 0xF];
+ value >>= 4;
+ mask >>= 4;
+ } while (mask != 0);
+ piece_.set(writer, end - writer);
+}
+
+// ----------------------------------------------------------------------
+// StrCat()
+// This merges the given strings or integers, with no delimiter. This
+// is designed to be the fastest possible way to construct a string out
+// of a mix of raw C strings, StringPieces, strings, and integer values.
+// ----------------------------------------------------------------------
+
+// Append is merely a version of memcpy that returns the address of the byte
+// after the area just overwritten. It comes in multiple flavors to minimize
+// call overhead.
+static char *Append1(char *out, const AlphaNum &x) {
+ memcpy(out, x.data(), x.size());
+ return out + x.size();
+}
+
+static char *Append2(char *out, const AlphaNum &x1, const AlphaNum &x2) {
+ memcpy(out, x1.data(), x1.size());
+ out += x1.size();
+
+ memcpy(out, x2.data(), x2.size());
+ return out + x2.size();
+}
+
+static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2,
+ const AlphaNum &x3, const AlphaNum &x4) {
+ memcpy(out, x1.data(), x1.size());
+ out += x1.size();
+
+ memcpy(out, x2.data(), x2.size());
+ out += x2.size();
+
+ memcpy(out, x3.data(), x3.size());
+ out += x3.size();
+
+ memcpy(out, x4.data(), x4.size());
+ return out + x4.size();
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result, a.size() + b.size());
+ char *const begin = &*result.begin();
+ char *out = Append2(begin, a, b);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result, a.size() + b.size() + c.size());
+ char *const begin = &*result.begin();
+ char *out = Append2(begin, a, b);
+ out = Append1(out, c);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result,
+ a.size() + b.size() + c.size() + d.size());
+ char *const begin = &*result.begin();
+ char *out = Append4(begin, a, b, c, d);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+namespace internal {
+
+// Do not call directly - these are not part of the public API.
+string CatPieces(std::initializer_list<StringPiece> pieces) {
+ string result;
+ size_t total_size = 0;
+ for (const StringPiece piece : pieces) total_size += piece.size();
+ gtl::STLStringResizeUninitialized(&result, total_size);
+
+ char *const begin = &*result.begin();
+ char *out = begin;
+ for (const StringPiece piece : pieces) {
+ const size_t this_size = piece.size();
+ memcpy(out, piece.data(), this_size);
+ out += this_size;
+ }
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+// It's possible to call StrAppend with a StringPiece that is itself a fragment
+// of the string we're appending to. However, the results of doing so are
+// undefined, so we check for this in debug mode. Unsigned math lets the check
+// be a single comparison.
+#define DCHECK_NO_OVERLAP(dest, src) \
+ DCHECK_GE(uintptr_t((src).data() - (dest).data()), uintptr_t((dest).size()))
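+// Note that if `src` begins before `dest`, the unsigned subtraction wraps to
+// a large value and the DCHECK passes; only a `src` that starts inside
+// [dest.data(), dest.data() + dest.size()) is caught.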
+
+void AppendPieces(string *result, std::initializer_list<StringPiece> pieces) {
+ size_t old_size = result->size();
+ size_t total_size = old_size;
+ for (const StringPiece piece : pieces) {
+ DCHECK_NO_OVERLAP(*result, piece);
+ total_size += piece.size();
+ }
+ gtl::STLStringResizeUninitialized(result, total_size);
+
+ char *const begin = &*result->begin();
+ char *out = begin + old_size;
+ for (const StringPiece piece : pieces) {
+ const size_t this_size = piece.size();
+ memcpy(out, piece.data(), this_size);
+ out += this_size;
+ }
+ DCHECK_EQ(out, begin + result->size());
+}
+
+} // namespace internal
+
+void StrAppend(string *result, const AlphaNum &a) {
+ DCHECK_NO_OVERLAP(*result, a);
+ result->append(a.data(), a.size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(result, old_size + a.size() + b.size());
+ char *const begin = &*result->begin();
+ char *out = Append2(begin + old_size, a, b);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ DCHECK_NO_OVERLAP(*result, c);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(result,
+ old_size + a.size() + b.size() + c.size());
+ char *const begin = &*result->begin();
+ char *out = Append2(begin + old_size, a, b);
+ out = Append1(out, c);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ DCHECK_NO_OVERLAP(*result, c);
+ DCHECK_NO_OVERLAP(*result, d);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(
+ result, old_size + a.size() + b.size() + c.size() + d.size());
+ char *const begin = &*result->begin();
+ char *out = Append4(begin + old_size, a, b, c, d);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
new file mode 100644
index 0000000000..763ad8368a
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -0,0 +1,229 @@
+// #status: RECOMMENDED
+// #category: operations on strings
+// #summary: Merges strings or numbers with no delimiter.
+//
+#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#define TENSORFLOW_LIB_STRINGS_STRCAT_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/port.h"
+
+// The AlphaNum type was designed to be used as the parameter type for StrCat().
+// Any routine accepting either a string or a number may accept it.
+// The basic idea is that by accepting a "const AlphaNum &" as an argument
+// to your function, your callers will automagically convert bools, integers,
+// and floating point values to strings for you.
+//
+// NOTE: Use of AlphaNum outside of the //strings package is unsupported except
+// for the specific case of function parameters of type "AlphaNum" or "const
+// AlphaNum &". In particular, instantiating AlphaNum directly as a stack
+// variable is not supported.
+//
+// Conversion from 8-bit values is not accepted because if it were, then an
+// attempt to pass ':' instead of ":" might result in a 58 ending up in your
+// result.
+//
+// Bools convert to "0" or "1".
+//
+// Floating point values are converted to a string which, if passed to strtod(),
+// would produce the exact same original double (except in case of NaN; all NaNs
+// are considered the same value). We try to keep the string short but it's not
+// guaranteed to be as short as possible.
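+// For example, StrCat(0.5) yields "0.5" (see strcat_test.cc).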
+//
+// You can convert to Hexadecimal output rather than Decimal output using Hex.
+// To do this, pass strings::Hex(my_int) as a parameter to StrCat. You may
+// specify a minimum field width using a separate parameter, so the equivalent
+// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::ZERO_PAD_4))
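+//
+// For example (an illustrative sketch):
+//   strings::StrCat("0x", strings::Hex(250, strings::ZERO_PAD_4));  // "0x00fa"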
+//
+// This class has implicit constructors.
+namespace tensorflow {
+namespace strings {
+
+enum PadSpec {
+ NO_PAD = 1,
+ ZERO_PAD_2,
+ ZERO_PAD_3,
+ ZERO_PAD_4,
+ ZERO_PAD_5,
+ ZERO_PAD_6,
+ ZERO_PAD_7,
+ ZERO_PAD_8,
+ ZERO_PAD_9,
+ ZERO_PAD_10,
+ ZERO_PAD_11,
+ ZERO_PAD_12,
+ ZERO_PAD_13,
+ ZERO_PAD_14,
+ ZERO_PAD_15,
+ ZERO_PAD_16,
+};
+
+struct Hex {
+ uint64 value;
+ enum PadSpec spec;
+ template <class Int>
+ explicit Hex(Int v, PadSpec s = NO_PAD)
+ : spec(s) {
+ // Prevent sign-extension by casting integers to
+ // their unsigned counterparts.
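+ // e.g. an int8 holding -1 formats as "ff", not "ffffffffffffffff".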
+ static_assert(
+ sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8,
+ "Unknown integer type");
+ value = sizeof(v) == 1
+ ? static_cast<uint8>(v)
+ : sizeof(v) == 2 ? static_cast<uint16>(v)
+ : sizeof(v) == 4 ? static_cast<uint32>(v)
+ : static_cast<uint64>(v);
+ }
+};
+
+class AlphaNum {
+ public:
+ // No bool ctor -- bools convert to an integral type.
+ // A bool ctor would also convert incoming pointers (bletch).
+
+ AlphaNum(int i32) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt32ToBufferLeft(i32, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned int u32) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt32ToBufferLeft(u32, digits_) - &digits_[0]) {}
+ AlphaNum(long x) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned long x) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+ AlphaNum(long long int i64) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt64ToBufferLeft(i64, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned long long int u64) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt64ToBufferLeft(u64, digits_) - &digits_[0]) {}
+
+ AlphaNum(float f) // NOLINT(runtime/explicit)
+ : piece_(digits_, strlen(FloatToBuffer(f, digits_))) {}
+ AlphaNum(double f) // NOLINT(runtime/explicit)
+ : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {}
+
+ AlphaNum(Hex hex); // NOLINT(runtime/explicit)
+
+ AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit)
+ AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
+ AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
+ : piece_(str) {}
+
+ StringPiece::size_type size() const { return piece_.size(); }
+ const char *data() const { return piece_.data(); }
+ StringPiece Piece() const { return piece_; }
+
+ private:
+ StringPiece piece_;
+ char digits_[kFastToBufferSize];
+
+ // Use ":" not ':'
+ AlphaNum(char c); // NOLINT(runtime/explicit)
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AlphaNum);
+};
+
+extern AlphaNum gEmptyAlphaNum;
+
+using strings::AlphaNum;
+using strings::gEmptyAlphaNum;
+
+// ----------------------------------------------------------------------
+// StrCat()
+// This merges the given strings or numbers, with no delimiter. This
+// is designed to be the fastest possible way to construct a string out
+// of a mix of raw C strings, StringPieces, strings, bool values,
+// and numeric values.
+//
+// Don't use this for user-visible strings. The localization process
+// works poorly on strings built up out of fragments.
+//
+// For clarity and performance, don't use StrCat when appending to a
+// string. In particular, avoid using any of these (anti-)patterns:
+// str.append(StrCat(...))
+// str += StrCat(...)
+// str = StrCat(str, ...)
+// where the last is the worst, with the potential to change a loop
+// from a linear time operation with O(1) dynamic allocations into a
+// quadratic time operation with O(n) dynamic allocations. StrAppend
+// is a better choice than any of the above, subject to the restriction
+// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may
+// be a reference into str.
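+//
+// For example, to build up a string in a loop, prefer (an illustrative
+// sketch; `pieces` is a hypothetical input):
+//   string out;
+//   for (const string &piece : pieces) StrAppend(&out, piece, ",");
+// over `out = StrCat(out, piece, ",")`, which re-copies `out` each iteration.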
+// ----------------------------------------------------------------------
+
+// For performance reasons, we have specializations for <= 4 args.
+string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d) TF_MUST_USE_RESULT;
+
+// inline definitions must be duplicated due to TF_MUST_USE_RESULT
+inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
+
+namespace internal {
+
+// Do not call directly - this is not part of the public API.
+string CatPieces(std::initializer_list<StringPiece> pieces);
+void AppendPieces(string *dest, std::initializer_list<StringPiece> pieces);
+
+} // namespace internal
+
+// Support 5 or more arguments
+template <typename... AV>
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e,
+ const AV &... args) TF_MUST_USE_RESULT;
+
+template <typename... AV>
+inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e, const AV &... args) {
+ return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(),
+ e.Piece(),
+ static_cast<const AlphaNum &>(args).Piece()...});
+}
+
+// ----------------------------------------------------------------------
+// StrAppend()
+// Same as above, but adds the output to the given string.
+// WARNING: For speed, StrAppend does not try to check each of its input
+// arguments to be sure that they are not a subset of the string being
+// appended to. That is, while this will work:
+//
+// string s = "foo";
+// s += s;
+//
+// This will not (necessarily) work:
+//
+// string s = "foo";
+// StrAppend(&s, s);
+//
+// Note: the pre-variadic implementations limited StrCat to 26 arguments and
+// StrAppend to 9. With the C++11 variadic templates below neither has a fixed
+// limit, and consecutive calls to StrAppend remain quite efficient in any
+// case.
+// ----------------------------------------------------------------------
+
+void StrAppend(string *dest, const AlphaNum &a);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d);
+
+// Support 5 or more arguments
+template <typename... AV>
+inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d, const AlphaNum &e,
+ const AV &... args) {
+ internal::AppendPieces(dest,
+ {a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(),
+ static_cast<const AlphaNum &>(args).Piece()...});
+}
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_
diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc
new file mode 100644
index 0000000000..9ff7d81af9
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat_test.cc
@@ -0,0 +1,324 @@
+#include "tensorflow/core/lib/strings/strcat.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+
+// Test StrCat of ints and longs of various sizes and signedness.
+TEST(StrCat, Ints) {
+ const int16 s = -1;
+ const uint16 us = 2;
+ const int i = -3;
+ const unsigned int ui = 4;
+ const int32 l = -5;
+ const uint32 ul = 6;
+ const int64 ll = -7;
+ const uint64 ull = 8;
+ const ptrdiff_t ptrdiff = -9;
+ const size_t size = 10;
+ const ssize_t ssize = -11;
+ const intptr_t intptr = -12;
+ const uintptr_t uintptr = 13;
+ string answer;
+ answer = StrCat(s, us);
+ EXPECT_EQ(answer, "-12");
+ answer = StrCat(i, ui);
+ EXPECT_EQ(answer, "-34");
+ answer = StrCat(l, ul);
+ EXPECT_EQ(answer, "-56");
+ answer = StrCat(ll, ull);
+ EXPECT_EQ(answer, "-78");
+ answer = StrCat(ptrdiff, size);
+ EXPECT_EQ(answer, "-910");
+ answer = StrCat(ssize, intptr);
+ EXPECT_EQ(answer, "-11-12");
+ answer = StrCat(uintptr, 0);
+ EXPECT_EQ(answer, "130");
+}
+
+TEST(StrCat, Basics) {
+ string result;
+
+ string strs[] = {"Hello", "Cruel", "World"};
+
+ StringPiece pieces[] = {"Hello", "Cruel", "World"};
+
+ const char *c_strs[] = {"Hello", "Cruel", "World"};
+
+ int32 i32s[] = {'H', 'C', 'W'};
+ uint64 ui64s[] = {12345678910LL, 10987654321LL};
+
+ result = StrCat(false, true, 2, 3);
+ EXPECT_EQ(result, "0123");
+
+ result = StrCat(-1);
+ EXPECT_EQ(result, "-1");
+
+ result = StrCat(0.5);
+ EXPECT_EQ(result, "0.5");
+
+ result = StrCat(strs[1], pieces[2]);
+ EXPECT_EQ(result, "CruelWorld");
+
+ result = StrCat(strs[0], ", ", pieces[2]);
+ EXPECT_EQ(result, "Hello, World");
+
+ result = StrCat(strs[0], ", ", strs[1], " ", strs[2], "!");
+ EXPECT_EQ(result, "Hello, Cruel World!");
+
+ result = StrCat(pieces[0], ", ", pieces[1], " ", pieces[2]);
+ EXPECT_EQ(result, "Hello, Cruel World");
+
+ result = StrCat(c_strs[0], ", ", c_strs[1], " ", c_strs[2]);
+ EXPECT_EQ(result, "Hello, Cruel World");
+
+ result = StrCat("ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!");
+ EXPECT_EQ(result, "ASCII 72, 67 87!");
+
+ result = StrCat(ui64s[0], ", ", ui64s[1], "!");
+ EXPECT_EQ(result, "12345678910, 10987654321!");
+
+ string one = "1"; // Actually, it's the size of this string that we want; a
+ // 64-bit build distinguishes between size_t and uint64,
+ // even though they're both unsigned 64-bit values.
+ result = StrCat("And a ", one.size(), " and a ", &result[2] - &result[0],
+ " and a ", one, " 2 3 4", "!");
+ EXPECT_EQ(result, "And a 1 and a 2 and a 1 2 3 4!");
+
+ // result = StrCat("Single chars won't compile", '!');
+ // result = StrCat("Neither will NULLs", NULL);
+ result = StrCat("To output a char by ASCII/numeric value, use +: ", '!' + 0);
+ EXPECT_EQ(result, "To output a char by ASCII/numeric value, use +: 33");
+
+ float f = 100000.5;
+ result = StrCat("A hundred K and a half is ", f);
+ EXPECT_EQ(result, "A hundred K and a half is 100000.5");
+
+ double d = f;
+ d *= d;
+ result = StrCat("A hundred K and a half squared is ", d);
+ EXPECT_EQ(result, "A hundred K and a half squared is 10000100000.25");
+
+ result = StrCat(1, 2, 333, 4444, 55555, 666666, 7777777, 88888888, 999999999);
+ EXPECT_EQ(result, "12333444455555666666777777788888888999999999");
+}
+
+TEST(StrCat, MaxArgs) {
+ string result;
+ // Test 10 up to 26 arguments, the old pre-variadic maximum
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a");
+ EXPECT_EQ(result, "123456789a");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b");
+ EXPECT_EQ(result, "123456789ab");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c");
+ EXPECT_EQ(result, "123456789abc");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d");
+ EXPECT_EQ(result, "123456789abcd");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e");
+ EXPECT_EQ(result, "123456789abcde");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f");
+ EXPECT_EQ(result, "123456789abcdef");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g");
+ EXPECT_EQ(result, "123456789abcdefg");
+ result =
+ StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", "h");
+ EXPECT_EQ(result, "123456789abcdefgh");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i");
+ EXPECT_EQ(result, "123456789abcdefghi");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j");
+ EXPECT_EQ(result, "123456789abcdefghij");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k");
+ EXPECT_EQ(result, "123456789abcdefghijk");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l");
+ EXPECT_EQ(result, "123456789abcdefghijkl");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m");
+ EXPECT_EQ(result, "123456789abcdefghijklm");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n");
+ EXPECT_EQ(result, "123456789abcdefghijklmn");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o");
+ EXPECT_EQ(result, "123456789abcdefghijklmno");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o", "p");
+ EXPECT_EQ(result, "123456789abcdefghijklmnop");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o", "p", "q");
+ EXPECT_EQ(result, "123456789abcdefghijklmnopq");
+ // No limit thanks to C++11's variadic templates
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", "f",
+ "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r",
+ "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D",
+ "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P",
+ "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z");
+ EXPECT_EQ(result,
+ "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ");
+}
+
+TEST(StrAppend, Basics) {
+ string result = "existing text";
+
+ string strs[] = {"Hello", "Cruel", "World"};
+
+ StringPiece pieces[] = {"Hello", "Cruel", "World"};
+
+ const char *c_strs[] = {"Hello", "Cruel", "World"};
+
+ int32 i32s[] = {'H', 'C', 'W'};
+ uint64 ui64s[] = {12345678910LL, 10987654321LL};
+
+ string::size_type old_size = result.size();
+ StrAppend(&result, strs[0]);
+ EXPECT_EQ(result.substr(old_size), "Hello");
+
+ old_size = result.size();
+ StrAppend(&result, strs[1], pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "CruelWorld");
+
+ old_size = result.size();
+ StrAppend(&result, strs[0], ", ", pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, World");
+
+ old_size = result.size();
+ StrAppend(&result, strs[0], ", ", strs[1], " ", strs[2], "!");
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World!");
+
+ old_size = result.size();
+ StrAppend(&result, pieces[0], ", ", pieces[1], " ", pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World");
+
+ old_size = result.size();
+ StrAppend(&result, c_strs[0], ", ", c_strs[1], " ", c_strs[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World");
+
+ old_size = result.size();
+ StrAppend(&result, "ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!");
+ EXPECT_EQ(result.substr(old_size), "ASCII 72, 67 87!");
+
+ old_size = result.size();
+ StrAppend(&result, ui64s[0], ", ", ui64s[1], "!");
+ EXPECT_EQ(result.substr(old_size), "12345678910, 10987654321!");
+
+ string one = "1"; // Actually, it's the size of this string that we want; a
+ // 64-bit build distinguishes between size_t and uint64,
+ // even though they're both unsigned 64-bit values.
+ old_size = result.size();
+ StrAppend(&result, "And a ", one.size(), " and a ", &result[2] - &result[0],
+ " and a ", one, " 2 3 4", "!");
+ EXPECT_EQ(result.substr(old_size), "And a 1 and a 2 and a 1 2 3 4!");
+
+ // result = StrCat("Single chars won't compile", '!');
+ // result = StrCat("Neither will NULLs", NULL);
+ old_size = result.size();
+ StrAppend(&result, "To output a char by ASCII/numeric value, use +: ",
+ '!' + 0);
+ EXPECT_EQ(result.substr(old_size),
+ "To output a char by ASCII/numeric value, use +: 33");
+
+ float f = 100000.5;
+ old_size = result.size();
+ StrAppend(&result, "A hundred K and a half is ", f);
+ EXPECT_EQ(result.substr(old_size), "A hundred K and a half is 100000.5");
+
+ double d = f;
+ d *= d;
+ old_size = result.size();
+ StrAppend(&result, "A hundred K and a half squared is ", d);
+ EXPECT_EQ(result.substr(old_size),
+ "A hundred K and a half squared is 10000100000.25");
+
+ // Test 9 arguments, the old maximum
+ old_size = result.size();
+ StrAppend(&result, 1, 22, 333, 4444, 55555, 666666, 7777777, 88888888, 9);
+ EXPECT_EQ(result.substr(old_size), "1223334444555556666667777777888888889");
+
+ // No limit thanks to C++11's variadic templates
+ old_size = result.size();
+ StrAppend(&result, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e",
+ "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r",
+ "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E",
+ "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R",
+ "S", "T", "U", "V", "W", "X", "Y", "Z",
+ "No limit thanks to C++11's variadic templates");
+ EXPECT_EQ(result.substr(old_size),
+ "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "No limit thanks to C++11's variadic templates");
+}
+
+TEST(StrAppend, Death) {
+ string s = "self";
+ EXPECT_DEBUG_DEATH(StrAppend(&s, s.c_str() + 1), "Check failed:");
+ EXPECT_DEBUG_DEATH(StrAppend(&s, s), "Check failed:");
+}
+
+static void CheckHex64(uint64 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16));
+ string expected = Printf("%016llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ expected = Printf("%08llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void CheckHex32(uint32 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string expected = Printf("%08x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void CheckHexSigned32(int32 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string expected = Printf("%08x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void TestFastPrints() {
+ using tensorflow::strings::Hex;
+
+ // Check small values, positive and negative, against printf references
+ for (int i = 0; i < 10000; i++) {
+ CheckHex64(i);
+ CheckHex32(i);
+ CheckHexSigned32(i);
+ CheckHexSigned32(-i);
+ }
+ CheckHex64(0x123456789abcdef0ull);
+ CheckHex32(0x12345678);
+
+ int8 minus_one_8bit = -1;
+ EXPECT_EQ("ff", StrCat(Hex(minus_one_8bit)));
+
+ int16 minus_one_16bit = -1;
+ EXPECT_EQ("ffff", StrCat(Hex(minus_one_16bit)));
+}
+
+TEST(Numbers, TestFunctionsMovedOverFromNumbersMain) { TestFastPrints(); }
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/stringprintf.cc b/tensorflow/core/lib/strings/stringprintf.cc
new file mode 100644
index 0000000000..b354706cbd
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf.cc
@@ -0,0 +1,85 @@
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+#include <errno.h>
+#include <stdarg.h> // For va_list and related operations
+#include <stdio.h> // MSVC requires this for _vsnprintf
+#include <vector>
+
+namespace tensorflow {
+namespace strings {
+
+#ifdef COMPILER_MSVC
+enum { IS_COMPILER_MSVC = 1 };
+#else
+enum { IS_COMPILER_MSVC = 0 };
+#endif
+
+void Appendv(string* dst, const char* format, va_list ap) {
+ // First try with a small fixed size buffer
+ static const int kSpaceLength = 1024;
+ char space[kSpaceLength];
+
+ // It's possible for methods that use a va_list to invalidate
+ // the data in it upon use. The fix is to make a copy
+ // of the structure before using it and use that copy instead.
+ va_list backup_ap;
+ va_copy(backup_ap, ap);
+ int result = vsnprintf(space, kSpaceLength, format, backup_ap);
+ va_end(backup_ap);
+
+ if (result < kSpaceLength) {
+ if (result >= 0) {
+ // Normal case -- everything fit.
+ dst->append(space, result);
+ return;
+ }
+
+ if (IS_COMPILER_MSVC) {
+ // Error or MSVC running out of space. MSVC 8.0 and higher
+ // can be asked about space needed with the special idiom below:
+ va_copy(backup_ap, ap);
+ result = vsnprintf(NULL, 0, format, backup_ap);
+ va_end(backup_ap);
+ }
+
+ if (result < 0) {
+ // Just an error.
+ return;
+ }
+ }
+
+ // Increase the buffer size to the size requested by vsnprintf,
+ // plus one for the closing \0.
+ int length = result + 1;
+ char* buf = new char[length];
+
+ // Restore the va_list before we use it again
+ va_copy(backup_ap, ap);
+ result = vsnprintf(buf, length, format, backup_ap);
+ va_end(backup_ap);
+
+ if (result >= 0 && result < length) {
+ // It fit
+ dst->append(buf, result);
+ }
+ delete[] buf;
+}
+
+string Printf(const char* format, ...) {
+ va_list ap;
+ va_start(ap, format);
+ string result;
+ Appendv(&result, format, ap);
+ va_end(ap);
+ return result;
+}
+
+void Appendf(string* dst, const char* format, ...) {
+ va_list ap;
+ va_start(ap, format);
+ Appendv(dst, format, ap);
+ va_end(ap);
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h
new file mode 100644
index 0000000000..23ca2583ca
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf.h
@@ -0,0 +1,37 @@
+// Printf variants that place their output in a C++ string.
+//
+// Usage:
+// string result = strings::Printf("%d %s\n", 10, "hello");
+// strings::Appendf(&result, "%d %s\n", 20, "there");
+
+#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+
+#include <stdarg.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace strings {
+
+// Return a C++ string
+extern string Printf(const char* format, ...)
+ // Tell the compiler to do printf format string checking.
+ TF_PRINTF_ATTRIBUTE(1, 2);
+
+// Append result to a supplied string
+extern void Appendf(string* dst, const char* format, ...)
+ // Tell the compiler to do printf format string checking.
+ TF_PRINTF_ATTRIBUTE(2, 3);
+
+// Lower-level routine that takes a va_list and appends to a specified
+// string. All other routines are just convenience wrappers around it.
+extern void Appendv(string* dst, const char* format, va_list ap);
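+//
+// For instance, a custom printf-style helper can forward its va_list here
+// (a sketch; `LogLine` is a hypothetical wrapper, not part of this library):
+//   void LogLine(const char* format, ...) {
+//     va_list ap;
+//     va_start(ap, format);
+//     string msg;
+//     Appendv(&msg, format, ap);
+//     va_end(ap);
+//     fprintf(stderr, "%s\n", msg.c_str());
+//   }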
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
diff --git a/tensorflow/core/lib/strings/stringprintf_test.cc b/tensorflow/core/lib/strings/stringprintf_test.cc
new file mode 100644
index 0000000000..737ed5c0e0
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+#include <errno.h>
+#include <locale.h>
+#include <string.h>
+
+#include <string>
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+namespace {
+
+TEST(PrintfTest, Empty) {
+ EXPECT_EQ("", Printf("%s", string().c_str()));
+ EXPECT_EQ("", Printf("%s", ""));
+}
+
+TEST(PrintfTest, Misc) {
+// MSVC does not support $ format specifier.
+#if !defined(COMPILER_MSVC)
+ EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123));
+#endif // !COMPILER_MSVC
+}
+
+TEST(AppendfTest, Empty) {
+ string value("Hello");
+ const char* empty = "";
+ Appendf(&value, "%s", empty);
+ EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, EmptyString) {
+ string value("Hello");
+ Appendf(&value, "%s", "");
+ EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, String) {
+ string value("Hello");
+ Appendf(&value, " %s", "World");
+ EXPECT_EQ("Hello World", value);
+}
+
+TEST(AppendfTest, Int) {
+ string value("Hello");
+ Appendf(&value, " %d", 123);
+ EXPECT_EQ("Hello 123", value);
+}
+
+TEST(PrintfTest, Multibyte) {
+ // If we are in multibyte mode and feed invalid multibyte sequence,
+ // Printf should return an empty string instead of running
+ // out of memory while trying to determine destination buffer size.
+ // see b/4194543.
+
+ // Save the current locale name; the buffer returned by setlocale() may be
+ // overwritten by later setlocale() calls, so keep a copy.
+ string old_locale = setlocale(LC_CTYPE, NULL);
+ // Push locale with multibyte mode
+ setlocale(LC_CTYPE, "en_US.utf8");
+
+ const char kInvalidCodePoint[] = "\375\067s";
+ string value = Printf("%.*s", 3, kInvalidCodePoint);
+
+ // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf
+ // returns error given an invalid codepoint. Other versions
+ // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim.
+ // We test that the output is one of the above.
+ EXPECT_TRUE(value.empty() || value == kInvalidCodePoint);
+
+ // Repeat with longer string, to make sure that the dynamically
+ // allocated path in StringAppendV is handled correctly.
+ int n = 2048;
+ char* buf = new char[n + 1];
+ memset(buf, ' ', n - 3);
+ memcpy(buf + n - 3, kInvalidCodePoint, 4);
+ value = Printf("%.*s", n, buf);
+ // See GRTEv2 vs. GRTEv3 comment above.
+ EXPECT_TRUE(value.empty() || value == buf);
+ delete[] buf;
+
+ setlocale(LC_CTYPE, old_locale.c_str());
+}
+
+TEST(PrintfTest, NoMultibyte) {
+ // No multibyte handling, but the string contains funny chars.
+ string old_locale = setlocale(LC_CTYPE, NULL);  // copy; see Multibyte test
+ setlocale(LC_CTYPE, "POSIX");
+ string value = Printf("%.*s", 3, "\375\067s");
+ setlocale(LC_CTYPE, old_locale.c_str());
+ EXPECT_EQ("\375\067s", value);
+}
+
+TEST(PrintfTest, DontOverwriteErrno) {
+ // Check that errno isn't overwritten unless we're printing
+ // something significantly larger than what people are normally
+ // printing in their badly written PLOG() statements.
+ errno = ECHILD;
+ string value = Printf("Hello, %s!", "World");
+ EXPECT_EQ(ECHILD, errno);
+}
+
+TEST(PrintfTest, LargeBuf) {
+ // Check that the large buffer is handled correctly.
+ int n = 2048;
+ char* buf = new char[n + 1];
+ memset(buf, ' ', n);
+ buf[n] = 0;
+ string value = Printf("%s", buf);
+ EXPECT_EQ(buf, value);
+ delete[] buf;
+}
+
+} // namespace
+
+} // namespace strings
+} // namespace tensorflow