Diffstat (limited to 'tensorflow/core/lib')
136 files changed, 19846 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc new file mode 100644 index 0000000000..ceb1001af0 --- /dev/null +++ b/tensorflow/core/lib/core/arena.cc @@ -0,0 +1,246 @@ +// This approach to arenas overcomes many of the limitations described +// in the "Specialized allocators" section of +// http://www.pdos.lcs.mit.edu/~dm/c++-new.html +// +// A somewhat similar approach to Gladiator, but for heap-detection, was +// suggested by Ron van der Wal and Scott Meyers at +// http://www.aristeia.com/BookErrata/M27Comments_frames.html + +#include "tensorflow/core/lib/core/arena.h" + +#include <assert.h> +#include <unistd.h> + +#include <vector> + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { +namespace core { + +static const int kPageSize = getpagesize(); + +// ---------------------------------------------------------------------- +// Arena::Arena() +// Arena::~Arena() +// Destroying the arena automatically calls Reset() +// ---------------------------------------------------------------------- + +Arena::Arena(const size_t block_size) + : remaining_(0), + block_size_(block_size), + freestart_(NULL), // set for real in Reset() + blocks_alloced_(1), + overflow_blocks_(NULL) { + assert(block_size > kDefaultAlignment); + + first_blocks_[0].mem = reinterpret_cast<char*>(malloc(block_size_)); + + first_blocks_[0].size = block_size_; + + Reset(); +} + +Arena::~Arena() { + FreeBlocks(); + assert(overflow_blocks_ == NULL); // FreeBlocks() should do that + // The first X blocks stay allocated always by default. Delete them now. + for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem); +} + +// Returns true iff it advances freestart_ to the first position +// satisfying alignment without exhausting the current block. +bool Arena::SatisfyAlignment(size_t alignment) { + const size_t overage = reinterpret_cast<size_t>(freestart_) & (alignment - 1); + if (overage > 0) { + const size_t waste = alignment - overage; + if (waste >= remaining_) { + return false; + } + freestart_ += waste; + remaining_ -= waste; + } + DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1)); + return true; +} + +// ---------------------------------------------------------------------- +// Arena::Reset() +// Clears all the memory an arena is using. +// ---------------------------------------------------------------------- + +void Arena::Reset() { + FreeBlocks(); + freestart_ = first_blocks_[0].mem; + remaining_ = first_blocks_[0].size; + + // There is no guarantee the first block is properly aligned, so + // enforce that now. + CHECK(SatisfyAlignment(kDefaultAlignment)); + + freestart_when_empty_ = freestart_; +} + +// ---------------------------------------------------------------------- +// Arena::MakeNewBlock() +// Our sbrk() equivalent. We always make blocks of the same size +// (though GetMemory() can also make a new block for really big +// data. +// ---------------------------------------------------------------------- + +void Arena::MakeNewBlock(const uint32 alignment) { + AllocatedBlock* block = AllocNewBlock(block_size_, alignment); + freestart_ = block->mem; + remaining_ = block->size; + CHECK(SatisfyAlignment(alignment)); +} + +// The following simple numeric routines also exist in util/math/mathutil.h +// but we don't want to depend on that library. + +// Euclid's algorithm for Greatest Common Denominator. 
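// Illustrative check of the two helpers that follow: GCD(24, 36) == 12 and
// LeastCommonMultiple(24, 36) == (36 / 12) * 24 == 72. AllocNewBlock() uses
// the LCM to combine a caller-requested alignment with kDefaultAlignment.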
+static uint32 GCD(uint32 x, uint32 y) { + while (y != 0) { + uint32 r = x % y; + x = y; + y = r; + } + return x; +} + +static uint32 LeastCommonMultiple(uint32 a, uint32 b) { + if (a > b) { + return (a / GCD(a, b)) * b; + } else if (a < b) { + return (b / GCD(b, a)) * a; + } else { + return a; + } +} + +// ------------------------------------------------------------- +// Arena::AllocNewBlock() +// Adds and returns an AllocatedBlock. +// The returned AllocatedBlock* is valid until the next call +// to AllocNewBlock or Reset. (i.e. anything that might +// affect overflow_blocks_). +// ------------------------------------------------------------- + +Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size, + const uint32 alignment) { + AllocatedBlock* block; + // Find the next block. + if (blocks_alloced_ < TF_ARRAYSIZE(first_blocks_)) { + // Use one of the pre-allocated blocks + block = &first_blocks_[blocks_alloced_++]; + } else { // oops, out of space, move to the vector + if (overflow_blocks_ == NULL) + overflow_blocks_ = new std::vector<AllocatedBlock>; + // Adds another block to the vector. + overflow_blocks_->resize(overflow_blocks_->size() + 1); + // block points to the last block of the vector. + block = &overflow_blocks_->back(); + } + + // NOTE(tucker): this utility is made slightly more complex by + // not disallowing the case where alignment > block_size. + // Can we, without breaking existing code? + + // Must be a multiple of kDefaultAlignment, unless requested + // alignment is 1, in which case we don't care at all. + const uint32 adjusted_alignment = + (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1); + + CHECK_LE(adjusted_alignment, 1 << 20) + << "Alignment on boundaries greater than 1MB not supported."; + + // If block_size > alignment we force block_size to be a multiple + // of alignment; if block_size < alignment we make no adjustment. + size_t adjusted_block_size = block_size; + if (adjusted_alignment > 1) { + if (adjusted_block_size > adjusted_alignment) { + const uint32 excess = adjusted_block_size % adjusted_alignment; + adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0); + } + block->mem = reinterpret_cast<char*>( + port::aligned_malloc(adjusted_block_size, adjusted_alignment)); + } else { + block->mem = reinterpret_cast<char*>(malloc(adjusted_block_size)); + } + block->size = adjusted_block_size; + CHECK(NULL != block->mem) << "block_size=" << block_size + << " adjusted_block_size=" << adjusted_block_size + << " alignment=" << alignment + << " adjusted_alignment=" << adjusted_alignment; + + return block; +} + +// ---------------------------------------------------------------------- +// Arena::GetMemoryFallback() +// We take memory out of our pool, aligned on the byte boundary +// requested. If we don't have space in our current pool, we +// allocate a new block (wasting the remaining space in the +// current block) and give you that. If your memory needs are +// too big for a single block, we make a special your-memory-only +// allocation -- this is equivalent to not using the arena at all. +// ---------------------------------------------------------------------- + +void* Arena::GetMemoryFallback(const size_t size, const int alignment) { + if (0 == size) { + return NULL; // stl/stl_alloc.h says this is okay + } + + // alignment must be a positive power of 2. 
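// Note: (x & (x - 1)) clears the lowest set bit of x, so the expression below
// is zero exactly when x is a power of two (or zero). For example 8 & 7 == 0,
// while 6 & 5 == 4, so 8 is accepted as an alignment and 6 is rejected.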
+ CHECK(alignment > 0 && 0 == (alignment & (alignment - 1))); + + // If the object is more than a quarter of the block size, allocate + // it separately to avoid wasting too much space in leftover bytes. + if (block_size_ == 0 || size > block_size_ / 4) { + return AllocNewBlock(size, alignment)->mem; + } + + // Enforce alignment on freestart_ then check for adequate space, + // which may require starting a new block. + if (!SatisfyAlignment(alignment) || size > remaining_) { + MakeNewBlock(alignment); + } + CHECK_LE(size, remaining_); + + remaining_ -= size; + void* result = freestart_; + freestart_ += size; + + return result; +} + +// ---------------------------------------------------------------------- +// Arena::ReturnMemoryFallback() +// Arena::FreeBlocks() +// Unlike GetMemory(), which does actual work, ReturnMemory() is a +// no-op: we don't "free" memory until Reset() is called. We do +// update some stats, though. Note we do no checking that the +// pointer you pass in was actually allocated by us, or that it +// was allocated for the size you say, so be careful here! +// FreeBlocks() does the work for Reset(), actually freeing all +// memory allocated in one fell swoop. +// ---------------------------------------------------------------------- + +void Arena::FreeBlocks() { + for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced + free(first_blocks_[i].mem); + first_blocks_[i].mem = NULL; + first_blocks_[i].size = 0; + } + blocks_alloced_ = 1; + if (overflow_blocks_ != NULL) { + std::vector<AllocatedBlock>::iterator it; + for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) { + free(it->mem); + } + delete overflow_blocks_; // These should be used very rarely + overflow_blocks_ = NULL; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h new file mode 100644 index 0000000000..59896803bb --- /dev/null +++ b/tensorflow/core/lib/core/arena.h @@ -0,0 +1,90 @@ +// TODO(vrv): Switch this to an open-sourced version of Arena. + +#ifndef TENSORFLOW_LIB_CORE_ARENA_H_ +#define TENSORFLOW_LIB_CORE_ARENA_H_ + +#include <assert.h> + +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// This class is "thread-compatible": different threads can access the +// arena at the same time without locking, as long as they use only +// const methods. +class Arena { + public: + // Allocates a thread-compatible arena with the specified block size. + explicit Arena(const size_t block_size); + ~Arena(); + + char* Alloc(const size_t size) { + return reinterpret_cast<char*>(GetMemory(size, 1)); + } + + void Reset(); + +// This should be the worst-case alignment for any type. This is +// good for IA-32, SPARC version 7 (the last one I know), and +// supposedly Alpha. i386 would be more time-efficient with a +// default alignment of 8, but ::operator new() uses alignment of 4, +// and an assertion will fail below after the call to MakeNewBlock() +// if you try to use a larger alignment. 
+#ifdef __i386__ + static const int kDefaultAlignment = 4; +#else + static const int kDefaultAlignment = 8; +#endif + + protected: + bool SatisfyAlignment(const size_t alignment); + void MakeNewBlock(const uint32 alignment); + void* GetMemoryFallback(const size_t size, const int align); + void* GetMemory(const size_t size, const int align) { + assert(remaining_ <= block_size_); // an invariant + if (size > 0 && size < remaining_ && align == 1) { // common case + void* result = freestart_; + freestart_ += size; + remaining_ -= size; + return result; + } + return GetMemoryFallback(size, align); + } + + size_t remaining_; + + private: + struct AllocatedBlock { + char* mem; + size_t size; + }; + + // Allocate new new block of at least block_size, with the specified + // alignment. + // The returned AllocatedBlock* is valid until the next call to AllocNewBlock + // or Reset (i.e. anything that might affect overflow_blocks_). + AllocatedBlock* AllocNewBlock(const size_t block_size, + const uint32 alignment); + + const size_t block_size_; + char* freestart_; // beginning of the free space in most recent block + char* freestart_when_empty_; // beginning of the free space when we're empty + // STL vector isn't as efficient as it could be, so we use an array at first + size_t blocks_alloced_; // how many of the first_blocks_ have been alloced + AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary + // if the first_blocks_ aren't enough, expand into overflow_blocks_. + std::vector<AllocatedBlock>* overflow_blocks_; + + void FreeBlocks(); // Frees all except first block + + TF_DISALLOW_COPY_AND_ASSIGN(Arena); +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ARENA_H_ diff --git a/tensorflow/core/lib/core/arena_test.cc b/tensorflow/core/lib/core/arena_test.cc new file mode 100644 index 0000000000..fa147c3014 --- /dev/null +++ b/tensorflow/core/lib/core/arena_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/arena.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { +namespace { + +// Write random data to allocated memory +static void TestMemory(void* mem, int size) { + // Check that we can memset the entire memory + memset(mem, 0xaa, size); + + // Do some memory allocation to check that the arena doesn't mess up + // the internal memory allocator + char* tmp[100]; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + tmp[i] = new char[i * i + 1]; + } + + memset(mem, 0xcc, size); + + // Free up the allocated memory; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + delete[] tmp[i]; + } + + // Check that we can memset the entire memory + memset(mem, 0xee, size); +} + +TEST(ArenaTest, TestBasicArena) { + Arena a(1024); + char* memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); + + // Allocate again + memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); +} + +TEST(ArenaTest, TestVariousArenaSizes) { + { + Arena a(1024); + + // Allocate blocksize + char* memory = a.Alloc(1024); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 1024); + + // Allocate another blocksize + char* memory2 = a.Alloc(1024); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1024); + } + + // Allocate an arena and allocate two blocks + // that together exceed a block size + { + Arena a(1024); + + // + char* memory = a.Alloc(768); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 768); + + // Allocate another blocksize + char* memory2 = a.Alloc(768); + ASSERT_NE(memory2, nullptr); + 
TestMemory(memory2, 768); + } + + // Allocate larger than a blocksize + { + Arena a(1024); + + char* memory = a.Alloc(10240); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 10240); + + // Allocate another blocksize + char* memory2 = a.Alloc(1234); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1234); + } +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc new file mode 100644 index 0000000000..0ea583e96f --- /dev/null +++ b/tensorflow/core/lib/core/bit_cast_test.cc @@ -0,0 +1,95 @@ +// Unit test for bit_cast template. + +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +// Marshall and unmarshall. +// ISO spec C++ section 3.9 promises this will work. + +template <int N> +struct marshall { + char buf[N]; +}; + +template <class T> +void TestMarshall(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + marshall<sizeof(T)> m0 = bit_cast<marshall<sizeof(T)> >(t0); + T t1 = bit_cast<T>(m0); + marshall<sizeof(T)> m1 = bit_cast<marshall<sizeof(T)> >(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T))); + } +} + +// Convert back and forth to an integral type. The C++ standard does +// not guarantee this will work. +// +// There are implicit assumptions about sizeof(float) and +// sizeof(double). These assumptions are quite extant everywhere. + +template <class T, class I> +void TestIntegral(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + I i0 = bit_cast<I>(t0); + T t1 = bit_cast<T>(i0); + I i1 = bit_cast<I>(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(i0, i1); + } +} + +TEST(BitCast, Bool) { + LOG(INFO) << "Test bool"; + static const bool bool_list[] = {false, true}; + TestMarshall<bool>(bool_list, TF_ARRAYSIZE(bool_list)); +} + +TEST(BitCast, Int32) { + static const int32 int_list[] = {0, 1, 100, 2147483647, + -1, -100, -2147483647, -2147483647 - 1}; + TestMarshall<int32>(int_list, TF_ARRAYSIZE(int_list)); +} + +TEST(BitCast, Int64) { + static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)}; + TestMarshall<int64>(int64_list, TF_ARRAYSIZE(int64_list)); +} + +TEST(BitCast, Uint64) { + static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63}; + TestMarshall<uint64>(uint64_list, TF_ARRAYSIZE(uint64_list)); +} + +TEST(BitCast, Float) { + static const float float_list[] = {0.0, 1.0, -1.0, 10.0, -10.0, 1e10, + 1e20, 1e-10, 1e-20, 2.71828, 3.14159}; + TestMarshall<float>(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral<float, int32>(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral<float, uint32>(float_list, TF_ARRAYSIZE(float_list)); +} + +TEST(BitCast, Double) { + static const double double_list[] = { + 0.0, + 1.0, + -1.0, + 10.0, + -10.0, + 1e10, + 1e100, + 1e-10, + 1e-100, + 2.718281828459045, + 3.141592653589793238462643383279502884197169399375105820974944}; + TestMarshall<double>(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral<double, int64>(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral<double, uint64>(double_list, TF_ARRAYSIZE(double_list)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h new file mode 100644 index 0000000000..5456a63168 --- /dev/null +++ b/tensorflow/core/lib/core/bits.h @@ -0,0 +1,84 @@ +#ifndef 
TENSORFLOW_LIB_CORE_BITS_H_ +#define TENSORFLOW_LIB_CORE_BITS_H_ + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Floor(uint32 n); +int Log2Floor64(uint64 n); + +// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Ceiling(uint32 n); +int Log2Ceiling64(uint64 n); + +// ------------------------------------------------------------------------ +// Implementation details follow +// ------------------------------------------------------------------------ + +#if defined(__GNUC__) + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + return n == 0 ? -1 : 31 ^ __builtin_clz(n); +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor64(uint64 n) { + return n == 0 ? -1 : 63 ^ __builtin_clzll(n); +} + +#else + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + if (n == 0) + return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +// Log2Floor64() is defined in terms of Log2Floor32() +inline int Log2Floor64(uint64 n) { + const uint32 topbits = static_cast<uint32>(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return Log2Floor(static_cast<uint32>(n)); + } else { + return 32 + Log2Floor(topbits); + } +} + +#endif + +inline int Log2Ceiling(uint32 n) { + int floor = Log2Floor(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +inline int Log2Ceiling64(uint64 n) { + int floor = Log2Floor64(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BITS_H_ diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h new file mode 100644 index 0000000000..f141be2c76 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter.h @@ -0,0 +1,41 @@ +#ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class BlockingCounter { + public: + BlockingCounter(int initial_count) : count_(initial_count) { + CHECK_GE(count_, 0); + } + + ~BlockingCounter() {} + + inline void DecrementCount() { + mutex_lock l(mu_); + --count_; + CHECK(count_ >= 0); + if (count_ == 0) { + cond_var_.notify_all(); + } + } + + inline void Wait() { + mutex_lock l(mu_); + while (count_ > 0) { + cond_var_.wait(l); + } + } + + private: + int count_; + mutex mu_; + condition_variable cond_var_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ diff --git a/tensorflow/core/lib/core/blocking_counter_test.cc b/tensorflow/core/lib/core/blocking_counter_test.cc new file mode 100644 index 0000000000..feb0342086 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter_test.cc @@ -0,0 +1,36 @@ +#include <gtest/gtest.h> + +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +TEST(BlockingCounterTest, 
TestZero) { + BlockingCounter bc(0); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestSingleThread) { + BlockingCounter bc(2); + bc.DecrementCount(); + bc.DecrementCount(); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestMultipleThread) { + int N = 3; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", N); + + BlockingCounter bc(N); + for (int i = 0; i < N; ++i) { + thread_pool->Schedule([&bc] { bc.DecrementCount(); }); + } + + bc.Wait(); + delete thread_pool; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h new file mode 100644 index 0000000000..5b72048ac5 --- /dev/null +++ b/tensorflow/core/lib/core/casts.h @@ -0,0 +1,85 @@ +// Various Google-specific casting templates. +// +// This code is compiled directly on many platforms, including client +// platforms like Windows, Mac, and embedded systems. Before making +// any changes here, make sure that you're not breaking any platforms. +// + +#ifndef TENSORFLOW_LIB_CORE_CASTS_H_ +#define TENSORFLOW_LIB_CORE_CASTS_H_ + +#include <string.h> // for memcpy + +namespace tensorflow { + +// bit_cast<Dest,Source> is a template function that implements the +// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in +// very low-level functions like the protobuf library and fast math +// support. +// +// float f = 3.14159265358979; +// int i = bit_cast<int32>(f); +// // i = 0x40490fdb +// +// The classical address-casting method is: +// +// // WRONG +// float f = 3.14159265358979; // WRONG +// int i = * reinterpret_cast<int*>(&f); // WRONG +// +// The address-casting method actually produces undefined behavior +// according to ISO C++ specification section 3.10 -15 -. Roughly, this +// section says: if an object in memory has one type, and a program +// accesses it with a different type, then the result is undefined +// behavior for most values of "different type". +// +// This is true for any cast syntax, either *(int*)&f or +// *reinterpret_cast<int*>(&f). And it is particularly true for +// conversions between integral lvalues and floating-point lvalues. +// +// The purpose of 3.10 -15- is to allow optimizing compilers to assume +// that expressions with different types refer to different memory. gcc +// 4.0.1 has an optimizer that takes advantage of this. So a +// non-conforming program quietly produces wildly incorrect output. +// +// The problem is not the use of reinterpret_cast. The problem is type +// punning: holding an object in memory of one type and reading its bits +// back using a different type. +// +// The C++ standard is more subtle and complex than this, but that +// is the basic idea. +// +// Anyways ... +// +// bit_cast<> calls memcpy() which is blessed by the standard, +// especially by the example in section 3.9 . Also, of course, +// bit_cast<> wraps up the nasty logic in one place. +// +// Fortunately memcpy() is very fast. In optimized mode, with a +// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline +// code with the minimal amount of data movement. On a 32-bit system, +// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8) +// compiles to two loads and two stores. +// +// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1. +// +// WARNING: if Dest or Source is a non-POD type, the result of the memcpy +// is likely to surprise you. 
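// Illustrative round trip with the template below:
//
//   double d0 = 3.5;
//   uint64 bits = bit_cast<uint64>(d0);  // 0x400C000000000000
//   double d1 = bit_cast<double>(bits);  // d1 == 3.5, bit-for-bit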
+// +// Props to Bill Gibbons for the compile time assertion technique and +// Art Komninos and Igor Tandetnik for the msvc experiments. +// +// -- mec 2005-10-17 + +template <class Dest, class Source> +inline Dest bit_cast(const Source& source) { + static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match"); + + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CASTS_H_ diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc new file mode 100644 index 0000000000..efff554742 --- /dev/null +++ b/tensorflow/core/lib/core/coding.cc @@ -0,0 +1,164 @@ +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace core { + +void EncodeFixed32(char* buf, uint32 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + } +} + +void EncodeFixed64(char* buf, uint64 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + buf[4] = (value >> 32) & 0xff; + buf[5] = (value >> 40) & 0xff; + buf[6] = (value >> 48) & 0xff; + buf[7] = (value >> 56) & 0xff; + } +} + +void PutFixed32(string* dst, uint32 value) { + char buf[sizeof(value)]; + EncodeFixed32(buf, value); + dst->append(buf, sizeof(buf)); +} + +void PutFixed64(string* dst, uint64 value) { + char buf[sizeof(value)]; + EncodeFixed64(buf, value); + dst->append(buf, sizeof(buf)); +} + +char* EncodeVarint32(char* dst, uint32 v) { + // Operate on characters as unsigneds + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + static const int B = 128; + if (v < (1 << 7)) { + *(ptr++) = v; + } else if (v < (1 << 14)) { + *(ptr++) = v | B; + *(ptr++) = v >> 7; + } else if (v < (1 << 21)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = v >> 14; + } else if (v < (1 << 28)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = v >> 21; + } else { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = (v >> 21) | B; + *(ptr++) = v >> 28; + } + return reinterpret_cast<char*>(ptr); +} + +void PutVarint32(string* dst, uint32 v) { + char buf[5]; + char* ptr = EncodeVarint32(buf, v); + dst->append(buf, ptr - buf); +} + +char* EncodeVarint64(char* dst, uint64 v) { + static const int B = 128; + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + while (v >= B) { + *(ptr++) = (v & (B - 1)) | B; + v >>= 7; + } + *(ptr++) = static_cast<unsigned char>(v); + return reinterpret_cast<char*>(ptr); +} + +void PutVarint64(string* dst, uint64 v) { + char buf[10]; + char* ptr = EncodeVarint64(buf, v); + dst->append(buf, ptr - buf); +} + +int VarintLength(uint64_t v) { + int len = 1; + while (v >= 128) { + v >>= 7; + len++; + } + return len; +} + +const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value) { + uint32 result = 0; + for (uint32 shift = 0; shift <= 28 && p < limit; shift += 7) { + uint32 byte = *(reinterpret_cast<const unsigned char*>(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast<const char*>(p); + } + } + return NULL; +} + +bool GetVarint32(StringPiece* input, uint32* 
value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint32Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +const char* GetVarint64Ptr(const char* p, const char* limit, uint64* value) { + uint64 result = 0; + for (uint32 shift = 0; shift <= 63 && p < limit; shift += 7) { + uint64 byte = *(reinterpret_cast<const unsigned char*>(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast<const char*>(p); + } + } + return NULL; +} + +bool GetVarint64(StringPiece* input, uint64* value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint64Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h new file mode 100644 index 0000000000..0c14bf1bbf --- /dev/null +++ b/tensorflow/core/lib/core/coding.h @@ -0,0 +1,55 @@ +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_LIB_CORE_CODING_H_ +#define TENSORFLOW_LIB_CORE_CODING_H_ + +#include "tensorflow/core/lib/core/raw_coding.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Put... that write directly into a character buffer +// REQUIRES: dst has enough space for the value being written +extern void EncodeFixed32(char* dst, uint32 value); +extern void EncodeFixed64(char* dst, uint64 value); +extern void PutFixed32(string* dst, uint32 value); +extern void PutFixed64(string* dst, uint64 value); + +extern void PutVarint32(string* dst, uint32 value); +extern void PutVarint64(string* dst, uint64 value); + +extern bool GetVarint32(StringPiece* input, uint32* value); +extern bool GetVarint64(StringPiece* input, uint64* value); + +extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v); +extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v); + +// Internal routine for use by fallback path of GetVarint32Ptr +extern const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value); +inline const char* GetVarint32Ptr(const char* p, const char* limit, + uint32* value) { + if (p < limit) { + uint32 result = *(reinterpret_cast<const unsigned char*>(p)); + if ((result & 128) == 0) { + *value = result; + return p + 1; + } + } + return GetVarint32PtrFallback(p, limit, value); +} + +extern char* EncodeVarint64(char* dst, uint64 v); + +// Returns the length of the varint32 or varint64 encoding of "v" +extern int VarintLength(uint64_t v); + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CODING_H_ diff --git a/tensorflow/core/lib/core/coding_test.cc b/tensorflow/core/lib/core/coding_test.cc new file mode 100644 index 0000000000..5e9e2c5e96 --- /dev/null +++ b/tensorflow/core/lib/core/coding_test.cc @@ -0,0 +1,168 @@ +#include "tensorflow/core/lib/core/coding.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { + +TEST(Coding, 
Fixed32) { + static const int N = 100000; + + string s; + for (uint32 v = 0; v < N; v++) { + char buf[sizeof(uint32)]; + EncodeFixed32(buf, v); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (uint32 v = 0; v < N; v++) { + uint32 actual = DecodeFixed32(p); + ASSERT_EQ(v, actual); + p += sizeof(uint32); + } +} + +TEST(Coding, Fixed64) { + string s; + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast<uint64>(1) << power; + char buf[sizeof(uint64)]; + EncodeFixed64(buf, v - 1); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 0); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 1); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast<uint64>(1) << power; + uint64 actual; + actual = DecodeFixed64(p); + ASSERT_EQ(v - 1, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 0, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 1, actual); + p += sizeof(uint64); + } +} + +// Test that encoding routines generate little-endian encodings +TEST(Coding, EncodingOutput) { + char dst[8]; + EncodeFixed32(dst, 0x04030201); + ASSERT_EQ(0x01, static_cast<int>(dst[0])); + ASSERT_EQ(0x02, static_cast<int>(dst[1])); + ASSERT_EQ(0x03, static_cast<int>(dst[2])); + ASSERT_EQ(0x04, static_cast<int>(dst[3])); + + EncodeFixed64(dst, 0x0807060504030201ull); + ASSERT_EQ(0x01, static_cast<int>(dst[0])); + ASSERT_EQ(0x02, static_cast<int>(dst[1])); + ASSERT_EQ(0x03, static_cast<int>(dst[2])); + ASSERT_EQ(0x04, static_cast<int>(dst[3])); + ASSERT_EQ(0x05, static_cast<int>(dst[4])); + ASSERT_EQ(0x06, static_cast<int>(dst[5])); + ASSERT_EQ(0x07, static_cast<int>(dst[6])); + ASSERT_EQ(0x08, static_cast<int>(dst[7])); +} + +TEST(Coding, Varint32) { + string s; + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 v = (i / 32) << (i % 32); + PutVarint32(&s, v); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 expected = (i / 32) << (i % 32); + uint32 actual; + p = GetVarint32Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(expected, actual); + } + ASSERT_EQ(p, s.data() + s.size()); +} + +TEST(Coding, Varint64) { + // Construct the list of values to check + std::vector<uint64> values; + // Some special values + values.push_back(0); + values.push_back(100); + values.push_back(~static_cast<uint64>(0)); + values.push_back(~static_cast<uint64>(0) - 1); + for (uint32 k = 0; k < 64; k++) { + // Test values near powers of two + const uint64 power = 1ull << k; + values.push_back(power); + values.push_back(power - 1); + values.push_back(power + 1); + } + + string s; + for (size_t i = 0; i < values.size(); i++) { + PutVarint64(&s, values[i]); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (size_t i = 0; i < values.size(); i++) { + ASSERT_TRUE(p < limit); + uint64 actual; + p = GetVarint64Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(values[i], actual); + } + ASSERT_EQ(p, limit); +} + +TEST(Coding, Varint32Overflow) { + uint32 result; + string input("\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint32Truncation) { + uint32 large_value = (1u << 31) + 100; + string s; + PutVarint32(&s, large_value); + uint32 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == 
NULL); + } + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +TEST(Coding, Varint64Overflow) { + uint64 result; + string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint64Truncation) { + uint64 large_value = (1ull << 63) + 100ull; + string s; + PutVarint64(&s, large_value); + uint64 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == NULL); + } + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc new file mode 100644 index 0000000000..0f1072ffaa --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/lib/core/command_line_flags.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { +namespace { + +// Templated function to convert a string to target values. +// Return true if the conversion is successful. Otherwise, return false. +template <typename T> +bool StringToValue(const string& content, T* value); + +template <> +bool StringToValue<int32>(const string& content, int* value) { + return str_util::NumericParse32(content, value); +} + +// Parse a single argument by linearly searching through the command table. +// The input format is: --argument=value. +// Return OK if the argument is used. It store the extracted value into the +// matching flag. +// Return NOT_FOUND if the argument is not recognized. +// Retrun INVALID_ARGUMENT if the command is recognized, but fails to extract +// its value. +template <typename T> +Status ParseArgument(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry<int>::Instance()->commands) { + string prefix = strings::StrCat("--", command.name, "="); + if (tensorflow::StringPiece(argument).starts_with(prefix)) { + string content = argument.substr(prefix.length()); + if (StringToValue<T>(content, command.value)) { + return Status::OK(); + } + return Status(error::INVALID_ARGUMENT, + strings::StrCat("Cannot parse integer in: ", argument)); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} + +// A specialization for booleans. The input format is: +// "--argument" or "--noargument". +// Parse a single argument by linearly searching through the command table. +// Return OK if the argument is used. The value is stored in the matching flag. +// Return NOT_FOUND if the argument is not recognized. 
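// Illustrative example: with a hypothetical flag declared as
// TF_DEFINE_bool(use_gpu, false, "whether to use the GPU"), the argument
// "--use_gpu" sets FLAGS_use_gpu to true, "--nouse_gpu" sets it to false,
// and any other spelling falls through to NOT_FOUND.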
+template <> +Status ParseArgument<bool>(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry<bool>::Instance()->commands) { + if (argument == strings::StrCat("--", command.name)) { + *command.value = true; + return Status::OK(); + } else if (argument == strings::StrCat("--no", command.name)) { + *command.value = false; + return Status::OK(); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} +} // namespace + +Status ParseCommandLineFlags(int* argc, char* argv[]) { + int unused_argc = 1; + for (int index = 1; index < *argc; ++index) { + Status s; + // Search bool commands. + s = ParseArgument<bool>(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Search int32 commands. + s = ParseArgument<int32>(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Pointer swap the unused argument to the front. + std::swap(argv[unused_argc++], argv[index]); + } + *argc = unused_argc; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h new file mode 100644 index 0000000000..f1a94c11f9 --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace internal { + +template <typename T> +struct CommandLineFlagRegistry { + static CommandLineFlagRegistry* Instance() { + static CommandLineFlagRegistry instance_; + return &instance_; + } + struct Command { + string name; + T* value; + string text; + }; + std::vector<Command> commands; + + private: + CommandLineFlagRegistry() {} + TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry); +}; + +template <typename T> +struct CommandLineFlagRegister { + CommandLineFlagRegister(const string& name, T* val, const string& text) { + CommandLineFlagRegistry<T>::Instance()->commands.push_back( + {name, val, text}); + } +}; + +#define TF_DEFINE_variable(type, name, default_value, text) \ + type FLAGS_##name = default_value; \ + namespace TF_flags_internal { \ + tensorflow::internal::CommandLineFlagRegister<type> \ + TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \ + } // namespace TF_flags_internal + +} // namespace internal + +#define TF_DEFINE_int32(name, default_value, text) \ + TF_DEFINE_variable(int32, name, default_value, text); + +#define TF_DEFINE_bool(name, default_value, text) \ + TF_DEFINE_variable(bool, name, default_value, text); + +// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv. +// Returned the number of unused arguments in *argc. +// Return error Status if the parsing encounters errors. +// TODO(opensource): switch to a command line argument parser that can be +// shared with other tests. +Status ParseCommandLineFlags(int* argc, char* argv[]); + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto new file mode 100644 index 0000000000..6735fd8f88 --- /dev/null +++ b/tensorflow/core/lib/core/error_codes.proto @@ -0,0 +1,145 @@ +syntax = "proto3"; + +package tensorflow.error; +// option cc_enable_arenas = true; + +// The canonical error codes for TensorFlow APIs. 
+// +// Warnings: +// +// - Do not change any numeric assignments. +// - Changes to this list should only be made if there is a compelling +// need that can't be satisfied in another way. Such changes +// must be approved by at least two OWNERS. +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply. +// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION. +enum Code { + // Not an error; returned on success + OK = 0; + + // The operation was cancelled (typically by the caller). + CANCELLED = 1; + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + UNKNOWN = 2; + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3; + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5; + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors). PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8; + + // Operation was rejected because the system is not in a state + // required for the operation's execution. For example, directory + // to be deleted may be non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: + // (a) Use UNAVAILABLE if the client can retry just the failing call. + // (b) Use ABORTED if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FAILED_PRECONDITION if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FAILED_PRECONDITION + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. 
+ // (d) Use FAILED_PRECONDITION if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue + // like sequencer check failures, transaction aborts, etc. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + ABORTED = 10; + + // Operation was attempted past the valid range. E.g., seeking or + // reading past end of file. + // + // Unlike INVALID_ARGUMENT, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate INVALID_ARGUMENT if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OUT_OF_RANGE if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between FAILED_PRECONDITION and + // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OUT_OF_RANGE error to detect when + // they are done. + OUT_OF_RANGE = 11; + + // Operation is not implemented or not supported/enabled in this service. + UNIMPLEMENTED = 12; + + // Internal errors. Means some invariants expected by underlying + // system has been broken. If you see one of these errors, + // something is very broken. + INTERNAL = 13; + + // The service is currently unavailable. This is a most likely a + // transient condition and may be corrected by retrying with + // a backoff. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + DATA_LOSS = 15; + + // An extra enum entry to prevent people from writing code that + // fails to compile when a new code is added. + // + // Nobody should ever reference this enumeration entry. In particular, + // if you write C++ code that switches on this enumeration, add a default: + // case instead of a case that mentions this enumeration entry. + // + // Nobody should rely on the value (currently 20) listed here. It + // may change in the future. + DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20; +} diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h new file mode 100644 index 0000000000..b0badd8c4d --- /dev/null +++ b/tensorflow/core/lib/core/errors.h @@ -0,0 +1,131 @@ +#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_ +#define TENSORFLOW_LIB_CORE_ERRORS_H_ + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace errors { + +typedef ::tensorflow::error::Code Code; + +// Append some context to an error message. Each time we append +// context put it on a new line, since it is possible for there +// to be several layers of additional context. +template <typename... Args> +void AppendToMessage(::tensorflow::Status* status, Args... args) { + *status = ::tensorflow::Status( + status->code(), + strings::StrCat(status->error_message(), "\n\t", args...)); +} + +// For propagating errors when calling a function. 
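// Illustrative sketch (FetchInput and Process are hypothetical helpers that
// return Status):
//
//   Status PrepareBatch(Batch* batch) {
//     TF_RETURN_IF_ERROR(FetchInput(batch));
//     TF_RETURN_WITH_CONTEXT_IF_ERROR(Process(batch), "while processing batch");
//     return Status::OK();
//   }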
+#define TF_RETURN_IF_ERROR(expr) \ + do { \ + const ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + +#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ + do { \ + ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Convenience functions for generating and using error status. +// Example usage: +// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); +// if (errors::IsInvalidArgument(status)) { ... } +// switch (status.code()) { case error::INVALID_ARGUMENT: ... } + +#define DECLARE_ERROR(FUNC, CONST) \ + template <typename... Args> \ + inline ::tensorflow::Status FUNC(Args... args) { \ + return ::tensorflow::Status(::tensorflow::error::CONST, \ + strings::StrCat(args...)); \ + } \ + inline bool Is##FUNC(const ::tensorflow::Status& status) { \ + return status.code() == ::tensorflow::error::CONST; \ + } + +DECLARE_ERROR(Cancelled, CANCELLED) +DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT) +DECLARE_ERROR(NotFound, NOT_FOUND) +DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS) +DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED) +DECLARE_ERROR(Unavailable, UNAVAILABLE) +DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION) +DECLARE_ERROR(OutOfRange, OUT_OF_RANGE) +DECLARE_ERROR(Unimplemented, UNIMPLEMENTED) +DECLARE_ERROR(Internal, INTERNAL) +DECLARE_ERROR(Aborted, ABORTED) +DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED) +DECLARE_ERROR(DataLoss, DATA_LOSS) +DECLARE_ERROR(Unknown, UNKNOWN) +DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED) +DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) + +#undef DECLARE_ERROR + +// The CanonicalCode() for non-errors. +using ::tensorflow::error::OK; + +// Convenience macros for asserting and handling exceptional conditions. +// Analogous to the CHECK* macros provided by logging.h. +// +// Example use: +// void Compute(OperationContext* context) { +// OP_REQUIRES(context, context->num_inputs() == 2, +// errors::InvalidArgument("FooOp requires 2 arguments")); +// ... +// Status status = SomeUncertainMethod(); +// OP_REQUIRES_OK(context, status); +// ... 
+// } + +#define OP_REQUIRES(CTX, EXP, STATUS) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } + +#define OP_REQUIRES_OK(CTX, STATUS) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } + +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +} // namespace errors +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ERRORS_H_ diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h new file mode 100644 index 0000000000..071e24285a --- /dev/null +++ b/tensorflow/core/lib/core/notification.h @@ -0,0 +1,42 @@ +#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_ +#define TENSORFLOW_UTIL_NOTIFICATION_H_ + +#include <assert.h> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Notification { + public: + Notification() : notified_(false) {} + ~Notification() {} + + void Notify() { + mutex_lock l(mu_); + assert(!notified_); + notified_ = true; + cv_.notify_all(); + } + + bool HasBeenNotified() { + mutex_lock l(mu_); + return notified_; + } + + void WaitForNotification() { + mutex_lock l(mu_); + while (!notified_) { + cv_.wait(l); + } + } + + private: + mutex mu_; + condition_variable cv_; + bool notified_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_NOTIFICATION_H_ diff --git a/tensorflow/core/lib/core/notification_test.cc b/tensorflow/core/lib/core/notification_test.cc new file mode 100644 index 0000000000..a9e8942f05 --- /dev/null +++ b/tensorflow/core/lib/core/notification_test.cc @@ -0,0 +1,64 @@ +#include <gtest/gtest.h> + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +TEST(NotificationTest, TestSingleNotification) { + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", 1); + + int counter = 0; + Notification start; + Notification proceed; + thread_pool->Schedule([&start, &proceed, &counter] { + start.Notify(); + proceed.WaitForNotification(); + ++counter; + }); + + // Wait for the thread to start + start.WaitForNotification(); + + // The thread should be waiting for the 'proceed' notification. + EXPECT_EQ(0, counter); + + // Unblock the thread + proceed.Notify(); + + delete thread_pool; // Wait for closure to finish. + + // Verify the counter has been incremented + EXPECT_EQ(1, counter); +} + +TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) { + const int num_closures = 4; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", num_closures); + + mutex lock; + int counter = 0; + Notification n; + + for (int i = 0; i < num_closures; ++i) { + thread_pool->Schedule([&n, &lock, &counter] { + n.WaitForNotification(); + mutex_lock l(lock); + ++counter; + }); + } + sleep(1); + + EXPECT_EQ(0, counter); + + n.Notify(); + delete thread_pool; // Wait for all closures to finish. 
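// The pool destructor waits for the closures scheduled above, so the counter
// can be read below without taking 'lock'.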
+ EXPECT_EQ(4, counter); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h new file mode 100644 index 0000000000..1fe49b75bb --- /dev/null +++ b/tensorflow/core/lib/core/raw_coding.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_ +#define TENSORFLOW_LIB_CORE_RAW_CODING_H_ + +#include <string.h> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Get... that read directly from a character buffer +// without any bounds checking. + +inline uint32 DecodeFixed32(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint32 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24)); + } +} + +inline uint64 DecodeFixed64(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint64 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + uint64 lo = DecodeFixed32(ptr); + uint64 hi = DecodeFixed32(ptr + 4); + return (hi << 32) | lo; + } +} + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_ diff --git a/tensorflow/core/lib/core/refcount.cc b/tensorflow/core/lib/core/refcount.cc new file mode 100644 index 0000000000..3ed8c58eb8 --- /dev/null +++ b/tensorflow/core/lib/core/refcount.cc @@ -0,0 +1,35 @@ +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace core { + +RefCounted::RefCounted() : ref_(1) {} + +RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); } + +void RefCounted::Ref() const { + DCHECK_GE(ref_.load(), 1); + ref_.fetch_add(1, std::memory_order_relaxed); +} + +bool RefCounted::Unref() const { + DCHECK_GT(ref_.load(), 0); + // If ref_==1, this object is owned only by the caller. Bypass a locked op + // in that case. + if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) { + // Make DCHECK in ~RefCounted happy + DCHECK((ref_.store(0), true)); + delete this; + return true; + } else { + return false; + } +} + +bool RefCounted::RefCountIsOne() const { + return (ref_.load(std::memory_order_acquire) == 1); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h new file mode 100644 index 0000000000..f727750f9e --- /dev/null +++ b/tensorflow/core/lib/core/refcount.h @@ -0,0 +1,63 @@ +#ifndef TENSORFLOW_LIB_CORE_REFCOUNT_H_ +#define TENSORFLOW_LIB_CORE_REFCOUNT_H_ + +#include <atomic> + +namespace tensorflow { +namespace core { + +class RefCounted { + public: + // Initial reference count is one. + RefCounted(); + + // Increments reference count by one. + void Ref() const; + + // Decrements reference count by one. If the count remains + // positive, returns false. When the count reaches zero, returns + // true and deletes this, in which case the caller must not access + // the object afterward. + bool Unref() const; + + // Return whether the reference count is one. 
+ // If the reference count is used in the conventional way, a + // reference count of 1 implies that the current thread owns the + // reference and no other thread shares it. + // This call performs the test for a reference count of one, and + // performs the memory barrier needed for the owning thread + // to act on the object, knowing that it has exclusive access to the + // object. + bool RefCountIsOne() const; + + protected: + // Make destructor protected so that RefCounted objects cannot + // be instantiated directly. Only subclasses can be instantiated. + virtual ~RefCounted(); + + private: + mutable std::atomic_int_fast32_t ref_; + + RefCounted(const RefCounted&) = delete; + void operator=(const RefCounted&) = delete; +}; + +// Helper class to unref an object when out-of-scope. +class ScopedUnref { + public: + explicit ScopedUnref(RefCounted* o) : obj_(o) {} + ~ScopedUnref() { + if (obj_) obj_->Unref(); + } + + private: + RefCounted* obj_; + + ScopedUnref(const ScopedUnref&) = delete; + void operator=(const ScopedUnref&) = delete; +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_REFCOUNT_H_ diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc new file mode 100644 index 0000000000..c042be2d61 --- /dev/null +++ b/tensorflow/core/lib/core/refcount_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/refcount.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { +namespace { + +static int constructed = 0; +static int destroyed = 0; + +class MyRef : public RefCounted { + public: + MyRef() { constructed++; } + ~MyRef() override { destroyed++; } +}; + +class RefTest : public testing::Test { + public: + RefTest() { + constructed = 0; + destroyed = 0; + } +}; + +TEST_F(RefTest, New) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, constructed); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefUnref) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Ref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefCountOne) { + MyRef* ref = new MyRef; + ASSERT_TRUE(ref->RefCountIsOne()); + ref->Unref(); +} + +TEST_F(RefTest, RefCountNotOne) { + MyRef* ref = new MyRef; + ref->Ref(); + ASSERT_FALSE(ref->RefCountIsOne()); + ref->Unref(); + ref->Unref(); +} + +TEST_F(RefTest, ConstRefUnref) { + const MyRef* cref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + cref->Ref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, ReturnOfUnref) { + MyRef* ref = new MyRef; + ref->Ref(); + EXPECT_FALSE(ref->Unref()); + EXPECT_TRUE(ref->Unref()); +} + +TEST_F(RefTest, ScopedUnref) { + { ScopedUnref unref(new MyRef); } + EXPECT_EQ(destroyed, 1); +} + +TEST_F(RefTest, ScopedUnref_Nullptr) { + { ScopedUnref unref(nullptr); } + EXPECT_EQ(destroyed, 0); +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc new file mode 100644 index 0000000000..24ce842560 --- /dev/null +++ b/tensorflow/core/lib/core/status.cc @@ -0,0 +1,107 @@ +#include "tensorflow/core/public/status.h" +#include <stdio.h> + +namespace tensorflow { + +Status::Status(tensorflow::error::Code code, StringPiece msg) { + assert(code != 
tensorflow::error::OK); + state_ = new State; + state_->code = code; + state_->msg = msg.ToString(); +} +Status::~Status() { delete state_; } + +void Status::Update(const Status& new_status) { + if (ok()) { + *this = new_status; + } +} + +void Status::SlowCopyFrom(const State* src) { + delete state_; + if (src == nullptr) { + state_ = nullptr; + } else { + state_ = new State(*src); + } +} + +const string& Status::empty_string() { + static string* empty = new string; + return *empty; +} + +string Status::ToString() const { + if (state_ == NULL) { + return "OK"; + } else { + char tmp[30]; + const char* type; + switch (code()) { + case tensorflow::error::CANCELLED: + type = "Cancelled"; + break; + case tensorflow::error::UNKNOWN: + type = "Unknown"; + break; + case tensorflow::error::INVALID_ARGUMENT: + type = "Invalid argument"; + break; + case tensorflow::error::DEADLINE_EXCEEDED: + type = "Deadline exceeded"; + break; + case tensorflow::error::NOT_FOUND: + type = "Not found"; + break; + case tensorflow::error::ALREADY_EXISTS: + type = "Already exists"; + break; + case tensorflow::error::PERMISSION_DENIED: + type = "Permission denied"; + break; + case tensorflow::error::UNAUTHENTICATED: + type = "Unauthenticated"; + break; + case tensorflow::error::RESOURCE_EXHAUSTED: + type = "Resource exhausted"; + break; + case tensorflow::error::FAILED_PRECONDITION: + type = "Failed precondition"; + break; + case tensorflow::error::ABORTED: + type = "Aborted"; + break; + case tensorflow::error::OUT_OF_RANGE: + type = "Out of range"; + break; + case tensorflow::error::UNIMPLEMENTED: + type = "Unimplemented"; + break; + case tensorflow::error::INTERNAL: + type = "Internal"; + break; + case tensorflow::error::UNAVAILABLE: + type = "Unavailable"; + break; + case tensorflow::error::DATA_LOSS: + type = "Data loss"; + break; + default: + snprintf(tmp, sizeof(tmp), "Unknown code(%d)", + static_cast<int>(code())); + type = tmp; + break; + } + string result(type); + result += ": "; + result += state_->msg; + return result; + } +} + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc new file mode 100644 index 0000000000..3ef6b3302a --- /dev/null +++ b/tensorflow/core/lib/core/status_test.cc @@ -0,0 +1,84 @@ +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(Status, OK) { + EXPECT_EQ(Status::OK().code(), error::OK); + EXPECT_EQ(Status::OK().error_message(), ""); + EXPECT_OK(Status::OK()); + ASSERT_OK(Status::OK()); + EXPECT_EQ(Status::OK(), Status()); + Status s; + EXPECT_TRUE(s.ok()); +} + +TEST(DeathStatus, CheckOK) { + Status status(errors::InvalidArgument("Invalid")); + ASSERT_DEATH(TF_CHECK_OK(status), "Invalid"); +} + +TEST(Status, Set) { + Status status; + status = Status(error::CANCELLED, "Error message"); + EXPECT_EQ(status.code(), error::CANCELLED); + EXPECT_EQ(status.error_message(), "Error message"); +} + +TEST(Status, Copy) { + Status a(errors::InvalidArgument("Invalid")); + Status b(a); + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Assign) { + Status a(errors::InvalidArgument("Invalid")); + Status b; + b = a; + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Update) { + Status s; + s.Update(Status::OK()); + ASSERT_TRUE(s.ok()); + Status 
a(errors::InvalidArgument("Invalid")); + s.Update(a); + ASSERT_EQ(s.ToString(), a.ToString()); + Status b(errors::Internal("Internal")); + s.Update(b); + ASSERT_EQ(s.ToString(), a.ToString()); + s.Update(Status::OK()); + ASSERT_EQ(s.ToString(), a.ToString()); + ASSERT_FALSE(s.ok()); +} + +TEST(Status, EqualsOK) { ASSERT_EQ(Status::OK(), Status()); } + +TEST(Status, EqualsSame) { + Status a(errors::InvalidArgument("Invalid")); + Status b(errors::InvalidArgument("Invalid")); + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsCopy) { + const Status a(errors::InvalidArgument("Invalid")); + const Status b = a; + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsDifferentCode) { + const Status a(errors::InvalidArgument("message")); + const Status b(errors::Internal("message")); + ASSERT_NE(a, b); +} + +TEST(Status, EqualsDifferentMessage) { + const Status a(errors::InvalidArgument("message")); + const Status b(errors::InvalidArgument("another")); + ASSERT_NE(a, b); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h new file mode 100644 index 0000000000..b3b4db429f --- /dev/null +++ b/tensorflow/core/lib/core/status_test_util.h @@ -0,0 +1,20 @@ +#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ +#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ + +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" + +// Macros for testing the results of functions that return util::Status. + +#define EXPECT_OK(statement) EXPECT_EQ(::tensorflow::Status::OK(), (statement)) +#define ASSERT_OK(statement) ASSERT_EQ(::tensorflow::Status::OK(), (statement)) + +// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not +// provide much value (when they fail, they would just print the OK status +// which conveys no more information than EXPECT_FALSE(status.ok()); +// If you want to check for particular errors, better alternatives are: +// EXPECT_EQ(::util::Status(...expected error...), status.StripMessage()); +// EXPECT_THAT(status.ToString(), HasSubstr("expected error")); +// Also, see testing/lib/util/status_util.h. + +#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc new file mode 100644 index 0000000000..57c5139f47 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/lib/core/stringpiece.h" + +#include <iostream> +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +size_t StringPiece::Hasher::operator()(StringPiece s) const { + return Hash64(s.data(), s.size()); +} + +std::ostream& operator<<(std::ostream& o, StringPiece piece) { + o.write(piece.data(), piece.size()); + return o; +} + +bool StringPiece::contains(StringPiece s) const { + return memmem(data_, size_, s.data_, s.size_) != nullptr; +} + +size_t StringPiece::find(char c, size_t pos) const { + if (pos >= size_) { + return npos; + } + const char* result = + reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos)); + return result != NULL ? result - data_ : npos; +} + +// Search range is [0..pos] inclusive. If pos == npos, search everything. 
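+// Illustrative example (not part of the original source): for
+// StringPiece("abcabc"), rfind('b') returns 4, while rfind('b', 3)
+// restricts the search to positions [0..3] and returns 1.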
+size_t StringPiece::rfind(char c, size_t pos) const { + if (size_ == 0) return npos; + for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) { + if (*p == c) { + return p - data_; + } + } + return npos; +} + +bool StringPiece::Consume(StringPiece x) { + if (starts_with(x)) { + remove_prefix(x.size_); + return true; + } + return false; +} + +StringPiece StringPiece::substr(size_t pos, size_t n) const { + if (pos > size_) pos = size_; + if (n > size_ - pos) n = size_ - pos; + return StringPiece(data_ + pos, n); +} + +const StringPiece::size_type StringPiece::npos = size_type(-1); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h new file mode 100644 index 0000000000..17d4b294e9 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.h @@ -0,0 +1,159 @@ +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. +// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_ +#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_ + +#include <assert.h> +#include <stddef.h> +#include <string.h> +#include <iosfwd> +#include <string> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class StringPiece { + public: + typedef size_t size_type; + + // Create an empty slice. + StringPiece() : data_(""), size_(0) {} + + // Create a slice that refers to d[0,n-1]. + StringPiece(const char* d, size_t n) : data_(d), size_(n) {} + + // Create a slice that refers to the contents of "s" + StringPiece(const string& s) : data_(s.data()), size_(s.size()) {} + + // Create a slice that refers to s[0,strlen(s)-1] + StringPiece(const char* s) : data_(s), size_(strlen(s)) {} + + void set(const void* data, size_t len) { + data_ = reinterpret_cast<const char*>(data); + size_ = len; + } + + // Return a pointer to the beginning of the referenced data + const char* data() const { return data_; } + + // Return the length (in bytes) of the referenced data + size_t size() const { return size_; } + + // Return true iff the length of the referenced data is zero + bool empty() const { return size_ == 0; } + + typedef const char* const_iterator; + typedef const char* iterator; + iterator begin() const { return data_; } + iterator end() const { return data_ + size_; } + + static const size_t npos; + + // Return the ith byte in the referenced data. + // REQUIRES: n < size() + char operator[](size_t n) const { + assert(n < size()); + return data_[n]; + } + + // Change this slice to refer to an empty array + void clear() { + data_ = ""; + size_ = 0; + } + + // Drop the first "n" bytes from this slice. + void remove_prefix(size_t n) { + assert(n <= size()); + data_ += n; + size_ -= n; + } + + void remove_suffix(size_t n) { + assert(size_ >= n); + size_ -= n; + } + + size_t find(char c, size_t pos = 0) const; + size_t rfind(char c, size_t pos = npos) const; + bool contains(StringPiece s) const; + + // Checks whether StringPiece starts with x and if so advances the beginning + // of it to past the match. It's basically a shortcut for starts_with + // followed by remove_prefix. 
+ bool Consume(StringPiece x); + + StringPiece substr(size_t pos, size_t n = npos) const; + + struct Hasher { + size_t operator()(StringPiece arg) const; + }; + + // Return a string that contains the copy of the referenced data. + std::string ToString() const { return std::string(data_, size_); } + + // Three-way comparison. Returns value: + // < 0 iff "*this" < "b", + // == 0 iff "*this" == "b", + // > 0 iff "*this" > "b" + int compare(StringPiece b) const; + + // Return true iff "x" is a prefix of "*this" + bool starts_with(StringPiece x) const { + return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0)); + } + // Return true iff "x" is a suffix of "*this" + bool ends_with(StringPiece x) const { + return ((size_ >= x.size_) && + (memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0)); + } + + private: + const char* data_; + size_t size_; + + // Intentionally copyable +}; + +inline bool operator==(StringPiece x, StringPiece y) { + return ((x.size() == y.size()) && + (memcmp(x.data(), y.data(), x.size()) == 0)); +} + +inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } + +inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; } +inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; } +inline bool operator<=(StringPiece x, StringPiece y) { + return x.compare(y) <= 0; +} +inline bool operator>=(StringPiece x, StringPiece y) { + return x.compare(y) >= 0; +} + +inline int StringPiece::compare(StringPiece b) const { + const size_t min_len = (size_ < b.size_) ? size_ : b.size_; + int r = memcmp(data_, b.data_, min_len); + if (r == 0) { + if (size_ < b.size_) + r = -1; + else if (size_ > b.size_) + r = +1; + } + return r; +} + +// allow StringPiece to be logged +extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece); + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_ diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc new file mode 100644 index 0000000000..e9b84d3102 --- /dev/null +++ b/tensorflow/core/lib/core/threadpool.cc @@ -0,0 +1,108 @@ +#include "tensorflow/core/lib/core/threadpool.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" + +namespace tensorflow { +namespace thread { + +struct ThreadPool::Waiter { + condition_variable cv; + bool ready; +}; + +ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) + : ThreadPool(env, ThreadOptions(), name, num_threads) {} + +ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads) + : name_(name) { + CHECK_GE(num_threads, 1); + string name_prefix = "tf_" + name_; + for (int i = 0; i < num_threads; i++) { + threads_.push_back(env->StartThread(thread_options, name_prefix, + [this]() { WorkerLoop(); })); + } +} + +ThreadPool::~ThreadPool() { + { + // Wait for all work to get done. + mutex_lock l(mu_); + + // Inform every thread to exit. + for (size_t i = 0; i < threads_.size(); ++i) { + pending_.push_back({nullptr, 0}); + } + + // Wakeup all waiters. + for (auto w : waiters_) { + w->ready = true; + w->cv.notify_one(); + } + } + + // Wait for threads to finish. 
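+  // (Editorial note: each Thread returned by Env::StartThread() is expected
+  // to join in its destructor, so deleting it below blocks until that worker
+  // has exited its loop.)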
+ for (auto t : threads_) { + delete t; + } +} + +bool ThreadPool::HasPendingClosures() const { + mutex_lock l(mu_); + return pending_.size() != 0; +} + +void ThreadPool::Schedule(std::function<void()> fn) { + CHECK(fn != nullptr); + uint64 id = 0; + if (port::Tracing::IsActive()) { + id = port::Tracing::UniqueId(); + port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure, + id); + } + + mutex_lock l(mu_); + pending_.push_back({fn, id}); + if (!waiters_.empty()) { + Waiter* w = waiters_.back(); + waiters_.pop_back(); + w->ready = true; + w->cv.notify_one(); + } +} + +void ThreadPool::WorkerLoop() { + port::Tracing::RegisterCurrentThread(name_.c_str()); + mutex_lock l(mu_); + Waiter w; + while (true) { + while (pending_.empty()) { + // Wait for work to be assigned to me + w.ready = false; + waiters_.push_back(&w); + while (!w.ready) { + w.cv.wait(l); + } + } + // Pick up pending work + Item item = pending_.front(); + pending_.pop_front(); + if (item.fn == nullptr) { + break; + } + mu_.unlock(); + if (item.id != 0) { + port::Tracing::ScopedActivity region( + port::Tracing::EventCategory::kRunClosure, item.id); + item.fn(); + } else { + item.fn(); + } + mu_.lock(); + } +} + +} // namespace thread +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h new file mode 100644 index 0000000000..5cf780fa86 --- /dev/null +++ b/tensorflow/core/lib/core/threadpool.h @@ -0,0 +1,59 @@ +#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_ +#define TENSORFLOW_LIB_CORE_THREADPOOL_H_ + +#include <deque> +#include <functional> +#include <thread> +#include <vector> +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace thread { + +class ThreadPool { + public: + // Construct a pool that contains "num_threads" threads with specified "name". + // env->StartThread() is used to create individual threads. + // + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const string& name, int num_threads); + + // Construct a pool that contains "num_threads" threads with specified "name". + // env->StartThread() is used to create individual threads. + // + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name, + int num_threads); + + // Wait until all scheduled work has finished and then destroy the + // set of threads. + virtual ~ThreadPool(); + + // Schedule fn() for execution in the pool of threads. + virtual void Schedule(std::function<void()> fn); + + virtual bool HasPendingClosures() const; + + private: + struct Waiter; + struct Item { + std::function<void()> fn; + uint64 id; + }; + + void WorkerLoop(); + + const string name_; + mutable mutex mu_; + std::vector<Thread*> threads_; // All threads + std::vector<Waiter*> waiters_; // Stack of waiting threads. 
+ std::deque<Item> pending_; // Queue of pending work + + TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool); +}; + +} // namespace thread +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_ diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc new file mode 100644 index 0000000000..f4909c445c --- /dev/null +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -0,0 +1,93 @@ +#include "tensorflow/core/lib/core/threadpool.h" + +#include <atomic> + +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/env.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace thread { + +static const int kNumThreads = 30; + +TEST(ThreadPool, Empty) { + for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { + fprintf(stderr, "Testing with %d threads\n", num_threads); + ThreadPool pool(Env::Default(), "test", num_threads); + } +} + +TEST(ThreadPool, DoWork) { + for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { + fprintf(stderr, "Testing with %d threads\n", num_threads); + const int kWorkItems = 15; + bool work[kWorkItems]; + for (int i = 0; i < kWorkItems; i++) { + work[i] = false; + } + { + ThreadPool pool(Env::Default(), "test", num_threads); + for (int i = 0; i < kWorkItems; i++) { + pool.Schedule([&work, i]() { + ASSERT_FALSE(work[i]); + work[i] = true; + }); + } + } + for (int i = 0; i < kWorkItems; i++) { + ASSERT_TRUE(work[i]); + } + } +} + +static void BM_Sequential(int iters) { + ThreadPool pool(Env::Default(), "test", kNumThreads); + // Decrement count sequentially until 0. + int count = iters; + mutex done_lock; + condition_variable done; + bool done_flag = false; + std::function<void()> work = [&pool, &count, &done_lock, &done, &done_flag, + &work]() { + if (count--) { + pool.Schedule(work); + } else { + mutex_lock l(done_lock); + done_flag = true; + done.notify_all(); + } + }; + work(); + mutex_lock l(done_lock); + if (!done_flag) { + done.wait(l); + } +} +BENCHMARK(BM_Sequential); + +static void BM_Parallel(int iters) { + ThreadPool pool(Env::Default(), "test", kNumThreads); + // Decrement count concurrently until 0. + std::atomic_int_fast32_t count(iters); + mutex done_lock; + condition_variable done; + bool done_flag = false; + for (int i = 0; i < iters; ++i) { + pool.Schedule([&count, &done_lock, &done, &done_flag]() { + if (count.fetch_sub(1) == 1) { + mutex_lock l(done_lock); + done_flag = true; + done.notify_all(); + } + }); + } + mutex_lock l(done_lock); + if (!done_flag) { + done.wait(l); + } +} +BENCHMARK(BM_Parallel); + +} // namespace thread +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h new file mode 100644 index 0000000000..813fb126e3 --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice.h @@ -0,0 +1,299 @@ +// An ArraySlice<T> represents an immutable array of elements of type +// T. It has a length "length", and a base pointer "ptr", and the +// array it represents contains the elements "ptr[0] .. ptr[len-1]". +// The backing store for the array is *not* owned by the ArraySlice +// object, and clients must arrange for the backing store to remain +// live while the ArraySlice object is in use. +// +// An ArraySlice<T> is somewhat analogous to a StringPiece, but for +// array elements of type T. +// +// Implicit conversion operations are provided from types such as +// std::vector<T> and util::gtl::InlinedVector<T, N>. 
Note that ArraySlice +// objects constructed from types in this way may be invalidated by +// any operations that mutate the underlying vector. +// +// One common use for ArraySlice is when passing arguments to a +// routine where you want to be able to accept a variety of array +// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array, +// etc.). The usual approach here is to have the client explicitly +// pass in a pointer and a length, as in: +// +// void MyRoutine(const int* elems, int N) { +// for (int i = 0; i < N; i++) { .. do something with elems[i] .. } +// } +// +// Unfortunately, this leads to ugly and error-prone code at the call site: +// +// std::vector<int> my_vector; +// MyRoutine(vector_as_array(&my_vector), my_vector.size()); +// +// util::gtl::InlinedVector<int, 4> my_inline_vector; +// MyRoutine(my_inline_vector.array(), my_inline_vector.size()); +// +// int my_array[10]; +// MyRoutine(my_array, 10); +// +// Instead, you can use an ArraySlice as the argument to the routine: +// +// void MyRoutine(ArraySlice<int> a) { +// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. } +// } +// +// This makes the call sites cleaner, for the most part: +// +// std::vector<int> my_vector; +// MyRoutine(my_vector); +// +// util::gtl::InlinedVector<int, 4> my_inline_vector; +// MyRoutine(my_inline_vector); +// +// int my_array[10]; +// MyRoutine(my_array); +// +// int* my_array = new int[10]; +// MyRoutine(gtl::ArraySlice<int>(my_array, 10)); +// +// MutableArraySlice<T> represents a mutable array of elements, and, like +// ArraySlice, does not own the backing store. The implicit constructors it +// provides allow functions not to worry about whether their mutable arguments +// refer to vectors, arrays, proto2::RepeatedFields, etc.: +// +// void MyMutatingRoutine(MutableArraySlice<int> a) { +// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. 
} +// } +// +// std::vector<int> my_vector; +// MyMutatingRoutine(&my_vector); +// +// int my_array[10]; +// MyMutatingRoutine(my_array); +// +// int* my_array = new int[10]; +// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10)); +// +// MyProto my_proto; +// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); } +// MyMutatingRoutine(my_proto.mutable_value()); + +#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ +#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ + +#include <initializer_list> +#include <type_traits> +#include <vector> + +#include "tensorflow/core/lib/gtl/array_slice_internal.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace gtl { + +template <typename T> +class ArraySlice { + private: + typedef array_slice_internal::ArraySliceImpl<T> Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + ArraySlice() : impl_(nullptr, 0) {} + ArraySlice(const_pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit) + : impl_(v.data(), v.size()) {} + + template <size_t N> + ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + template <int N> + ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit) + : impl_(v.array(), v.size()) {} + + // The constructor for any class supplying 'data() const' that returns either + // const T* or a less const-qualified version of it, and 'some_integral_type + // size() const'. proto2::RepeatedField<T>, string and (since C++11) + // std::vector<T,A> and std::array<T, N> are examples of this. See + // array_slice_internal.h for details. + template <typename V, + typename = typename Impl::template EnableIfConvertibleFrom<V>> + ArraySlice(const V& v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Implicitly constructs an ArraySlice from an initializer list. This makes it + // possible to pass a brace-enclosed initializer list to a function expecting + // an ArraySlice: + // void Process(ArraySlice<int> x); + // Process({1, 2, 3}); + // The data referenced by the initializer_list must outlive this + // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return + // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may + // reference data that is no longer valid. + ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit) + : impl_(v.begin(), v.size()) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. 
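+  // Illustrative example (not part of the original comment): given an
+  // ArraySlice<int> s over {0, 1, 2, 3, 4},
+  //   ArraySlice<int>(s, 1, 2)    views {1, 2}, and
+  //   ArraySlice<int>(s, 3, npos) views {3, 4}.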
+ ArraySlice(const ArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + const_pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + const_reference operator[](size_type i) const { return impl_[i]; } + const_reference at(size_type i) const { return impl_.at(i); } + const_reference front() const { return impl_.front(); } + const_reference back() const { return impl_.back(); } + + const_iterator begin() const { return impl_.begin(); } + const_iterator end() const { return impl_.end(); } + const_reverse_iterator rbegin() const { return impl_.rbegin(); } + const_reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + // These relational operators have the same semantics as the + // std::vector<T> relational operators: they do deep (elementwise) + // comparisons. Array slices are equal iff their size is the same + // and all their elements are equal. + bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; } + bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; } + + private: + Impl impl_; +}; + +// Mutable version of ArraySlice, which allows the clients to mutate the +// underlying data. It is implicitly convertible to ArraySlice since it provides +// the data() and size() methods with correct signatures. When a +// MutableArraySlice is created from a pointer to a container (as opposed to raw +// memory pointer), the pointer must not be null. +// +// A note on const-ness: "mutable" here refers to the mutability of the +// underlying data, not of the slice itself. It is perfectly reasonable to have +// a variable of type "const MutableArraySlice<T>"; this means that the bounds +// of the view on the array cannot be changed, but the underlying data in the +// array still may be modified. This is akin to a "T* const" pointer, as opposed +// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>). 
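+// A brief usage sketch (illustrative only; ZeroOut is a hypothetical helper,
+// and the call relies on the implicit std::vector<int>* conversion declared
+// below):
+//
+//   void ZeroOut(gtl::MutableArraySlice<int> a) {
+//     for (size_t i = 0; i < a.size(); i++) a[i] = 0;
+//   }
+//
+//   std::vector<int> v(8, 1);
+//   ZeroOut(&v);   // v now holds eight zeros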
+template <typename T> +class MutableArraySlice { + private: + typedef array_slice_internal::MutableArraySliceImpl<T> Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + MutableArraySlice() : impl_(nullptr, 0) {} + MutableArraySlice(pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit) + : impl_(v->data(), v->size()) {} + + template <size_t N> + MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + template <int N> + MutableArraySlice( + InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit) + : impl_(v->mutable_array(), v->size()) {} + + // The constructor for any class supplying 'T* data()' or 'T* mutable_data()' + // (the former is called if both exist), and 'some_integral_type size() + // const'. proto2::RepeatedField is an example of this. Also supports string + // arguments, when T==char. The appropriate ctor is selected using SFINAE. See + // array_slice_internal.h for details. + template <typename V, + typename = typename Impl::template EnableIfConvertibleFrom<V>> + MutableArraySlice(V* v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Substring of another MutableArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. + MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + // Accessors. + pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + reference operator[](size_type i) const { return impl_[i]; } + reference at(size_type i) const { return impl_.at(i); } + reference front() const { return impl_.front(); } + reference back() const { return impl_.back(); } + + iterator begin() const { return impl_.begin(); } + iterator end() const { return impl_.end(); } + reverse_iterator rbegin() const { return impl_.rbegin(); } + reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + bool operator==(ArraySlice<T> other) const { + return ArraySlice<T>(*this) == other; + } + bool operator!=(ArraySlice<T> other) const { + return ArraySlice<T>(*this) != other; + } + + // DEPRECATED(jacobsa): Please use data() instead. 
+ pointer mutable_data() const { return impl_.data(); } + + private: + Impl impl_; +}; + +template <typename T> +const typename ArraySlice<T>::size_type ArraySlice<T>::npos; +template <typename T> +const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos; + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h new file mode 100644 index 0000000000..080f0a38d8 --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice_internal.h @@ -0,0 +1,253 @@ +// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by +// array_slice.h. + +// Helper functions and templates for ArraySlice. + +#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ +#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ + +#include <stddef.h> +#include <algorithm> +#include <iterator> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { +namespace array_slice_internal { + +// Template logic for generic constructors. + +// Wrappers whose Get() delegates to the appropriate method of a container, and +// is defined when this method exists. Delegates to the const method if C is a +// const type. +struct Data { + template <typename C> + static decltype(std::declval<C>().data()) Get(C* v) { + return v->data(); + } +}; + +struct MutableData { + template <typename C> + static decltype(std::declval<C>().mutable_data()) Get(C* v) { + return v->mutable_data(); + } +}; + +struct Size { + template <typename C> + static decltype(std::declval<C>().size()) Get(C* v) { + return v->size(); + } +}; + +struct MutableStringData { + // Defined only for string. + static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); } +}; + +// Checks whether M::Get(C*) is defined and has a return type R such that +// Checker::valid<R>()==true. +template <typename M, typename Checker, typename C> +struct HasGetHelper : public M { + private: + struct None {}; + // M::Get is selected when it is viable. Get(...) is selected otherwise. + using M::Get; + static None Get(...); + + public: + static constexpr bool HasGet() { + using Result = decltype(Get(std::declval<C*>())); + return !std::is_same<Result, None>() && Checker::template valid<Result>(); + } +}; + +// Defines HasGet() for a particular method, container, and checker. If +// HasGet()==true, provides Get() that delegates to the method. +template <typename M, typename Checker, typename C, + bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()> +struct Wrapper { + static constexpr bool HasGet() { return false; } +}; + +template <typename M, typename Checker, typename C> +struct Wrapper<M, Checker, C, true> { + static constexpr bool HasGet() { return true; } + static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); } +}; + +// Type checker for a method returning an integral value. +struct SizeChecker { + template <typename R> + static constexpr bool valid() { + return std::is_integral<R>::value; + } +}; + +// Type checker for a method returning either a pointer to T or a less const +// version of that. +template <typename T> +struct DataChecker { + // We want to enable conversion from std::vector<T*> to ArraySlice<const T*> + // but + // disable conversion from std::vector<Derived> to ArraySlice<Base>. 
Here we + // use + // the fact that U** is convertible to Q* const* if and only if Q is the same + // type or a more cv-qualified version of U. + template <typename R> + static constexpr bool valid() { + return std::is_convertible<R*, T* const*>::value; + } +}; + +// Aliases to A if A::HasGet()==true, or to B otherwise. +template <typename A, typename B> +using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type; + +// Wraps C::data() const, returning a pointer to const data. +template <typename T, typename C> +using ContainerData = Wrapper<Data, DataChecker<const T>, const C>; + +// Wraps a method returning a pointer to mutable data. Prefers data() over +// mutable_data(), and handles strings when T==char. If data() returns a pointer +// to mutable data, it is most likely overloaded, but may also be a single +// method 'T* C::data() const' in a non-STL-compliant container. +template <typename T, typename C> +using ContainerMutableData = + FirstWithGet<Wrapper<Data, DataChecker<T>, C>, + FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>, + Wrapper<MutableStringData, DataChecker<T>, C>>>; + +// Wraps C::size() const. +template <typename C> +using ContainerSize = Wrapper<Size, SizeChecker, const C>; + +// Implementation class for ArraySlice and MutableArraySlice. In the case of +// ArraySlice, T will be a const type; for MutableArraySlice, T will be a +// mutable type. +template <typename T> +class ArraySliceImplBase { + public: + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef pointer iterator; + typedef const_pointer const_iterator; + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + static const size_type npos = -1; + + ArraySliceImplBase(pointer array, size_type length) + : ptr_(array), length_(length) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len) + : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {} + + // Some of the const methods below return pointers and references to mutable + // data. This is only the case in this internal class; ArraySlice and + // MutableArraySlice provide deep-constness. 
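+  // (For instance, ArraySlice<T> instantiates this base with a const element
+  // type, so its public API only ever hands out const T.)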
+ + pointer data() const { return ptr_; } + size_type size() const { return length_; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + reference operator[](size_type i) const { return ptr_[i]; } + reference at(size_type i) const { + DCHECK_LT(i, length_); + return ptr_[i]; + } + reference front() const { + DCHECK_GT(length_, 0); + return ptr_[0]; + } + reference back() const { + DCHECK_GT(length_, 0); + return ptr_[length_ - 1]; + } + + void remove_prefix(size_type n) { + DCHECK_GE(length_, n); + ptr_ += n; + length_ -= n; + } + void remove_suffix(size_type n) { + DCHECK_GE(length_, n); + length_ -= n; + } + + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + reverse_iterator rbegin() const { return reverse_iterator(end()); } + reverse_iterator rend() const { return reverse_iterator(begin()); } + + bool operator==(const ArraySliceImplBase& other) const { + if (size() != other.size()) return false; + if (data() == other.data()) return true; + return std::equal(data(), data() + size(), other.data()); + } + bool operator!=(const ArraySliceImplBase& other) const { + return !(*this == other); + } + + private: + pointer ptr_; + size_type length_; +}; + +template <typename T> +class ArraySliceImpl : public ArraySliceImplBase<const T> { + public: + using ArraySliceImplBase<const T>::ArraySliceImplBase; + + // Defined iff the data and size accessors for the container C have been + // defined. + template <typename C> + using EnableIfConvertibleFrom = + typename std::enable_if<ContainerData<T, C>::HasGet() && + ContainerSize<C>::HasGet()>::type; + + // Constructs from a container when EnableIfConvertibleFrom is + // defined. std::addressof handles types with overloaded operator&. + template <typename C> + explicit ArraySliceImpl(const C& v) + : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)), + ContainerSize<C>::Get(std::addressof(v))) {} +}; + +template <typename T> +class MutableArraySliceImpl : public ArraySliceImplBase<T> { + public: + using ArraySliceImplBase<T>::ArraySliceImplBase; + + template <typename C> + using EnableIfConvertibleFrom = + typename std::enable_if<ContainerMutableData<T, C>::HasGet() && + ContainerSize<C>::HasGet()>::type; + + template <typename C> + explicit MutableArraySliceImpl(C* v) + : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v), + ContainerSize<C>::Get(v)) {} +}; + +} // namespace array_slice_internal +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc new file mode 100644 index 0000000000..33ee8fc8dd --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice_test.cc @@ -0,0 +1,646 @@ +#include "tensorflow/core/lib/gtl/array_slice.h" + +#include <algorithm> +#include <array> +#include <string> +#include <vector> + +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace gtl { +namespace { + +typedef ArraySlice<int> IntSlice; +typedef ArraySlice<char> CharSlice; +typedef MutableArraySlice<int> MutableIntSlice; +typedef MutableArraySlice<char> MutableCharSlice; +typedef std::vector<int> IntVec; + +// Append 0..len-1 to *v +template <typename Vector> +static void Fill(Vector* v, int len, int offset = 0) { + for (int i = 0; i < len; i++) { + v->push_back(i + offset); + } +} + +static void 
TestHelper(const IntSlice& vorig, const IntVec& vec) { + IntSlice other; // To test the assignment return value. + IntSlice v = other = vorig; + const int len = vec.size(); + EXPECT_EQ(v.size(), vec.size()); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(v[i], vec[i]); + EXPECT_EQ(v.at(i), vec[i]); + } + EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec)); + + int counter = 0; + for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(counter, *it); + counter++; + } + EXPECT_EQ(counter, len); + + counter = 0; + for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(counter, *it); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + EXPECT_EQ(0, v.front()); + EXPECT_EQ(len - 1, v.back()); + v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i, v[i]); + } + if (len > 1) { + v.pop_front(); + EXPECT_EQ(len - 2, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i + 1, v[i]); + } + } + } +} + +// The element access test that is applicable both when MutableArraySlice is +// const and when it's not. +template <class V> +void MutableTestHelperTemplated(V v, int* ptr, const int len) { + CHECK_EQ(v.size(), len); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(ptr + i, &v[i]); + EXPECT_EQ(ptr + i, &v.at(i)); + } + EXPECT_EQ(ptr, v.begin()); + EXPECT_EQ(ptr + len, v.end()); + EXPECT_EQ(ptr, v.data()); + + int counter = 0; + for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(ptr + counter, &*it); + counter++; + } + EXPECT_EQ(counter, len); + + EXPECT_EQ(len, std::distance(v.rbegin(), v.rend())); + + if (len > 0) { + EXPECT_EQ(ptr, &v.front()); + EXPECT_EQ(ptr + len - 1, &v.back()); + EXPECT_EQ(ptr + len - 1, &*v.rbegin()); + EXPECT_EQ(ptr, &*(v.rend() - 1)); + } +} + +static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr, + const int len) { + // Test the data accessors both when the MutableArraySlice is declared const, + // and when it is not. + MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len); + MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len); + + MutableIntSlice other; // To test the assignment return value. + MutableIntSlice v = other = vorig; + EXPECT_EQ(ptr, v.mutable_data()); + + int counter = 0; + for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(ptr + counter, &*it); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + // Test that elements are assignable. + v[0] = 1; + v.front() = 2; + v.back() = 5; + *v.mutable_data() = 4; + std::fill(v.begin(), v.end(), 5); + std::fill(v.rbegin(), v.rend(), 6); + // Test size-changing methods. 
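+    // (Note: these only shrink the slice's view; the underlying array is
+    // left untouched.)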
+ v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(ptr + i, &v[i]); + } + if (len > 1) { + v.pop_front(); + EXPECT_EQ(len - 2, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(ptr + i + 1, &v[i]); + } + } + } +} + +template <typename Vector> +static void TestImplicitConversion(const IntSlice& v, const Vector& vec) { + EXPECT_EQ(v.size(), vec.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(v[i], vec[i]); + } +} + +template <typename Vector> +static void TestImplicitConversion(const CharSlice& v, const Vector& vec) { + TestImplicitConversion(IntVec(v.begin(), v.end()), vec); +} + +static void TestImplicitConversion(const MutableIntSlice& v, const int* data, + int size) { + EXPECT_EQ(size, v.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(data + i, &v[i]); + } +} + +static void TestImplicitConversion(const MutableCharSlice& v, const char* data, + int size) { + EXPECT_EQ(size, v.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(data + i, &v[i]); + } +} +// A struct supplying the data(), mutable_data() and size() methods, just like +// e.g. proto2::RepeatedField. +struct RepeatedField { + std::vector<int> storage; + const int* data() const { return storage.data(); } + int* mutable_data() { return storage.data(); } + int size() const { return storage.size(); } +}; + +// A struct supplying the data() (both mutable and const versions) and +// size(). It also supplies mutable_data() but we test that data() is selected +// instead. +struct ContainerWithOverloads { + std::vector<int> storage; + std::vector<int> wrong_storage; + const int* data() const { return storage.data(); } + int* data() { return storage.data(); } + // MutableArraySlice should not call mutable_data(), preferring data() + // instead. + int* mutable_data() { return wrong_storage.data(); } + int size() const { return storage.size(); } +}; + +// A struct supplying data() and size() methods. 
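+// Its const-qualified data() returns a mutable int*, which exercises the
+// "less const-qualified pointer" acceptance rule of DataChecker in
+// array_slice_internal.h.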
+struct ContainerWithShallowConstData { + std::vector<int> storage; + int* data() const { return const_cast<int*>(storage.data()); } + int size() const { return storage.size(); } +}; + +TEST(IntSlice, Simple) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + TestHelper(IntSlice(vec), vec); + TestHelper(IntSlice(vec.data(), vec.size()), vec); + } +} + +TEST(IntSlice, WithPosAndLen) { + IntVec vec; + Fill(&vec, 20); + for (size_t len = 0; len < vec.size(); len++) { + IntVec subvec(vec.begin(), vec.begin() + len); + TestImplicitConversion(IntSlice(vec, 0, len), subvec); + TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec); + } + EXPECT_EQ(0, IntSlice(vec, 0, 0).size()); + EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size()); + TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec); +} + +TEST(IntSlice, Clear) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + IntSlice v(vec); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(IntSlice, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + IntVec avec, bvec; + Fill(&avec, l1); + Fill(&bvec, l2, 100); + IntSlice a(avec), b(bvec); + using std::swap; + swap(a, b); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(i, b[i]); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(100 + i, a[i]); + } + } + } +} + +TEST(IntSlice, ImplicitConversion) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + IntSlice slice; + slice = vec; + TestImplicitConversion(vec, vec); + TestImplicitConversion(slice, vec); + TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec); + } +} + +TEST(IntSlice, InlinedVectorConversion) { + for (int len = 0; len < 20; len++) { + InlinedVector<int, 4> inline_vec; + for (int i = 0; i < len; i++) { + inline_vec.push_back(i); + } + IntVec vec; + Fill(&vec, len); + IntSlice v = inline_vec; // Test assignment + static_cast<void>(v); + TestImplicitConversion(inline_vec, vec); + } +} + +TEST(IntSlice, StaticArrayConversion) { + int array[20]; + IntVec vec; + Fill(&vec, TF_ARRAYSIZE(array)); + std::copy(vec.begin(), vec.end(), array); + IntSlice v = array; // Test assignment + static_cast<void>(v); + TestImplicitConversion(array, vec); +} + +TEST(IntSlice, StdArrayConversion) { + std::array<int, 20> array; + IntVec vec; + Fill(&vec, array.size()); + std::copy(vec.begin(), vec.end(), array.begin()); + + // Check assignment. + { + IntSlice v = array; + static_cast<void>(v); + } + + // Check sub-slice initialization. + { + IntSlice v = {array, 10, 15}; + static_cast<void>(v); + } + + TestImplicitConversion(array, vec); +} + +// Values according to the Fill function. 
+static const int test_const_array[] = {0, 1, 2}; + +TEST(IntSlice, ConstStaticArrayConversion) { + IntVec vec; + Fill(&vec, TF_ARRAYSIZE(test_const_array)); + IntSlice v = test_const_array; // Test assignment + static_cast<void>(v); + TestImplicitConversion(test_const_array, vec); +} + +TEST(IntSlice, RepeatedFieldConversion) { + RepeatedField repeated_field; + IntVec vec; + Fill(&vec, 20); + repeated_field.storage = vec; + IntSlice v = repeated_field; // Test assignment + static_cast<void>(v); + TestImplicitConversion(repeated_field, vec); +} + +TEST(IntSlice, ContainerWithOverloadsConversion) { + ContainerWithOverloads container; + Fill(&container.storage, 20); + container.wrong_storage.resize(container.size()); + IntSlice v = container; // Test assignment + static_cast<void>(v); + TestImplicitConversion(container, container.storage); +} + +TEST(IntSlice, ContainerWithShallowConstDataConversion) { + ContainerWithShallowConstData container; + Fill(&container.storage, 20); + IntSlice v = container; // Test assignment + static_cast<void>(v); + TestImplicitConversion(container, container.storage); +} + +TEST(IntSlice, MutableIntSliceConversion) { + IntVec vec(20); + IntSlice slice = MutableIntSlice(&vec); + EXPECT_EQ(vec.size(), slice.size()); + EXPECT_EQ(vec.data(), slice.data()); +} + +TEST(IntSlice, Equality) { + IntVec vec1(20); + IntVec vec2(20); + // These two slices are from different vectors, but have the same + // size and have the same elements (right now). They should + // compare equal. + const IntSlice from1(vec1); + const IntSlice from2(vec2); + EXPECT_EQ(from1, from1); + EXPECT_EQ(from1, from2); + + // This verifies that MutableArraySlices can be compared freely with + // ArraySlices. + const MutableIntSlice mutable_from1(&vec1); + const MutableIntSlice mutable_from2(&vec2); + EXPECT_EQ(from1, mutable_from1); + EXPECT_EQ(mutable_from1, from1); + EXPECT_EQ(mutable_from1, mutable_from2); + EXPECT_EQ(mutable_from2, mutable_from1); + + // With a different size, the array slices should not be equal. + EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1)); + + // With different contents, the array slices should not be equal. + ++vec2.back(); + EXPECT_NE(from1, from2); +} + +// Compile-asserts that the argument has the expected type. 
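+// For example (mirroring the test below), CheckType<IntSlice::iterator>(
+// slice.begin()) compiles only if begin() returns exactly IntSlice::iterator.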
+template <typename Expected, typename T> +void CheckType(const T& value) { + testing::StaticAssertTypeEq<Expected, T>(); +} + +TEST(IntSlice, ExposesContainerTypesAndConsts) { + IntSlice slice; + const IntSlice const_slice; + CheckType<IntSlice::iterator>(slice.begin()); + CheckType<IntSlice::const_iterator>(const_slice.end()); + CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin()); + CheckType<IntSlice::reverse_iterator>(slice.rend()); + testing::StaticAssertTypeEq<int, IntSlice::value_type>(); + testing::StaticAssertTypeEq<const int*, IntSlice::pointer>(); + testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>(); + EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos); +} + +void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); } + +void TestRange(IntSlice slice, int from, int to) { + ASSERT_EQ(to - from + 1, slice.size()); + for (size_t i = 0; i < slice.size(); ++i) { + EXPECT_EQ(from + i, slice[i]); + } +} + +TEST(IntSlice, InitializerListConversion) { + TestEmpty({}); + TestRange({1}, 1, 1); + TestRange({10, 11, 12, 13}, 10, 13); +} + +TEST(CharSlice, StringConversion) { + IntVec vec; + Fill(&vec, 20); + string str(vec.begin(), vec.end()); + CharSlice v = str; // Test assignment + static_cast<void>(v); + TestImplicitConversion(str, vec); +} + +TEST(IntPtrSlice, ConstConversion) { + int one = 1; + int two = 2; + std::vector<int*> vec; + vec.push_back(&one); + vec.push_back(&two); + ArraySlice<const int*> v = vec; + ASSERT_EQ(2, v.size()); + EXPECT_EQ(&one, v[0]); + EXPECT_EQ(&two, v[1]); +} + +TEST(MutableIntSlice, Simple) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableTestHelper(MutableIntSlice(&vec), vec.data(), len); + MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len); + } +} + +TEST(MutableIntSlice, WithPosAndLen) { + IntVec vec(20); + for (size_t len = 0; len < vec.size(); len++) { + TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len); + TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len), + vec.data(), len); + } + EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size()); + EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size()); + TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos), + vec.data(), vec.size()); +} + +TEST(MutableIntSlice, Clear) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableIntSlice v(&vec); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(MutableIntSlice, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + IntVec avec(l1), bvec(l2); + MutableIntSlice a(&avec), b(&bvec); + using std::swap; + swap(a, b); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(&avec[i], &b[i]); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(&bvec[i], &a[i]); + } + } + } +} + +TEST(MutableIntSlice, ImplicitConversion) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableIntSlice slice; + slice = &vec; + TestImplicitConversion(&vec, vec.data(), len); + TestImplicitConversion(slice, vec.data(), len); + TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(), + len); + } +} + +TEST(MutableIntSlice, InlinedVectorConversion) { + for (int len = 0; len < 20; len++) { + InlinedVector<int, 4> inline_vec; + for (int i = 0; i < len; i++) { + inline_vec.push_back(i); + } + MutableIntSlice v = &inline_vec; // Test assignment + static_cast<void>(v); + 
TestImplicitConversion(&inline_vec, inline_vec.array(), inline_vec.size()); + } +} + +TEST(MutableIntSlice, StaticArrayConversion) { + int array[20]; + MutableIntSlice v = array; // Test assignment + static_cast<void>(v); + TestImplicitConversion(array, array, TF_ARRAYSIZE(array)); +} + +TEST(MutableIntSlice, StdArrayConversion) { + std::array<int, 20> array; + + // Check assignment. + { + MutableIntSlice v = &array; + static_cast<void>(v); + } + + // Check sub-slice initialization. + { + MutableIntSlice v = {&array, 10, 15}; + static_cast<void>(v); + } + + TestImplicitConversion(&array, &array[0], array.size()); +} + +TEST(MutableIntSlice, RepeatedFieldConversion) { + RepeatedField repeated_field; + Fill(&repeated_field.storage, 20); + MutableIntSlice v = &repeated_field; // Test assignment + static_cast<void>(v); + TestImplicitConversion(&repeated_field, repeated_field.storage.data(), + repeated_field.storage.size()); +} + +TEST(MutableIntSlice, ContainerWithOverloadsConversion) { + ContainerWithOverloads container; + Fill(&container.storage, 20); + container.wrong_storage.resize(container.size()); + MutableIntSlice v = &container; // Test assignment + static_cast<void>(v); + TestImplicitConversion(&container, container.storage.data(), + container.storage.size()); +} + +TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) { + ContainerWithShallowConstData container; + Fill(&container.storage, 20); + MutableIntSlice v = &container; // Test assignment + static_cast<void>(v); + TestImplicitConversion(&container, container.storage.data(), + container.storage.size()); +} + +TEST(MutableIntSlice, TypedefsAndConstants) { + testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>(); + testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>(); + testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>(); + testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>(); + testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>(); + + EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos); +} + +TEST(MutableIntSlice, IteratorsAndReferences) { + auto accept_pointer = [](int* x) {}; + auto accept_reference = [](int& x) {}; + auto accept_iterator = [](MutableIntSlice::iterator x) {}; + auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {}; + + int a[1]; + MutableIntSlice s = a; + + accept_pointer(s.data()); + accept_pointer(s.mutable_data()); + accept_iterator(s.begin()); + accept_iterator(s.end()); + accept_reverse_iterator(s.rbegin()); + accept_reverse_iterator(s.rend()); + + accept_reference(s[0]); + accept_reference(s.at(0)); + accept_reference(s.front()); + accept_reference(s.back()); +} + +TEST(MutableIntSlice, IteratorsAndReferences_Const) { + auto accept_pointer = [](int* x) {}; + auto accept_reference = [](int& x) {}; + auto accept_iterator = [](MutableIntSlice::iterator x) {}; + auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {}; + + int a[1]; + const MutableIntSlice s = a; + + accept_pointer(s.data()); + accept_pointer(s.mutable_data()); + accept_iterator(s.begin()); + accept_iterator(s.end()); + accept_reverse_iterator(s.rbegin()); + accept_reverse_iterator(s.rend()); + + accept_reference(s[0]); + accept_reference(s.at(0)); + accept_reference(s.front()); + accept_reference(s.back()); +} + +bool TestMutableOverload(MutableIntSlice slice) { return false; } + +bool TestMutableOverload(MutableCharSlice slice) { return true; } + +TEST(MutableCharSlice, StringConversion) { + for 
(int len = 0; len < 20; len++) { + string str(len, '\0'); + MutableCharSlice v = &str; // Test assignment + static_cast<void>(v); + TestImplicitConversion(v, str.data(), str.size()); + } + // Verify that only the correct overload is feasible. Note that this would + // fail if the string ctor was declared simply as MutableArraySlice(string*), + // since in that case both overloads would be feasible. + string str; + EXPECT_TRUE(TestMutableOverload(&str)); +} + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h new file mode 100644 index 0000000000..82b6c2299f --- /dev/null +++ b/tensorflow/core/lib/gtl/edit_distance.h @@ -0,0 +1,82 @@ +#ifndef TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ +#define TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace gtl { + +// Calculate the Levenshtein Edit Distance between two contiguous +// sequences, s and t, of type T. +// +// The Levenshtein distance is a symmetric distance defined as the +// smallest number of insertions, deletions, and substitutions +// required to convert sequence s to t (and vice versa). +// Note, this distance does not consider transpositions. +// +// For more details and a reference implementation, see: +// https://en.wikipedia.org/wiki/Levenshtein_distance +// +// This implementation has time complexity O(|s|*|t|) +// and space complexity O(min(|s|, |t|)), where +// |x| := x.size() +// +// A simple call to LevenshteinDistance looks like: +// +// int64 dist = LevenshteinDistance("hi", "bye", std::equal_to<char>()); +// +template <typename T, typename Cmp> +inline int64 LevenshteinDistance(const gtl::ArraySlice<T>& s, + const gtl::ArraySlice<T>& t, const Cmp& cmp) { + const int64 s_size = s.size(); + const int64 t_size = t.size(); + + if (s_size == 0) return t_size; + if (t_size == 0) return s_size; + if (s == t) return 0; + if (t_size > s_size) return LevenshteinDistance(t, s, cmp); + + // Create work vectors + gtl::InlinedVector<int64, 32> scratch0(t_size + 1); + gtl::InlinedVector<int64, 32> scratch1(t_size + 1); + + int64* previous = scratch0.data(); + int64* current = scratch1.data(); + + // Initialize previous row of distances + std::iota(scratch0.begin(), scratch0.end(), 0); + + for (int64 i = 0; i < s_size; ++i) { + // Swap current and previous rows for next iteration + std::swap(previous, current); + + // Calculate current row distances from previous row + current[0] = i + 1; + + // Fill in the rest of the row + for (int64 j = 0; j < t_size; ++j) { + const int64 cost = cmp(s[i], t[j]) ? 
0 : 1; + current[j + 1] = + std::min(current[j] + 1, // deletion cost + std::min(previous[j + 1] + 1, // insertion cost + previous[j] + cost)); // substitution cost + } + } + + return current[t_size]; +} + +template <typename Container1, typename Container2, typename Cmp> +inline int64 LevenshteinDistance(const Container1& s, const Container2& t, + const Cmp& cmp) { + return LevenshteinDistance( + gtl::ArraySlice<typename Container1::value_type>(s.data(), s.size()), + gtl::ArraySlice<typename Container1::value_type>(t.data(), t.size()), + cmp); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ diff --git a/tensorflow/core/lib/gtl/edit_distance_test.cc b/tensorflow/core/lib/gtl/edit_distance_test.cc new file mode 100644 index 0000000000..0526ee0a05 --- /dev/null +++ b/tensorflow/core/lib/gtl/edit_distance_test.cc @@ -0,0 +1,125 @@ +#include "tensorflow/core/lib/gtl/edit_distance.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace gtl { +namespace { + +class LevenshteinDistanceTest : public ::testing::Test { + protected: + std::vector<char> empty_; + std::string s1_; + std::string s1234_; + std::string s567_; + std::string kilo_; + std::string kilogram_; + std::string mother_; + std::string grandmother_; + std::string lower_; + std::string upper_; + + void SetUp() override { + s1_ = "1"; + s1234_ = "1234"; + s567_ = "567"; + kilo_ = "kilo"; + kilogram_ = "kilogram"; + mother_ = "mother"; + grandmother_ = "grandmother"; + lower_ = "lower case"; + upper_ = "UPPER case"; + } +}; + +TEST_F(LevenshteinDistanceTest, BothEmpty) { + ASSERT_EQ(LevenshteinDistance(empty_, empty_, std::equal_to<char>()), 0); +} + +TEST_F(LevenshteinDistanceTest, OneEmpty) { + ASSERT_EQ(LevenshteinDistance(s1234_, empty_, std::equal_to<char>()), 4); + ASSERT_EQ(LevenshteinDistance(empty_, s567_, std::equal_to<char>()), 3); +} + +TEST_F(LevenshteinDistanceTest, SingleElement) { + ASSERT_EQ(LevenshteinDistance(s1234_, s1_, std::equal_to<char>()), 3); + ASSERT_EQ(LevenshteinDistance(s1_, s1234_, std::equal_to<char>()), 3); +} + +TEST_F(LevenshteinDistanceTest, Prefix) { + ASSERT_EQ(LevenshteinDistance(kilo_, kilogram_, std::equal_to<char>()), 4); + ASSERT_EQ(LevenshteinDistance(kilogram_, kilo_, std::equal_to<char>()), 4); +} + +TEST_F(LevenshteinDistanceTest, Suffix) { + ASSERT_EQ(LevenshteinDistance(mother_, grandmother_, std::equal_to<char>()), + 5); + ASSERT_EQ(LevenshteinDistance(grandmother_, mother_, std::equal_to<char>()), + 5); +} + +TEST_F(LevenshteinDistanceTest, DifferentComparisons) { + ASSERT_EQ(LevenshteinDistance(lower_, upper_, std::equal_to<char>()), 5); + ASSERT_EQ(LevenshteinDistance(upper_, lower_, std::equal_to<char>()), 5); + ASSERT_EQ( + LevenshteinDistance(gtl::ArraySlice<char>(lower_.data(), lower_.size()), + gtl::ArraySlice<char>(upper_.data(), upper_.size()), + std::equal_to<char>()), + 5); + auto no_case_cmp = [](char c1, char c2) { + return std::tolower(c1) == std::tolower(c2); + }; + ASSERT_EQ(LevenshteinDistance(lower_, upper_, no_case_cmp), 3); + ASSERT_EQ(LevenshteinDistance(upper_, lower_, no_case_cmp), 3); +} + +TEST_F(LevenshteinDistanceTest, Vectors) { + ASSERT_EQ( + LevenshteinDistance(std::string("algorithm"), std::string("altruistic"), + std::equal_to<char>()), + 6); +} + +static void BM_EditDistanceHelper(int n, int len, bool completely_different) { + string a = + "The 
quick brown fox jumped over the lazy dog and on and on and on" + " Every good boy deserves fudge. In fact, this is a very long sentence " + " w/many bytes.."; + while (a.size() < static_cast<size_t>(len)) { + a = a + a; + } + string b = a; + if (completely_different) { + for (size_t i = 0; i < b.size(); i++) { + b[i]++; + } + } + while (n-- > 0) { + LevenshteinDistance(gtl::ArraySlice<char>(a.data(), len), + gtl::ArraySlice<char>(b.data(), len), + std::equal_to<char>()); + } +} + +static void BM_EditDistanceSame(int n, int len) { + BM_EditDistanceHelper(n, len, false); +} +static void BM_EditDistanceDiff(int n, int len) { + BM_EditDistanceHelper(n, len, true); +} + +BENCHMARK(BM_EditDistanceSame)->Arg(5); +BENCHMARK(BM_EditDistanceSame)->Arg(50); +BENCHMARK(BM_EditDistanceSame)->Arg(200); +BENCHMARK(BM_EditDistanceSame)->Arg(1000); +BENCHMARK(BM_EditDistanceDiff)->Arg(5); +BENCHMARK(BM_EditDistanceDiff)->Arg(50); +BENCHMARK(BM_EditDistanceDiff)->Arg(200); +BENCHMARK(BM_EditDistanceDiff)->Arg(1000); + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h new file mode 100644 index 0000000000..c23075129c --- /dev/null +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -0,0 +1,839 @@ +// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage +// for sequences of length <= N are provided inline without requiring +// any heap allocation. Typically N is very small (e.g., 4) so that +// sequences that are expected to be short do not require allocations. +// +// Only some of the std::vector<> operations are currently implemented. +// Other operations may be added as needed to facilitate migrating +// code that uses std::vector<> to InlinedVector<>. +// +// NOTE: If you want an inlined version to replace use of a +// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS> +// in util/bitmap/inlined_bitvector.h +// +// TODO(billydonahue): change size_t to size_type where appropriate. + +#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ +#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ + +#include <stddef.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> +#include <algorithm> +#include <iterator> +#include <memory> +#include <type_traits> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" + +#include <initializer_list> // NOLINT(build/include_order) + +namespace tensorflow { +namespace gtl { + +template <typename T, int N, typename A = std::allocator<T> > +class InlinedVector { + public: + typedef A allocator_type; + typedef typename allocator_type::value_type value_type; + typedef typename allocator_type::pointer pointer; + typedef typename allocator_type::const_pointer const_pointer; + typedef typename allocator_type::reference reference; + typedef typename allocator_type::const_reference const_reference; + typedef typename allocator_type::size_type size_type; + typedef typename allocator_type::difference_type difference_type; + typedef pointer iterator; + typedef const_pointer const_iterator; + + // Create an empty vector + InlinedVector(); + explicit InlinedVector(const allocator_type& alloc); + + // Create a vector with n copies of value_type(). 
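+  // (Illustration, assuming int elements: InlinedVector<int, 4> v(3) holds
+  // {0, 0, 0} entirely in the inline storage, while InlinedVector<int, 4>
+  // v(8) value-initializes eight ints in a heap-backed Allocation.)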
+ explicit InlinedVector(size_t n); + + // Create a vector with n copies of elem + InlinedVector(size_t n, const value_type& elem, + const allocator_type& alloc = allocator_type()); + + // Create and initialize with the elements [range_start .. range_end). + // The unused enable_if argument restricts this constructor so that it is + // elided when value_type is an integral type. This prevents ambiguous + // interpretation between a call to this constructor with two integral + // arguments and a call to the preceding (n, elem) constructor. + template <typename InputIterator> + InlinedVector( + InputIterator range_start, InputIterator range_end, + const allocator_type& alloc = allocator_type(), + typename std::enable_if<!std::is_integral<InputIterator>::value>::type* = + NULL) + : allocator_and_tag_(alloc) { + AppendRange(range_start, range_end); + } + + InlinedVector(std::initializer_list<value_type> init, + const allocator_type& alloc = allocator_type()) + : allocator_and_tag_(alloc) { + AppendRange(init.begin(), init.end()); + } + + InlinedVector(const InlinedVector& v); + + ~InlinedVector() { clear(); } + + InlinedVector& operator=(const InlinedVector& v) { + // Optimized to avoid reallocation. + // Prefer reassignment to copy construction for elements. + if (size() < v.size()) { // grow + reserve(v.size()); + std::copy(v.begin(), v.begin() + size(), begin()); + std::copy(v.begin() + size(), v.end(), std::back_inserter(*this)); + } else { // maybe shrink + erase(begin() + v.size(), end()); + std::copy(v.begin(), v.end(), begin()); + } + return *this; + } + + size_t size() const { + return allocated() ? allocation().size() : tag().size(); + } + + bool empty() const { return (size() == 0); } + + // Return number of elements that can be stored in vector + // without requiring a reallocation of underlying memory + size_t capacity() const { return allocated() ? allocation().capacity() : N; } + + // Return a pointer to the underlying array. + // Only result[0,size()-1] are defined. + const_pointer data() const { + return allocated() ? allocated_space() : inlined_space(); + } + pointer data() { return allocated() ? allocated_space() : inlined_space(); } + + // An older name for the more standard-friendly .data(). + const_pointer array() const { return data(); } + pointer mutable_array() { return data(); } + + // Remove all elements + void clear() { + size_t s = size(); + if (allocated()) { + DestroyAllocated(allocated_space(), allocated_space() + s); + allocation().Dealloc(allocator()); + } else { + DestroyInlined(inlined_space(), inlined_space() + s); + } + tag() = Tag(); + } + + // Return the ith element + // REQUIRES: 0 <= i < size() + const value_type& at(size_t i) const { + DCHECK_LT(i, size()); + return array()[i]; + } + const value_type& operator[](size_t i) const { + DCHECK_LT(i, size()); + return array()[i]; + } + + // Return a non-const reference to the ith element + // REQUIRES: 0 <= i < size() + value_type& at(size_t i) { + DCHECK_LT(i, size()); + return mutable_array()[i]; + } + value_type& operator[](size_t i) { + DCHECK_LT(i, size()); + return mutable_array()[i]; + } + + value_type& back() { + DCHECK(!empty()); + return at(size() - 1); + } + + const value_type& back() const { + DCHECK(!empty()); + return at(size() - 1); + } + + value_type& front() { + DCHECK(!empty()); + return at(0); + } + + const value_type& front() const { + DCHECK(!empty()); + return at(0); + } + + // Append t to the vector. + // Increases size() by one. 
+ // Amortized complexity: O(1) + // Worst-case complexity: O(size()) + void push_back(const value_type& t) { + size_t s = size(); + DCHECK_LE(s, capacity()); + if (s == capacity()) { + return GrowAndPushBack(t); + } + DCHECK_LT(s, capacity()); + + if (allocated()) { + ConstructAllocated(allocated_space() + s, t); + } else { + ConstructInlined(inlined_space() + s, t); + } + + set_size_internal(s + 1); + } + + void pop_back() { + DCHECK(!empty()); + size_t s = size(); + if (allocated()) { + DestroyAllocated(allocated_space() + s - 1, allocated_space() + s); + } else { + DestroyInlined(inlined_space() + s - 1, inlined_space() + s); + } + set_size_internal(s - 1); + } + + // Resizes the vector to contain "n" elements. + // If "n" is smaller than the initial size, extra elements are destroyed. + // If "n" is larger than the initial size, enough copies of "elem" + // are appended to increase the size to "n". If "elem" is omitted, + // new elements are value-initialized. + void resize(size_t n); + void resize(size_t n, const value_type& elem); + + iterator begin() { return mutable_array(); } + const_iterator begin() const { return array(); } + + iterator end() { return mutable_array() + size(); } + const_iterator end() const { return array() + size(); } + + iterator insert(iterator pos, const value_type& v); + + iterator erase(iterator pos) { + DCHECK_LT(pos, end()); + DCHECK_GE(pos, begin()); + std::copy(pos + 1, end(), pos); + pop_back(); + return pos; + } + + iterator erase(iterator first, iterator last); + + // Enlarges the underlying representation so it can hold at least + // "n" elements without reallocation. + // Does not change size() or the actual contents of the vector. + void reserve(size_t n) { + if (n > capacity()) { + // Make room for new elements + EnlargeBy(n - size()); + } + } + + // Swap the contents of *this with other. + // REQUIRES: value_type is swappable and copyable. + void swap(InlinedVector& other); + + allocator_type get_allocator() const { return allocator(); } + + private: + struct AllocatorTraits { + typedef typename allocator_type::value_type value_type; + typedef typename allocator_type::pointer pointer; + typedef typename allocator_type::size_type size_type; + + static void construct(allocator_type& a, // NOLINT(runtime/references) + pointer p) { + // Tricky: do we support non-copyable types, or support allocators + // that do special things with construct()? Non-copyable types are + // needed today, so they are more important. When we sort out the + // Android NDK C++11 problem, we will be able to use the proper + // std::allocator_traits<A>::construct(p, ...). + // + // a.construct(p, value_type()); + new (p) value_type(); + } + static void construct(allocator_type& a, // NOLINT(runtime/references) + pointer p, const value_type& t) { + a.construct(p, t); + } + static void destroy(allocator_type& a, // NOLINT(runtime/references) + pointer p) { + a.destroy(p); + } + static pointer allocate(allocator_type& a, // NOLINT(runtime/references) + size_type n) { + return a.allocate(n); + } + static void deallocate(allocator_type& a, // NOLINT(runtime/references) + pointer p, size_type n) { + a.deallocate(p, n); + } + }; + + // If the vector is inlined, holds the size of the vector. + // If the vector is allocated, holds the special value kAllocated, + // and the size is stored in the vector's Allocation. 
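+  // (For example, with N == 8: a five-element vector keeps tag().size() == 5
+  // and needs no heap allocation; once it grows past eight elements the tag
+  // only records kAllocated (size_t(-1)) and the element count lives in
+  // Allocation::size() instead.)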
+ class Tag { + public: + Tag() : size_(0) {} + size_t size() const { return size_; } + void set_size(size_t n) { size_ = n; } + bool allocated() const { return size_ == kAllocated; } + void set_allocated() { size_ = kAllocated; } + + private: + static const size_t kAllocated = -1; + size_t size_; + }; + + // Derives from allocator_type to use the empty base class optimization. + // If the allocator_type is stateless, we can 'store' + // our instance of it for free. + class AllocatorAndTag : private allocator_type { + public: + explicit AllocatorAndTag(const allocator_type& a, Tag t = Tag()) + : allocator_type(a), tag_(t) {} + Tag& tag() { return tag_; } + const Tag& tag() const { return tag_; } + allocator_type& allocator() { return *this; } + const allocator_type& allocator() const { return *this; } + + private: + Tag tag_; + }; + + class Allocation { + public: + Allocation(allocator_type& a, // NOLINT(runtime/references) + size_t capacity) + : size_(0), + capacity_(capacity), + buffer_(AllocatorTraits::allocate(a, capacity_)) {} + + void Dealloc(allocator_type& a) { // NOLINT(runtime/references) + AllocatorTraits::deallocate(a, buffer(), capacity()); + } + + size_t size() const { return size_; } + void set_size(size_t s) { size_ = s; } + size_t capacity() const { return capacity_; } + const value_type* buffer() const { return buffer_; } + value_type* buffer() { return buffer_; } + + private: + size_t size_; + size_t capacity_; + value_type* buffer_; + }; + + const Tag& tag() const { return allocator_and_tag_.tag(); } + Tag& tag() { return allocator_and_tag_.tag(); } + + Allocation& allocation() { return *rep_.allocation_storage.allocation.get(); } + const Allocation& allocation() const { + return *rep_.allocation_storage.allocation.get(); + } + void init_allocation(const Allocation& allocation) { + rep_.allocation_storage.allocation.Init(allocation); + } + + value_type* inlined_space() { return rep_.inlined_storage.inlined[0].get(); } + const value_type* inlined_space() const { + return rep_.inlined_storage.inlined[0].get(); + } + + value_type* allocated_space() { return allocation().buffer(); } + const value_type* allocated_space() const { return allocation().buffer(); } + + const allocator_type& allocator() const { + return allocator_and_tag_.allocator(); + } + allocator_type& allocator() { return allocator_and_tag_.allocator(); } + + bool allocated() const { return tag().allocated(); } + void set_allocated() { return tag().set_allocated(); } + + void set_size_internal(size_t n) { + if (allocated()) { + allocation().set_size(n); + } else { + tag().set_size(n); + } + } + + // Enlarge the underlying representation so we can store size_ + delta elems. + // The size is not changed, and any newly added memory is not initialized. 
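+  // (Growth is by repeated doubling of the current capacity until it covers
+  // the request: e.g. with N == 4, a full four-element vector that calls
+  // reserve(100) goes through EnlargeBy(96) and ends up with capacity 128.)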
+ void EnlargeBy(size_t delta); + + void ResetAllocation(Allocation new_allocation) { + if (allocated()) { + DestroyAllocated(allocated_space(), allocated_space() + size()); + DCHECK_EQ(begin(), allocated_space()); + allocation().Dealloc(allocator()); + allocation() = new_allocation; + } else { + DestroyInlined(inlined_space(), inlined_space() + size()); + init_allocation(new_allocation); // bug: only init once + set_allocated(); + } + } + + void GrowAndPushBack(const value_type& t) { + DCHECK_EQ(size(), capacity()); + const size_t s = size(); + + Allocation new_allocation(allocator(), 2 * capacity()); + new_allocation.set_size(s + 1); + + UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer()); + ConstructAllocated(new_allocation.buffer() + s, t); + + ResetAllocation(new_allocation); + } + + void InitAssign(size_t n); + void InitAssign(size_t n, const value_type& t); + + void ConstructInlined(pointer p) { new (p) value_type(); } + + void ConstructInlined(pointer p, const value_type& t) { + new (p) value_type(t); + } + + void ConstructAllocated(pointer p) { + AllocatorTraits::construct(allocator(), p); + } + void ConstructAllocated(pointer p, const value_type& t) { + AllocatorTraits::construct(allocator(), p, t); + } + + template <typename Iter> + void UninitializedCopyInlined(Iter src, Iter src_last, value_type* dst) { + std::uninitialized_copy(src, src_last, dst); + } + + template <typename Iter> + void UninitializedCopyAllocated(Iter src, Iter src_last, value_type* dst) { + for (; src != src_last; ++dst, ++src) ConstructAllocated(dst, *src); + } + + void UninitializedFillInlined(value_type* dst, value_type* dst_last) { + for (; dst != dst_last; ++dst) ConstructInlined(dst); + } + void UninitializedFillInlined(value_type* dst, value_type* dst_last, + const value_type& t) { + std::uninitialized_fill(dst, dst_last, t); + } + + void UninitializedFillAllocated(value_type* dst, value_type* dst_last) { + for (; dst != dst_last; ++dst) ConstructAllocated(dst); + } + void UninitializedFillAllocated(value_type* dst, value_type* dst_last, + const value_type& t) { + for (; dst != dst_last; ++dst) ConstructAllocated(dst, t); + } + + // Destroy [ptr, ptr_last) in place. + void DestroyInlined(value_type* ptr, value_type* ptr_last); + void DestroyAllocated(value_type* ptr, value_type* ptr_last); + + template <typename Iter> + void AppendRange(Iter first, Iter last, std::input_iterator_tag); + + // Faster path for forward iterators. + template <typename Iter> + void AppendRange(Iter first, Iter last, std::forward_iterator_tag); + + template <typename Iter> + void AppendRange(Iter first, Iter last); + + AllocatorAndTag allocator_and_tag_; + + // Either the inlined or allocated representation + union Rep { + // Use struct to perform indirection that solves a bizarre compilation + // error on Visual Studio (all known versions). 
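+    // (The union therefore occupies roughly the larger of N * sizeof(value_type)
+    // and sizeof(Allocation), so small-N vectors pay little space beyond the
+    // allocator-and-tag member.)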
+ struct { + tensorflow::ManualConstructor<value_type> inlined[N]; + } inlined_storage; + struct { + tensorflow::ManualConstructor<Allocation> allocation; + } allocation_storage; + } rep_; +}; + +template <typename T, int N, typename A> +const size_t InlinedVector<T, N, A>::Tag::kAllocated; + +template <typename T, int N, typename A> +inline void swap(InlinedVector<T, N, A>& a, InlinedVector<T, N, A>& b) { + a.swap(b); +} + +template <typename T, int N, typename A> +inline bool operator==(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); +} + +template <typename T, int N, typename A> +inline bool operator!=(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return !(a == b); +} + +template <typename T, int N, typename A> +inline bool operator<(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); +} + +template <typename T, int N, typename A> +inline bool operator>(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return b < a; +} + +template <typename T, int N, typename A> +inline bool operator<=(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return !(b < a); +} + +template <typename T, int N, typename A> +inline bool operator>=(const InlinedVector<T, N, A>& a, + const InlinedVector<T, N, A>& b) { + return !(a < b); +} + +// ======================================== +// Implementation + +template <typename T, int N, typename A> +inline InlinedVector<T, N, A>::InlinedVector() + : allocator_and_tag_(allocator_type()) {} + +template <typename T, int N, typename A> +inline InlinedVector<T, N, A>::InlinedVector(const allocator_type& alloc) + : allocator_and_tag_(alloc) {} + +template <typename T, int N, typename A> +inline InlinedVector<T, N, A>::InlinedVector(size_t n) + : allocator_and_tag_(allocator_type()) { + InitAssign(n); +} + +template <typename T, int N, typename A> +inline InlinedVector<T, N, A>::InlinedVector(size_t n, const value_type& elem, + const allocator_type& alloc) + : allocator_and_tag_(alloc) { + InitAssign(n, elem); +} + +template <typename T, int N, typename A> +inline InlinedVector<T, N, A>::InlinedVector(const InlinedVector& v) + : allocator_and_tag_(v.allocator()) { + reserve(v.size()); + if (allocated()) { + UninitializedCopyAllocated(v.begin(), v.end(), allocated_space()); + } else { + UninitializedCopyInlined(v.begin(), v.end(), inlined_space()); + } + set_size_internal(v.size()); +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, A>::InitAssign(size_t n, const value_type& t) { + if (n > static_cast<size_t>(N)) { + Allocation new_allocation(allocator(), n); + init_allocation(new_allocation); + set_allocated(); + UninitializedFillAllocated(allocated_space(), allocated_space() + n, t); + } else { + UninitializedFillInlined(inlined_space(), inlined_space() + n, t); + } + set_size_internal(n); +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, A>::InitAssign(size_t n) { + if (n > static_cast<size_t>(N)) { + Allocation new_allocation(allocator(), n); + init_allocation(new_allocation); + set_allocated(); + UninitializedFillAllocated(allocated_space(), allocated_space() + n); + } else { + UninitializedFillInlined(inlined_space(), inlined_space() + n); + } + set_size_internal(n); +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, 
A>::resize(size_t n) { + size_t s = size(); + if (n < s) { + erase(begin() + n, end()); + return; + } + reserve(n); + DCHECK_GE(capacity(), n); + + // Fill new space with elements constructed in-place. + if (allocated()) { + UninitializedFillAllocated(allocated_space() + s, allocated_space() + n); + } else { + UninitializedFillInlined(inlined_space() + s, inlined_space() + n); + } + set_size_internal(n); +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, A>::resize(size_t n, const value_type& elem) { + size_t s = size(); + if (n < s) { + erase(begin() + n, end()); + return; + } + reserve(n); + DCHECK_GE(capacity(), n); + + // Fill new space with copies of 'elem'. + if (allocated()) { + UninitializedFillAllocated(allocated_space() + s, allocated_space() + n, + elem); + } else { + UninitializedFillInlined(inlined_space() + s, inlined_space() + n, elem); + } + set_size_internal(n); +} + +template <typename T, int N, typename A> +typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::insert( + iterator pos, const value_type& v) { + DCHECK_GE(pos, begin()); + DCHECK_LE(pos, end()); + if (pos == end()) { + push_back(v); + return end() - 1; + } + size_t s = size(); + size_t idx = std::distance(begin(), pos); + if (s == capacity()) { + EnlargeBy(1); + } + CHECK_LT(s, capacity()); + pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. + + if (allocated()) { + ConstructAllocated(allocated_space() + s, *(allocated_space() + s - 1)); + std::copy_backward(pos, allocated_space() + s - 1, allocated_space() + s); + } else { + ConstructInlined(inlined_space() + s, *(inlined_space() + s - 1)); + std::copy_backward(pos, inlined_space() + s - 1, inlined_space() + s); + } + + *pos = v; + + set_size_internal(s + 1); + return pos; +} + +template <typename T, int N, typename A> +typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::erase( + iterator first, iterator last) { + DCHECK_LE(begin(), first); + DCHECK_LE(first, last); + DCHECK_LE(last, end()); + + size_t s = size(); + ptrdiff_t erase_gap = std::distance(first, last); + + if (allocated()) { + std::copy(last, allocated_space() + s, first); + DestroyAllocated(allocated_space() + s - erase_gap, allocated_space() + s); + } else { + std::copy(last, inlined_space() + s, first); + DestroyInlined(inlined_space() + s - erase_gap, inlined_space() + s); + } + + set_size_internal(size() - erase_gap); + + return first; +} + +template <typename T, int N, typename A> +void InlinedVector<T, N, A>::swap(InlinedVector& other) { + using std::swap; // Augment ADL with std::swap. + if (&other == this) { + return; + } + if (allocated() && other.allocated()) { + // Both out of line, so just swap the tag, allocation, and allocator. + swap(tag(), other.tag()); + swap(allocation(), other.allocation()); + swap(allocator(), other.allocator()); + return; + } + if (!allocated() && !other.allocated()) { + // Both inlined: swap up to smaller size, then move remaining elements. + InlinedVector* a = this; + InlinedVector* b = &other; + if (size() < other.size()) { + swap(a, b); + } + + const size_t a_size = a->size(); + const size_t b_size = b->size(); + DCHECK_GE(a_size, b_size); + // 'a' is larger. Swap the elements up to the smaller array size. 
+ std::swap_ranges(a->inlined_space(), a->inlined_space() + b_size, + b->inlined_space()); + + // Move the remaining elements: A[b_size,a_size) -> B[b_size,a_size) + b->UninitializedCopyInlined(a->inlined_space() + b_size, + a->inlined_space() + a_size, + b->inlined_space() + b_size); + a->DestroyInlined(a->inlined_space() + b_size, a->inlined_space() + a_size); + + swap(a->tag(), b->tag()); + swap(a->allocator(), b->allocator()); + DCHECK_EQ(b->size(), a_size); + DCHECK_EQ(a->size(), b_size); + return; + } + // One is out of line, one is inline. + // We first move the elements from the inlined vector into the + // inlined space in the other vector. We then put the other vector's + // pointer/capacity into the originally inlined vector and swap + // the tags. + InlinedVector* a = this; + InlinedVector* b = &other; + if (a->allocated()) { + swap(a, b); + } + DCHECK(!a->allocated()); + DCHECK(b->allocated()); + const size_t a_size = a->size(); + const size_t b_size = b->size(); + + // Made Local copies of size(), don't need tag() accurate anymore + swap(a->tag(), b->tag()); + + // Copy b_allocation out before b's union gets clobbered by inline_space. + Allocation b_allocation = b->allocation(); + + b->UninitializedCopyInlined(a->inlined_space(), a->inlined_space() + a_size, + b->inlined_space()); + a->DestroyInlined(a->inlined_space(), a->inlined_space() + a_size); + + a->allocation() = b_allocation; + + if (a->allocator() != b->allocator()) { + swap(a->allocator(), b->allocator()); + } + + DCHECK_EQ(b->size(), a_size); + DCHECK_EQ(a->size(), b_size); +} + +template <typename T, int N, typename A> +void InlinedVector<T, N, A>::EnlargeBy(size_t delta) { + const size_t s = size(); + DCHECK_LE(s, capacity()); + + size_t target = std::max(static_cast<size_t>(N), s + delta); + + // Compute new capacity by repeatedly doubling current capacity + // TODO(psrc): Check and avoid overflow? + size_t new_capacity = capacity(); + while (new_capacity < target) { + new_capacity <<= 1; + } + + Allocation new_allocation(allocator(), new_capacity); + new_allocation.set_size(s); + + UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer()); + + ResetAllocation(new_allocation); +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, A>::DestroyInlined(value_type* ptr, + value_type* ptr_last) { + for (value_type* p = ptr; p != ptr_last; ++p) { + p->~value_type(); + } + +// Overwrite unused memory with 0xab so we can catch uninitialized usage. +// Cast to void* to tell the compiler that we don't care that we might be +// scribbling on a vtable pointer. +#ifndef NDEBUG + if (ptr != ptr_last) { + memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr)); + } +#endif +} + +template <typename T, int N, typename A> +inline void InlinedVector<T, N, A>::DestroyAllocated(value_type* ptr, + value_type* ptr_last) { + for (value_type* p = ptr; p != ptr_last; ++p) { + AllocatorTraits::destroy(allocator(), p); + } + +// Overwrite unused memory with 0xab so we can catch uninitialized usage. +// Cast to void* to tell the compiler that we don't care that we might be +// scribbling on a vtable pointer. 
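+// (0xab is simply a recognizable poison byte: in debug builds a read of a
+// destroyed element shows up as a 0xabab... pattern instead of a plausible
+// stale value, which makes use-after-destroy bugs easier to spot.)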
+#ifndef NDEBUG + if (ptr != ptr_last) { + memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr)); + } +#endif +} + +template <typename T, int N, typename A> +template <typename Iter> +inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last, + std::input_iterator_tag) { + std::copy(first, last, std::back_inserter(*this)); +} + +template <typename T, int N, typename A> +template <typename Iter> +inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last, + std::forward_iterator_tag) { + typedef typename std::iterator_traits<Iter>::difference_type Length; + Length length = std::distance(first, last); + reserve(size() + length); + if (allocated()) { + UninitializedCopyAllocated(first, last, allocated_space() + size()); + } else { + UninitializedCopyInlined(first, last, inlined_space() + size()); + } + set_size_internal(size() + length); +} + +template <typename T, int N, typename A> +template <typename Iter> +inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last) { + typedef typename std::iterator_traits<Iter>::iterator_category IterTag; + AppendRange(first, last, IterTag()); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc new file mode 100644 index 0000000000..ec5fe1eaa8 --- /dev/null +++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc @@ -0,0 +1,905 @@ +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +#include <list> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +typedef tensorflow::gtl::InlinedVector<int, 8> IntVec; + +// A type that counts number of live occurrences of the type +static int64 instances = 0; +class Instance { + public: + int value_; + explicit Instance(int x) : value_(x) { instances++; } + Instance(const Instance& x) : value_(x.value_) { instances++; } + ~Instance() { instances--; } + + friend inline void swap(Instance& a, Instance& b) { + using std::swap; + swap(a.value_, b.value_); + } + + friend std::ostream& operator<<(std::ostream& o, const Instance& v) { + return o << "[value:" << v.value_ << "]"; + } +}; + +typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec; + +// A simple reference counted class to make sure that the proper elements are +// destroyed in the erase(begin, end) test. 
+class RefCounted { + public: + RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } + + RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { + VLOG(5) << "[RefCounted: copy" + << " from count @" << v.count_ << "]"; + Ref(); + } + + ~RefCounted() { + Unref(); + count_ = NULL; + } + + friend void swap(RefCounted& a, RefCounted& b) { + using std::swap; + swap(a.value_, b.value_); + swap(a.count_, b.count_); + } + + RefCounted& operator=(RefCounted v) { + using std::swap; + swap(*this, v); + return *this; + } + + void Ref() const { + CHECK(count_ != NULL); + ++(*count_); + VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; + } + + void Unref() const { + --(*count_); + CHECK_GE(*count_, 0); + VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; + } + + int count() const { return *count_; } + + friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { + return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; + } + + int value_; + int* count_; +}; + +typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec; + +// A class with a vtable pointer +class Dynamic { + public: + virtual ~Dynamic() {} + + friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { + return o << "[Dynamic]"; + } +}; + +typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec; + +// Append 0..len-1 to *v +static void Fill(IntVec* v, int len, int offset = 0) { + for (int i = 0; i < len; i++) { + v->push_back(i + offset); + } +} + +static IntVec Fill(int len, int offset = 0) { + IntVec v; + Fill(&v, len, offset); + return v; +} + +TEST(IntVec, SimpleOps) { + for (int len = 0; len < 20; len++) { + IntVec v; + const IntVec& cv = v; // const alias + + Fill(&v, len); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(i, v[i]); + } + EXPECT_EQ(v.begin(), v.array()); + EXPECT_EQ(v.begin(), v.mutable_array()); + + EXPECT_EQ(v.begin(), v.data()); + EXPECT_EQ(cv.begin(), cv.data()); + + int counter = 0; + for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { + EXPECT_EQ(counter, *iter); + counter++; + } + EXPECT_EQ(counter, len); + + counter = 0; + for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { + EXPECT_EQ(counter, *iter); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + EXPECT_EQ(0, v.front()); + EXPECT_EQ(len - 1, v.back()); + v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i, v[i]); + } + } + } +} + +TEST(IntVec, Erase) { + for (int len = 1; len < 20; len++) { + for (int i = 0; i < len; ++i) { + IntVec v; + Fill(&v, len); + v.erase(v.begin() + i); + EXPECT_EQ(len - 1, v.size()); + for (int j = 0; j < i; ++j) { + EXPECT_EQ(j, v[j]); + } + for (int j = i; j < len - 1; ++j) { + EXPECT_EQ(j + 1, v[j]); + } + } + } +} + +// At the end of this test loop, the elements between [erase_begin, erase_end) +// should have reference counts == 0, and all others elements should have +// reference counts == 1. 
+TEST(RefCountedVec, EraseBeginEnd) { + for (int len = 1; len < 20; ++len) { + for (int erase_begin = 0; erase_begin < len; ++erase_begin) { + for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { + std::vector<int> counts(len, 0); + RefCountedVec v; + for (int i = 0; i < len; ++i) { + v.push_back(RefCounted(i, &counts[i])); + } + + int erase_len = erase_end - erase_begin; + + v.erase(v.begin() + erase_begin, v.begin() + erase_end); + + EXPECT_EQ(len - erase_len, v.size()); + + // Check the elements before the first element erased. + for (int i = 0; i < erase_begin; ++i) { + EXPECT_EQ(i, v[i].value_); + } + + // Check the elements after the first element erased. + for (size_t i = erase_begin; i < v.size(); ++i) { + EXPECT_EQ(i + erase_len, v[i].value_); + } + + // Check that the elements at the beginning are preserved. + for (int i = 0; i < erase_begin; ++i) { + EXPECT_EQ(1, counts[i]); + } + + // Check that the erased elements are destroyed + for (int i = erase_begin; i < erase_end; ++i) { + EXPECT_EQ(0, counts[i]); + } + + // Check that the elements at the end are preserved. + for (int i = erase_end; i < len; ++i) { + EXPECT_EQ(1, counts[i]); + } + } + } + } +} + +struct NoDefaultCtor { + explicit NoDefaultCtor(int /* x */) {} +}; +struct NoCopy { + NoCopy() {} + NoCopy(const NoCopy& /* x */) = delete; +}; +struct NoAssign { + NoAssign() {} + NoAssign& operator=(const NoAssign& /* x */) = delete; +}; +TEST(InlinedVectorTest, NoDefaultCtor) { + tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2)); + (void)v; +} +TEST(InlinedVectorTest, NoCopy) { + tensorflow::gtl::InlinedVector<NoCopy, 1> v(10); + (void)v; +} +TEST(InlinedVectorTest, NoAssign) { + tensorflow::gtl::InlinedVector<NoAssign, 1> v(10); + (void)v; +} + +TEST(IntVec, Insert) { + for (int len = 0; len < 20; len++) { + for (int pos = 0; pos <= len; pos++) { + IntVec v; + Fill(&v, len); + v.insert(v.begin() + pos, 9999); + EXPECT_EQ(v.size(), len + 1); + for (int i = 0; i < pos; i++) { + EXPECT_EQ(v[i], i); + } + EXPECT_EQ(v[pos], 9999); + for (size_t i = pos + 1; i < v.size(); i++) { + EXPECT_EQ(v[i], i - 1); + } + } + } +} + +TEST(RefCountedVec, InsertConstructorDestructor) { + // Make sure the proper construction/destruction happen during insert + // operations. + for (int len = 0; len < 20; len++) { + SCOPED_TRACE(len); + for (int pos = 0; pos <= len; pos++) { + SCOPED_TRACE(pos); + std::vector<int> counts(len, 0); + RefCountedVec v; + for (int i = 0; i < len; ++i) { + SCOPED_TRACE(i); + v.push_back(RefCounted(i, &counts[i])); + } + + for (auto elem : counts) { + EXPECT_EQ(1, elem); + } + + int inserted_count = 0; + RefCounted insert_element(9999, &inserted_count); + EXPECT_EQ(1, inserted_count); + v.insert(v.begin() + pos, insert_element); + EXPECT_EQ(2, inserted_count); + // Check that the elements at the end are preserved. 
+ for (auto elem : counts) { + EXPECT_EQ(1, elem); + } + EXPECT_EQ(2, inserted_count); + } + } +} + +TEST(IntVec, Resize) { + for (int len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + + // Try resizing up and down by k elements + static const int kResizeElem = 1000000; + for (int k = 0; k < 10; k++) { + // Enlarging resize + v.resize(len + k, kResizeElem); + EXPECT_EQ(len + k, v.size()); + EXPECT_LE(len + k, v.capacity()); + for (int i = 0; i < len + k; i++) { + if (i < len) { + EXPECT_EQ(i, v[i]); + } else { + EXPECT_EQ(kResizeElem, v[i]); + } + } + + // Shrinking resize + v.resize(len, kResizeElem); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + for (int i = 0; i < len; i++) { + EXPECT_EQ(i, v[i]); + } + } + } +} + +TEST(IntVec, InitWithLength) { + for (int len = 0; len < 20; len++) { + IntVec v(len, 7); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + for (int i = 0; i < len; i++) { + EXPECT_EQ(7, v[i]); + } + } +} + +TEST(IntVec, CopyConstructorAndAssignment) { + for (int len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + + IntVec v2(v); + EXPECT_EQ(v, v2); + + for (int start_len = 0; start_len < 20; start_len++) { + IntVec v3; + Fill(&v3, start_len, 99); // Add dummy elements that should go away + v3 = v; + EXPECT_EQ(v, v3); + } + } +} + +TEST(OverheadTest, Storage) { + // Check for size overhead. + // In particular, ensure that std::allocator doesn't cost anything to store. + // The union should be absorbing some of the allocation bookkeeping overhead + // in the larger vectors, leaving only the size_ field as overhead. + using tensorflow::gtl::InlinedVector; + EXPECT_EQ(3 * sizeof(int*), + sizeof(InlinedVector<int*, 1>) - 1 * sizeof(int*)); + EXPECT_EQ(2 * sizeof(int*), + sizeof(InlinedVector<int*, 2>) - 2 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 3>) - 3 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 4>) - 4 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 5>) - 5 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 6>) - 6 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 7>) - 7 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector<int*, 8>) - 8 * sizeof(int*)); +} + +TEST(IntVec, Clear) { + for (int len = 0; len < 20; len++) { + SCOPED_TRACE(len); + IntVec v; + Fill(&v, len); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(IntVec, Reserve) { + for (size_t len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + + for (size_t newlen = 0; newlen < 100; newlen++) { + const int* start_rep = v.array(); + v.reserve(newlen); + const int* final_rep = v.array(); + if (newlen <= len) { + EXPECT_EQ(start_rep, final_rep); + } + EXPECT_LE(newlen, v.capacity()); + + // Filling up to newlen should not change rep + while (v.size() < newlen) { + v.push_back(0); + } + EXPECT_EQ(final_rep, v.array()); + } + } +} + +template <typename T> +static std::vector<typename T::value_type> Vec(const T& src) { + std::vector<typename T::value_type> result; + for (const auto& elem : src) { + result.push_back(elem); + } + return result; +} + +TEST(IntVec, SelfRefPushBack) { + std::vector<string> std_v; + tensorflow::gtl::InlinedVector<string, 4> v; + const string s = "A very long string to ensure heap."; + std_v.push_back(s); + v.push_back(s); + for (int i = 0; i < 20; ++i) { + EXPECT_EQ(std_v, Vec(v)); + + 
v.push_back(v.back()); + std_v.push_back(std_v.back()); + } + EXPECT_EQ(std_v, Vec(v)); +} + +TEST(IntVec, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + SCOPED_TRACE(l1); + for (int l2 = 0; l2 < 20; l2++) { + SCOPED_TRACE(l2); + IntVec a = Fill(l1, 0); + IntVec b = Fill(l2, 100); + { + using std::swap; + swap(a, b); + } + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + SCOPED_TRACE(i); + EXPECT_EQ(i, b[i]); + } + for (int i = 0; i < l2; i++) { + SCOPED_TRACE(i); + EXPECT_EQ(100 + i, a[i]); + } + } + } +} + +TEST(InstanceVec, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + InstanceVec a, b; + for (int i = 0; i < l1; i++) a.push_back(Instance(i)); + for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); + EXPECT_EQ(l1 + l2, instances); + { + using std::swap; + swap(a, b); + } + EXPECT_EQ(l1 + l2, instances); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(i, b[i].value_); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(100 + i, a[i].value_); + } + } + } +} + +TEST(IntVec, EqualAndNotEqual) { + IntVec a, b; + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + a.push_back(3); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + b.push_back(3); + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + b.push_back(7); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + a.push_back(6); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + a.clear(); + b.clear(); + for (int i = 0; i < 100; i++) { + a.push_back(i); + b.push_back(i); + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + b[i] = b[i] + 1; + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + b[i] = b[i] - 1; // Back to before + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + } +} + +TEST(IntVec, RelationalOps) { + IntVec a, b; + EXPECT_FALSE(a < b); + EXPECT_FALSE(b < a); + EXPECT_FALSE(a > b); + EXPECT_FALSE(b > a); + EXPECT_TRUE(a <= b); + EXPECT_TRUE(b <= a); + EXPECT_TRUE(a >= b); + EXPECT_TRUE(b >= a); + b.push_back(3); + EXPECT_TRUE(a < b); + EXPECT_FALSE(b < a); + EXPECT_FALSE(a > b); + EXPECT_TRUE(b > a); + EXPECT_TRUE(a <= b); + EXPECT_FALSE(b <= a); + EXPECT_FALSE(a >= b); + EXPECT_TRUE(b >= a); +} + +TEST(InstanceVec, CountConstructorsDestructors) { + const int start = instances; + for (int len = 0; len < 20; len++) { + InstanceVec v; + for (int i = 0; i < len; i++) { + v.push_back(Instance(i)); + } + EXPECT_EQ(start + len, instances); + + { // Copy constructor should create 'len' more instances. 
+ InstanceVec v_copy(v); + EXPECT_EQ(start + len + len, instances); + } + EXPECT_EQ(start + len, instances); + + // Enlarging resize() must construct some objects + v.resize(len + 10, Instance(100)); + EXPECT_EQ(start + len + 10, instances); + + // Shrinking resize() must destroy some objects + v.resize(len, Instance(100)); + EXPECT_EQ(start + len, instances); + + // reserve() must not increase the number of initialized objects + v.reserve(len + 1000); + EXPECT_EQ(start + len, instances); + + // pop_back() and erase() must destroy one object + if (len > 0) { + v.pop_back(); + EXPECT_EQ(start + len - 1, instances); + if (!v.empty()) { + v.erase(v.begin()); + EXPECT_EQ(start + len - 2, instances); + } + } + } + EXPECT_EQ(start, instances); +} + +TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { + const int start = instances; + for (int len = 0; len < 20; len++) { + for (int longorshort = 0; longorshort <= 1; ++longorshort) { + InstanceVec longer, shorter; + for (int i = 0; i < len; i++) { + longer.push_back(Instance(i)); + shorter.push_back(Instance(i)); + } + longer.push_back(Instance(len)); + EXPECT_EQ(start + len + len + 1, instances); + + if (longorshort) { + shorter = longer; + EXPECT_EQ(start + (len + 1) + (len + 1), instances); + } else { + longer = shorter; + EXPECT_EQ(start + len + len, instances); + } + } + } + EXPECT_EQ(start, instances); +} + +TEST(RangedConstructor, SimpleType) { + std::vector<int> source_v = {4, 5, 6}; + // First try to fit in inline backing + tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end()); + EXPECT_EQ(3, v.size()); + EXPECT_EQ(4, v.capacity()); // Indication that we're still on inlined storage + EXPECT_EQ(4, v[0]); + EXPECT_EQ(5, v[1]); + EXPECT_EQ(6, v[2]); + + // Now, force a re-allocate + tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(), + source_v.end()); + EXPECT_EQ(3, realloc_v.size()); + EXPECT_LT(2, realloc_v.capacity()); + EXPECT_EQ(4, realloc_v[0]); + EXPECT_EQ(5, realloc_v[1]); + EXPECT_EQ(6, realloc_v[2]); +} + +TEST(RangedConstructor, ComplexType) { + // We also use a list here to pass a different flavor of iterator (e.g. not + // random-access). + std::list<Instance> source_v = {Instance(0)}; + + // First try to fit in inline backing + tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(), + source_v.end()); + EXPECT_EQ(1, v.size()); + EXPECT_EQ(1, v.capacity()); // Indication that we're still on inlined storage + EXPECT_EQ(0, v[0].value_); + + std::list<Instance> source_v2 = {Instance(0), Instance(1)}; + // Now, force a re-allocate + tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(), + source_v2.end()); + EXPECT_EQ(2, realloc_v.size()); + EXPECT_LT(1, realloc_v.capacity()); + EXPECT_EQ(0, realloc_v[0].value_); + EXPECT_EQ(1, realloc_v[1].value_); +} + +TEST(RangedConstructor, ElementsAreConstructed) { + std::vector<string> source_v = {"cat", "dog"}; + + // Force expansion and re-allocation of v. Ensures that when the vector is + // expanded that new elements are constructed. 
+ tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end()); + EXPECT_EQ("cat", v[0]); + EXPECT_EQ("dog", v[1]); +} + +TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { + auto vec = tensorflow::gtl::InlinedVector<int, 4>{4, 5, 6}; + EXPECT_EQ(3, vec.size()); + EXPECT_EQ(4, vec.capacity()); + EXPECT_EQ(4, vec[0]); + EXPECT_EQ(5, vec[1]); + EXPECT_EQ(6, vec[2]); +} + +TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { + auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6}; + EXPECT_EQ(3, vec.size()); + EXPECT_LE(3, vec.capacity()); + EXPECT_EQ(4, vec[0]); + EXPECT_EQ(5, vec[1]); + EXPECT_EQ(6, vec[2]); +} + +TEST(InitializerListConstructor, DisparateTypesInList) { + EXPECT_EQ((std::vector<int>{-7, 8}), + Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL})); + + EXPECT_EQ( + (std::vector<string>{"foo", "bar"}), + Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")})); +} + +TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { + auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)}; + EXPECT_EQ(1, vec.size()); + EXPECT_EQ(1, vec.capacity()); + EXPECT_EQ(0, vec[0].value_); +} + +TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { + auto vec = + tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)}; + EXPECT_EQ(2, vec.size()); + EXPECT_LE(2, vec.capacity()); + EXPECT_EQ(0, vec[0].value_); + EXPECT_EQ(1, vec[1].value_); +} + +TEST(DynamicVec, DynamicVecCompiles) { + DynamicVec v; + (void)v; +} + +#ifdef INLINED_VECTOR_HAS_ALLOC +TEST(AllocatorSupportTest, Constructors) { + typedef STLCountingAllocator<int> MyAlloc; + typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec; + const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7}; + int64 allocated = 0; + MyAlloc alloc(&allocated); + { AllocVec TF_ATTRIBUTE_UNUSED v; } + { AllocVec TF_ATTRIBUTE_UNUSED v(alloc); } + { AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); } +#ifdef LANG_CXX11 + { AllocVec TF_ATTRIBUTE_UNUSED v({1, 2, 3}, alloc); } +#endif // LANG_CXX11 +} + +TEST(AllocatorSupportTest, CountAllocations) { + typedef STLCountingAllocator<int> MyAlloc; + typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec; + const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7}; + int64 allocated = 0; + MyAlloc alloc(&allocated); + { + AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + 4, alloc); + EXPECT_THAT(allocated, 0); + } + EXPECT_THAT(allocated, 0); + { + AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); + EXPECT_THAT(allocated, v.size() * sizeof(int)); + } + EXPECT_THAT(allocated, 0); +} + +TEST(AllocatorSupportTest, SwapBothAllocated) { + typedef STLCountingAllocator<int> MyAlloc; + typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec; + int64 allocated1 = 0; + int64 allocated2 = 0; + { + const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7}; + const std::vector<int> ia2 = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + MyAlloc a1(&allocated1); + MyAlloc a2(&allocated2); + AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1); + AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2); + EXPECT_LT(v1.capacity(), v2.capacity()); + EXPECT_THAT(allocated1, v1.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, v2.capacity() * sizeof(int)); + v1.swap(v2); + EXPECT_EQ(ia2, Vec(v1)); + EXPECT_EQ(ia1, Vec(v2)); + EXPECT_THAT(allocated1, v2.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, v1.capacity() * sizeof(int)); + } + EXPECT_THAT(allocated1, 0); + EXPECT_THAT(allocated2, 0); +} + 
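+// (STLCountingAllocator is not defined in this file; the checks above and
+// below assume a std::allocator-style wrapper that adds n * sizeof(int) to
+// the supplied counter in allocate(n) and subtracts it again in deallocate(),
+// so the counter returns to zero once every vector goes out of scope.)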
+TEST(AllocatorSupportTest, SwapOneAllocated) { + typedef STLCountingAllocator<int> MyAlloc; + typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec; + int64 allocated1 = 0; + int64 allocated2 = 0; + { + const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7}; + const std::vector<int> ia2 = {0, 1, 2, 3}; + MyAlloc a1(&allocated1); + MyAlloc a2(&allocated2); + AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1); + AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2); + EXPECT_THAT(allocated1, v1.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, 0); + v1.swap(v2); + EXPECT_EQ(ia2, Vec(v1)); + EXPECT_EQ(ia1, Vec(v2)); + EXPECT_THAT(allocated1, v2.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, 0); + EXPECT_TRUE(v2.get_allocator() == a1); + EXPECT_TRUE(v1.get_allocator() == a2); + } + EXPECT_THAT(allocated1, 0); + EXPECT_THAT(allocated2, 0); +} +#endif // INLINED_VECTOR_HAS_ALLOC + +static void BM_InlinedVectorFill(int iters, int len) { + for (int i = 0; i < iters; i++) { + IntVec v; + for (int j = 0; j < len; j++) { + v.push_back(j); + } + } + testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); + +static void BM_InlinedVectorFillRange(int iters, int len) { + std::unique_ptr<int[]> ia(new int[len]); + for (int j = 0; j < len; j++) { + ia[j] = j; + } + for (int i = 0; i < iters; i++) { + IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); + } + testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); + +static void BM_StdVectorFill(int iters, int len) { + for (int i = 0; i < iters; i++) { + std::vector<int> v; + for (int j = 0; j < len; j++) { + v.push_back(j); + } + } + testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_StdVectorFill)->Range(0, 1024); + +namespace { +struct Buffer { // some arbitrary structure for benchmarking. 
+ char* base; + int length; + int capacity; + void* user_data; +}; +} // anonymous namespace + +static void BM_InlinedVectorTenAssignments(int iters, int len) { + typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec; + + BufferVec src; + src.resize(len); + + iters *= 10; + BufferVec dst; + for (int i = 0; i < iters; i++) { + dst = src; + } +} +BENCHMARK(BM_InlinedVectorTenAssignments) + ->Arg(0) + ->Arg(1) + ->Arg(2) + ->Arg(3) + ->Arg(4) + ->Arg(20); + +static void BM_CreateFromInitializerList(int iters) { + for (; iters > 0; iters--) { + tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3}; + (void)x[0]; + } +} +BENCHMARK(BM_CreateFromInitializerList); + +namespace { + +struct LargeSwappable { + LargeSwappable() : d_(1024, 17) {} + ~LargeSwappable() {} + LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} + + friend void swap(LargeSwappable& a, LargeSwappable& b) { + using std::swap; + swap(a.d_, b.d_); + } + + LargeSwappable& operator=(LargeSwappable o) { + using std::swap; + swap(*this, o); + return *this; + } + + std::vector<int> d_; +}; + +} // namespace + +static void BM_LargeSwappableElements(int iters, int len) { + typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec; + Vec a(len); + Vec b; + while (--iters >= 0) { + using std::swap; + swap(a, b); + } +} +BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h new file mode 100644 index 0000000000..d3fcb08d38 --- /dev/null +++ b/tensorflow/core/lib/gtl/int_type.h @@ -0,0 +1,343 @@ +// #status: LEGACY +// #category: Miscellaneous +// #summary: Integral types; prefer util/intops/strong_int.h +// #bugs: Infrastructure > C++ Library Team > util +// +// IntType is a simple template class mechanism for defining "logical" +// integer-like class types that support many of the same functionalities +// as native integer types, but which prevent assignment, construction, and +// other operations from other similar integer-like types. Essentially, the +// template class IntType<IntTypeName, ValueType> (where ValueType assumes +// valid scalar types such as int, uint, int32, etc) has the additional +// property that it cannot be assigned to or constructed from other IntTypes +// or native integer types of equal or implicitly convertible type. +// +// The class is useful for preventing mingling of integer variables with +// different logical roles or units. Unfortunately, C++ provides relatively +// good type-safety for user-defined classes but not for integer types. It is +// essentially up to the user to use nice variable names and comments to prevent +// accidental mismatches, such as confusing a user-index with a group-index or a +// time-in-milliseconds with a time-in-seconds. The use of typedefs are limited +// in that regard as they do not enforce type-safety. +// +// USAGE ----------------------------------------------------------------------- +// +// DEFINE_INT_TYPE(IntTypeName, ValueType); +// +// where: +// IntTypeName: is the desired (unique) name for the "logical" integer type. +// ValueType: is one of the integral types as defined by base::is_integral +// (see base/type_traits.h). 
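+//
+// In this port the macro is spelled TF_LIB_GTL_DEFINE_INT_TYPE, e.g.
+//
+//   TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);
+//
+// The examples below keep the shorter DEFINE_INT_TYPE spelling for brevity.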
+// +// DISALLOWED OPERATIONS / TYPE-SAFETY ENFORCEMENT ----------------------------- +// +// Consider these definitions and variable declarations: +// DEFINE_INT_TYPE(GlobalDocID, int64); +// DEFINE_INT_TYPE(LocalDocID, int64); +// GlobalDocID global; +// LocalDocID local; +// +// The class IntType prevents: +// +// 1) Assignments of other IntTypes with different IntTypeNames. +// +// global = local; <-- Fails to compile! +// local = global; <-- Fails to compile! +// +// 2) Explicit/implicit conversion from an IntType to another IntType. +// +// LocalDocID l(global); <-- Fails to compile! +// LocalDocID l = global; <-- Fails to compile! +// +// void GetGlobalDoc(GlobalDocID global) { } +// GetGlobalDoc(global); <-- Compiles fine, types match! +// GetGlobalDoc(local); <-- Fails to compile! +// +// 3) Implicit conversion from an IntType to a native integer type. +// +// void GetGlobalDoc(int64 global) { ... +// GetGlobalDoc(global); <-- Fails to compile! +// GetGlobalDoc(local); <-- Fails to compile! +// +// void GetLocalDoc(int32 local) { ... +// GetLocalDoc(global); <-- Fails to compile! +// GetLocalDoc(local); <-- Fails to compile! +// +// +// SUPPORTED OPERATIONS -------------------------------------------------------- +// +// The following operators are supported: unary: ++ (both prefix and postfix), +// +, -, ! (logical not), ~ (one's complement); comparison: ==, !=, <, <=, >, +// >=; numerical: +, -, *, /; assignment: =, +=, -=, /=, *=; stream: <<. Each +// operator allows the same IntTypeName and the ValueType to be used on +// both left- and right-hand sides. +// +// It also supports an accessor value() returning the stored value as ValueType, +// and a templatized accessor value<T>() method that serves as syntactic sugar +// for static_cast<T>(var.value()). These accessors are useful when assigning +// the stored value into protocol buffer fields and using it as printf args. +// +// The class also defines a hash functor that allows the IntType to be used +// as key to hashable containers such as std::unordered_map and +// std::unordered_set. +// +// We suggest using the IntTypeIndexedContainer wrapper around FixedArray and +// STL vector (see int-type-indexed-container.h) if an IntType is intended to +// be used as an index into these containers. These wrappers are indexed in a +// type-safe manner using IntTypes to ensure type-safety. +// +// NB: this implementation does not attempt to abide by or enforce dimensional +// analysis on these scalar types. +// +// EXAMPLES -------------------------------------------------------------------- +// +// DEFINE_INT_TYPE(GlobalDocID, int64); +// GlobalDocID global = 3; +// cout << global; <-- Prints 3 to stdout. +// +// for (GlobalDocID i(0); i < global; ++i) { +// cout << i; +// } <-- Print(ln)s 0 1 2 to stdout +// +// DEFINE_INT_TYPE(LocalDocID, int64); +// LocalDocID local; +// cout << local; <-- Prints 0 to stdout it default +// initializes the value to 0. +// +// local = 5; +// local *= 2; +// LocalDocID l(local); +// cout << l + local; <-- Prints 20 to stdout. +// +// GenericSearchRequest request; +// request.set_doc_id(global.value()); <-- Uses value() to extract the value +// from the IntType class. +// +// REMARKS --------------------------------------------------------------------- +// +// The following bad usage is permissible although discouraged. Essentially, it +// involves using the value*() accessors to extract the native integer type out +// of the IntType class. 
Keep in mind that the primary reason for the IntType +// class is to prevent *accidental* mingling of similar logical integer types -- +// and not type casting from one type to another. +// +// DEFINE_INT_TYPE(GlobalDocID, int64); +// DEFINE_INT_TYPE(LocalDocID, int64); +// GlobalDocID global; +// LocalDocID local; +// +// global = local.value(); <-- Compiles fine. +// +// void GetGlobalDoc(GlobalDocID global) { ... +// GetGlobalDoc(local.value()); <-- Compiles fine. +// +// void GetGlobalDoc(int64 global) { ... +// GetGlobalDoc(local.value()); <-- Compiles fine. + +#ifndef TENSORFLOW_LIB_GTL_INT_TYPE_H_ +#define TENSORFLOW_LIB_GTL_INT_TYPE_H_ + +#include <stddef.h> +#include <functional> +#include <iosfwd> +#include <ostream> // NOLINT +#include <unordered_map> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace gtl { + +template <typename IntTypeName, typename _ValueType> +class IntType; + +// Defines the IntType using value_type and typedefs it to int_type_name. +// The struct int_type_name ## _tag_ trickery is needed to ensure that a new +// type is created per int_type_name. +#define TF_LIB_GTL_DEFINE_INT_TYPE(int_type_name, value_type) \ + struct int_type_name##_tag_ {}; \ + typedef ::tensorflow::gtl::IntType<int_type_name##_tag_, value_type> \ + int_type_name; + +// Holds an integer value (of type ValueType) and behaves as a ValueType by +// exposing assignment, unary, comparison, and arithmetic operators. +// +// The template parameter IntTypeName defines the name for the int type and must +// be unique within a binary (the convenient DEFINE_INT_TYPE macro at the end of +// the file generates a unique IntTypeName). The parameter ValueType defines +// the integer type value (see supported list above). +// +// This class is NOT thread-safe. +template <typename IntTypeName, typename _ValueType> +class IntType { + public: + typedef _ValueType ValueType; // for non-member operators + typedef IntType<IntTypeName, ValueType> ThisType; // Syntactic sugar. + + // Note that this may change from time to time without notice. + struct Hasher { + size_t operator()(const IntType& arg) const { + return static_cast<size_t>(arg.value()); + } + }; + + public: + // Default c'tor initializing value_ to 0. + constexpr IntType() : value_(0) {} + // C'tor explicitly initializing from a ValueType. + constexpr explicit IntType(ValueType value) : value_(value) {} + + // IntType uses the default copy constructor, destructor and assign operator. + // The defaults are sufficient and omitting them allows the compiler to add + // the move constructor/assignment. 
+ + // -- ACCESSORS -------------------------------------------------------------- + // The class provides a value() accessor returning the stored ValueType value_ + // as well as a templatized accessor that is just a syntactic sugar for + // static_cast<T>(var.value()); + constexpr ValueType value() const { return value_; } + + template <typename ValType> + constexpr ValType value() const { + return static_cast<ValType>(value_); + } + + // -- UNARY OPERATORS -------------------------------------------------------- + ThisType& operator++() { // prefix ++ + ++value_; + return *this; + } + const ThisType operator++(int v) { // postfix ++ + ThisType temp(*this); + ++value_; + return temp; + } + ThisType& operator--() { // prefix -- + --value_; + return *this; + } + const ThisType operator--(int v) { // postfix -- + ThisType temp(*this); + --value_; + return temp; + } + + constexpr bool operator!() const { return value_ == 0; } + constexpr const ThisType operator+() const { return ThisType(value_); } + constexpr const ThisType operator-() const { return ThisType(-value_); } + constexpr const ThisType operator~() const { return ThisType(~value_); } + +// -- ASSIGNMENT OPERATORS --------------------------------------------------- +// We support the following assignment operators: =, +=, -=, *=, /=, <<=, >>= +// and %= for both ThisType and ValueType. +#define INT_TYPE_ASSIGNMENT_OP(op) \ + ThisType& operator op(const ThisType& arg_value) { \ + value_ op arg_value.value(); \ + return *this; \ + } \ + ThisType& operator op(ValueType arg_value) { \ + value_ op arg_value; \ + return *this; \ + } + INT_TYPE_ASSIGNMENT_OP(+= ); + INT_TYPE_ASSIGNMENT_OP(-= ); + INT_TYPE_ASSIGNMENT_OP(*= ); + INT_TYPE_ASSIGNMENT_OP(/= ); + INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT + INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT + INT_TYPE_ASSIGNMENT_OP(%= ); +#undef INT_TYPE_ASSIGNMENT_OP + + ThisType& operator=(ValueType arg_value) { + value_ = arg_value; + return *this; + } + + private: + // The integer value of type ValueType. + ValueType value_; + + static_assert(std::is_integral<ValueType>::value, "invalid integer type"); +} TF_PACKED; + +// -- NON-MEMBER STREAM OPERATORS ---------------------------------------------- +// We provide the << operator, primarily for logging purposes. Currently, there +// seems to be no need for an >> operator. +template <typename IntTypeName, typename ValueType> +std::ostream& operator<<(std::ostream& os, // NOLINT + IntType<IntTypeName, ValueType> arg) { + return os << arg.value(); +} + +// -- NON-MEMBER ARITHMETIC OPERATORS ------------------------------------------ +// We support only the +, -, *, and / operators with the same IntType and +// ValueType types. The reason is to allow simple manipulation on these IDs +// when used as indices in vectors and arrays. +// +// NB: Although it is possible to do IntType * IntType and IntType / IntType, +// it is probably non-sensical from a dimensionality analysis perspective. 
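+//
+// A minimal sketch of the allowed forms (MyId is a hypothetical type, assumed
+// to be defined via TF_LIB_GTL_DEFINE_INT_TYPE(MyId, int32)):
+//
+//   MyId a(6), b(3);
+//   MyId c = a + b;   // MyId(9): IntType op IntType
+//   MyId d = a / 3;   // MyId(2): IntType op ValueType
+//   MyId e = 2 * b;   // MyId(6): ValueType op IntType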
+#define INT_TYPE_ARITHMETIC_OP(op) \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr IntType<IntTypeName, ValueType> operator op( \ + IntType<IntTypeName, ValueType> id_1, \ + IntType<IntTypeName, ValueType> id_2) { \ + return IntType<IntTypeName, ValueType>(id_1.value() op id_2.value()); \ + } \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr IntType<IntTypeName, ValueType> operator op( \ + IntType<IntTypeName, ValueType> id, \ + typename IntType<IntTypeName, ValueType>::ValueType arg_val) { \ + return IntType<IntTypeName, ValueType>(id.value() op arg_val); \ + } \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr IntType<IntTypeName, ValueType> operator op( \ + typename IntType<IntTypeName, ValueType>::ValueType arg_val, \ + IntType<IntTypeName, ValueType> id) { \ + return IntType<IntTypeName, ValueType>(arg_val op id.value()); \ + } +INT_TYPE_ARITHMETIC_OP(+); +INT_TYPE_ARITHMETIC_OP(-); +INT_TYPE_ARITHMETIC_OP(*); +INT_TYPE_ARITHMETIC_OP(/ ); +INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT +INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT +INT_TYPE_ARITHMETIC_OP(% ); +#undef INT_TYPE_ARITHMETIC_OP + +// -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------ +// Static inline comparison operators. We allow all comparison operators among +// the following types (OP \in [==, !=, <, <=, >, >=]: +// IntType<IntTypeName, ValueType> OP IntType<IntTypeName, ValueType> +// IntType<IntTypeName, ValueType> OP ValueType +// ValueType OP IntType<IntTypeName, ValueType> +#define INT_TYPE_COMPARISON_OP(op) \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr bool operator op( \ + IntType<IntTypeName, ValueType> id_1, \ + IntType<IntTypeName, ValueType> id_2) { \ + return id_1.value() op id_2.value(); \ + } \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr bool operator op( \ + IntType<IntTypeName, ValueType> id, \ + typename IntType<IntTypeName, ValueType>::ValueType val) { \ + return id.value() op val; \ + } \ + template <typename IntTypeName, typename ValueType> \ + static inline constexpr bool operator op( \ + typename IntType<IntTypeName, ValueType>::ValueType val, \ + IntType<IntTypeName, ValueType> id) { \ + return val op id.value(); \ + } +INT_TYPE_COMPARISON_OP(== ); // NOLINT +INT_TYPE_COMPARISON_OP(!= ); // NOLINT +INT_TYPE_COMPARISON_OP(< ); // NOLINT +INT_TYPE_COMPARISON_OP(<= ); // NOLINT +INT_TYPE_COMPARISON_OP(> ); // NOLINT +INT_TYPE_COMPARISON_OP(>= ); // NOLINT +#undef INT_TYPE_COMPARISON_OP + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_INT_TYPE_H_ diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc new file mode 100644 index 0000000000..694886d345 --- /dev/null +++ b/tensorflow/core/lib/gtl/int_type_test.cc @@ -0,0 +1,282 @@ +// Unit test cases for IntType. 
+ +#include <memory> +#include <unordered_map> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TF_LIB_GTL_DEFINE_INT_TYPE(Int8_IT, int8); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt8_IT, uint8); +TF_LIB_GTL_DEFINE_INT_TYPE(Int16_IT, int16); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt16_IT, uint16); +TF_LIB_GTL_DEFINE_INT_TYPE(Int32_IT, int32); +TF_LIB_GTL_DEFINE_INT_TYPE(Int64_IT, int64); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt32_IT, uint32); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt64_IT, uint64); +TF_LIB_GTL_DEFINE_INT_TYPE(Long_IT, long); // NOLINT + +template <typename IntType_Type> +class IntTypeTest : public ::testing::Test { + public: + typedef IntType_Type T; +}; + +// All tests below will be executed on all supported IntTypes. +typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT, + Int64_IT, UInt64_IT, Long_IT> SupportedIntTypes; + +TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); + +TYPED_TEST(IntTypeTest, TestInitialization) { + constexpr typename TestFixture::T a; + constexpr typename TestFixture::T b(1); + constexpr typename TestFixture::T c(b); + EXPECT_EQ(0, a); // default initialization to 0 + EXPECT_EQ(1, b); + EXPECT_EQ(1, c); +} + +TYPED_TEST(IntTypeTest, TestOperators) { + typename TestFixture::T a(0); + typename TestFixture::T b(1); + typename TestFixture::T c(2); + constexpr typename TestFixture::T d(3); + constexpr typename TestFixture::T e(4); + + // On all EXPECT_EQ below, we use the accessor value() as to not invoke the + // comparison operators which must themselves be tested. + + // -- UNARY OPERATORS -------------------------------------------------------- + EXPECT_EQ(0, (a++).value()); + EXPECT_EQ(2, (++a).value()); + EXPECT_EQ(2, (a--).value()); + EXPECT_EQ(0, (--a).value()); + + EXPECT_EQ(true, !a); + EXPECT_EQ(false, !b); + static_assert(!d == false, "Unary operator! failed"); + + EXPECT_EQ(a.value(), +a); + static_assert(+d == d.value(), "Unary operator+ failed"); + EXPECT_EQ(-a.value(), -a); + static_assert(-d == -d.value(), "Unary operator- failed"); + EXPECT_EQ(~a.value(), ~a); // ~zero + EXPECT_EQ(~b.value(), ~b); // ~non-zero + static_assert(~d == ~d.value(), "Unary operator~ failed"); + + // -- ASSIGNMENT OPERATORS --------------------------------------------------- + // We test all assignment operators using IntType and constant as arguments. + // We also test the return from the operators. 
+ // From same IntType + c = a = b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + // From constant + c = b = 2; + EXPECT_EQ(2, b.value()); + EXPECT_EQ(2, c.value()); + // From same IntType + c = a += b; + EXPECT_EQ(3, a.value()); + EXPECT_EQ(3, c.value()); + c = a -= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a *= b; + EXPECT_EQ(2, a.value()); + EXPECT_EQ(2, c.value()); + c = a /= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a <<= b; + EXPECT_EQ(4, a.value()); + EXPECT_EQ(4, c.value()); + c = a >>= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a %= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + // From constant + c = a += 2; + EXPECT_EQ(3, a.value()); + EXPECT_EQ(3, c.value()); + c = a -= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a *= 2; + EXPECT_EQ(2, a.value()); + EXPECT_EQ(2, c.value()); + c = a /= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a <<= 2; + EXPECT_EQ(4, a.value()); + EXPECT_EQ(4, c.value()); + c = a >>= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a %= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + + // -- COMPARISON OPERATORS --------------------------------------------------- + a = 0; + b = 1; + + EXPECT_FALSE(a == b); + EXPECT_TRUE(a == 0); // NOLINT + EXPECT_FALSE(1 == a); // NOLINT + static_assert(d == d, "operator== failed"); + static_assert(d == 3, "operator== failed"); + static_assert(3 == d, "operator== failed"); + EXPECT_TRUE(a != b); + EXPECT_TRUE(a != 1); // NOLINT + EXPECT_FALSE(0 != a); // NOLINT + static_assert(d != e, "operator!= failed"); + static_assert(d != 4, "operator!= failed"); + static_assert(4 != d, "operator!= failed"); + EXPECT_TRUE(a < b); + EXPECT_TRUE(a < 1); // NOLINT + EXPECT_FALSE(0 < a); // NOLINT + static_assert(d < e, "operator< failed"); + static_assert(d < 4, "operator< failed"); + static_assert(3 < e, "operator< failed"); + EXPECT_TRUE(a <= b); + EXPECT_TRUE(a <= 1); // NOLINT + EXPECT_TRUE(0 <= a); // NOLINT + static_assert(d <= e, "operator<= failed"); + static_assert(d <= 4, "operator<= failed"); + static_assert(3 <= e, "operator<= failed"); + EXPECT_FALSE(a > b); + EXPECT_FALSE(a > 1); // NOLINT + EXPECT_FALSE(0 > a); // NOLINT + static_assert(e > d, "operator> failed"); + static_assert(e > 3, "operator> failed"); + static_assert(4 > d, "operator> failed"); + EXPECT_FALSE(a >= b); + EXPECT_FALSE(a >= 1); // NOLINT + EXPECT_TRUE(0 >= a); // NOLINT + static_assert(e >= d, "operator>= failed"); + static_assert(e >= 3, "operator>= failed"); + static_assert(4 >= d, "operator>= failed"); + + // -- BINARY OPERATORS ------------------------------------------------------- + a = 1; + b = 3; + EXPECT_EQ(4, (a + b).value()); + EXPECT_EQ(4, (a + 3).value()); + EXPECT_EQ(4, (1 + b).value()); + static_assert((d + e).value() == 7, "Binary operator+ failed"); + static_assert((d + 4).value() == 7, "Binary operator+ failed"); + static_assert((3 + e).value() == 7, "Binary operator+ failed"); + EXPECT_EQ(2, (b - a).value()); + EXPECT_EQ(2, (b - 1).value()); + EXPECT_EQ(2, (3 - a).value()); + static_assert((e - d).value() == 1, "Binary operator- failed"); + static_assert((e - 3).value() == 1, "Binary operator- failed"); + static_assert((4 - d).value() == 1, "Binary operator- failed"); + EXPECT_EQ(3, (a * b).value()); + EXPECT_EQ(3, (a * 3).value()); + EXPECT_EQ(3, (1 * b).value()); + static_assert((d * e).value() == 12, "Binary operator* failed"); + static_assert((d * 4).value() == 12, "Binary operator* 
failed"); + static_assert((3 * e).value() == 12, "Binary operator* failed"); + EXPECT_EQ(0, (a / b).value()); + EXPECT_EQ(0, (a / 3).value()); + EXPECT_EQ(0, (1 / b).value()); + static_assert((d / e).value() == 0, "Binary operator/ failed"); + static_assert((d / 4).value() == 0, "Binary operator/ failed"); + static_assert((3 / e).value() == 0, "Binary operator/ failed"); + EXPECT_EQ(8, (a << b).value()); + EXPECT_EQ(8, (a << 3).value()); + EXPECT_EQ(8, (1 << b).value()); + static_assert((d << e).value() == 48, "Binary operator<< failed"); + static_assert((d << 4).value() == 48, "Binary operator<< failed"); + static_assert((3 << e).value() == 48, "Binary operator<< failed"); + b = 8; + EXPECT_EQ(4, (b >> a).value()); + EXPECT_EQ(4, (b >> 1).value()); + EXPECT_EQ(4, (8 >> a).value()); + static_assert((d >> e).value() == 0, "Binary operator>> failed"); + static_assert((d >> 4).value() == 0, "Binary operator>> failed"); + static_assert((3 >> e).value() == 0, "Binary operator>> failed"); + b = 3; + a = 2; + EXPECT_EQ(1, (b % a).value()); + EXPECT_EQ(1, (b % 2).value()); + EXPECT_EQ(1, (3 % a).value()); + static_assert((e % d).value() == 1, "Binary operator% failed"); + static_assert((e % 3).value() == 1, "Binary operator% failed"); + static_assert((4 % d).value() == 1, "Binary operator% failed"); +} + +TYPED_TEST(IntTypeTest, TestHashFunctor) { + std::unordered_map<typename TestFixture::T, char, + typename TestFixture::T::Hasher> map; + typename TestFixture::T a(0); + map[a] = 'c'; + EXPECT_EQ('c', map[a]); + map[++a] = 'o'; + EXPECT_EQ('o', map[a]); + + typename TestFixture::T b(a); + EXPECT_EQ(typename TestFixture::T::Hasher()(a), + typename TestFixture::T::Hasher()(b)); +} + +// Tests the use of the templatized value accessor that performs static_casts. +// We use -1 to force casting in unsigned integers. +TYPED_TEST(IntTypeTest, TestValueAccessor) { + constexpr typename TestFixture::T::ValueType i = -1; + constexpr typename TestFixture::T int_type(i); + EXPECT_EQ(i, int_type.value()); + static_assert(int_type.value() == i, "value() failed"); + // The use of the keyword 'template' (suggested by Clang) is only necessary + // as this code is part of a template class. Weird syntax though. Good news + // is that only int_type.value<int>() is needed in most code. + EXPECT_EQ(static_cast<int>(i), int_type.template value<int>()); + EXPECT_EQ(static_cast<int8>(i), int_type.template value<int8>()); + EXPECT_EQ(static_cast<int16>(i), int_type.template value<int16>()); + EXPECT_EQ(static_cast<int32>(i), int_type.template value<int32>()); + EXPECT_EQ(static_cast<uint32>(i), int_type.template value<uint32>()); + EXPECT_EQ(static_cast<int64>(i), int_type.template value<int64>()); + EXPECT_EQ(static_cast<uint64>(i), int_type.template value<uint64>()); + EXPECT_EQ(static_cast<long>(i), int_type.template value<long>()); // NOLINT + static_assert(int_type.template value<int>() == static_cast<int>(i), + "value<Value>() failed"); +} + +TYPED_TEST(IntTypeTest, TestMove) { + // Check that the int types have move constructor/assignment. + // We do this by composing a struct with an int type and a unique_ptr. This + // struct can't be copied due to the unique_ptr, so it must be moved. + // If this compiles, it means that the int types have move operators. + struct NotCopyable { + typename TestFixture::T inttype; + std::unique_ptr<int> ptr; + + static NotCopyable Make(int i) { + NotCopyable f; + f.inttype = typename TestFixture::T(i); + f.ptr.reset(new int(i)); + return f; + } + }; + + // Test move constructor. 
+ NotCopyable foo = NotCopyable::Make(123); + EXPECT_EQ(123, foo.inttype); + EXPECT_EQ(123, *foo.ptr); + + // Test move assignment. + foo = NotCopyable::Make(321); + EXPECT_EQ(321, foo.inttype); + EXPECT_EQ(321, *foo.ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h new file mode 100644 index 0000000000..baec85c40a --- /dev/null +++ b/tensorflow/core/lib/gtl/iterator_range.h @@ -0,0 +1,49 @@ +// This provides a very simple, boring adaptor for a begin and end iterator +// into a range type. This should be used to build range views that work well +// with range based for loops and range based constructors. +// +// Note that code here follows more standards-based coding conventions as it +// is mirroring proposed interfaces for standardization. +// +// Converted from chandlerc@'s code to Google style by joshl@. + +#ifndef TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ +#define TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ + +#include <utility> + +namespace tensorflow { +namespace gtl { + +// A range adaptor for a pair of iterators. +// +// This just wraps two iterators into a range-compatible interface. Nothing +// fancy at all. +template <typename IteratorT> +class iterator_range { + public: + iterator_range() : begin_iterator_(), end_iterator_() {} + iterator_range(IteratorT begin_iterator, IteratorT end_iterator) + : begin_iterator_(std::move(begin_iterator)), + end_iterator_(std::move(end_iterator)) {} + + IteratorT begin() const { return begin_iterator_; } + IteratorT end() const { return end_iterator_; } + + private: + IteratorT begin_iterator_, end_iterator_; +}; + +// Convenience function for iterating over sub-ranges. +// +// This provides a bit of syntactic sugar to make using sub-ranges +// in for loops a bit easier. Analogous to std::make_pair(). 
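+//
+// For example:
+//
+//   std::vector<int> v = {2, 3, 5, 7, 11, 13};
+//   for (int prime : gtl::make_range(v.begin() + 1, v.begin() + 4)) {
+//     // visits 3, 5, 7
+//   }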
+template <class T> +iterator_range<T> make_range(T x, T y) { + return iterator_range<T>(std::move(x), std::move(y)); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ diff --git a/tensorflow/core/lib/gtl/iterator_range_test.cc b/tensorflow/core/lib/gtl/iterator_range_test.cc new file mode 100644 index 0000000000..328be4ecbc --- /dev/null +++ b/tensorflow/core/lib/gtl/iterator_range_test.cc @@ -0,0 +1,60 @@ +#include "tensorflow/core/lib/gtl/iterator_range.h" + +#include <vector> +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace gtl { +namespace { + +TEST(IteratorRange, WholeVector) { + std::vector<int> v = {2, 3, 5, 7, 11, 13}; + iterator_range<std::vector<int>::iterator> range(v.begin(), v.end()); + int index = 0; + for (int prime : range) { + ASSERT_LT(index, v.size()); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(v.size(), index); +} + +TEST(IteratorRange, VectorMakeRange) { + std::vector<int> v = {2, 3, 5, 7, 11, 13}; + auto range = make_range(v.begin(), v.end()); + int index = 0; + for (int prime : range) { + ASSERT_LT(index, v.size()); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(v.size(), index); +} + +TEST(IteratorRange, PartArray) { + int v[] = {2, 3, 5, 7, 11, 13}; + iterator_range<int*> range(&v[1], &v[4]); // 3, 5, 7 + int index = 1; + for (int prime : range) { + ASSERT_LT(index, TF_ARRAYSIZE(v)); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(4, index); +} + +TEST(IteratorRange, ArrayMakeRange) { + int v[] = {2, 3, 5, 7, 11, 13}; + auto range = make_range(&v[1], &v[4]); // 3, 5, 7 + int index = 1; + for (int prime : range) { + ASSERT_LT(index, TF_ARRAYSIZE(v)); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(4, index); +} +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/manual_constructor.h b/tensorflow/core/lib/gtl/manual_constructor.h new file mode 100644 index 0000000000..39f029ed4a --- /dev/null +++ b/tensorflow/core/lib/gtl/manual_constructor.h @@ -0,0 +1,230 @@ +// ManualConstructor statically-allocates space in which to store some +// object, but does not initialize it. You can then call the constructor +// and destructor for the object yourself as you see fit. This is useful +// for memory management optimizations, where you want to initialize and +// destroy an object multiple times but only allocate it once. +// +// (When I say ManualConstructor statically allocates space, I mean that +// the ManualConstructor object itself is forced to be the right size.) + +#ifndef TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ +#define TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ + +#include <stddef.h> +#include <new> +#include <utility> + +#include "tensorflow/core/platform/port.h" // For aligned_malloc/aligned_free + +namespace tensorflow { +namespace gtl { +namespace internal { + +// +// Provides a char array with the exact same alignment as another type. The +// first parameter must be a complete type, the second parameter is how many +// of that type to provide space for. +// +// TF_LIB_GTL_ALIGNED_CHAR_ARRAY(struct stat, 16) storage_; +// +// Because MSVC and older GCCs require that the argument to their alignment +// construct to be a literal constant integer, we use a template instantiated +// at all the possible powers of two. 
+#ifndef SWIG +template <int alignment, int size> +struct AlignType {}; +template <int size> +struct AlignType<0, size> { + typedef char result[size]; +}; +#if defined(COMPILER_MSVC) +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X)) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof(T) +#elif defined(COMPILER_GCC3) || __GNUC__ >= 3 || defined(__APPLE__) || \ + defined(COMPILER_ICC) || defined(OS_NACL) || defined(__clang__) +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X))) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T) +#endif + +#if defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#define TF_LIB_GTL_ALIGNTYPE_TEMPLATE(X) \ + template <int size> \ + struct AlignType<X, size> { \ + typedef TF_LIB_GTL_ALIGN_ATTRIBUTE(X) char result[size]; \ + } + +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(16); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(32); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(64); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(128); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(256); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(512); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1024); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2048); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4096); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8192); +// Any larger and MSVC++ will complain. + +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + typename tensorflow::gtl::internal::AlignType<TF_LIB_GTL_ALIGN_OF(T), \ + sizeof(T) * Size>::result + +#undef TF_LIB_GTL_ALIGNTYPE_TEMPLATE +#undef TF_LIB_GTL_ALIGN_ATTRIBUTE + +#else // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) +#error "You must define TF_LIB_GTL_ALIGNED_CHAR_ARRAY for your compiler." +#endif // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#else // !SWIG + +// SWIG can't represent alignment and doesn't care about alignment on data +// members (it works fine without it). +template <typename Size> +struct AlignType { + typedef char result[Size]; +}; +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + tensorflow::gtl::internal::AlignType<Size * sizeof(T)>::result + +// Enough to parse with SWIG, will never be used by running code. +#define TF_LIB_GTL_ALIGN_OF(Type) 16 + +#endif // !SWIG + +} // namespace internal +} // namespace gtl + +template <typename Type> +class ManualConstructor { + public: + // No constructor or destructor because one of the most useful uses of + // this class is as part of a union, and members of a union cannot have + // constructors or destructors. And, anyway, the whole point of this + // class is to bypass these. + + // Support users creating arrays of ManualConstructor<>s. This ensures that + // the array itself has the correct alignment. + static void* operator new[](size_t size) { + return port::aligned_malloc(size, TF_LIB_GTL_ALIGN_OF(Type)); + } + static void operator delete[](void* mem) { port::aligned_free(mem); } + + inline Type* get() { return reinterpret_cast<Type*>(space_); } + inline const Type* get() const { + return reinterpret_cast<const Type*>(space_); + } + + inline Type* operator->() { return get(); } + inline const Type* operator->() const { return get(); } + + inline Type& operator*() { return *get(); } + inline const Type& operator*() const { return *get(); } + + inline void Init() { new (space_) Type; } + +// Init() constructs the Type instance using the given arguments +// (which are forwarded to Type's constructor). In C++11, Init() can +// take any number of arguments of any type, and forwards them perfectly. 
+// On pre-C++11 platforms, it can take up to 11 arguments, and may not be +// able to forward certain kinds of arguments. +// +// Note that Init() with no arguments performs default-initialization, +// not zero-initialization (i.e it behaves the same as "new Type;", not +// "new Type();"), so it will leave non-class types uninitialized. +#ifdef LANG_CXX11 + template <typename... Ts> + inline void Init(Ts&&... args) { // NOLINT + new (space_) Type(std::forward<Ts>(args)...); // NOLINT + } +#else // !defined(LANG_CXX11) + template <typename T1> + inline void Init(const T1& p1) { + new (space_) Type(p1); + } + + template <typename T1, typename T2> + inline void Init(const T1& p1, const T2& p2) { + new (space_) Type(p1, p2); + } + + template <typename T1, typename T2, typename T3> + inline void Init(const T1& p1, const T2& p2, const T3& p3) { + new (space_) Type(p1, p2, p3); + } + + template <typename T1, typename T2, typename T3, typename T4> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4) { + new (space_) Type(p1, p2, p3, p4); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5) { + new (space_) Type(p1, p2, p3, p4, p5); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6) { + new (space_) Type(p1, p2, p3, p4, p5, p6); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10); + } + + template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11> + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10, const T11& p11) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11); + } +#endif // LANG_CXX11 + + inline void Destroy() { get()->~Type(); } + + private: + TF_LIB_GTL_ALIGNED_CHAR_ARRAY(Type, 1) space_; +}; + +#undef TF_LIB_GTL_ALIGNED_CHAR_ARRAY +#undef TF_LIB_GTL_ALIGN_OF + +} // namespace tensorflow + +#endif // 
TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ diff --git a/tensorflow/core/lib/gtl/manual_constructor_test.cc b/tensorflow/core/lib/gtl/manual_constructor_test.cc new file mode 100644 index 0000000000..a929591be2 --- /dev/null +++ b/tensorflow/core/lib/gtl/manual_constructor_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/gtl/manual_constructor.h" + +#include <stdint.h> + +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +static int constructor_count_ = 0; + +template <int kSize> +struct TestN { + TestN() { ++constructor_count_; } + ~TestN() { --constructor_count_; } + char a[kSize]; +}; + +typedef TestN<1> Test1; +typedef TestN<2> Test2; +typedef TestN<3> Test3; +typedef TestN<4> Test4; +typedef TestN<5> Test5; +typedef TestN<9> Test9; +typedef TestN<15> Test15; + +} // namespace + +namespace { + +TEST(ManualConstructorTest, Sizeof) { + CHECK_EQ(sizeof(ManualConstructor<Test1>), sizeof(Test1)); + CHECK_EQ(sizeof(ManualConstructor<Test2>), sizeof(Test2)); + CHECK_EQ(sizeof(ManualConstructor<Test3>), sizeof(Test3)); + CHECK_EQ(sizeof(ManualConstructor<Test4>), sizeof(Test4)); + CHECK_EQ(sizeof(ManualConstructor<Test5>), sizeof(Test5)); + CHECK_EQ(sizeof(ManualConstructor<Test9>), sizeof(Test9)); + CHECK_EQ(sizeof(ManualConstructor<Test15>), sizeof(Test15)); + + CHECK_EQ(constructor_count_, 0); + ManualConstructor<Test1> mt[4]; + CHECK_EQ(sizeof(mt), 4); + CHECK_EQ(constructor_count_, 0); + mt[0].Init(); + CHECK_EQ(constructor_count_, 1); + mt[0].Destroy(); +} + +TEST(ManualConstructorTest, Alignment) { + // We want to make sure that ManualConstructor aligns its memory properly + // on a word barrier. Otherwise, it might be unexpectedly slow, since + // memory access will be unaligned. + + struct { + char a; + ManualConstructor<void*> b; + } test1; + struct { + char a; + void* b; + } control1; + + // TODO(bww): Make these tests more direct with C++11 alignment_of<T>::value. + EXPECT_EQ(reinterpret_cast<char*>(test1.b.get()) - &test1.a, + reinterpret_cast<char*>(&control1.b) - &control1.a); + EXPECT_EQ(reinterpret_cast<intptr_t>(test1.b.get()) % sizeof(control1.b), 0); + + struct { + char a; + ManualConstructor<long double> b; + } test2; + struct { + char a; + long double b; + } control2; + + EXPECT_EQ(reinterpret_cast<char*>(test2.b.get()) - &test2.a, + reinterpret_cast<char*>(&control2.b) - &control2.a); +#ifdef ARCH_K8 + EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 16, 0); +#endif +#ifdef ARCH_PIII + EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 4, 0); +#endif +} + +TEST(ManualConstructorTest, DefaultInitialize) { + struct X { + X() : x(123) {} + int x; + }; + union { + ManualConstructor<X> x; + ManualConstructor<int> y; + } u; + *u.y = -1; + u.x.Init(); // should default-initialize u.x + EXPECT_EQ(123, u.x->x); +} + +TEST(ManualConstructorTest, ZeroInitializePOD) { + union { + ManualConstructor<int> x; + ManualConstructor<int> y; + } u; + *u.y = -1; + u.x.Init(); // should not zero-initialize u.x + EXPECT_EQ(-1, *u.y); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h new file mode 100644 index 0000000000..c953de57c7 --- /dev/null +++ b/tensorflow/core/lib/gtl/map_util.h @@ -0,0 +1,123 @@ +// This file provides utility functions for use with STL map-like data +// structures, such as std::map and hash_map. Some functions will also work with +// sets, such as ContainsKey(). 
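+//
+// Typical usage (a minimal sketch using the helpers declared below):
+//
+//   std::map<string, string> m;
+//   gtl::InsertIfNotPresent(&m, "key", "value");        // true on first insert
+//   const string* found = gtl::FindOrNull(m, "key");    // NULL if absent
+//   string s = gtl::FindWithDefault(m, "missing", "");  // copy the result; see
+//                                                       // the FindWithDefault
+//                                                       // warning below.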
+ +#ifndef TENSORFLOW_LIB_GTL_MAP_UTIL_H_ +#define TENSORFLOW_LIB_GTL_MAP_UTIL_H_ + +#include <stddef.h> +#include <iterator> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace tensorflow { +namespace gtl { + +// Returns a pointer to the const value associated with the given key if it +// exists, or NULL otherwise. +template <class Collection> +const typename Collection::value_type::second_type* FindOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Same as above but returns a pointer to the non-const value. +template <class Collection> +typename Collection::value_type::second_type* FindOrNull( + Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Returns the pointer value associated with the given key. If none is found, +// NULL is returned. The function is designed to be used with a map of keys to +// pointers. +// +// This function does not distinguish between a missing key and a key mapped +// to a NULL value. +template <class Collection> +typename Collection::value_type::second_type FindPtrOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +template <class Collection> +const typename Collection::value_type::second_type& FindWithDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return value; + } + return it->second; +} + +// Inserts the given key and value into the given collection if and only if the +// given key did NOT already exist in the collection. If the key previously +// existed in the collection, the value is not changed. Returns true if the +// key-value pair was inserted; returns false if the key was already present. +template <class Collection> +bool InsertIfNotPresent(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).second; +} + +// Same as above except the key and value are passed separately. 
+template <class Collection> +bool InsertIfNotPresent( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertIfNotPresent(collection, + typename Collection::value_type(key, value)); +} + +// Looks up a given key and value pair in a collection and inserts the key-value +// pair if it's not already present. Returns a reference to the value associated +// with the key. +template <class Collection> +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, const typename Collection::value_type& vt) { + return collection->insert(vt).first->second; +} + +// Same as above except the key-value are passed separately. +template <class Collection> +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return LookupOrInsert(collection, + typename Collection::value_type(key, value)); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_MAP_UTIL_H_ diff --git a/tensorflow/core/lib/gtl/map_util_test.cc b/tensorflow/core/lib/gtl/map_util_test.cc new file mode 100644 index 0000000000..356f987337 --- /dev/null +++ b/tensorflow/core/lib/gtl/map_util_test.cc @@ -0,0 +1,47 @@ +#include "tensorflow/core/lib/gtl/map_util.h" + +#include <map> +#include <set> +#include <string> +#include "tensorflow/core/platform/port.h" + +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(MapUtil, Find) { + typedef std::map<string, string> Map; + Map m; + + // Check that I can use a type that's implicitly convertible to the + // key or value type, such as const char* -> string. + EXPECT_EQ("", gtl::FindWithDefault(m, "foo", "")); + m["foo"] = "bar"; + EXPECT_EQ("bar", gtl::FindWithDefault(m, "foo", "")); + EXPECT_EQ("bar", *gtl::FindOrNull(m, "foo")); + string str; + EXPECT_TRUE(m.count("foo") > 0); + EXPECT_EQ(m["foo"], "bar"); +} + +TEST(MapUtil, LookupOrInsert) { + typedef std::map<string, string> Map; + Map m; + + // Check that I can use a type that's implicitly convertible to the + // key or value type, such as const char* -> string. + EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "xyz")); + EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "abc")); +} + +TEST(MapUtil, InsertIfNotPresent) { + // Set operations + typedef std::set<int> Set; + Set s; + EXPECT_TRUE(gtl::InsertIfNotPresent(&s, 0)); + EXPECT_EQ(s.count(0), 1); + EXPECT_FALSE(gtl::InsertIfNotPresent(&s, 0)); + EXPECT_EQ(s.count(0), 1); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/stl_util.h b/tensorflow/core/lib/gtl/stl_util.h new file mode 100644 index 0000000000..83abcd6b55 --- /dev/null +++ b/tensorflow/core/lib/gtl/stl_util.h @@ -0,0 +1,130 @@ +// This file provides utility functions for use with STL + +#ifndef TENSORFLOW_LIB_GTL_STL_UTIL_H_ +#define TENSORFLOW_LIB_GTL_STL_UTIL_H_ + +#include <stddef.h> +#include <algorithm> +#include <iterator> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace tensorflow { +namespace gtl { + +// Returns a mutable char* pointing to a string's internal buffer, which may not +// be null-terminated. Returns NULL for an empty string. If not non-null, +// writing through this pointer will modify the string. 
+// +// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the +// next call to a string method that invalidates iterators. +// +// In C++11 you may simply use &str[0] to get a mutable char*. +// +// Prior to C++11, there was no standard-blessed way of getting a mutable +// reference to a string's internal buffer. The requirement that string be +// contiguous is officially part of the C++11 standard [string.require]/5. +// According to Matt Austern, this should already work on all current C++98 +// implementations. +inline char* string_as_array(string* str) { + return str->empty() ? NULL : &*str->begin(); +} + +// Returns the T* array for the given vector, or NULL if the vector was empty. +// +// Note: If you know the array will never be empty, you can use &*v.begin() +// directly, but that is may dump core if v is empty. This function is the most +// efficient code that will work, taking into account how our STL is actually +// implemented. THIS IS NON-PORTABLE CODE, so use this function instead of +// repeating the nonportable code everywhere. If our STL implementation changes, +// we will need to change this as well. +template <typename T, typename Allocator> +inline T* vector_as_array(std::vector<T, Allocator>* v) { +#if defined NDEBUG && !defined _GLIBCXX_DEBUG + return &*v->begin(); +#else + return v->empty() ? NULL : &*v->begin(); +#endif +} +// vector_as_array overload for const std::vector<>. +template <typename T, typename Allocator> +inline const T* vector_as_array(const std::vector<T, Allocator>* v) { +#if defined NDEBUG && !defined _GLIBCXX_DEBUG + return &*v->begin(); +#else + return v->empty() ? NULL : &*v->begin(); +#endif +} + +// Like str->resize(new_size), except any new characters added to "*str" as a +// result of resizing may be left uninitialized, rather than being filled with +// '0' bytes. Typically used when code is then going to overwrite the backing +// store of the string with known data. Uses a Google extension to ::string. +inline void STLStringResizeUninitialized(string* s, size_t new_size) { +#if __google_stl_resize_uninitialized_string + s->resize_uninitialized(new_size); +#else + s->resize(new_size); +#endif +} + +// Calls delete (non-array version) on the SECOND item (pointer) in each pair in +// the range [begin, end). +// +// Note: If you're calling this on an entire container, you probably want to +// call STLDeleteValues(&container) instead, or use ValueDeleter. +template <typename ForwardIterator> +void STLDeleteContainerPairSecondPointers(ForwardIterator begin, + ForwardIterator end) { + while (begin != end) { + ForwardIterator temp = begin; + ++begin; + delete temp->second; + } +} + +// Deletes all the elements in an STL container and clears the container. This +// function is suitable for use with a vector, set, hash_set, or any other STL +// container which defines sensible begin(), end(), and clear() methods. +// +// If container is NULL, this function is a no-op. +template <typename T> +void STLDeleteElements(T* container) { + if (!container) return; + auto it = container->begin(); + while (it != container->end()) { + auto temp = it; + ++it; + delete *temp; + } + container->clear(); +} + +// Given an STL container consisting of (key, value) pairs, STLDeleteValues +// deletes all the "value" components and clears the container. Does nothing in +// the case it's given a NULL pointer. 
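+//
+// A minimal sketch, assuming a map whose values are caller-owned pointers
+// (Thing is a hypothetical type):
+//
+//   std::map<string, Thing*> owned;
+//   owned["a"] = new Thing;
+//   gtl::STLDeleteValues(&owned);   // deletes every value pointer, then clears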
+template <typename T> +void STLDeleteValues(T* container) { + if (!container) return; + auto it = container->begin(); + while (it != container->end()) { + auto temp = it; + ++it; + delete temp->second; + } + container->clear(); +} + +// Sorts and removes duplicates from a sequence container. +template <typename T> +inline void STLSortAndRemoveDuplicates(T* v) { + std::sort(v->begin(), v->end()); + v->erase(std::unique(v->begin(), v->end()), v->end()); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_STL_UTIL_H_ diff --git a/tensorflow/core/lib/gtl/top_n.h b/tensorflow/core/lib/gtl/top_n.h new file mode 100644 index 0000000000..b95b998c21 --- /dev/null +++ b/tensorflow/core/lib/gtl/top_n.h @@ -0,0 +1,324 @@ +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +#ifndef TENSORFLOW_LIB_GTL_TOP_N_H_ +#define TENSORFLOW_LIB_GTL_TOP_N_H_ + +#include <stddef.h> +#include <algorithm> +#include <functional> +#include <string> +#include <vector> + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template <class T, class Cmp = std::greater<T> > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. 
+  //
+  // The state transition graph is as follows:
+  //
+  //             peek_bottom()                (limit_+1) elements
+  //  UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
+  //      |                                                           ^
+  //      |                      (limit_+1) elements                  |
+  //      +-----------------------------------------------------------+
+
+  enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
+  using UnsortedIterator = typename std::vector<T>::const_iterator;
+
+  // 'limit' is the maximum number of top results to return.
+  explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
+  TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}
+
+  size_t limit() const { return limit_; }
+
+  // Number of elements currently held by this TopN object. This
+  // will be no greater than 'limit' passed to the constructor.
+  size_t size() const { return std::min(elements_.size(), limit_); }
+
+  bool empty() const { return size() == 0; }
+
+  // If you know how many elements you will push at the time you create the
+  // TopN object, you can call reserve to preallocate the memory that TopN
+  // will need to process all 'n' pushes. Calling this method is optional.
+  void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
+
+  // Push 'v'. If the maximum number of elements was exceeded, drop the
+  // lowest element and return it in 'dropped' (if given). If the maximum is not
+  // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
+  // nullptr, in which case it is not filled in.
+  // Requires: T is CopyAssignable, Swappable
+  void push(const T &v) { push(v, nullptr); }
+  void push(const T &v, T *dropped) { PushInternal(v, dropped); }
+
+  // Move overloads of push.
+  // Requires: T is MoveAssignable, Swappable
+  void push(T &&v) {  // NOLINT(build/c++11)
+    push(std::move(v), nullptr);
+  }
+  void push(T &&v, T *dropped) {  // NOLINT(build/c++11)
+    PushInternal(std::move(v), dropped);
+  }
+
+  // Peeks the bottom result without calling Extract()
+  const T &peek_bottom();
+
+  // Extract the elements as a vector sorted in descending order. The caller
+  // assumes ownership of the vector and must delete it when done. This is a
+  // destructive operation. The only method that can be called immediately
+  // after Extract() is Reset().
+  std::vector<T> *Extract();
+
+  // Similar to Extract(), but makes no guarantees the elements are in sorted
+  // order. As with Extract(), the caller assumes ownership of the vector and
+  // must delete it when done. This is a destructive operation. The only
+  // method that can be called immediately after ExtractUnsorted() is Reset().
+  std::vector<T> *ExtractUnsorted();
+
+  // A non-destructive version of Extract(). Copy the elements in a new vector
+  // sorted in descending order and return it. The caller assumes ownership of
+  // the new vector and must delete it when done. After calling
+  // ExtractNondestructive(), the caller can continue to push() new elements.
+  std::vector<T> *ExtractNondestructive() const;
+
+  // A non-destructive version of Extract(). Copy the elements to a given
+  // vector sorted in descending order. After calling
+  // ExtractNondestructive(), the caller can continue to push() new elements.
+  // Note:
+  //   1. The given argument must be allocated.
+  //   2. Any data contained in the vector prior to the call will be deleted
+  //      from it. After the call the vector will contain only the elements
+  //      from the data structure.
+  void ExtractNondestructive(std::vector<T> *output) const;
+
+  // A non-destructive version of ExtractUnsorted(). Copy the elements in a new
+  // vector and return it, with no guarantees the elements are in sorted order.
+  // The caller assumes ownership of the new vector and must delete it when
+  // done. After calling ExtractUnsortedNondestructive(), the caller can
+  // continue to push() new elements.
+  std::vector<T> *ExtractUnsortedNondestructive() const;
+
+  // A non-destructive version of ExtractUnsorted(). Copy the elements into
+  // a given vector, with no guarantees the elements are in sorted order.
+  // After calling ExtractUnsortedNondestructive(), the caller can continue
+  // to push() new elements.
+  // Note:
+  //   1. The given argument must be allocated.
+  //   2. Any data contained in the vector prior to the call will be deleted
+  //      from it. After the call the vector will contain only the elements
+  //      from the data structure.
+  void ExtractUnsortedNondestructive(std::vector<T> *output) const;
+
+  // Return an iterator to the beginning (end) of the container,
+  // with no guarantees about the order of iteration. These iterators are
+  // invalidated by mutation of the data structure.
+  UnsortedIterator unsorted_begin() const { return elements_.begin(); }
+  UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
+
+  // Accessor for comparator template argument.
+  Cmp *comparator() { return &cmp_; }
+
+  // This removes all elements. If Extract() or ExtractUnsorted() has been
+  // called, this will put it back in an empty but usable state.
+  void Reset();
+
+ private:
+  template <typename U>
+  void PushInternal(U &&v, T *dropped);  // NOLINT(build/c++11)
+
+  // elements_ can be in one of two states:
+  //   elements_.size() <= limit_:  elements_ is an unsorted vector of elements
+  //      pushed so far.
+  //   elements_.size() > limit_:  The last element of elements_ is unused;
+  //      the other elements of elements_ are an stl heap whose size is exactly
+  //      limit_. In this case elements_.size() is exactly one greater than
+  //      limit_, but don't use "elements_.size() == limit_ + 1" to check for
+  //      that because you'll get a false positive if limit_ == size_t(-1).
+  std::vector<T> elements_;
+  size_t limit_;  // Maximum number of elements to find
+  Cmp cmp_;       // Greater-than comparison function
+  State state_ = UNORDERED;
+};
+
+// ----------------------------------------------------------------------
+// Implementations of non-inline functions
+
+template <class T, class Cmp>
+template <typename U>
+void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) {  // NOLINT(build/c++11)
+  if (limit_ == 0) {
+    if (dropped) *dropped = std::forward<U>(v);  // NOLINT(build/c++11)
+    return;
+  }
+  if (state_ != HEAP_SORTED) {
+    elements_.push_back(std::forward<U>(v));  // NOLINT(build/c++11)
+    if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
+      // Easy case: we just pushed the new element back
+    } else {
+      // To maintain the BOTTOM_KNOWN state, we need to make sure that
+      // the element at position 0 is always the smallest. So we put
+      // the new element at position 0 and push the original bottom
+      // element in the back.
+      // Warning: this code is subtle.
+      using std::swap;
+      swap(elements_.front(), elements_.back());
+    }
+    if (elements_.size() == limit_ + 1) {
+      // Transition from unsorted vector to a heap.
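+      // With the "greater"-style comparator, make_heap leaves the smallest of
+      // the limit_+1 elements at the front; that element is reported through
+      // 'dropped', and pop_heap then parks it in the unused back slot, which
+      // later pushes reuse as scratch space.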
+ std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + elements_.back() = std::forward<U>(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11) + } + } +} + +template <class T, class Cmp> +const T &TopN<T, Cmp>::peek_bottom() { + CHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::Extract() { + auto out = new std::vector<T>; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() { + auto out = new std::vector<T>; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + out->pop_back(); + } + return out; +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const { + auto out = new std::vector<T>; + ExtractNondestructive(out); + return out; +} + +template <class T, class Cmp> +void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const { + CHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const { + auto elements = new std::vector<T>; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template <class T, class Cmp> +void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const { + CHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template <class T, class Cmp> +void TopN<T, Cmp>::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_TOP_N_H_ diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc new file mode 100644 index 0000000000..1812a1bd3f --- /dev/null +++ b/tensorflow/core/lib/gtl/top_n_test.cc @@ -0,0 +1,249 @@ +// Unit test for TopN. 
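+
+// Illustrative usage sketch (hypothetical helper, not exercised by the tests
+// below): keep the three largest values from a stream and read them back in
+// descending order. The extra includes are only so the sketch is
+// self-contained.
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include <vector>
+
+inline std::vector<int> ExampleTop3() {
+  tensorflow::gtl::TopN<int> top(3);            // default Cmp is std::greater<int>
+  for (int v : {5, 1, 9, 3, 7}) top.push(v);    // values below the top 3 get dropped
+  std::vector<int> *extracted = top.Extract();  // caller owns the returned vector
+  std::vector<int> result = *extracted;         // {9, 7, 5}, sorted descending
+  delete extracted;
+  return result;
+}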
+ +#include "tensorflow/core/lib/gtl/top_n.h" + +#include <string> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace { + +using tensorflow::gtl::TopN; +using tensorflow::random::PhiloxRandom; +using tensorflow::random::SimplePhilox; +using tensorflow::string; + +// Move the contents from an owned raw pointer, returning by value. +// Objects are easier to manage by value. +template <class T> +T ConsumeRawPtr(T *p) { + T tmp = std::move(*p); + delete p; + return tmp; +} + +template <class Cmp> +void TestIntTopNHelper(size_t limit, size_t n_elements, const Cmp &cmp, + SimplePhilox *random, bool test_peek, + bool test_extract_unsorted) { + LOG(INFO) << "Testing limit=" << limit << ", n_elements=" << n_elements + << ", test_peek=" << test_peek + << ", test_extract_unsorted=" << test_extract_unsorted; + TopN<int, Cmp> top(limit, cmp); + std::vector<int> shadow(n_elements); + for (int i = 0; i != n_elements; ++i) shadow[i] = random->Uniform(limit); + for (int e : shadow) top.push(e); + std::sort(shadow.begin(), shadow.end(), cmp); + size_t top_size = std::min(limit, n_elements); + EXPECT_EQ(top_size, top.size()); + if (test_peek && top_size != 0) { + EXPECT_EQ(shadow[top_size - 1], top.peek_bottom()); + } + std::vector<int> v; + if (test_extract_unsorted) { + v = ConsumeRawPtr(top.ExtractUnsorted()); + std::sort(v.begin(), v.end(), cmp); + } else { + v = ConsumeRawPtr(top.Extract()); + } + EXPECT_EQ(top_size, v.size()); + for (int i = 0; i != top_size; ++i) { + VLOG(1) << "Top element " << v[i]; + EXPECT_EQ(shadow[i], v[i]); + } +} + +template <class Cmp> +void TestIntTopN(size_t limit, size_t n_elements, const Cmp &cmp, + SimplePhilox *random) { + // Test peek_bottom() and Extract() + TestIntTopNHelper(limit, n_elements, cmp, random, true, false); + // Test Extract() + TestIntTopNHelper(limit, n_elements, cmp, random, false, false); + // Test peek_bottom() and ExtractUnsorted() + TestIntTopNHelper(limit, n_elements, cmp, random, true, true); + // Test ExtractUnsorted() + TestIntTopNHelper(limit, n_elements, cmp, random, false, true); +} + +TEST(TopNTest, Misc) { + PhiloxRandom philox(1, 1); + SimplePhilox random(&philox); + + TestIntTopN(0, 5, std::greater<int>(), &random); + TestIntTopN(32, 0, std::greater<int>(), &random); + TestIntTopN(6, 6, std::greater<int>(), &random); + TestIntTopN(6, 6, std::less<int>(), &random); + TestIntTopN(1000, 999, std::greater<int>(), &random); + TestIntTopN(1000, 1000, std::greater<int>(), &random); + TestIntTopN(1000, 1001, std::greater<int>(), &random); + TestIntTopN(2300, 28393, std::less<int>(), &random); + TestIntTopN(30, 100, std::greater<int>(), &random); + TestIntTopN(100, 30, std::less<int>(), &random); + TestIntTopN(size_t(-1), 3, std::greater<int>(), &random); + TestIntTopN(size_t(-1), 0, std::greater<int>(), &random); + TestIntTopN(0, 5, std::greater<int>(), &random); +} + +TEST(TopNTest, String) { + LOG(INFO) << "Testing strings"; + + TopN<string> top(3); + EXPECT_TRUE(top.empty()); + top.push("abracadabra"); + top.push("waldemar"); + EXPECT_EQ(2, top.size()); + EXPECT_EQ("abracadabra", top.peek_bottom()); + top.push(""); + EXPECT_EQ(3, top.size()); + EXPECT_EQ("", top.peek_bottom()); + top.push("top"); + EXPECT_EQ(3, top.size()); + EXPECT_EQ("abracadabra", top.peek_bottom()); + top.push("Google"); + top.push("test"); + EXPECT_EQ(3, top.size()); + 
EXPECT_EQ("test", top.peek_bottom()); + TopN<string> top2(top); + TopN<string> top3(5); + top3 = top; + EXPECT_EQ("test", top3.peek_bottom()); + { + std::vector<string> s = ConsumeRawPtr(top.Extract()); + EXPECT_EQ(s[0], "waldemar"); + EXPECT_EQ(s[1], "top"); + EXPECT_EQ(s[2], "test"); + } + + top2.push("zero"); + EXPECT_EQ(top2.peek_bottom(), "top"); + + { + std::vector<string> s = ConsumeRawPtr(top2.Extract()); + EXPECT_EQ(s[0], "zero"); + EXPECT_EQ(s[1], "waldemar"); + EXPECT_EQ(s[2], "top"); + } + { + std::vector<string> s = ConsumeRawPtr(top3.Extract()); + EXPECT_EQ(s[0], "waldemar"); + EXPECT_EQ(s[1], "top"); + EXPECT_EQ(s[2], "test"); + } + + TopN<string> top4(3); + // Run this test twice to check Reset(): + for (int i = 0; i < 2; ++i) { + top4.push("abcd"); + top4.push("ijkl"); + top4.push("efgh"); + top4.push("mnop"); + std::vector<string> s = ConsumeRawPtr(top4.Extract()); + EXPECT_EQ(s[0], "mnop"); + EXPECT_EQ(s[1], "ijkl"); + EXPECT_EQ(s[2], "efgh"); + top4.Reset(); + } +} + +// Test that pointers aren't leaked from a TopN if we use the 2-argument version +// of push(). +TEST(TopNTest, Ptr) { + LOG(INFO) << "Testing 2-argument push()"; + TopN<string *> topn(3); + for (int i = 0; i < 8; ++i) { + string *dropped = NULL; + topn.push(new string(std::to_string(i)), &dropped); + delete dropped; + } + + for (int i = 8; i > 0; --i) { + string *dropped = NULL; + topn.push(new string(std::to_string(i)), &dropped); + delete dropped; + } + + std::vector<string *> extract = ConsumeRawPtr(topn.Extract()); + tensorflow::gtl::STLDeleteElements(&extract); +} + +struct PointeeGreater { + template <typename T> + bool operator()(const T &a, const T &b) const { + return *a > *b; + } +}; + +TEST(TopNTest, MoveOnly) { + using StrPtr = std::unique_ptr<string>; + TopN<StrPtr, PointeeGreater> topn(3); + for (int i = 0; i < 8; ++i) topn.push(StrPtr(new string(std::to_string(i)))); + for (int i = 8; i > 0; --i) topn.push(StrPtr(new string(std::to_string(i)))); + + std::vector<StrPtr> extract = ConsumeRawPtr(topn.Extract()); + EXPECT_EQ(extract.size(), 3); + EXPECT_EQ(*(extract[0]), "8"); + EXPECT_EQ(*(extract[1]), "7"); + EXPECT_EQ(*(extract[2]), "7"); +} + +// Test that Nondestructive extracts do not need a Reset() afterwards, +// and that pointers aren't leaked from a TopN after calling them. 
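+// (ExtractNondestructive and ExtractUnsortedNondestructive copy the elements
+// out, so the TopN keeps its contents and can keep accepting pushes.)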
+TEST(TopNTest, Nondestructive) { + LOG(INFO) << "Testing Nondestructive extracts"; + TopN<int> top4(4); + for (int i = 0; i < 8; ++i) { + top4.push(i); + std::vector<int> v = ConsumeRawPtr(top4.ExtractNondestructive()); + EXPECT_EQ(std::min(i + 1, 4), v.size()); + for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]); + } + + TopN<int> top3(3); + for (int i = 0; i < 8; ++i) { + top3.push(i); + std::vector<int> v = ConsumeRawPtr(top3.ExtractUnsortedNondestructive()); + std::sort(v.begin(), v.end(), std::greater<int>()); + EXPECT_EQ(std::min(i + 1, 3), v.size()); + for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]); + } +} + +struct ForbiddenCmp { + bool operator()(int lhs, int rhs) const { + LOG(FATAL) << "ForbiddenCmp called " << lhs << " " << rhs; + } +}; + +TEST(TopNTest, ZeroLimit) { + TopN<int, ForbiddenCmp> top(0); + top.push(1); + top.push(2); + + int dropped = -1; + top.push(1, &dropped); + top.push(2, &dropped); + + std::vector<int> v; + top.ExtractNondestructive(&v); + EXPECT_EQ(0, v.size()); +} + +TEST(TopNTest, Iteration) { + TopN<int> top(4); + for (int i = 0; i < 8; ++i) top.push(i); + std::vector<int> actual(top.unsorted_begin(), top.unsorted_end()); + // Check that we have 4,5,6,7 as the top 4 (in some order, so we sort) + sort(actual.begin(), actual.end()); + EXPECT_EQ(actual.size(), 4); + EXPECT_EQ(actual[0], 4); + EXPECT_EQ(actual[1], 5); + EXPECT_EQ(actual[2], 6); + EXPECT_EQ(actual[3], 7); +} +} // namespace diff --git a/tensorflow/core/lib/hash/crc32c.cc b/tensorflow/core/lib/hash/crc32c.cc new file mode 100644 index 0000000000..3bef1cf78d --- /dev/null +++ b/tensorflow/core/lib/hash/crc32c.cc @@ -0,0 +1,244 @@ +// A portable implementation of crc32c, optimized to handle +// four bytes at a time. + +#include "tensorflow/core/lib/hash/crc32c.h" + +#include <stdint.h> +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace crc32c { + +static const uint32 table0_[256] = { + 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, + 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, + 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c, + 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, + 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, + 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, + 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512, + 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, + 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, + 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, + 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf, + 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, + 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, + 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, + 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f, + 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, + 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, + 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, + 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e, + 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, + 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, + 0xdde0eb2a, 0x2f8b6829, 
0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, + 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4, + 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, + 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, + 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, + 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5, + 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, + 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, + 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, + 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905, + 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, + 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, + 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, + 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8, + 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, + 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, + 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, + 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6, + 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, + 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, + 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, + 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351}; +static const uint32 table1_[256] = { + 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab, + 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21, + 0xd39ea264, 0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6, + 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4, + 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92, + 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b, + 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd, + 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff, + 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28, + 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2, + 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73, + 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41, + 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17, + 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c, + 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a, + 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78, + 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad, + 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27, + 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0, + 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2, + 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94, + 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260, + 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136, + 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004, + 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3, + 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059, + 0x1798a91c, 
0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c, + 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be, + 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8, + 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3, + 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5, + 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287, + 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556, + 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc, + 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b, + 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439, + 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f, + 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766, + 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430, + 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502, + 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5, + 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f, + 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483}; +static const uint32 table2_[256] = { + 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664, + 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6, + 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4, + 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3, + 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b, + 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67, + 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af, + 0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8, + 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa, + 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828, + 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7, + 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0, + 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878, + 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20, + 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8, + 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff, + 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9, + 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b, + 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439, + 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e, + 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6, + 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730, + 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8, + 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def, + 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad, + 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f, + 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29, + 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e, + 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6, + 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae, + 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66, + 
0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71, + 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe, + 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c, + 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e, + 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79, + 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1, + 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd, + 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035, + 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622, + 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760, + 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2, + 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8}; +static const uint32 table3_[256] = { + 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b, + 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf, + 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85, + 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7, + 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990, + 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2, + 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5, + 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7, + 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd, + 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69, + 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f, + 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 0xb2a6e72d, + 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a, + 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3, + 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784, + 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6, + 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027, + 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3, + 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9, + 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b, + 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc, + 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006, + 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061, + 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213, + 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349, + 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd, + 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c, + 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e, + 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519, + 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0, + 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7, + 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5, + 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93, + 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07, + 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d, + 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 
0x09629f2f, + 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48, + 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a, + 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d, + 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f, + 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825, + 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1, + 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842}; + +// Used to fetch a naturally-aligned 32-bit word in little endian byte-order +static inline uint32_t LE_LOAD32(const uint8_t *p) { + return core::DecodeFixed32(reinterpret_cast<const char *>(p)); +} + +uint32 Extend(uint32 crc, const char *buf, size_t size) { + const uint8 *p = reinterpret_cast<const uint8 *>(buf); + const uint8 *e = p + size; + uint32 l = crc ^ 0xffffffffu; + +#define STEP1 \ + do { \ + int c = (l & 0xff) ^ *p++; \ + l = table0_[c] ^ (l >> 8); \ + } while (0) + +#define STEP4 \ + do { \ + uint32 c = l ^ LE_LOAD32(p); \ + p += 4; \ + l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^ \ + table1_[(c >> 16) & 0xff] ^ table0_[c >> 24]; \ + } while (0) + + // Point x at first 4-byte aligned byte in string. This might be + // just past the end of the string. + const uintptr_t pval = reinterpret_cast<uintptr_t>(p); + const uint8 *x = reinterpret_cast<const uint8 *>(((pval + 3) >> 2) << 2); + if (x <= e) { + // Process bytes until finished or p is 4-byte aligned + while (p != x) { + STEP1; + } + } + // Process bytes 16 at a time + while ((e - p) >= 16) { + STEP4; + STEP4; + STEP4; + STEP4; + } + // Process bytes 4 at a time + while ((e - p) >= 4) { + STEP4; + } + // Process the last few bytes + while (p != e) { + STEP1; + } +#undef STEP4 +#undef STEP1 + return l ^ 0xffffffffu; +} + +} // namespace crc32c +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h new file mode 100644 index 0000000000..f728b6f5e7 --- /dev/null +++ b/tensorflow/core/lib/hash/crc32c.h @@ -0,0 +1,39 @@ +#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_ +#define TENSORFLOW_LIB_HASH_CRC32C_H_ + +#include <stddef.h> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace crc32c { + +// Return the crc32c of concat(A, data[0,n-1]) where init_crc is the +// crc32c of some string A. Extend() is often used to maintain the +// crc32c of a stream of data. +extern uint32 Extend(uint32 init_crc, const char* data, size_t n); + +// Return the crc32c of data[0,n-1] +inline uint32 Value(const char* data, size_t n) { return Extend(0, data, n); } + +static const uint32 kMaskDelta = 0xa282ead8ul; + +// Return a masked representation of crc. +// +// Motivation: it is problematic to compute the CRC of a string that +// contains embedded CRCs. Therefore we recommend that CRCs stored +// somewhere (e.g., in files) should be masked before being stored. +inline uint32 Mask(uint32 crc) { + // Rotate right by 15 bits and add a constant. + return ((crc >> 15) | (crc << 17)) + kMaskDelta; +} + +// Return the crc whose masked representation is masked_crc. 
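+// Unmask(Mask(crc)) == crc for every crc value: the 15/17-bit rotation and
+// the addition of kMaskDelta (mod 2^32) are both invertible on uint32.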
+inline uint32 Unmask(uint32 masked_crc) { + uint32 rot = masked_crc - kMaskDelta; + return ((rot >> 17) | (rot << 15)); +} + +} // namespace crc32c +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_HASH_CRC32C_H_ diff --git a/tensorflow/core/lib/hash/crc32c_test.cc b/tensorflow/core/lib/hash/crc32c_test.cc new file mode 100644 index 0000000000..54aced3186 --- /dev/null +++ b/tensorflow/core/lib/hash/crc32c_test.cc @@ -0,0 +1,51 @@ +#include "tensorflow/core/lib/hash/crc32c.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace crc32c { + +TEST(CRC, StandardResults) { + // From rfc3720 section B.4. + char buf[32]; + + memset(buf, 0, sizeof(buf)); + ASSERT_EQ(0x8a9136aa, Value(buf, sizeof(buf))); + + memset(buf, 0xff, sizeof(buf)); + ASSERT_EQ(0x62a8ab43, Value(buf, sizeof(buf))); + + for (int i = 0; i < 32; i++) { + buf[i] = i; + } + ASSERT_EQ(0x46dd794e, Value(buf, sizeof(buf))); + + for (int i = 0; i < 32; i++) { + buf[i] = 31 - i; + } + ASSERT_EQ(0x113fdb5c, Value(buf, sizeof(buf))); + + unsigned char data[48] = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + ASSERT_EQ(0xd9963a56, Value(reinterpret_cast<char*>(data), sizeof(data))); +} + +TEST(CRC, Values) { ASSERT_NE(Value("a", 1), Value("foo", 3)); } + +TEST(CRC, Extend) { + ASSERT_EQ(Value("hello world", 11), Extend(Value("hello ", 6), "world", 5)); +} + +TEST(CRC, Mask) { + uint32 crc = Value("foo", 3); + ASSERT_NE(crc, Mask(crc)); + ASSERT_NE(crc, Mask(Mask(crc))); + ASSERT_EQ(crc, Unmask(Mask(crc))); + ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc))))); +} + +} // namespace crc32c +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/lib/hash/hash.cc new file mode 100644 index 0000000000..075d252412 --- /dev/null +++ b/tensorflow/core/lib/hash/hash.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/hash/hash.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/raw_coding.h" + +#include <string.h> + +namespace tensorflow { + +// 0xff is in case char is signed. +static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; } +static inline uint64 ByteAs64(char c) { return static_cast<uint64>(c) & 0xff; } + +uint32 Hash32(const char* data, size_t n, uint32 seed) { + // 'm' and 'r' are mixing constants generated offline. + // They're not really 'magic', they just happen to work well. + + const uint32 m = 0x5bd1e995; + const int r = 24; + + // Initialize the hash to a 'random' value + uint32 h = seed ^ n; + + // Mix 4 bytes at a time into the hash + while (n >= 4) { + uint32 k = core::DecodeFixed32(data); + + k *= m; + k ^= k >> r; + k *= m; + + h *= m; + h ^= k; + + data += 4; + n -= 4; + } + + // Handle the last few bytes of the input array + + switch (n) { + case 3: + h ^= ByteAs32(data[2]) << 16; + TF_FALLTHROUGH_INTENDED; + case 2: + h ^= ByteAs32(data[1]) << 8; + TF_FALLTHROUGH_INTENDED; + case 1: + h ^= ByteAs32(data[0]); + h *= m; + } + + // Do a few final mixes of the hash to ensure the last few + // bytes are well-incorporated. 
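+  // (Final avalanche: the xor-shift and multiply steps below spread the
+  // influence of the trailing bytes across all 32 bits of the result.)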
+ + h ^= h >> 13; + h *= m; + h ^= h >> 15; + + return h; +} + +uint64 Hash64(const char* data, size_t n, uint64 seed) { + const uint64 m = 0xc6a4a7935bd1e995; + const int r = 47; + + uint64 h = seed ^ (n * m); + + while (n >= 8) { + uint64 k = core::DecodeFixed64(data); + data += 8; + n -= 8; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + switch (n) { + case 7: + h ^= ByteAs64(data[6]) << 48; + TF_FALLTHROUGH_INTENDED; + case 6: + h ^= ByteAs64(data[5]) << 40; + TF_FALLTHROUGH_INTENDED; + case 5: + h ^= ByteAs64(data[4]) << 32; + TF_FALLTHROUGH_INTENDED; + case 4: + h ^= ByteAs64(data[3]) << 24; + TF_FALLTHROUGH_INTENDED; + case 3: + h ^= ByteAs64(data[2]) << 16; + TF_FALLTHROUGH_INTENDED; + case 2: + h ^= ByteAs64(data[1]) << 8; + TF_FALLTHROUGH_INTENDED; + case 1: + h ^= ByteAs64(data[0]); + h *= m; + } + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h new file mode 100644 index 0000000000..af56218fed --- /dev/null +++ b/tensorflow/core/lib/hash/hash.h @@ -0,0 +1,28 @@ +// Simple hash functions used for internal data structures + +#ifndef TENSORFLOW_LIB_HASH_HASH_H_ +#define TENSORFLOW_LIB_HASH_HASH_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <string> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +extern uint32 Hash32(const char* data, size_t n, uint32 seed); +extern uint64 Hash64(const char* data, size_t n, uint64 seed); + +inline uint64 Hash64(const char* data, size_t n) { + return Hash64(data, n, 0xDECAFCAFFE); +} + +inline uint64 Hash64(const string& str) { + return Hash64(str.data(), str.size()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_HASH_HASH_H_ diff --git a/tensorflow/core/lib/hash/hash_test.cc b/tensorflow/core/lib/hash/hash_test.cc new file mode 100644 index 0000000000..9d3b970f3b --- /dev/null +++ b/tensorflow/core/lib/hash/hash_test.cc @@ -0,0 +1,64 @@ +#include <vector> + +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(Hash, SignedUnsignedIssue) { + const unsigned char d1[1] = {0x62}; + const unsigned char d2[2] = {0xc3, 0x97}; + const unsigned char d3[3] = {0xe2, 0x99, 0xa5}; + const unsigned char d4[4] = {0xe1, 0x80, 0xb9, 0x32}; + const unsigned char d5[48] = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + + struct Case { + uint32 hash32; + uint64 hash64; + const unsigned char* data; + size_t size; + uint32 seed; + }; + + for (Case c : std::vector<Case>{ + {0x471a8188u, 0x4c61ea3eeda4cb87ull, nullptr, 0, 0xbc9f1d34}, + {0xd615eba5u, 0x091309f7ef916c8aull, d1, sizeof(d1), 0xbc9f1d34}, + {0x0c3cccdau, 0xa815bcdf1d1af01cull, d2, sizeof(d2), 0xbc9f1d34}, + {0x3ba37e0eu, 0x02167564e4d06430ull, d3, sizeof(d3), 0xbc9f1d34}, + {0x16174eb3u, 0x8f7ed82ffc21071full, d4, sizeof(d4), 0xbc9f1d34}, + {0x98b1926cu, 0xce196580c97aff1eull, d5, sizeof(d5), 0x12345678}, + }) { + EXPECT_EQ(c.hash32, + Hash32(reinterpret_cast<const char*>(c.data), c.size, c.seed)); + EXPECT_EQ(c.hash64, + Hash64(reinterpret_cast<const char*>(c.data), c.size, c.seed)); + + // Check hashes with inputs 
aligned differently. + for (int align = 1; align <= 7; align++) { + std::string input(align, 'x'); + input.append(reinterpret_cast<const char*>(c.data), c.size); + EXPECT_EQ(c.hash32, Hash32(&input[align], c.size, c.seed)); + EXPECT_EQ(c.hash64, Hash64(&input[align], c.size, c.seed)); + } + } +} + +static void BM_Hash32(int iters, int len) { + std::string input(len, 'x'); + uint32 h = 0; + for (int i = 0; i < iters; i++) { + h = Hash32(input.data(), len, 1); + } + testing::BytesProcessed(static_cast<int64>(iters) * len); + VLOG(1) << h; +} +BENCHMARK(BM_Hash32)->Range(1, 1024); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/histogram/histogram.cc b/tensorflow/core/lib/histogram/histogram.cc new file mode 100644 index 0000000000..4c29d687b7 --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram.cc @@ -0,0 +1,247 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/histogram/histogram.h" +#include <float.h> +#include <math.h> +#include "tensorflow/core/framework/summary.pb.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +namespace tensorflow { +namespace histogram { + +static std::vector<double>* InitDefaultBucketsInner() { + std::vector<double> buckets; + std::vector<double> neg_buckets; + // Make buckets whose range grows by 10% starting at 1.0e-12 up to 1.0e20 + double v = 1.0e-12; + while (v < 1.0e20) { + buckets.push_back(v); + neg_buckets.push_back(-v); + v *= 1.1; + } + buckets.push_back(DBL_MAX); + neg_buckets.push_back(-DBL_MAX); + std::reverse(neg_buckets.begin(), neg_buckets.end()); + std::vector<double>* result = new std::vector<double>; + result->insert(result->end(), neg_buckets.begin(), neg_buckets.end()); + result->push_back(0.0); + result->insert(result->end(), buckets.begin(), buckets.end()); + return result; +} + +static gtl::ArraySlice<double> InitDefaultBuckets() { + static std::vector<double>* default_bucket_limits = InitDefaultBucketsInner(); + return *default_bucket_limits; +} + +Histogram::Histogram() : bucket_limits_(InitDefaultBuckets()) { Clear(); } + +// Create a histogram with a custom set of bucket limits, +// specified in "custom_buckets[0..custom_buckets.size()-1]" +Histogram::Histogram(gtl::ArraySlice<double> custom_bucket_limits) + : custom_bucket_limits_(custom_bucket_limits.begin(), + custom_bucket_limits.end()), + bucket_limits_(custom_bucket_limits_) { +#ifndef NDEBUG + DCHECK_GT(bucket_limits_.size(), 0); + // Verify that the bucket boundaries are strictly increasing + for (size_t i = 1; i < bucket_limits_.size(); i++) { + DCHECK_GT(bucket_limits_[i], bucket_limits_[i - 1]); + } +#endif + Clear(); +} + +bool Histogram::DecodeFromProto(const HistogramProto& proto) { + if ((proto.bucket_size() != proto.bucket_limit_size()) || + (proto.bucket_size() == 0)) { + return false; + } + min_ = proto.min(); + max_ = proto.max(); + num_ = proto.num(); + sum_ = proto.sum(); + sum_squares_ = proto.sum_squares(); + custom_bucket_limits_.clear(); + custom_bucket_limits_.insert(custom_bucket_limits_.end(), + proto.bucket_limit().begin(), + proto.bucket_limit().end()); + bucket_limits_ = custom_bucket_limits_; + buckets_.clear(); + buckets_.insert(buckets_.end(), proto.bucket().begin(), proto.bucket().end()); + return true; +} + +void Histogram::Clear() { + min_ = bucket_limits_[bucket_limits_.size() - 
1]; + max_ = -DBL_MAX; + num_ = 0; + sum_ = 0; + sum_squares_ = 0; + buckets_.resize(bucket_limits_.size()); + for (size_t i = 0; i < bucket_limits_.size(); i++) { + buckets_[i] = 0; + } +} + +void Histogram::Add(double value) { + int b = + std::upper_bound(bucket_limits_.begin(), bucket_limits_.end(), value) - + bucket_limits_.begin(); + + buckets_[b] += 1.0; + if (min_ > value) min_ = value; + if (max_ < value) max_ = value; + num_++; + sum_ += value; + sum_squares_ += (value * value); +} + +double Histogram::Median() const { return Percentile(50.0); } + +double Histogram::Percentile(double p) const { + if (num_ == 0.0) return 0.0; + double threshold = num_ * (p / 100.0); + double sum = 0; + for (size_t b = 0; b < buckets_.size(); b++) { + sum += buckets_[b]; + if (sum >= threshold) { + // Scale linearly within this bucket + double left_point = (b == 0) ? min_ : bucket_limits_[b - 1]; + double right_point = bucket_limits_[b]; + double left_sum = sum - buckets_[b]; + double right_sum = sum; + double pos = (threshold - left_sum) / (right_sum - left_sum); + double r = left_point + (right_point - left_point) * pos; + if (r < min_) r = min_; + if (r > max_) r = max_; + return r; + } + } + return max_; +} + +double Histogram::Average() const { + if (num_ == 0.0) return 0; + return sum_ / num_; +} + +double Histogram::StandardDeviation() const { + if (num_ == 0.0) return 0; + double variance = (sum_squares_ * num_ - sum_ * sum_) / (num_ * num_); + return sqrt(variance); +} + +std::string Histogram::ToString() const { + std::string r; + char buf[200]; + snprintf(buf, sizeof(buf), "Count: %.0f Average: %.4f StdDev: %.2f\n", num_, + Average(), StandardDeviation()); + r.append(buf); + snprintf(buf, sizeof(buf), "Min: %.4f Median: %.4f Max: %.4f\n", + (num_ == 0.0 ? 0.0 : min_), Median(), max_); + r.append(buf); + r.append("------------------------------------------------------\n"); + const double mult = num_ > 0 ? 100.0 / num_ : 0.0; + double sum = 0; + for (size_t b = 0; b < buckets_.size(); b++) { + if (buckets_[b] <= 0.0) continue; + sum += buckets_[b]; + snprintf(buf, sizeof(buf), "[ %10.2g, %10.2g ) %7.0f %7.3f%% %7.3f%% ", + ((b == 0) ? -DBL_MAX : bucket_limits_[b - 1]), // left + bucket_limits_[b], // right + buckets_[b], // count + mult * buckets_[b], // percentage + mult * sum); // cum percentage + r.append(buf); + + // Add hash marks based on percentage; 20 marks for 100%. + int marks = static_cast<int>(20 * (buckets_[b] / num_) + 0.5); + r.append(marks, '#'); + r.push_back('\n'); + } + return r; +} + +void Histogram::EncodeToProto(HistogramProto* proto, + bool preserve_zero_buckets) const { + proto->Clear(); + proto->set_min(min_); + proto->set_max(max_); + proto->set_num(num_); + proto->set_sum(sum_); + proto->set_sum_squares(sum_squares_); + for (size_t i = 0; i < buckets_.size();) { + double end = bucket_limits_[i]; + double count = buckets_[i]; + i++; + if (!preserve_zero_buckets && count <= 0.0) { + // Find run of empty buckets and collapse them into one + while (i < buckets_.size() && buckets_[i] <= 0.0) { + end = bucket_limits_[i]; + count = buckets_[i]; + i++; + } + } + proto->add_bucket_limit(end); + proto->add_bucket(count); + } + if (proto->bucket_size() == 0.0) { + // It's easier when we restore if we always have at least one bucket entry + proto->add_bucket_limit(DBL_MAX); + proto->add_bucket(0.0); + } +} + +// ThreadSafeHistogram implementation. 
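+// Each method acquires mu_ and forwards to the wrapped Histogram; there is
+// no extra logic at this layer.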
+bool ThreadSafeHistogram::DecodeFromProto(const HistogramProto& proto) { + mutex_lock l(mu_); + return histogram_.DecodeFromProto(proto); +} + +void ThreadSafeHistogram::Clear() { + mutex_lock l(mu_); + histogram_.Clear(); +} + +void ThreadSafeHistogram::Add(double value) { + mutex_lock l(mu_); + histogram_.Add(value); +} + +void ThreadSafeHistogram::EncodeToProto(HistogramProto* proto, + bool preserve_zero_buckets) const { + mutex_lock l(mu_); + histogram_.EncodeToProto(proto, preserve_zero_buckets); +} + +double ThreadSafeHistogram::Median() const { + mutex_lock l(mu_); + return histogram_.Median(); +} + +double ThreadSafeHistogram::Percentile(double p) const { + mutex_lock l(mu_); + return histogram_.Percentile(p); +} + +double ThreadSafeHistogram::Average() const { + mutex_lock l(mu_); + return histogram_.Average(); +} + +double ThreadSafeHistogram::StandardDeviation() const { + mutex_lock l(mu_); + return histogram_.StandardDeviation(); +} + +std::string ThreadSafeHistogram::ToString() const { + mutex_lock l(mu_); + return histogram_.ToString(); +} + +} // namespace histogram +} // namespace tensorflow diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h new file mode 100644 index 0000000000..9b655f3acb --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram.h @@ -0,0 +1,119 @@ +#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ +#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ + +#include <string> +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class HistogramProto; + +namespace histogram { + +class Histogram { + public: + // Create a histogram with a default set of bucket boundaries. + // Buckets near zero cover very small ranges (e.g. 10^-12), and each + // bucket range grows by ~10% as we head away from zero. The + // buckets cover the range from -DBL_MAX to DBL_MAX. + Histogram(); + + // Create a histogram with a custom set of bucket boundaries, + // specified in "custom_bucket_limits[0..custom_bucket_limits.size()-1]" + // REQUIRES: custom_bucket_limits[i] values are monotonically increasing. + // REQUIRES: custom_bucket_limits is not empty() + explicit Histogram(gtl::ArraySlice<double> custom_bucket_limits); + + // Restore the state of a histogram that was previously encoded + // via Histogram::EncodeToProto. Note that only the bucket boundaries + // generated by EncodeToProto will be restored. + bool DecodeFromProto(const HistogramProto& proto); + + ~Histogram() {} + + void Clear(); + void Add(double value); + + // Save the current state of the histogram to "*proto". If + // "preserve_zero_buckets" is false, only non-zero bucket values and + // ranges are saved, and the bucket boundaries of zero-valued buckets + // are lost. + void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const; + + // Return the median of the values in the histogram + double Median() const; + + // Return the "p"th percentile [0.0..100.0] of the values in the + // distribution + double Percentile(double p) const; + + // Return the average value of the distribution + double Average() const; + + // Return the standard deviation of values in the distribution + double StandardDeviation() const; + + // Returns a multi-line human-readable string representing the histogram + // contents. 
Example output: + // Count: 4 Average: 251.7475 StdDev: 432.02 + // Min: -3.0000 Median: 5.0000 Max: 1000.0000 + // ------------------------------------------------------ + // [ -5, 0 ) 1 25.000% 25.000% ##### + // [ 0, 5 ) 1 25.000% 50.000% ##### + // [ 5, 10 ) 1 25.000% 75.000% ##### + // [ 1000, 10000 ) 1 25.000% 100.000% ##### + std::string ToString() const; + + private: + double min_; + double max_; + double num_; + double sum_; + double sum_squares_; + + std::vector<double> custom_bucket_limits_; + gtl::ArraySlice<double> bucket_limits_; + std::vector<double> buckets_; + + TF_DISALLOW_COPY_AND_ASSIGN(Histogram); +}; + +// Wrapper around a Histogram object that is thread safe. +// +// All methods hold a lock while delegating to a Histogram object owned by the +// ThreadSafeHistogram instance. +// +// See Histogram for documentation of the methods. +class ThreadSafeHistogram { + public: + ThreadSafeHistogram() {} + explicit ThreadSafeHistogram(gtl::ArraySlice<double> custom_bucket_limits) + : histogram_(custom_bucket_limits) {} + bool DecodeFromProto(const HistogramProto& proto); + + ~ThreadSafeHistogram() {} + + void Clear(); + + // TODO(mdevin): It might be a good idea to provide a AddN(<many values>) + // method to avoid grabbing/releasing the lock when adding many values. + void Add(double value); + + void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const; + double Median() const; + double Percentile(double p) const; + double Average() const; + double StandardDeviation() const; + std::string ToString() const; + + private: + mutable mutex mu_; + Histogram histogram_ GUARDED_BY(mu_); +}; + +} // namespace histogram +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ diff --git a/tensorflow/core/lib/histogram/histogram_test.cc b/tensorflow/core/lib/histogram/histogram_test.cc new file mode 100644 index 0000000000..ede44fe85b --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram_test.cc @@ -0,0 +1,112 @@ +#include "tensorflow/core/lib/histogram/histogram.h" +#include <float.h> +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/framework/summary.pb.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace histogram { + +static void Validate(const Histogram& h) { + string s1 = h.ToString(); + LOG(ERROR) << s1; + + HistogramProto proto_with_zeroes; + h.EncodeToProto(&proto_with_zeroes, true); + Histogram h2; + EXPECT_TRUE(h2.DecodeFromProto(proto_with_zeroes)); + string s2 = h2.ToString(); + LOG(ERROR) << s2; + + EXPECT_EQ(s1, s2); + + HistogramProto proto_no_zeroes; + h.EncodeToProto(&proto_no_zeroes, false); + LOG(ERROR) << proto_no_zeroes.DebugString(); + Histogram h3; + EXPECT_TRUE(h3.DecodeFromProto(proto_no_zeroes)); + string s3 = h3.ToString(); + LOG(ERROR) << s3; + + EXPECT_EQ(s1, s3); +} + +TEST(Histogram, Empty) { + Histogram h; + Validate(h); +} + +TEST(Histogram, SingleValue) { + Histogram h; + h.Add(-3.0); + Validate(h); +} + +TEST(Histogram, CustomBuckets) { + Histogram h({-10, -5, 0, 5, 10, 100, 1000, 10000, DBL_MAX}); + h.Add(-3.0); + h.Add(4.99); + h.Add(5.0); + h.Add(1000.0); + Validate(h); +} + +TEST(Histogram, Percentile) { + Histogram h({0, 10, 100, DBL_MAX}); + h.Add(-2); + h.Add(-2); + h.Add(0); + double median = h.Percentile(50.0); + EXPECT_EQ(median, -0.5); +} + +TEST(Histogram, Basic) { + Histogram h; + for (int i = 0; i < 100; i++) { + h.Add(i); + } + for (int i = 1000; i < 100000; i += 1000) { + h.Add(i); + } + Validate(h); +} + +TEST(ThreadSafeHistogram, Basic) { + // Fill a 
normal histogram. + Histogram h; + for (int i = 0; i < 100; i++) { + h.Add(i); + } + + // Fill a thread-safe histogram with the same values. + ThreadSafeHistogram tsh; + for (int i = 0; i < 100; i++) { + tsh.Add(i); + } + + for (int i = 0; i < 2; ++i) { + bool preserve_zero_buckets = (i == 0); + HistogramProto h_proto; + h.EncodeToProto(&h_proto, preserve_zero_buckets); + HistogramProto tsh_proto; + tsh.EncodeToProto(&tsh_proto, preserve_zero_buckets); + + // Let's decode from the proto of the other histogram type. + Histogram h2; + EXPECT_TRUE(h2.DecodeFromProto(tsh_proto)); + ThreadSafeHistogram tsh2; + EXPECT_TRUE(tsh2.DecodeFromProto(h_proto)); + + // Now let's reencode and check they match. + EXPECT_EQ(h2.ToString(), tsh2.ToString()); + } + + EXPECT_EQ(h.Median(), tsh.Median()); + EXPECT_EQ(h.Percentile(40.0), tsh.Percentile(40.0)); + EXPECT_EQ(h.Average(), tsh.Average()); + EXPECT_EQ(h.StandardDeviation(), tsh.StandardDeviation()); + EXPECT_EQ(h.ToString(), tsh.ToString()); +} + +} // namespace histogram +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc new file mode 100644 index 0000000000..1ddaa2eb78 --- /dev/null +++ b/tensorflow/core/lib/io/block.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// Decodes the blocks generated by block_builder.cc. + +#include "tensorflow/core/lib/io/block.h" + +#include <vector> +#include <algorithm> +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +inline uint32 Block::NumRestarts() const { + assert(size_ >= sizeof(uint32)); + return core::DecodeFixed32(data_ + size_ - sizeof(uint32)); +} + +Block::Block(const BlockContents& contents) + : data_(contents.data.data()), + size_(contents.data.size()), + owned_(contents.heap_allocated) { + if (size_ < sizeof(uint32)) { + size_ = 0; // Error marker + } else { + size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32); + if (NumRestarts() > max_restarts_allowed) { + // The size is too small for NumRestarts() + size_ = 0; + } else { + restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32); + } + } +} + +Block::~Block() { + if (owned_) { + delete[] data_; + } +} + +// Helper routine: decode the next block entry starting at "p", +// storing the number of shared key bytes, non_shared key bytes, +// and the length of the value in "*shared", "*non_shared", and +// "*value_length", respectively. Will not dereference past "limit". +// +// If any errors are detected, returns NULL. Otherwise, returns a +// pointer to the key delta (just past the three decoded values). 
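+// The layout being decoded is the one produced by BlockBuilder::Add():
+//   shared (varint32), non_shared (varint32), value_length (varint32),
+//   key_delta (char[non_shared]), value (char[value_length]);
+// the fast path below applies when all three varints fit in a single byte.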
+static inline const char* DecodeEntry(const char* p, const char* limit, + uint32* shared, uint32* non_shared, + uint32* value_length) { + if (limit - p < 3) return NULL; + *shared = reinterpret_cast<const unsigned char*>(p)[0]; + *non_shared = reinterpret_cast<const unsigned char*>(p)[1]; + *value_length = reinterpret_cast<const unsigned char*>(p)[2]; + if ((*shared | *non_shared | *value_length) < 128) { + // Fast path: all three values are encoded in one byte each + p += 3; + } else { + if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL; + } + + if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) { + return NULL; + } + return p; +} + +class Block::Iter : public Iterator { + private: + const char* const data_; // underlying block contents + uint32 const restarts_; // Offset of restart array (list of fixed32) + uint32 const num_restarts_; // Number of uint32 entries in restart array + + // current_ is offset in data_ of current entry. >= restarts_ if !Valid + uint32 current_; + uint32 restart_index_; // Index of restart block in which current_ falls + string key_; + StringPiece value_; + Status status_; + + inline int Compare(const StringPiece& a, const StringPiece& b) const { + return a.compare(b); + } + + // Return the offset in data_ just past the end of the current entry. + inline uint32 NextEntryOffset() const { + return (value_.data() + value_.size()) - data_; + } + + uint32 GetRestartPoint(uint32 index) { + assert(index < num_restarts_); + return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32)); + } + + void SeekToRestartPoint(uint32 index) { + key_.clear(); + restart_index_ = index; + // current_ will be fixed by ParseNextKey(); + + // ParseNextKey() starts at the end of value_, so set value_ accordingly + uint32 offset = GetRestartPoint(index); + value_ = StringPiece(data_ + offset, 0); + } + + public: + Iter(const char* data, uint32 restarts, uint32 num_restarts) + : data_(data), + restarts_(restarts), + num_restarts_(num_restarts), + current_(restarts_), + restart_index_(num_restarts_) { + assert(num_restarts_ > 0); + } + + virtual bool Valid() const { return current_ < restarts_; } + virtual Status status() const { return status_; } + virtual StringPiece key() const { + assert(Valid()); + return key_; + } + virtual StringPiece value() const { + assert(Valid()); + return value_; + } + + virtual void Next() { + assert(Valid()); + ParseNextKey(); + } + + virtual void Seek(const StringPiece& target) { + // Binary search in restart array to find the last restart point + // with a key < target + uint32 left = 0; + uint32 right = num_restarts_ - 1; + while (left < right) { + uint32 mid = (left + right + 1) / 2; + uint32 region_offset = GetRestartPoint(mid); + uint32 shared, non_shared, value_length; + const char* key_ptr = + DecodeEntry(data_ + region_offset, data_ + restarts_, &shared, + &non_shared, &value_length); + if (key_ptr == NULL || (shared != 0)) { + CorruptionError(); + return; + } + StringPiece mid_key(key_ptr, non_shared); + if (Compare(mid_key, target) < 0) { + // Key at "mid" is smaller than "target". Therefore all + // blocks before "mid" are uninteresting. + left = mid; + } else { + // Key at "mid" is >= "target". Therefore all blocks at or + // after "mid" are uninteresting. 
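+        // (Restart point 0 is never ruled out, so "left" always names a valid
+        // restart block for the linear scan that follows.)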
+ right = mid - 1; + } + } + + // Linear search (within restart block) for first key >= target + SeekToRestartPoint(left); + while (true) { + if (!ParseNextKey()) { + return; + } + if (Compare(key_, target) >= 0) { + return; + } + } + } + + virtual void SeekToFirst() { + SeekToRestartPoint(0); + ParseNextKey(); + } + + private: + void CorruptionError() { + current_ = restarts_; + restart_index_ = num_restarts_; + status_ = errors::DataLoss("bad entry in block"); + key_.clear(); + value_.clear(); + } + + bool ParseNextKey() { + current_ = NextEntryOffset(); + const char* p = data_ + current_; + const char* limit = data_ + restarts_; // Restarts come right after data + if (p >= limit) { + // No more entries to return. Mark as invalid. + current_ = restarts_; + restart_index_ = num_restarts_; + return false; + } + + // Decode next entry + uint32 shared, non_shared, value_length; + p = DecodeEntry(p, limit, &shared, &non_shared, &value_length); + if (p == NULL || key_.size() < shared) { + CorruptionError(); + return false; + } else { + key_.resize(shared); + key_.append(p, non_shared); + value_ = StringPiece(p + non_shared, value_length); + while (restart_index_ + 1 < num_restarts_ && + GetRestartPoint(restart_index_ + 1) < current_) { + ++restart_index_; + } + return true; + } + } +}; + +Iterator* Block::NewIterator() { + if (size_ < sizeof(uint32)) { + return NewErrorIterator(errors::DataLoss("bad block contents")); + } + const uint32 num_restarts = NumRestarts(); + if (num_restarts == 0) { + return NewEmptyIterator(); + } else { + return new Iter(data_, restart_offset_, num_restarts); + } +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h new file mode 100644 index 0000000000..bf53245b8d --- /dev/null +++ b/tensorflow/core/lib/io/block.h @@ -0,0 +1,45 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_H_ +#define TENSORFLOW_LIB_IO_BLOCK_H_ + +#include <stddef.h> +#include <stdint.h> +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +struct BlockContents; + +class Block { + public: + // Initialize the block with the specified contents. + explicit Block(const BlockContents& contents); + + ~Block(); + + size_t size() const { return size_; } + Iterator* NewIterator(); + + private: + uint32 NumRestarts() const; + + const char* data_; + size_t size_; + uint32 restart_offset_; // Offset in data_ of restart array + bool owned_; // Block owns data_[] + + // No copying allowed + Block(const Block&); + void operator=(const Block&); + + class Iter; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_H_ diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc new file mode 100644 index 0000000000..d94048d744 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// BlockBuilder generates blocks where keys are prefix-compressed: +// +// When we store a key, we drop the prefix shared with the previous +// string. 
This helps reduce the space requirement significantly. +// Furthermore, once every K keys, we do not apply the prefix +// compression and store the entire key. We call this a "restart +// point". The tail end of the block stores the offsets of all of the +// restart points, and can be used to do a binary search when looking +// for a particular key. Values are stored as-is (without compression) +// immediately following the corresponding key. +// +// An entry for a particular key-value pair has the form: +// shared_bytes: varint32 +// unshared_bytes: varint32 +// value_length: varint32 +// key_delta: char[unshared_bytes] +// value: char[value_length] +// shared_bytes == 0 for restart points. +// +// The trailer of the block has the form: +// restarts: uint32[num_restarts] +// num_restarts: uint32 +// restarts[i] contains the offset within the block of the ith restart point. + +#include "tensorflow/core/lib/io/block_builder.h" + +#include <algorithm> +#include <assert.h> +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace table { + +BlockBuilder::BlockBuilder(const Options* options) + : options_(options), restarts_(), counter_(0), finished_(false) { + assert(options->block_restart_interval >= 1); + restarts_.push_back(0); // First restart point is at offset 0 +} + +void BlockBuilder::Reset() { + buffer_.clear(); + restarts_.clear(); + restarts_.push_back(0); // First restart point is at offset 0 + counter_ = 0; + finished_ = false; + last_key_.clear(); +} + +size_t BlockBuilder::CurrentSizeEstimate() const { + return (buffer_.size() + // Raw data buffer + restarts_.size() * sizeof(uint32) + // Restart array + sizeof(uint32)); // Restart array length +} + +StringPiece BlockBuilder::Finish() { + // Append restart array + for (size_t i = 0; i < restarts_.size(); i++) { + core::PutFixed32(&buffer_, restarts_[i]); + } + core::PutFixed32(&buffer_, restarts_.size()); + finished_ = true; + return StringPiece(buffer_); +} + +void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) { + StringPiece last_key_piece(last_key_); + assert(!finished_); + assert(counter_ <= options_->block_restart_interval); + assert(buffer_.empty() // No values yet? + || key.compare(last_key_piece) > 0); + size_t shared = 0; + if (counter_ < options_->block_restart_interval) { + // See how much sharing to do with previous string + const size_t min_length = std::min(last_key_piece.size(), key.size()); + while ((shared < min_length) && (last_key_piece[shared] == key[shared])) { + shared++; + } + } else { + // Restart compression + restarts_.push_back(buffer_.size()); + counter_ = 0; + } + const size_t non_shared = key.size() - shared; + + // Add "<shared><non_shared><value_size>" to buffer_ + core::PutVarint32(&buffer_, shared); + core::PutVarint32(&buffer_, non_shared); + core::PutVarint32(&buffer_, value.size()); + + // Add string delta to buffer_ followed by value + buffer_.append(key.data() + shared, non_shared); + buffer_.append(value.data(), value.size()); + + // Update state + last_key_.resize(shared); + last_key_.append(key.data() + shared, non_shared); + assert(StringPiece(last_key_) == key); + counter_++; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h new file mode 100644 index 0000000000..e07a647805 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.h @@ -0,0 +1,57 @@ +// Copyright (c) 2011 The LevelDB Authors. 
All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ +#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ + +#include <vector> + +#include <stdint.h> +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace table { + +struct Options; + +class BlockBuilder { + public: + explicit BlockBuilder(const Options* options); + + // Reset the contents as if the BlockBuilder was just constructed. + void Reset(); + + // REQUIRES: Finish() has not been called since the last call to Reset(). + // REQUIRES: key is larger than any previously added key + void Add(const StringPiece& key, const StringPiece& value); + + // Finish building the block and return a slice that refers to the + // block contents. The returned slice will remain valid for the + // lifetime of this builder or until Reset() is called. + StringPiece Finish(); + + // Returns an estimate of the current (uncompressed) size of the block + // we are building. + size_t CurrentSizeEstimate() const; + + // Return true iff no entries have been added since the last Reset() + bool empty() const { return buffer_.empty(); } + + private: + const Options* options_; + string buffer_; // Destination buffer + std::vector<uint32> restarts_; // Restart points + int counter_; // Number of entries emitted since restart + bool finished_; // Has Finish() been called? + string last_key_; + + // No copying allowed + BlockBuilder(const BlockBuilder&); + void operator=(const BlockBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc new file mode 100644 index 0000000000..259cfc13dc --- /dev/null +++ b/tensorflow/core/lib/io/format.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
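Before the on-disk format helpers, a brief usage sketch of the BlockBuilder interface above. It assumes that table_options.h (not shown in this hunk) declares an Options struct with a block_restart_interval field, as the constructor's assertion implies; the keys and values are made up.

#include <string>
#include "tensorflow/core/lib/io/block_builder.h"
#include "tensorflow/core/lib/io/table_options.h"

// Builds one prefix-compressed block from a few keys added in sorted order.
void BuildSampleBlock(std::string* out) {
  tensorflow::table::Options options;
  options.block_restart_interval = 16;  // full key stored every 16 entries

  tensorflow::table::BlockBuilder builder(&options);
  builder.Add("apple", "1");
  builder.Add("apricot", "2");  // only "ricot" is stored; "ap" is shared
  builder.Add("banana", "3");

  // Finish() appends the restart array; the result stays valid until Reset().
  tensorflow::StringPiece contents = builder.Finish();
  out->assign(contents.data(), contents.size());
}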
+ +#include "tensorflow/core/lib/io/format.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +void BlockHandle::EncodeTo(string* dst) const { + // Sanity check that all fields have been set + assert(offset_ != ~static_cast<uint64>(0)); + assert(size_ != ~static_cast<uint64>(0)); + core::PutVarint64(dst, offset_); + core::PutVarint64(dst, size_); +} + +Status BlockHandle::DecodeFrom(StringPiece* input) { + if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { + return Status::OK(); + } else { + return errors::DataLoss("bad block handle"); + } +} + +void Footer::EncodeTo(string* dst) const { +#ifndef NDEBUG + const size_t original_size = dst->size(); +#endif + metaindex_handle_.EncodeTo(dst); + index_handle_.EncodeTo(dst); + dst->resize(2 * BlockHandle::kMaxEncodedLength); // Padding + core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber & 0xffffffffu)); + core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber >> 32)); + assert(dst->size() == original_size + kEncodedLength); +} + +Status Footer::DecodeFrom(StringPiece* input) { + const char* magic_ptr = input->data() + kEncodedLength - 8; + const uint32 magic_lo = core::DecodeFixed32(magic_ptr); + const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); + const uint64 magic = + ((static_cast<uint64>(magic_hi) << 32) | (static_cast<uint64>(magic_lo))); + if (magic != kTableMagicNumber) { + return errors::DataLoss("not an sstable (bad magic number)"); + } + + Status result = metaindex_handle_.DecodeFrom(input); + if (result.ok()) { + result = index_handle_.DecodeFrom(input); + } + if (result.ok()) { + // We skip over any leftover data (just padding for now) in "input" + const char* end = magic_ptr + 8; + *input = StringPiece(end, input->data() + input->size() - end); + } + return result; +} + +Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { + result->data = StringPiece(); + result->cachable = false; + result->heap_allocated = false; + + // Read the block contents as well as the type/crc footer. + // See table_builder.cc for the code that built this structure. + size_t n = static_cast<size_t>(handle.size()); + char* buf = new char[n + kBlockTrailerSize]; + StringPiece contents; + Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); + if (!s.ok()) { + delete[] buf; + return s; + } + if (contents.size() != n + kBlockTrailerSize) { + delete[] buf; + return errors::DataLoss("truncated block read"); + } + + // Check the crc of the type and the block contents + const char* data = contents.data(); // Pointer to where Read put the data + // This checksum verification is optional. We leave it on for now + const bool verify_checksum = true; + if (verify_checksum) { + const uint32 crc = crc32c::Unmask(core::DecodeFixed32(data + n + 1)); + const uint32 actual = crc32c::Value(data, n + 1); + if (actual != crc) { + delete[] buf; + s = errors::DataLoss("block checksum mismatch"); + return s; + } + } + + switch (data[n]) { + case kNoCompression: + if (data != buf) { + // File implementation gave us pointer to some other data. + // Use it directly under the assumption that it will be live + // while the file is open. 
+ delete[] buf; + result->data = StringPiece(data, n); + result->heap_allocated = false; + result->cachable = false; // Do not double-cache + } else { + result->data = StringPiece(buf, n); + result->heap_allocated = true; + result->cachable = true; + } + + // Ok + break; + case kSnappyCompression: { + size_t ulength = 0; + if (!port::Snappy_GetUncompressedLength(data, n, &ulength)) { + delete[] buf; + return errors::DataLoss("corrupted compressed block contents"); + } + char* ubuf = new char[ulength]; + if (!port::Snappy_Uncompress(data, n, ubuf)) { + delete[] buf; + delete[] ubuf; + return errors::DataLoss("corrupted compressed block contents"); + } + delete[] buf; + result->data = StringPiece(ubuf, ulength); + result->heap_allocated = true; + result->cachable = true; + break; + } + default: + delete[] buf; + return errors::DataLoss("bad block type"); + } + + return Status::OK(); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h new file mode 100644 index 0000000000..3121c41bb8 --- /dev/null +++ b/tensorflow/core/lib/io/format.h @@ -0,0 +1,99 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_FORMAT_H_ +#define TENSORFLOW_LIB_IO_FORMAT_H_ + +#include <string> +#include <stdint.h> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/io/table_builder.h" + +namespace tensorflow { +class RandomAccessFile; +namespace table { + +class Block; + +// BlockHandle is a pointer to the extent of a file that stores a data +// block or a meta block. +class BlockHandle { + public: + BlockHandle(); + + // The offset of the block in the file. + uint64 offset() const { return offset_; } + void set_offset(uint64 offset) { offset_ = offset; } + + // The size of the stored block + uint64 size() const { return size_; } + void set_size(uint64 size) { size_ = size; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Maximum encoding length of a BlockHandle + enum { kMaxEncodedLength = 10 + 10 }; + + private: + uint64 offset_; + uint64 size_; +}; + +// Footer encapsulates the fixed information stored at the tail +// end of every table file. +class Footer { + public: + Footer() {} + + // The block handle for the metaindex block of the table + const BlockHandle& metaindex_handle() const { return metaindex_handle_; } + void set_metaindex_handle(const BlockHandle& h) { metaindex_handle_ = h; } + + // The block handle for the index block of the table + const BlockHandle& index_handle() const { return index_handle_; } + void set_index_handle(const BlockHandle& h) { index_handle_ = h; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Encoded length of a Footer. Note that the serialization of a + // Footer will always occupy exactly this many bytes. It consists + // of two block handles and a magic number. + enum { kEncodedLength = 2 * BlockHandle::kMaxEncodedLength + 8 }; + + private: + BlockHandle metaindex_handle_; + BlockHandle index_handle_; +}; + +// kTableMagicNumber was picked by running +// echo http://code.google.com/p/leveldb/ | sha1sum +// and taking the leading 64 bits. 
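As a companion to the BlockHandle declared above, a hedged round-trip sketch of its varint encoding; the offset and size values are arbitrary and the check of the decoded fields is only indicated in comments.

#include <string>
#include "tensorflow/core/lib/io/format.h"

// Encode a handle into a string and decode it back.
void BlockHandleRoundTrip() {
  tensorflow::table::BlockHandle handle;
  handle.set_offset(4096);
  handle.set_size(512);

  std::string encoded;
  handle.EncodeTo(&encoded);  // two varint64s, at most kMaxEncodedLength bytes

  tensorflow::table::BlockHandle decoded;
  tensorflow::StringPiece input(encoded);
  tensorflow::Status s = decoded.DecodeFrom(&input);
  // On success: s.ok(), decoded.offset() == 4096, decoded.size() == 512.
}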
+static const uint64 kTableMagicNumber = 0xdb4775248b80fb57ull; + +// 1-byte type + 32-bit crc +static const size_t kBlockTrailerSize = 5; + +struct BlockContents { + StringPiece data; // Actual contents of data + bool cachable; // True iff data can be cached + bool heap_allocated; // True iff caller should delete[] data.data() +}; + +// Read the block identified by "handle" from "file". On failure +// return non-OK. On success fill *result and return OK. +extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); + +// Implementation details follow. Clients should ignore, + +inline BlockHandle::BlockHandle() + : offset_(~static_cast<uint64>(0)), size_(~static_cast<uint64>(0)) {} + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_FORMAT_H_ diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc new file mode 100644 index 0000000000..8fa245a546 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.cc @@ -0,0 +1,112 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) + : file_(file), + file_pos_(0), + size_(buffer_bytes), + buf_(new char[size_]), + pos_(buf_), + limit_(buf_) {} + +InputBuffer::~InputBuffer() { + delete file_; + delete[] buf_; +} + +Status InputBuffer::FillBuffer() { + StringPiece data; + Status s = file_->Read(file_pos_, size_, &data, buf_); + if (data.data() != buf_) { + memmove(buf_, data.data(), data.size()); + } + pos_ = buf_; + limit_ = pos_ + data.size(); + file_pos_ += data.size(); + return s; +} + +Status InputBuffer::ReadLine(string* result) { + result->clear(); + int i; + Status s; + for (i = 0;; i++) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + char c = *pos_++; + if (c == '\n') { + // We don't append the '\n' to *result + return Status::OK(); + } + *result += c; + } + if (errors::IsOutOfRange(s) && !result->empty()) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) { + result->clear(); + if (bytes_to_read < 0) { + return errors::InvalidArgument("Can't read a negative number of bytes: ", + bytes_to_read); + } + result->reserve(bytes_to_read); + Status s; + while (result->size() < static_cast<size_t>(bytes_to_read)) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_copy = + std::min<int64>(limit_ - pos_, bytes_to_read - result->size()); + result->insert(result->size(), pos_, bytes_to_copy); + pos_ += bytes_to_copy; + } + if (errors::IsOutOfRange(s) && + (result->size() == static_cast<size_t>(bytes_to_read))) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::SkipNBytes(int64 bytes_to_skip) { + if (bytes_to_skip < 0) { + return errors::InvalidArgument("Can only skip forward, not ", + bytes_to_skip); + } + int64 bytes_skipped = 0; + Status s; + while (bytes_skipped < bytes_to_skip) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_advance = + std::min<int64>(limit_ - pos_, bytes_to_skip - bytes_skipped); + bytes_skipped += bytes_to_advance; + pos_ += bytes_to_advance; + } + if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { + return Status::OK(); + } + 
return s; +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h new file mode 100644 index 0000000000..6879f30567 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.h @@ -0,0 +1,62 @@ +#ifndef TENSORFLOW_LIB_IO_INPUTBUFFER_H_ +#define TENSORFLOW_LIB_IO_INPUTBUFFER_H_ + +#include <string> +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace io { + +// An InputBuffer provides a buffer on top of a RandomAccessFile. +// A given instance of an InputBuffer is NOT safe for concurrent use +// by multiple threads +class InputBuffer { + public: + // Create an InputBuffer for "file" with a buffer size of + // "buffer_bytes" bytes. Takes ownership of "file" and will + // delete it when the InputBuffer is destroyed. + InputBuffer(RandomAccessFile* file, size_t buffer_bytes); + ~InputBuffer(); + + // Read one text line of data into "*result" until end-of-file or a + // \n is read. (The \n is not included in the result.) Overwrites + // any existing data in *result. + // + // If successful, returns OK. If we are already at the end of the + // file, we return an OUT_OF_RANGE error. Otherwise, we return + // some other non-OK status. + Status ReadLine(string* result); + + // Reads bytes_to_read bytes into *result, overwriting *result. + // + // If successful, returns OK. If we there are not enough bytes to + // read before the end of the file, we return an OUT_OF_RANGE error. + // Otherwise, we return some other non-OK status. + Status ReadNBytes(int64 bytes_to_read, string* result); + + // Like ReadNBytes() without returning the bytes read. + Status SkipNBytes(int64 bytes_to_skip); + + // Returns the position in the file. 
+ int64 Tell() const { return file_pos_ - (limit_ - pos_); } + + private: + Status FillBuffer(); + + RandomAccessFile* file_; // Owned + int64 file_pos_; // Next position to read from in "file_" + size_t size_; // Size of "buf_" + char* buf_; // The buffer itself + // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_" + char* pos_; // Current position in "buf" + char* limit_; // Just past end of valid data in "buf" + + TF_DISALLOW_COPY_AND_ASSIGN(InputBuffer); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_INPUTBUFFER_H_ diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc new file mode 100644 index 0000000000..34094f018c --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer_test.cc @@ -0,0 +1,174 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" + +#include "tensorflow/core/public/env.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +static std::vector<int> BufferSizes() { + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536}; +} + +TEST(InputBuffer, ReadLine_Empty) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, ""); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine1) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three\n"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_NoTrailingNewLine) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_EmptyLines) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\n\n\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + 
TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(3, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(4, &read)); + EXPECT_EQ(read, "3456"); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(7, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, "789"); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +TEST(InputBuffer, SkipNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(3)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(2, &read)); + EXPECT_EQ(read, "34"); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(2)); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(1, &read)); + EXPECT_EQ(read, "7"); + EXPECT_EQ(8, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc new file mode 100644 index 0000000000..878e93a911 --- /dev/null +++ b/tensorflow/core/lib/io/iterator.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
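The InputBuffer interface shown earlier signals end of file with OUT_OF_RANGE, the same convention these tests exercise. A small sketch of a typical read loop follows; the buffer size is arbitrary and the file name comes from the caller.

#include <string>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/public/env.h"

// Counts the lines in a file; OUT_OF_RANGE from ReadLine() means clean EOF.
tensorflow::Status CountLines(tensorflow::Env* env, const std::string& fname,
                              tensorflow::int64* lines) {
  tensorflow::RandomAccessFile* file;
  tensorflow::Status s = env->NewRandomAccessFile(fname, &file);
  if (!s.ok()) return s;

  tensorflow::io::InputBuffer in(file, 64 << 10);  // takes ownership of file
  *lines = 0;
  std::string line;
  while ((s = in.ReadLine(&line)).ok()) ++*lines;
  return tensorflow::errors::IsOutOfRange(s) ? tensorflow::Status::OK() : s;
}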
+ +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +Iterator::Iterator() { + cleanup_.function = NULL; + cleanup_.next = NULL; +} + +Iterator::~Iterator() { + if (cleanup_.function != NULL) { + (*cleanup_.function)(cleanup_.arg1, cleanup_.arg2); + for (Cleanup* c = cleanup_.next; c != NULL;) { + (*c->function)(c->arg1, c->arg2); + Cleanup* next = c->next; + delete c; + c = next; + } + } +} + +void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { + assert(func != NULL); + Cleanup* c; + if (cleanup_.function == NULL) { + c = &cleanup_; + } else { + c = new Cleanup; + c->next = cleanup_.next; + cleanup_.next = c; + } + c->function = func; + c->arg1 = arg1; + c->arg2 = arg2; +} + +namespace { +class EmptyIterator : public Iterator { + public: + EmptyIterator(const Status& s) : status_(s) {} + virtual bool Valid() const { return false; } + virtual void Seek(const StringPiece& target) {} + virtual void SeekToFirst() {} + virtual void Next() { assert(false); } + StringPiece key() const { + assert(false); + return StringPiece(); + } + StringPiece value() const { + assert(false); + return StringPiece(); + } + virtual Status status() const { return status_; } + + private: + Status status_; +}; +} // namespace + +Iterator* NewEmptyIterator() { return new EmptyIterator(Status::OK()); } + +Iterator* NewErrorIterator(const Status& status) { + return new EmptyIterator(status); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h new file mode 100644 index 0000000000..603a2f95fe --- /dev/null +++ b/tensorflow/core/lib/io/iterator.h @@ -0,0 +1,93 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// An iterator yields a sequence of key/value pairs from a source. +// The following class defines the interface. Multiple implementations +// are provided by this library. In particular, iterators are provided +// to access the contents of a Table or a DB. +// +// Multiple threads can invoke const methods on an Iterator without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same Iterator must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_IO_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_ITERATOR_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace table { + +class Iterator { + public: + Iterator(); + virtual ~Iterator(); + + // An iterator is either positioned at a key/value pair, or + // not valid. This method returns true iff the iterator is valid. + virtual bool Valid() const = 0; + + // Position at the first key in the source. The iterator is Valid() + // after this call iff the source is not empty. + virtual void SeekToFirst() = 0; + + // Position at the first key in the source that is at or past target. + // The iterator is Valid() after this call iff the source contains + // an entry that comes at or past target. + virtual void Seek(const StringPiece& target) = 0; + + // Moves to the next entry in the source. After this call, Valid() is + // true iff the iterator was not positioned at the last entry in the source. + // REQUIRES: Valid() + virtual void Next() = 0; + + // Return the key for the current entry. 
The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece key() const = 0; + + // Return the value for the current entry. The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece value() const = 0; + + // If an error has occurred, return it. Else return an ok status. + virtual Status status() const = 0; + + // Clients are allowed to register function/arg1/arg2 triples that + // will be invoked when this iterator is destroyed. + // + // Note that unlike all of the preceding methods, this method is + // not abstract and therefore clients should not override it. + typedef void (*CleanupFunction)(void* arg1, void* arg2); + void RegisterCleanup(CleanupFunction function, void* arg1, void* arg2); + + private: + struct Cleanup { + CleanupFunction function; + void* arg1; + void* arg2; + Cleanup* next; + }; + Cleanup cleanup_; + + // No copying allowed + Iterator(const Iterator&); + void operator=(const Iterator&); +}; + +// Return an empty iterator (yields nothing). +extern Iterator* NewEmptyIterator(); + +// Return an empty iterator with the specified status. +extern Iterator* NewErrorIterator(const Status& status); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_ITERATOR_H_ diff --git a/tensorflow/core/lib/io/match.cc b/tensorflow/core/lib/io/match.cc new file mode 100644 index 0000000000..1563642d0b --- /dev/null +++ b/tensorflow/core/lib/io/match.cc @@ -0,0 +1,31 @@ +#include "tensorflow/core/lib/io/match.h" +#include <fnmatch.h> +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +Status GetMatchingFiles(Env* env, const string& pattern, + std::vector<string>* results) { + results->clear(); + std::vector<string> all_files; + string dir = Dirname(pattern).ToString(); + if (dir.empty()) dir = "."; + string basename_pattern = Basename(pattern).ToString(); + Status s = env->GetChildren(dir, &all_files); + if (!s.ok()) { + return s; + } + for (const auto& f : all_files) { + int flags = 0; + if (fnmatch(basename_pattern.c_str(), Basename(f).ToString().c_str(), + flags) == 0) { + results->push_back(JoinPath(dir, f)); + } + } + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/match.h b/tensorflow/core/lib/io/match.h new file mode 100644 index 0000000000..fd194178e7 --- /dev/null +++ b/tensorflow/core/lib/io/match.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_LIB_IO_MATCH_H_ +#define TENSORFLOW_LIB_IO_MATCH_H_ + +#include <vector> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +class Env; +namespace io { + +// Given a pattern, return the set of files that match the pattern. +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. If +// successful, return Status::OK and store the matching files in +// "*results". Otherwise, return a non-OK status. 
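A minimal call site for the routine described above might look like the following sketch; the directory and shard names are hypothetical, and only the basename part of the pattern may contain wildcards.

#include <string>
#include <vector>
#include "tensorflow/core/lib/io/match.h"
#include "tensorflow/core/public/env.h"

// Lists files whose basename matches a wildcard pattern.
tensorflow::Status ListShards(std::vector<std::string>* files) {
  return tensorflow::io::GetMatchingFiles(tensorflow::Env::Default(),
                                          "/tmp/data/part-*", files);
}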
+Status GetMatchingFiles(Env* env, const string& pattern, + std::vector<string>* results); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_MATCH_H_ diff --git a/tensorflow/core/lib/io/match_test.cc b/tensorflow/core/lib/io/match_test.cc new file mode 100644 index 0000000000..aaa56e4e7e --- /dev/null +++ b/tensorflow/core/lib/io/match_test.cc @@ -0,0 +1,51 @@ +#include <algorithm> +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace io { + +static string Match(Env* env, const string& suffix_pattern) { + std::vector<string> results; + Status s = GetMatchingFiles(env, JoinPath(testing::TmpDir(), suffix_pattern), + &results); + if (!s.ok()) { + return s.ToString(); + } else { + string r; + std::sort(results.begin(), results.end()); + for (size_t i = 0; i < results.size(); i++) { + strings::StrAppend(&r, (i > 0) ? "," : "", Basename(results[i])); + } + return r; + } +} +TEST(GetMatchingFiles, Simple) { + Env* env = Env::Default(); + EXPECT_EQ(Match(env, "thereisnosuchfile"), ""); + EXPECT_EQ(Match(env, "thereisnosuchfile*"), ""); + + // Populate a few files + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-00"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-0a"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-01"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-aaa"), "")); + + EXPECT_EQ(Match(env, "match-*"), "match-00,match-01,match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-0[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?a*"), "match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-??"), "match-00,match-01,match-0a"); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc new file mode 100644 index 0000000000..1359ded0f0 --- /dev/null +++ b/tensorflow/core/lib/io/path.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace io { + +string JoinPath(StringPiece part1, StringPiece part2) { + string result; + + StringPiece paths[2] = {part1, part2}; + for (StringPiece path : paths) { + if (path.empty()) continue; + + if (result.empty()) { + result = path.ToString(); + continue; + } + + if (result[result.size() - 1] == '/') { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path.substr(1)); + } else { + strings::StrAppend(&result, path); + } + } else { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path); + } else { + strings::StrAppend(&result, "/", path); + } + } + } + + return result; +} + +namespace internal { + +// Return the parts of the path, split on the final "/". If there is no +// "/" in the path, the first part of the output is empty and the second +// is the input. If the only "/" in the path is the first character, it is +// the first part of the output. +std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) { + auto pos = path.rfind('/'); + + // Handle the case with no '/' in 'path'. 
+ if (pos == StringPiece::npos) + return std::make_pair(StringPiece(path.data(), 0), path); + + // Handle the case with a single leading '/' in 'path'. + if (pos == 0) + return std::make_pair(StringPiece(path.data(), 1), + StringPiece(path.data() + 1, path.size() - 1)); + + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} + +// Return the parts of the basename of path, split on the final ".". +// If there is no "." in the basename or "." is the final character in the +// basename, the second value will be empty. +std::pair<StringPiece, StringPiece> SplitBasename(StringPiece path) { + path = Basename(path); + + auto pos = path.rfind('.'); + if (pos == StringPiece::npos) + return std::make_pair(path, StringPiece(path.data() + path.size(), 0)); + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} +} // namespace internal + +bool IsAbsolutePath(StringPiece path) { + return !path.empty() && path[0] == '/'; +} + +StringPiece Dirname(StringPiece path) { + return internal::SplitPath(path).first; +} + +StringPiece Basename(StringPiece path) { + return internal::SplitPath(path).second; +} + +StringPiece Extension(StringPiece path) { + return internal::SplitBasename(path).second; +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h new file mode 100644 index 0000000000..01483f1702 --- /dev/null +++ b/tensorflow/core/lib/io/path.h @@ -0,0 +1,47 @@ +#ifndef TENSORFLOW_LIB_IO_PATH_H_ +#define TENSORFLOW_LIB_IO_PATH_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class StringPiece; +namespace io { + +// Utility routines for processing filenames + +// Join multiple paths together, without introducing unnecessary path +// separators. +// For example: +// +// Arguments | JoinPath +// ---------------------------+---------- +// '/foo', 'bar' | /foo/bar +// '/foo/', 'bar' | /foo/bar +// '/foo', '/bar' | /foo/bar +// +// Usage: +// string path = io::JoinPath("/mydir", filename); +// string path = io::JoinPath(FLAGS_test_srcdir, filename); +string JoinPath(StringPiece part1, StringPiece part2); + +// Return true if path is absolute. +bool IsAbsolutePath(StringPiece path); + +// Returns the part of the path before the final "/". If there is a single +// leading "/" in the path, the result will be the leading "/". If there is +// no "/" in the path, the result is the empty prefix of the input. +StringPiece Dirname(StringPiece path); + +// Returns the part of the path after the final "/". If there is no +// "/" in the path, the result is the same as the input. +StringPiece Basename(StringPiece path); + +// Returns the part of the basename of path after the final ".". If +// there is no "." in the basename, the result is empty. 
+StringPiece Extension(StringPiece path); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_PATH_H_ diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc new file mode 100644 index 0000000000..b670e44f1f --- /dev/null +++ b/tensorflow/core/lib/io/path_test.cc @@ -0,0 +1,65 @@ +#include "tensorflow/core/lib/io/path.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace io { + +TEST(PathTest, JoinPath) { + EXPECT_EQ("/foo/bar", JoinPath("/foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "/bar")); + EXPECT_EQ("/foo/bar", JoinPath("/foo", "/bar")); + + EXPECT_EQ("/bar", JoinPath("", "/bar")); + EXPECT_EQ("bar", JoinPath("", "bar")); + EXPECT_EQ("/foo", JoinPath("/foo", "")); + + EXPECT_EQ("/foo/bar/baz/blah/blink/biz", + JoinPath("/foo/bar/baz/", "/blah/blink/biz")); +} + +TEST(PathTest, IsAbsolutePath) { + EXPECT_FALSE(IsAbsolutePath("")); + EXPECT_FALSE(IsAbsolutePath("../foo")); + EXPECT_FALSE(IsAbsolutePath("foo")); + EXPECT_FALSE(IsAbsolutePath("./foo")); + EXPECT_FALSE(IsAbsolutePath("foo/bar/baz/")); + EXPECT_TRUE(IsAbsolutePath("/foo")); + EXPECT_TRUE(IsAbsolutePath("/foo/bar/../baz")); +} + +TEST(PathTest, Dirname) { + EXPECT_EQ("/hello", Dirname("/hello/")); + EXPECT_EQ("/", Dirname("/hello")); + EXPECT_EQ("hello", Dirname("hello/world")); + EXPECT_EQ("hello", Dirname("hello/")); + EXPECT_EQ("", Dirname("world")); + EXPECT_EQ("/", Dirname("/")); + EXPECT_EQ("", Dirname("")); +} + +TEST(PathTest, Basename) { + EXPECT_EQ("", Basename("/hello/")); + EXPECT_EQ("hello", Basename("/hello")); + EXPECT_EQ("world", Basename("hello/world")); + EXPECT_EQ("", Basename("hello/")); + EXPECT_EQ("world", Basename("world")); + EXPECT_EQ("", Basename("/")); + EXPECT_EQ("", Basename("")); +} + +TEST(PathTest, Extension) { + EXPECT_EQ("gif", Extension("foo.gif")); + EXPECT_EQ("", Extension("foo.")); + EXPECT_EQ("", Extension("")); + EXPECT_EQ("", Extension("/")); + EXPECT_EQ("", Extension("foo")); + EXPECT_EQ("", Extension("foo/")); + EXPECT_EQ("gif", Extension("/a/path/to/foo.gif")); + EXPECT_EQ("html", Extension("/a/path.bar/to/foo.html")); + EXPECT_EQ("", Extension("/a/path.bar/to/foo")); + EXPECT_EQ("baz", Extension("/a/path.bar/to/foo.bar.baz")); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc new file mode 100644 index 0000000000..2f0fabff63 --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.cc @@ -0,0 +1,80 @@ +#include "tensorflow/core/lib/io/record_reader.h" + +#include <limits.h> +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {} + +RecordReader::~RecordReader() {} + +// Read n+4 bytes from file, verify that checksum of first n bytes is +// stored in the last 4 bytes and store the first n bytes in *result. +// May use *storage as backing store. 
+static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, + size_t n, StringPiece* result, + string* storage) { + if (n >= SIZE_MAX - sizeof(uint32)) { + return errors::DataLoss("record size too large"); + } + + const size_t expected = n + sizeof(uint32); + storage->resize(expected); + StringPiece data; + Status s = file->Read(offset, expected, &data, &(*storage)[0]); + if (!s.ok()) { + return s; + } + if (data.size() != expected) { + if (data.size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); + return Status::OK(); +} + +Status RecordReader::ReadRecord(uint64* offset, string* record) { + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + + // Read length + StringPiece lbuf; + Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record); + if (!s.ok()) { + return s; + } + const uint64 length = core::DecodeFixed64(lbuf.data()); + + // Read data + StringPiece data; + s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { + s = errors::DataLoss("truncated record at ", *offset); + } + return s; + } + if (record->data() != data.data()) { + // RandomAccessFile placed the data in some other location. + memmove(&(*record)[0], data.data(), data.size()); + } + + record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h new file mode 100644 index 0000000000..a8c1b0dd5d --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_ +#define TENSORFLOW_LIB_IO_RECORD_READER_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class RandomAccessFile; + +namespace io { + +class RecordReader { + public: + // Create a reader that will return log records from "*file". + // "*file" must remain live while this Reader is in use. + explicit RecordReader(RandomAccessFile* file); + + ~RecordReader(); + + // Read the record at "*offset" into *record and update *offset to + // point to the offset of the next record. Returns OK on success, + // OUT_OF_RANGE for end of file, or something else for an error. 
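Given the OUT_OF_RANGE convention described above, a typical loop over all records in a file could look like this sketch; the file name comes from the caller and the reader never owns the file.

#include <string>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/public/env.h"

// Reads every record from "fname" into *records; OUT_OF_RANGE from
// ReadRecord() marks a clean end of file.
tensorflow::Status ReadAllRecords(tensorflow::Env* env,
                                  const std::string& fname,
                                  std::vector<std::string>* records) {
  tensorflow::RandomAccessFile* file;
  tensorflow::Status s = env->NewRandomAccessFile(fname, &file);
  if (!s.ok()) return s;

  tensorflow::io::RecordReader reader(file);  // reader does not own the file
  tensorflow::uint64 offset = 0;
  std::string record;
  while ((s = reader.ReadRecord(&offset, &record)).ok()) {
    records->push_back(record);
  }
  delete file;
  return tensorflow::errors::IsOutOfRange(s) ? tensorflow::Status::OK() : s;
}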
+ Status ReadRecord(uint64* offset, string* record); + + private: + RandomAccessFile* src_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_ diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc new file mode 100644 index 0000000000..3d7f1509ab --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/lib/io/record_writer.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" + +namespace tensorflow { +namespace io { + +RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {} + +RecordWriter::~RecordWriter() {} + +static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); +} + +Status RecordWriter::WriteRecord(StringPiece data) { + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + char header[sizeof(uint64) + sizeof(uint32)]; + core::EncodeFixed64(header + 0, data.size()); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); + Status s = dest_->Append(StringPiece(header, sizeof(header))); + if (!s.ok()) { + return s; + } + s = dest_->Append(data); + if (!s.ok()) { + return s; + } + char footer[sizeof(uint32)]; + core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); + return dest_->Append(StringPiece(footer, sizeof(footer))); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h new file mode 100644 index 0000000000..c7af00e5ae --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.h @@ -0,0 +1,34 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_ +#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class WritableFile; + +namespace io { + +class RecordWriter { + public: + // Create a writer that will append data to "*dest". + // "*dest" must be initially empty. + // "*dest" must remain live while this Writer is in use. + explicit RecordWriter(WritableFile* dest); + + ~RecordWriter(); + + Status WriteRecord(StringPiece slice); + + private: + WritableFile* const dest_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_ diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc new file mode 100644 index 0000000000..3e9c816443 --- /dev/null +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -0,0 +1,245 @@ +#include "tensorflow/core/lib/io/record_reader.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +// Construct a string of the specified length made out of the supplied +// partial string. 
+static string BigString(const string& partial_string, size_t n) { + string result; + while (result.size() < n) { + result.append(partial_string); + } + result.resize(n); + return result; +} + +// Construct a string from a number +static string NumberString(int n) { + char buf[50]; + snprintf(buf, sizeof(buf), "%d.", n); + return string(buf); +} + +// Return a skewed potentially long string +static string RandomSkewedString(int i, random::SimplePhilox* rnd) { + return BigString(NumberString(i), rnd->Skewed(17)); +} + +class RecordioTest : public testing::Test { + private: + class StringDest : public WritableFile { + public: + string contents_; + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + Status Append(const StringPiece& slice) override { + contents_.append(slice.data(), slice.size()); + return Status::OK(); + } + }; + + class StringSource : public RandomAccessFile { + public: + StringPiece contents_; + mutable bool force_error_; + mutable bool returned_partial_; + StringSource() : force_error_(false), returned_partial_(false) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; + + if (force_error_) { + force_error_ = false; + returned_partial_ = true; + return errors::DataLoss("read error"); + } + + if (offset >= contents_.size()) { + return errors::OutOfRange("end of file"); + } + + if (contents_.size() < offset + n) { + n = contents_.size() - offset; + returned_partial_ = true; + } + *result = StringPiece(contents_.data() + offset, n); + return Status::OK(); + } + }; + + StringDest dest_; + StringSource source_; + bool reading_; + uint64 readpos_; + RecordWriter* writer_; + RecordReader* reader_; + + public: + RecordioTest() + : reading_(false), + readpos_(0), + writer_(new RecordWriter(&dest_)), + reader_(new RecordReader(&source_)) {} + + ~RecordioTest() override { + delete writer_; + delete reader_; + } + + void Write(const string& msg) { + ASSERT_TRUE(!reading_) << "Write() after starting to read"; + ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); + } + + size_t WrittenBytes() const { return dest_.contents_.size(); } + + string Read() { + if (!reading_) { + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + } + string record; + Status s = reader_->ReadRecord(&readpos_, &record); + if (s.ok()) { + return record; + } else if (errors::IsOutOfRange(s)) { + return "EOF"; + } else { + return s.ToString(); + } + } + + void IncrementByte(int offset, int delta) { + dest_.contents_[offset] += delta; + } + + void SetByte(int offset, char new_byte) { + dest_.contents_[offset] = new_byte; + } + + void ShrinkSize(int bytes) { + dest_.contents_.resize(dest_.contents_.size() - bytes); + } + + void FixChecksum(int header_offset, int len) { + // Compute crc of type/len/data + uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); + crc = crc32c::Mask(crc); + core::EncodeFixed32(&dest_.contents_[header_offset], crc); + } + + void ForceError() { source_.force_error_ = true; } + + void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } + + void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) { + Write("foo"); + Write("bar"); + Write(BigString("x", 10000)); + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + uint64 offset = WrittenBytes() + offset_past_end; + string record; + Status 
s = reader_->ReadRecord(&offset, &record); + ASSERT_TRUE(errors::IsOutOfRange(s)) << s; + } +}; + +TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); } + +TEST_F(RecordioTest, ReadWrite) { + Write("foo"); + Write("bar"); + Write(""); + Write("xxxx"); + ASSERT_EQ("foo", Read()); + ASSERT_EQ("bar", Read()); + ASSERT_EQ("", Read()); + ASSERT_EQ("xxxx", Read()); + ASSERT_EQ("EOF", Read()); + ASSERT_EQ("EOF", Read()); // Make sure reads at eof work +} + +TEST_F(RecordioTest, ManyRecords) { + for (int i = 0; i < 100000; i++) { + Write(NumberString(i)); + } + for (int i = 0; i < 100000; i++) { + ASSERT_EQ(NumberString(i), Read()); + } + ASSERT_EQ("EOF", Read()); +} + +TEST_F(RecordioTest, RandomRead) { + const int N = 500; + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + Write(RandomSkewedString(i, &rnd)); + } + } + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + ASSERT_EQ(RandomSkewedString(i, &rnd), Read()); + } + } + ASSERT_EQ("EOF", Read()); +} + +// Tests of all the error paths in log_reader.cc follow: +static void AssertHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " + << expected; +} + +TEST_F(RecordioTest, ReadError) { + Write("foo"); + ForceError(); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLength) { + Write("foo"); + IncrementByte(6, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLengthCrc) { + Write("foo"); + IncrementByte(10, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptData) { + Write("foo"); + IncrementByte(14, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptDataCrc) { + Write("foo"); + IncrementByte(WrittenBytes() - 1, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } + +TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc new file mode 100644 index 0000000000..769d7e72a5 --- /dev/null +++ b/tensorflow/core/lib/io/table.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
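The tests above drive RecordWriter and RecordReader through in-memory files. Against a real Env, the writer side could look like this sketch; the output path is made up, and the WritableFile comes from env.h.

#include <string>
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/public/env.h"

// Appends three records to a fresh file; each record gets the
// length/CRC framing produced by RecordWriter::WriteRecord().
tensorflow::Status WriteSampleRecords(tensorflow::Env* env,
                                      const std::string& fname) {
  tensorflow::WritableFile* file;
  tensorflow::Status s = env->NewWritableFile(fname, &file);
  if (!s.ok()) return s;

  tensorflow::io::RecordWriter writer(file);  // writer does not own the file
  for (const char* msg : {"foo", "bar", "baz"}) {
    s = writer.WriteRecord(msg);
    if (!s.ok()) break;
  }
  if (s.ok()) s = file->Close();
  delete file;
  return s;
}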
+ +#include "tensorflow/core/lib/io/table.h" + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/io/two_level_iterator.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +struct Table::Rep { + ~Rep() { delete index_block; } + + Options options; + Status status; + RandomAccessFile* file; + // XXX uint64 cache_id; + + BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer + Block* index_block; +}; + +Status Table::Open(const Options& options, RandomAccessFile* file, + uint64 size, Table** table) { + *table = NULL; + if (size < Footer::kEncodedLength) { + return errors::DataLoss("file is too short to be an sstable"); + } + + char footer_space[Footer::kEncodedLength]; + StringPiece footer_input; + Status s = + file->Read(size - Footer::kEncodedLength, Footer::kEncodedLength, + &footer_input, footer_space); + if (!s.ok()) return s; + + Footer footer; + s = footer.DecodeFrom(&footer_input); + if (!s.ok()) return s; + + // Read the index block + BlockContents contents; + Block* index_block = NULL; + if (s.ok()) { + s = ReadBlock(file, footer.index_handle(), &contents); + if (s.ok()) { + index_block = new Block(contents); + } + } + + if (s.ok()) { + // We've successfully read the footer and the index block: we're + // ready to serve requests. + Rep* rep = new Table::Rep; + rep->options = options; + rep->file = file; + rep->metaindex_handle = footer.metaindex_handle(); + rep->index_block = index_block; + // XXX rep->cache_id = (options.block_cache ? + // options.block_cache->NewId() : 0); + *table = new Table(rep); + } else { + if (index_block) delete index_block; + } + + return s; +} + +Table::~Table() { delete rep_; } + +static void DeleteBlock(void* arg, void* ignored) { + delete reinterpret_cast<Block*>(arg); +} + +// Convert an index iterator value (i.e., an encoded BlockHandle) +// into an iterator over the contents of the corresponding block. +Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { + Table* table = reinterpret_cast<Table*>(arg); + // Cache* block_cache = table->rep_->options.block_cache; + Block* block = NULL; + // Cache::Handle* cache_handle = NULL; + + BlockHandle handle; + StringPiece input = index_value; + Status s = handle.DecodeFrom(&input); + // We intentionally allow extra stuff in index_value so that we + // can add more features in the future. 
+ + if (s.ok()) { + BlockContents contents; + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + } + } + + Iterator* iter; + if (block != NULL) { + iter = block->NewIterator(); + iter->RegisterCleanup(&DeleteBlock, block, NULL); + } else { + iter = NewErrorIterator(s); + } + return iter; +} + +Iterator* Table::NewIterator() const { + return NewTwoLevelIterator(rep_->index_block->NewIterator(), + &Table::BlockReader, const_cast<Table*>(this)); +} + +Status Table::InternalGet(const StringPiece& k, void* arg, + void (*saver)(void*, const StringPiece&, + const StringPiece&)) { + Status s; + Iterator* iiter = rep_->index_block->NewIterator(); + iiter->Seek(k); + if (iiter->Valid()) { + BlockHandle handle; + Iterator* block_iter = BlockReader(this, iiter->value()); + block_iter->Seek(k); + if (block_iter->Valid()) { + (*saver)(arg, block_iter->key(), block_iter->value()); + } + s = block_iter->status(); + delete block_iter; + } + if (s.ok()) { + s = iiter->status(); + } + delete iiter; + return s; +} + +uint64 Table::ApproximateOffsetOf(const StringPiece& key) const { + Iterator* index_iter = rep_->index_block->NewIterator(); + index_iter->Seek(key); + uint64 result; + if (index_iter->Valid()) { + BlockHandle handle; + StringPiece input = index_iter->value(); + Status s = handle.DecodeFrom(&input); + if (s.ok()) { + result = handle.offset(); + } else { + // Strange: we can't decode the block handle in the index block. + // We'll just return the offset of the metaindex block, which is + // close to the whole file size for this case. + result = rep_->metaindex_handle.offset(); + } + } else { + // key is past the last key in the file. Approximate the offset + // by returning the offset of the metaindex block (which is + // right near the end of the file). + result = rep_->metaindex_handle.offset(); + } + delete index_iter; + return result; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h new file mode 100644 index 0000000000..230dded2d4 --- /dev/null +++ b/tensorflow/core/lib/io/table.h @@ -0,0 +1,76 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TABLE_H_ +#define TENSORFLOW_LIB_IO_TABLE_H_ + +#include <stdint.h> +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +class RandomAccessFile; + +namespace table { + +class Block; +class BlockHandle; +class Footer; +struct Options; + +// A Table is a sorted map from strings to strings. Tables are +// immutable and persistent. A Table may be safely accessed from +// multiple threads without external synchronization. +class Table { + public: + // Attempt to open the table that is stored in bytes [0..file_size) + // of "file", and read the metadata entries necessary to allow + // retrieving data from the table. + // + // If successful, returns ok and sets "*table" to the newly opened + // table. The client should delete "*table" when no longer needed. + // If there was an error while initializing the table, sets "*table" + // to NULL and returns a non-ok status. Does not take ownership of + // "*file", but the client must ensure that "file" remains live + // for the duration of the returned table's lifetime. 
+ static Status Open(const Options& options, RandomAccessFile* file, + uint64 file_size, Table** table); + + ~Table(); + + // Returns a new iterator over the table contents. + // The result of NewIterator() is initially invalid (caller must + // call one of the Seek methods on the iterator before using it). + Iterator* NewIterator() const; + + // Given a key, return an approximate byte offset in the file where + // the data for that key begins (or would begin if the key were + // present in the file). The returned value is in terms of file + // bytes, and so includes effects like compression of the underlying data. + // E.g., the approximate offset of the last key in the table will + // be close to the file length. + uint64 ApproximateOffsetOf(const StringPiece& key) const; + + private: + struct Rep; + Rep* rep_; + + explicit Table(Rep* rep) { rep_ = rep; } + static Iterator* BlockReader(void*, const StringPiece&); + + // Calls (*handle_result)(arg, ...) with the entry found after a call + // to Seek(key). May not make such a call if filter policy says + // that key is not present. + Status InternalGet(const StringPiece& key, void* arg, + void (*handle_result)(void* arg, const StringPiece& k, + const StringPiece& v)); + + // No copying allowed + Table(const Table&); + void operator=(const Table&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_H_ diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc new file mode 100644 index 0000000000..b786888b30 --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.cc @@ -0,0 +1,263 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table_builder.h" + +#include <assert.h> +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +namespace { + +void FindShortestSeparator(string* start, const StringPiece& limit) { + // Find length of common prefix + size_t min_length = std::min(start->size(), limit.size()); + size_t diff_index = 0; + while ((diff_index < min_length) && + ((*start)[diff_index] == limit[diff_index])) { + diff_index++; + } + + if (diff_index >= min_length) { + // Do not shorten if one string is a prefix of the other + } else { + uint8 diff_byte = static_cast<uint8>((*start)[diff_index]); + if (diff_byte < static_cast<uint8>(0xff) && + diff_byte + 1 < static_cast<uint8>(limit[diff_index])) { + (*start)[diff_index]++; + start->resize(diff_index + 1); + assert(StringPiece(*start).compare(limit) < 0); + } + } +} + +void FindShortSuccessor(string* key) { + // Find first character that can be incremented + size_t n = key->size(); + for (size_t i = 0; i < n; i++) { + const uint8 byte = (*key)[i]; + if (byte != static_cast<uint8>(0xff)) { + (*key)[i] = byte + 1; + key->resize(i + 1); + return; + } + } + // *key is a run of 0xffs. Leave it alone. 
+} +} // namespace + +struct TableBuilder::Rep { + Options options; + Options index_block_options; + WritableFile* file; + uint64 offset; + Status status; + BlockBuilder data_block; + BlockBuilder index_block; + string last_key; + int64 num_entries; + bool closed; // Either Finish() or Abandon() has been called. + + // We do not emit the index entry for a block until we have seen the + // first key for the next data block. This allows us to use shorter + // keys in the index block. For example, consider a block boundary + // between the keys "the quick brown fox" and "the who". We can use + // "the r" as the key for the index block entry since it is >= all + // entries in the first block and < all entries in subsequent + // blocks. + // + // Invariant: r->pending_index_entry is true only if data_block is empty. + bool pending_index_entry; + BlockHandle pending_handle; // Handle to add to index block + + string compressed_output; + + Rep(const Options& opt, WritableFile* f) + : options(opt), + index_block_options(opt), + file(f), + offset(0), + data_block(&options), + index_block(&index_block_options), + num_entries(0), + closed(false), + pending_index_entry(false) { + index_block_options.block_restart_interval = 1; + } +}; + +TableBuilder::TableBuilder(const Options& options, WritableFile* file) + : rep_(new Rep(options, file)) {} + +TableBuilder::~TableBuilder() { + assert(rep_->closed); // Catch errors where caller forgot to call Finish() + delete rep_; +} + +void TableBuilder::Add(const StringPiece& key, const StringPiece& value) { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->num_entries > 0) { + assert(key.compare(StringPiece(r->last_key)) > 0); + // See if this key+value would make our current block overly large. If + // so, emit the current block before adding this key/value + const int kOverlyLargeBlockRatio = 2; + const size_t this_entry_bytes = key.size() + value.size(); + if (this_entry_bytes >= kOverlyLargeBlockRatio * r->options.block_size) { + Flush(); + } + } + + if (r->pending_index_entry) { + assert(r->data_block.empty()); + FindShortestSeparator(&r->last_key, key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + + r->last_key.assign(key.data(), key.size()); + r->num_entries++; + r->data_block.Add(key, value); + + const size_t estimated_block_size = r->data_block.CurrentSizeEstimate(); + if (estimated_block_size >= r->options.block_size) { + Flush(); + } +} + +void TableBuilder::Flush() { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->data_block.empty()) return; + assert(!r->pending_index_entry); + WriteBlock(&r->data_block, &r->pending_handle); + if (ok()) { + r->pending_index_entry = true; + r->status = r->file->Flush(); + } +} + +void TableBuilder::WriteBlock(BlockBuilder* block, BlockHandle* handle) { + // File format contains a sequence of blocks where each block has: + // block_data: uint8[n] + // type: uint8 + // crc: uint32 + assert(ok()); + Rep* r = rep_; + StringPiece raw = block->Finish(); + + StringPiece block_contents; + CompressionType type = r->options.compression; + // TODO(postrelease): Support more compression options: zlib? 
+ switch (type) { + case kNoCompression: + block_contents = raw; + break; + + case kSnappyCompression: { + string* compressed = &r->compressed_output; + if (port::Snappy_Compress(raw.data(), raw.size(), compressed) && + compressed->size() < raw.size() - (raw.size() / 8u)) { + block_contents = *compressed; + } else { + // Snappy not supported, or compressed less than 12.5%, so just + // store uncompressed form + block_contents = raw; + type = kNoCompression; + } + break; + } + } + WriteRawBlock(block_contents, type, handle); + r->compressed_output.clear(); + block->Reset(); +} + +void TableBuilder::WriteRawBlock(const StringPiece& block_contents, + CompressionType type, BlockHandle* handle) { + Rep* r = rep_; + handle->set_offset(r->offset); + handle->set_size(block_contents.size()); + r->status = r->file->Append(block_contents); + if (r->status.ok()) { + char trailer[kBlockTrailerSize]; + trailer[0] = type; + uint32 crc = crc32c::Value(block_contents.data(), block_contents.size()); + crc = crc32c::Extend(crc, trailer, 1); // Extend crc to cover block type + core::EncodeFixed32(trailer + 1, crc32c::Mask(crc)); + r->status = r->file->Append(StringPiece(trailer, kBlockTrailerSize)); + if (r->status.ok()) { + r->offset += block_contents.size() + kBlockTrailerSize; + } + } +} + +Status TableBuilder::status() const { return rep_->status; } + +Status TableBuilder::Finish() { + Rep* r = rep_; + Flush(); + assert(!r->closed); + r->closed = true; + + BlockHandle metaindex_block_handle, index_block_handle; + + // Write metaindex block + if (ok()) { + BlockBuilder meta_index_block(&r->options); + // TODO(postrelease): Add stats and other meta blocks + WriteBlock(&meta_index_block, &metaindex_block_handle); + } + + // Write index block + if (ok()) { + if (r->pending_index_entry) { + FindShortSuccessor(&r->last_key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + WriteBlock(&r->index_block, &index_block_handle); + } + + // Write footer + if (ok()) { + Footer footer; + footer.set_metaindex_handle(metaindex_block_handle); + footer.set_index_handle(index_block_handle); + string footer_encoding; + footer.EncodeTo(&footer_encoding); + r->status = r->file->Append(footer_encoding); + if (r->status.ok()) { + r->offset += footer_encoding.size(); + } + } + return r->status; +} + +void TableBuilder::Abandon() { + Rep* r = rep_; + assert(!r->closed); + r->closed = true; +} + +uint64 TableBuilder::NumEntries() const { return rep_->num_entries; } + +uint64 TableBuilder::FileSize() const { return rep_->offset; } + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h new file mode 100644 index 0000000000..cebf4d8e0c --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.h @@ -0,0 +1,87 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// TableBuilder provides the interface used to build a Table +// (an immutable and sorted map from keys to values). +// +// Multiple threads can invoke const methods on a TableBuilder without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same TableBuilder must use +// external synchronization. 
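The header comment above summarizes the TableBuilder contract; the sketch below shows, under stated assumptions, how the builder and Table::Open fit together end to end. MemSink and MemSource are hypothetical in-memory stand-ins modeled on the StringSink/StringSource test helpers that appear later in this change; they are not part of the library.

#include <string.h>
#include <string>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/iterator.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/env.h"

namespace tensorflow {
namespace table {

// In-memory WritableFile: collects everything the builder appends.
class MemSink : public WritableFile {
 public:
  Status Append(const StringPiece& data) override {
    contents_.append(data.data(), data.size());
    return Status::OK();
  }
  Status Close() override { return Status::OK(); }
  Status Flush() override { return Status::OK(); }
  Status Sync() override { return Status::OK(); }
  string contents_;
};

// In-memory RandomAccessFile over the bytes produced by MemSink.
class MemSource : public RandomAccessFile {
 public:
  explicit MemSource(const string& contents) : contents_(contents) {}
  Status Read(uint64 offset, size_t n, StringPiece* result,
              char* scratch) const override {
    if (offset > contents_.size()) {
      return errors::InvalidArgument("invalid Read offset");
    }
    if (offset + n > contents_.size()) n = contents_.size() - offset;
    memcpy(scratch, contents_.data() + offset, n);
    *result = StringPiece(scratch, n);
    return Status::OK();
  }

 private:
  string contents_;
};

Status BuildAndScan() {
  // Build: keys must be Add()ed in increasing lexicographic order;
  // Finish() emits the metaindex block, the index block and the footer.
  MemSink sink;
  Options options;
  options.compression = kNoCompression;
  TableBuilder builder(options, &sink);
  builder.Add("apple", "red");
  builder.Add("banana", "yellow");
  Status s = builder.Finish();
  if (!s.ok()) return s;

  // Read back: Open() parses the footer and index block; the iterator is
  // initially invalid, so SeekToFirst()/Seek() must be called before use.
  MemSource source(sink.contents_);
  Table* t = NULL;
  s = Table::Open(options, &source, sink.contents_.size(), &t);
  if (!s.ok()) return s;
  Iterator* it = t->NewIterator();
  int entries = 0;
  for (it->SeekToFirst(); it->Valid(); it->Next()) {
    ++entries;  // it->key() / it->value() view the current entry
  }
  s = it->status();
  delete it;
  delete t;
  return s;
}

}  // namespace table
}  // namespace tensorflow

The two helpers deliberately mirror the file fakes used by table_test.cc further down, which is why their signatures match the WritableFile and RandomAccessFile interfaces exactly.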
+ +#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ +#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ + +#include <stdint.h> +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class WritableFile; +namespace table { + +class BlockBuilder; +class BlockHandle; + +class TableBuilder { + public: + // Create a builder that will store the contents of the table it is + // building in *file. Does not close the file. It is up to the + // caller to close the file after calling Finish(). + TableBuilder(const Options& options, WritableFile* file); + + // REQUIRES: Either Finish() or Abandon() has been called. + ~TableBuilder(); + + // Add key,value to the table being constructed. + // REQUIRES: key is after any previously added key in lexicographic order. + // REQUIRES: Finish(), Abandon() have not been called + void Add(const StringPiece& key, const StringPiece& value); + + // Advanced operation: flush any buffered key/value pairs to file. + // Can be used to ensure that two adjacent entries never live in + // the same data block. Most clients should not need to use this method. + // REQUIRES: Finish(), Abandon() have not been called + void Flush(); + + // Return non-ok iff some error has been detected. + Status status() const; + + // Finish building the table. Stops using the file passed to the + // constructor after this function returns. + // REQUIRES: Finish(), Abandon() have not been called + Status Finish(); + + // Indicate that the contents of this builder should be abandoned. Stops + // using the file passed to the constructor after this function returns. + // If the caller is not going to call Finish(), it must call Abandon() + // before destroying this builder. + // REQUIRES: Finish(), Abandon() have not been called + void Abandon(); + + // Number of calls to Add() so far. + uint64 NumEntries() const; + + // Size of the file generated so far. If invoked after a successful + // Finish() call, returns the size of the final generated file. + uint64 FileSize() const; + + private: + bool ok() const { return status().ok(); } + void WriteBlock(BlockBuilder* block, BlockHandle* handle); + void WriteRawBlock(const StringPiece& data, CompressionType, + BlockHandle* handle); + + struct Rep; + Rep* rep_; + + // No copying allowed + TableBuilder(const TableBuilder&); + void operator=(const TableBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt new file mode 100644 index 0000000000..7edb9fb121 --- /dev/null +++ b/tensorflow/core/lib/io/table_format.txt @@ -0,0 +1,8 @@ +File format +=========== + +The table format is heavily based on the table format for the LevelDB +open source key/value store, with the exception that our tables +do not support "filter" meta blocks (Bloom Filters). See: + +https://code.google.com/p/leveldb/source/browse/doc/table_format.txt diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h new file mode 100644 index 0000000000..45b061b03b --- /dev/null +++ b/tensorflow/core/lib/io/table_options.h @@ -0,0 +1,53 @@ +#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ +#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ + +#include <stddef.h> + +namespace tensorflow { +namespace table { + +// DB contents are stored in a set of blocks, each of which holds a +// sequence of key,value pairs. Each block may be compressed before +// being stored in a file. 
The following enum describes which +// compression method (if any) is used to compress a block. +enum CompressionType { + // NOTE: do not change the values of existing entries, as these are + // part of the persistent format on disk. + kNoCompression = 0x0, + kSnappyCompression = 0x1 +}; + +// Options to control the behavior of a table (passed to Table::Open) +struct Options { + // Approximate size of user data packed per block. Note that the + // block size specified here corresponds to uncompressed data. The + // actual size of the unit read from disk may be smaller if + // compression is enabled. This parameter can be changed dynamically. + size_t block_size = 262144; + + // Number of keys between restart points for delta encoding of keys. + // This parameter can be changed dynamically. Most clients should + // leave this parameter alone. + int block_restart_interval = 16; + + // Compress blocks using the specified compression algorithm. This + // parameter can be changed dynamically. + // + // Default: kSnappyCompression, which gives lightweight but fast + // compression. + // + // Typical speeds of kSnappyCompression on an Intel(R) Core(TM)2 2.4GHz: + // ~200-500MB/s compression + // ~400-800MB/s decompression + // Note that these speeds are significantly faster than most + // persistent storage speeds, and therefore it is typically never + // worth switching to kNoCompression. Even if the input data is + // incompressible, the kSnappyCompression implementation will + // efficiently detect that and will switch to uncompressed mode. + CompressionType compression = kSnappyCompression; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc new file mode 100644 index 0000000000..66e90ac64e --- /dev/null +++ b/tensorflow/core/lib/io/table_test.cc @@ -0,0 +1,601 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table.h" + +#include <map> +#include <string> +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +namespace test { +static StringPiece RandomString(random::SimplePhilox* rnd, int len, + string* dst) { + dst->resize(len); + for (int i = 0; i < len; i++) { + (*dst)[i] = static_cast<char>(' ' + rnd->Uniform(95)); // ' ' .. '~' + } + return StringPiece(*dst); +} +static string RandomKey(random::SimplePhilox* rnd, int len) { + // Make sure to generate a wide variety of characters so we + // test the boundary conditions for short-key optimizations. 
+ static const char kTestChars[] = {'\0', '\1', 'a', 'b', 'c', + 'd', 'e', '\xfd', '\xfe', '\xff'}; + string result; + for (int i = 0; i < len; i++) { + result += kTestChars[rnd->Uniform(sizeof(kTestChars))]; + } + return result; +} +static StringPiece CompressibleString(random::SimplePhilox* rnd, + double compressed_fraction, size_t len, + string* dst) { + int raw = static_cast<int>(len * compressed_fraction); + if (raw < 1) raw = 1; + string raw_data; + RandomString(rnd, raw, &raw_data); + + // Duplicate the random data until we have filled "len" bytes + dst->clear(); + while (dst->size() < len) { + dst->append(raw_data); + } + dst->resize(len); + return StringPiece(*dst); +} +} + +static void Increment(string* key) { key->push_back('\0'); } + +// An STL comparator that compares two StringPieces +namespace { +struct STLLessThan { + STLLessThan() {} + bool operator()(const string& a, const string& b) const { + return StringPiece(a).compare(StringPiece(b)) < 0; + } +}; +} // namespace + +class StringSink : public WritableFile { + public: + ~StringSink() {} + + const string& contents() const { return contents_; } + + virtual Status Close() { return Status::OK(); } + virtual Status Flush() { return Status::OK(); } + virtual Status Sync() { return Status::OK(); } + + virtual Status Append(const StringPiece& data) { + contents_.append(data.data(), data.size()); + return Status::OK(); + } + + private: + string contents_; +}; + +class StringSource : public RandomAccessFile { + public: + StringSource(const StringPiece& contents) + : contents_(contents.data(), contents.size()), bytes_read_(0) {} + + virtual ~StringSource() {} + + uint64 Size() const { return contents_.size(); } + + virtual Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const { + if (offset > contents_.size()) { + return errors::InvalidArgument("invalid Read offset"); + } + if (offset + n > contents_.size()) { + n = contents_.size() - offset; + } + memcpy(scratch, &contents_[offset], n); + *result = StringPiece(scratch, n); + bytes_read_ += n; + return Status::OK(); + } + + uint64 BytesRead() const { return bytes_read_; } + + private: + string contents_; + mutable uint64 bytes_read_; +}; + +typedef std::map<string, string, STLLessThan> KVMap; + +// Helper class for tests to unify the interface between +// BlockBuilder/TableBuilder and Block/Table. +class Constructor { + public: + explicit Constructor() : data_(STLLessThan()) {} + virtual ~Constructor() {} + + void Add(const string& key, const StringPiece& value) { + data_[key] = value.ToString(); + } + + // Finish constructing the data structure with all the keys that have + // been added so far. 
Returns the keys in sorted order in "*keys" + // and stores the key/value pairs in "*kvmap" + void Finish(const Options& options, std::vector<string>* keys, KVMap* kvmap) { + *kvmap = data_; + keys->clear(); + for (KVMap::const_iterator it = data_.begin(); it != data_.end(); ++it) { + keys->push_back(it->first); + } + data_.clear(); + Status s = FinishImpl(options, *kvmap); + ASSERT_TRUE(s.ok()) << s.ToString(); + } + + // Construct the data structure from the data in "data" + virtual Status FinishImpl(const Options& options, const KVMap& data) = 0; + + virtual Iterator* NewIterator() const = 0; + + virtual const KVMap& data() { return data_; } + + private: + KVMap data_; +}; + +class BlockConstructor : public Constructor { + public: + BlockConstructor() : block_(NULL) {} + ~BlockConstructor() { delete block_; } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + delete block_; + block_ = NULL; + BlockBuilder builder(&options); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + } + // Open the block + data_ = builder.Finish().ToString(); + BlockContents contents; + contents.data = data_; + contents.cachable = false; + contents.heap_allocated = false; + block_ = new Block(contents); + return Status::OK(); + } + virtual Iterator* NewIterator() const { return block_->NewIterator(); } + + private: + string data_; + Block* block_; +}; + +class TableConstructor : public Constructor { + public: + TableConstructor() : source_(NULL), table_(NULL) {} + ~TableConstructor() { Reset(); } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + Reset(); + StringSink sink; + TableBuilder builder(options, &sink); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + TF_CHECK_OK(builder.status()); + } + Status s = builder.Finish(); + TF_CHECK_OK(s) << s.ToString(); + + CHECK_EQ(sink.contents().size(), builder.FileSize()); + + // Open the table + source_ = new StringSource(sink.contents()); + Options table_options; + return Table::Open(table_options, source_, sink.contents().size(), &table_); + } + + virtual Iterator* NewIterator() const { return table_->NewIterator(); } + + uint64 ApproximateOffsetOf(const StringPiece& key) const { + return table_->ApproximateOffsetOf(key); + } + + uint64 BytesRead() const { return source_->BytesRead(); } + + private: + void Reset() { + delete table_; + delete source_; + table_ = NULL; + source_ = NULL; + } + + StringSource* source_; + Table* table_; +}; + +enum TestType { TABLE_TEST, BLOCK_TEST }; + +struct TestArgs { + TestType type; + int restart_interval; +}; + +static const TestArgs kTestArgList[] = { + {TABLE_TEST, 16}, {TABLE_TEST, 1}, {TABLE_TEST, 1024}, + {BLOCK_TEST, 16}, {BLOCK_TEST, 1}, {BLOCK_TEST, 1024}, +}; +static const int kNumTestArgs = sizeof(kTestArgList) / sizeof(kTestArgList[0]); + +class Harness : public ::testing::Test { + public: + Harness() : constructor_(NULL) {} + + void Init(const TestArgs& args) { + delete constructor_; + constructor_ = NULL; + options_ = Options(); + + options_.block_restart_interval = args.restart_interval; + // Use shorter block size for tests to exercise block boundary + // conditions more. 
+ options_.block_size = 256; + switch (args.type) { + case TABLE_TEST: + constructor_ = new TableConstructor(); + break; + case BLOCK_TEST: + constructor_ = new BlockConstructor(); + break; + } + } + + ~Harness() { delete constructor_; } + + void Add(const string& key, const string& value) { + constructor_->Add(key, value); + } + + void Test(random::SimplePhilox* rnd) { + std::vector<string> keys; + KVMap data; + constructor_->Finish(options_, &keys, &data); + + TestForwardScan(keys, data); + TestRandomAccess(rnd, keys, data); + } + + void TestForwardScan(const std::vector<string>& keys, const KVMap& data) { + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + iter->SeekToFirst(); + for (KVMap::const_iterator model_iter = data.begin(); + model_iter != data.end(); ++model_iter) { + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + iter->Next(); + } + ASSERT_TRUE(!iter->Valid()); + delete iter; + } + + void TestRandomAccess(random::SimplePhilox* rnd, + const std::vector<string>& keys, const KVMap& data) { + static const bool kVerbose = false; + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + KVMap::const_iterator model_iter = data.begin(); + if (kVerbose) fprintf(stderr, "---\n"); + for (int i = 0; i < 200; i++) { + const int toss = rnd->Uniform(3); + switch (toss) { + case 0: { + if (iter->Valid()) { + if (kVerbose) fprintf(stderr, "Next\n"); + iter->Next(); + ++model_iter; + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + } + break; + } + + case 1: { + if (kVerbose) fprintf(stderr, "SeekToFirst\n"); + iter->SeekToFirst(); + model_iter = data.begin(); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } + + case 2: { + string key = PickRandomKey(rnd, keys); + model_iter = data.lower_bound(key); + if (kVerbose) + fprintf(stderr, "Seek '%s'\n", str_util::CEscape(key).c_str()); + iter->Seek(StringPiece(key)); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } + } + } + delete iter; + } + + string ToString(const KVMap& data, const KVMap::const_iterator& it) { + if (it == data.end()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const KVMap& data, const KVMap::const_reverse_iterator& it) { + if (it == data.rend()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const Iterator* it) { + if (!it->Valid()) { + return "END"; + } else { + return "'" + it->key().ToString() + "->" + it->value().ToString() + "'"; + } + } + + string PickRandomKey(random::SimplePhilox* rnd, + const std::vector<string>& keys) { + if (keys.empty()) { + return "foo"; + } else { + const int index = rnd->Uniform(keys.size()); + string result = keys[index]; + switch (rnd->Uniform(3)) { + case 0: + // Return an existing key + break; + case 1: { + // Attempt to return something smaller than an existing key + if (result.size() > 0 && result[result.size() - 1] > '\0') { + result[result.size() - 1]--; + } + break; + } + case 2: { + // Return something larger than an existing key + Increment(&result); + break; + } + } + return result; + } + } + + private: + Options options_; + Constructor* constructor_; +}; + +// Test empty table/block. +TEST_F(Harness, Empty) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 1, 17); + random::SimplePhilox rnd(&philox); + Test(&rnd); + } +} + +// Special test for a block with no restart entries. 
The C++ leveldb +// code never generates such blocks, but the Java version of leveldb +// seems to. +TEST_F(Harness, ZeroRestartPointsInBlock) { + char data[sizeof(uint32)]; + memset(data, 0, sizeof(data)); + BlockContents contents; + contents.data = StringPiece(data, sizeof(data)); + contents.cachable = false; + contents.heap_allocated = false; + Block block(contents); + Iterator* iter = block.NewIterator(); + iter->SeekToFirst(); + ASSERT_TRUE(!iter->Valid()); + iter->Seek("foo"); + ASSERT_TRUE(!iter->Valid()); + delete iter; +} + +// Test the empty key +TEST_F(Harness, SimpleEmptyKey) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 1, 17); + random::SimplePhilox rnd(&philox); + Add("", "v"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleSingle) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 2, 17); + random::SimplePhilox rnd(&philox); + Add("abc", "v"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleMulti) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 3, 17); + random::SimplePhilox rnd(&philox); + Add("abc", "v"); + Add("abcd", "v"); + Add("ac", "v2"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleMultiBigValues) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 3, 17); + random::SimplePhilox rnd(&philox); + Add("ainitial", "tiny"); + Add("anext", string(10000000, 'a')); + Add("anext2", string(10000000, 'b')); + Add("azz", "tiny"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleSpecialKey) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 4, 17); + random::SimplePhilox rnd(&philox); + Add("\xff\xff", "v3"); + Test(&rnd); + } +} + +TEST_F(Harness, Randomized) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 5, 17); + random::SimplePhilox rnd(&philox); + for (int num_entries = 0; num_entries < 2000; + num_entries += (num_entries < 50 ? 
1 : 200)) { + if ((num_entries % 10) == 0) { + fprintf(stderr, "case %d of %d: num_entries = %d\n", (i + 1), + int(kNumTestArgs), num_entries); + } + for (int e = 0; e < num_entries; e++) { + string v; + Add(test::RandomKey(&rnd, rnd.Skewed(4)), + test::RandomString(&rnd, rnd.Skewed(5), &v).ToString()); + } + Test(&rnd); + } + } +} + +static bool Between(uint64 val, uint64 low, uint64 high) { + bool result = (val >= low) && (val <= high); + if (!result) { + fprintf(stderr, "Value %llu is not in range [%llu, %llu]\n", + (unsigned long long)(val), (unsigned long long)(low), + (unsigned long long)(high)); + } + return result; +} + +class TableTest {}; + +TEST(TableTest, ApproximateOffsetOfPlain) { + TableConstructor c; + c.Add("k01", "hello"); + c.Add("k02", "hello2"); + c.Add("k03", string(10000, 'x')); + c.Add("k04", string(200000, 'x')); + c.Add("k05", string(300000, 'x')); + c.Add("k06", "hello3"); + c.Add("k07", string(100000, 'x')); + std::vector<string> keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kNoCompression; + c.Finish(options, &keys, &kvmap); + + ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01a"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 10, 500)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 10000, 11000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04a"), 210000, 211000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k05"), 210000, 211000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k06"), 510000, 511000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k07"), 510000, 511000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 610000, 612000)); +} + +static bool SnappyCompressionSupported() { + string out; + StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + return port::Snappy_Compress(in.data(), in.size(), &out); +} + +TEST(TableTest, ApproximateOffsetOfCompressed) { + if (!SnappyCompressionSupported()) { + fprintf(stderr, "skipping compression tests\n"); + return; + } + + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + TableConstructor c; + string tmp; + c.Add("k01", "hello"); + c.Add("k02", test::CompressibleString(&rnd, 0.25, 10000, &tmp)); + c.Add("k03", "hello3"); + c.Add("k04", test::CompressibleString(&rnd, 0.25, 10000, &tmp)); + std::vector<string> keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kSnappyCompression; + c.Finish(options, &keys, &kvmap); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 10, 100)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 2000, 3000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 2000, 3000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 4000, 6000)); +} + +TEST(TableTest, SeekToFirstKeyDoesNotReadTooMuch) { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string tmp; + TableConstructor c; + c.Add("k01", "firstvalue"); + c.Add("k03", test::CompressibleString(&rnd, 0.25, 1000000, &tmp)); + c.Add("k04", "abc"); + std::vector<string> keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kNoCompression; + c.Finish(options, &keys, &kvmap); + + Iterator* iter = c.NewIterator(); + iter->Seek("k01"); + 
delete iter; + // Make sure we don't read the big second block when just trying to + // retrieve the data in the first key + EXPECT_LT(c.BytesRead(), 200); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/two_level_iterator.cc b/tensorflow/core/lib/io/two_level_iterator.cc new file mode 100644 index 0000000000..409baade6d --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/two_level_iterator.h" + +#include "tensorflow/core/lib/io/table.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +namespace { + +typedef Iterator* (*BlockFunction)(void*, const StringPiece&); + +class TwoLevelIterator : public Iterator { + public: + TwoLevelIterator(Iterator* index_iter, BlockFunction block_function, + void* arg); + + virtual ~TwoLevelIterator(); + + virtual void Seek(const StringPiece& target); + virtual void SeekToFirst(); + virtual void Next(); + + virtual bool Valid() const { + return (data_iter_ == nullptr) ? false : data_iter_->Valid(); + } + virtual StringPiece key() const { + assert(Valid()); + return data_iter_->key(); + } + virtual StringPiece value() const { + assert(Valid()); + return data_iter_->value(); + } + virtual Status status() const { + // It'd be nice if status() returned a const Status& instead of a + // Status + if (!index_iter_->status().ok()) { + return index_iter_->status(); + } else if (data_iter_ != NULL && !data_iter_->status().ok()) { + return data_iter_->status(); + } else { + return status_; + } + } + + private: + void SaveError(const Status& s) { + if (status_.ok() && !s.ok()) status_ = s; + } + void SkipEmptyDataBlocksForward(); + void SetDataIterator(Iterator* data_iter); + void InitDataBlock(); + + BlockFunction block_function_; + void* arg_; + Status status_; + Iterator* index_iter_; + Iterator* data_iter_; // May be NULL + // If data_iter_ is non-NULL, then "data_block_handle_" holds the + // "index_value" passed to block_function_ to create the data_iter_. 
+ string data_block_handle_; +}; + +TwoLevelIterator::TwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) + : block_function_(block_function), + arg_(arg), + index_iter_(index_iter), + data_iter_(NULL) {} + +TwoLevelIterator::~TwoLevelIterator() { + delete index_iter_; + delete data_iter_; +} + +void TwoLevelIterator::Seek(const StringPiece& target) { + index_iter_->Seek(target); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->Seek(target); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SeekToFirst() { + index_iter_->SeekToFirst(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::Next() { + assert(Valid()); + data_iter_->Next(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SkipEmptyDataBlocksForward() { + while (data_iter_ == NULL || !data_iter_->Valid()) { + // Move to next block + if (!index_iter_->Valid()) { + SetDataIterator(NULL); + return; + } + index_iter_->Next(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + } +} + +void TwoLevelIterator::SetDataIterator(Iterator* data_iter) { + if (data_iter_ != NULL) { + SaveError(data_iter_->status()); + delete data_iter_; + } + data_iter_ = data_iter; +} + +void TwoLevelIterator::InitDataBlock() { + if (!index_iter_->Valid()) { + SetDataIterator(NULL); + } else { + StringPiece handle = index_iter_->value(); + if (data_iter_ != NULL && handle.compare(data_block_handle_) == 0) { + // data_iter_ is already constructed with this iterator, so + // no need to change anything + } else { + Iterator* iter = (*block_function_)(arg_, handle); + data_block_handle_.assign(handle.data(), handle.size()); + SetDataIterator(iter); + } + } +} + +} // namespace + +Iterator* NewTwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) { + return new TwoLevelIterator(index_iter, block_function, arg); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h new file mode 100644 index 0000000000..1cc5d2f921 --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -0,0 +1,30 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ + +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +// Return a new two level iterator. A two-level iterator contains an +// index iterator whose values point to a sequence of blocks where +// each block is itself a sequence of key,value pairs. The returned +// two-level iterator yields the concatenation of all key/value pairs +// in the sequence of blocks. Takes ownership of "index_iter" and +// will delete it when no longer needed. +// +// Uses a supplied function to convert an index_iter value into +// an iterator over the contents of the corresponding block. 
+extern Iterator* NewTwoLevelIterator( + Iterator* index_iter, + Iterator* (*block_function)(void* arg, const StringPiece& index_value), + void* arg); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.cc b/tensorflow/core/lib/jpeg/jpeg_handle.cc new file mode 100644 index 0000000000..4521be0afb --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_handle.cc @@ -0,0 +1,162 @@ +// This file implements a memory destination for libjpeg +// The design is very similar to jdatadst.c in libjpeg +// These functions are not meant to be used directly, see jpeg_mem.h instead. +// We are filling out stubs required by jpeglib, those stubs are private to +// the implementation, we are just making available JPGMemSrc, JPGMemDest + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" + +#include <setjmp.h> +#include <stddef.h> + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace jpeg { + +void CatchError(j_common_ptr cinfo) { + (*cinfo->err->output_message)(cinfo); + jmp_buf *jpeg_jmpbuf = reinterpret_cast<jmp_buf *>(cinfo->client_data); + jpeg_destroy(cinfo); + longjmp(*jpeg_jmpbuf, 1); +} + +// ***************************************************************************** +// ***************************************************************************** +// ***************************************************************************** +// Destination functions + +// ----------------------------------------------------------------------------- +void MemInitDestination(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest); + VLOG(1) << "Initializing buffer=" << dest->bufsize << " bytes"; + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = dest->bufsize; + dest->datacount = 0; + if (dest->dest) { + dest->dest->clear(); + } +} + +// ----------------------------------------------------------------------------- +boolean MemEmptyOutputBuffer(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest); + VLOG(1) << "Writing " << dest->bufsize << " bytes"; + if (dest->dest) { + dest->dest->append(reinterpret_cast<char *>(dest->buffer), dest->bufsize); + } + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = dest->bufsize; + return TRUE; +} + +// ----------------------------------------------------------------------------- +void MemTermDestination(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest); + VLOG(1) << "Writing " << dest->bufsize - dest->pub.free_in_buffer << " bytes"; + if (dest->dest) { + dest->dest->append(reinterpret_cast<char *>(dest->buffer), + dest->bufsize - dest->pub.free_in_buffer); + VLOG(1) << "Total size= " << dest->dest->size(); + } + dest->datacount = dest->bufsize - dest->pub.free_in_buffer; +} + +// ----------------------------------------------------------------------------- +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) { + SetDest(cinfo, buffer, bufsize, NULL); +} + +// ----------------------------------------------------------------------------- +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize, + string *destination) { + MemDestMgr *dest; + if (cinfo->dest == NULL) { + cinfo->dest = reinterpret_cast<struct jpeg_destination_mgr *>( + (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo), + JPOOL_PERMANENT, sizeof(MemDestMgr))); + } + + dest = reinterpret_cast<MemDestMgr 
*>(cinfo->dest); + dest->bufsize = bufsize; + dest->buffer = static_cast<JOCTET *>(buffer); + dest->dest = destination; + dest->pub.init_destination = MemInitDestination; + dest->pub.empty_output_buffer = MemEmptyOutputBuffer; + dest->pub.term_destination = MemTermDestination; +} + +// ***************************************************************************** +// ***************************************************************************** +// ***************************************************************************** +// Source functions + +// ----------------------------------------------------------------------------- +void MemInitSource(j_decompress_ptr cinfo) { + MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src); + src->pub.next_input_byte = src->data; + src->pub.bytes_in_buffer = src->datasize; +} + +// ----------------------------------------------------------------------------- +// We emulate the same error-handling as fill_input_buffer() from jdatasrc.c, +// for coherency's sake. +boolean MemFillInputBuffer(j_decompress_ptr cinfo) { + static const JOCTET kEOIBuffer[2] = {0xff, JPEG_EOI}; + MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src); + if (src->pub.bytes_in_buffer == 0 && src->pub.next_input_byte == src->data) { + // empty file -> treated as an error. + ERREXIT(cinfo, JERR_INPUT_EMPTY); + return FALSE; + } else if (src->pub.bytes_in_buffer) { + // if there's still some data left, it's probably corrupted + return src->try_recover_truncated_jpeg ? TRUE : FALSE; + } else if (src->pub.next_input_byte != kEOIBuffer && + src->try_recover_truncated_jpeg) { + // In an attempt to recover truncated files, we insert a fake EOI + WARNMS(cinfo, JWRN_JPEG_EOF); + src->pub.next_input_byte = kEOIBuffer; + src->pub.bytes_in_buffer = 2; + return TRUE; + } else { + // We already inserted a fake EOI and it wasn't enough, so this time + // it's really an error. 
+ ERREXIT(cinfo, JERR_FILE_READ); + return FALSE; + } +} + +// ----------------------------------------------------------------------------- +void MemTermSource(j_decompress_ptr cinfo) {} + +// ----------------------------------------------------------------------------- +void MemSkipInputData(j_decompress_ptr cinfo, long jump) { + MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src); + src->pub.bytes_in_buffer -= jump; + src->pub.next_input_byte += jump; +} + +// ----------------------------------------------------------------------------- +void SetSrc(j_decompress_ptr cinfo, const void *data, + unsigned long int datasize, bool try_recover_truncated_jpeg) { + MemSourceMgr *src; + + cinfo->src = reinterpret_cast<struct jpeg_source_mgr *>( + (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo), + JPOOL_PERMANENT, sizeof(MemSourceMgr))); + + src = reinterpret_cast<MemSourceMgr *>(cinfo->src); + src->pub.init_source = MemInitSource; + src->pub.fill_input_buffer = MemFillInputBuffer; + src->pub.skip_input_data = MemSkipInputData; + src->pub.resync_to_restart = jpeg_resync_to_restart; + src->pub.term_source = MemTermSource; + src->data = reinterpret_cast<const unsigned char *>(data); + src->datasize = datasize; + src->pub.bytes_in_buffer = 0; + src->pub.next_input_byte = NULL; + src->try_recover_truncated_jpeg = try_recover_truncated_jpeg; +} + +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h new file mode 100644 index 0000000000..58f7f6f666 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_handle.h @@ -0,0 +1,51 @@ +// This file declares the functions and structures for memory I/O with libjpeg +// These functions are not meant to be used directly, see jpeg_mem.h isntead. + +#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ +#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ + +extern "C" { +#include "external/jpeg_archive/jpeg-9a/jinclude.h" +#include "external/jpeg_archive/jpeg-9a/jpeglib.h" +#include "external/jpeg_archive/jpeg-9a/jerror.h" +#include "external/jpeg_archive/jpeg-9a/transupp.h" // for rotations +} + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace jpeg { + +// Handler for fatal JPEG library errors: clean up & return +void CatchError(j_common_ptr cinfo); + +typedef struct { + struct jpeg_destination_mgr pub; + JOCTET *buffer; + int bufsize; + int datacount; + string *dest; +} MemDestMgr; + +typedef struct { + struct jpeg_source_mgr pub; + const unsigned char *data; + unsigned long int datasize; + bool try_recover_truncated_jpeg; +} MemSourceMgr; + +void SetSrc(j_decompress_ptr cinfo, const void *data, + unsigned long int datasize, bool try_recover_truncated_jpeg); + +// JPEG destination: we will store all the data in a buffer "buffer" of total +// size "bufsize", if the buffer overflows, we will be in trouble. +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize); +// Same as above, except that buffer is only used as a temporary structure and +// is emptied into "destination" as soon as it fills up. 
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize, + string *destination); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc new file mode 100644 index 0000000000..556f13e388 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -0,0 +1,557 @@ +// This file defines functions to compress and uncompress JPEG data +// to and from memory, as well as some direct manipulations of JPEG string + +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +#include <setjmp.h> +#include <string.h> +#include <algorithm> +#include <memory> +#include <string> + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace jpeg { + +// ----------------------------------------------------------------------------- +// Decompression + +namespace { + +enum JPEGErrors { + JPEGERRORS_OK, + JPEGERRORS_UNEXPECTED_END_OF_DATA, + JPEGERRORS_BAD_PARAM +}; + +// Prevent bad compiler behaviour in ASAN mode by wrapping most of the +// arguments in a struct struct. +class FewerArgsForCompiler { + public: + FewerArgsForCompiler(int datasize, const UncompressFlags& flags, int* nwarn, + std::function<uint8*(int, int, int)> allocate_output) + : datasize_(datasize), + flags_(flags), + pnwarn_(nwarn), + allocate_output_(allocate_output), + fraction_read_(0.), + height_(0), + stride_(0) { + if (pnwarn_ != nullptr) *pnwarn_ = 0; + } + + const int datasize_; + const UncompressFlags flags_; + int* const pnwarn_; + std::function<uint8*(int, int, int)> allocate_output_; + float fraction_read_; // fraction of scanline lines successfully read + int height_; + int stride_; +}; + +uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { + // unpack the argball + const int datasize = argball->datasize_; + const auto& flags = argball->flags_; + const int ratio = flags.ratio; + int components = flags.components; + int stride = flags.stride; // may be 0 + int* const nwarn = argball->pnwarn_; // may be NULL + + // can't decode if the ratio is not recognized by libjpeg + if ((ratio != 1) && (ratio != 2) && (ratio != 4) && (ratio != 8)) { + return nullptr; + } + + // if empty image, return + if (datasize == 0 || srcdata == NULL) return nullptr; + + // Declare temporary buffer pointer here so that we can free on error paths + JSAMPLE* tempdata = nullptr; + + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. + JPEGErrors error = JPEGERRORS_OK; + struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr); + jmp_buf jpeg_jmpbuf; + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + return nullptr; + } + + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, srcdata, datasize, flags.try_recover_truncated_jpeg); + jpeg_read_header(&cinfo, TRUE); + + // Set components automatically if desired + if (components == 0) components = cinfo.num_components; + + // set grayscale and ratio parameters + switch (components) { + case 1: + cinfo.out_color_space = JCS_GRAYSCALE; + break; + case 3: + case 4: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + // always use cmyk for output in a 4 channel jpeg. libjpeg has a builtin + // decoder. 
+ cinfo.out_color_space = JCS_CMYK; + } else { + cinfo.out_color_space = JCS_RGB; + } + break; + default: + LOG(ERROR) << " Invalid components value " << components << std::endl; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + cinfo.do_fancy_upsampling = boolean(flags.fancy_upscaling); + cinfo.scale_num = 1; + cinfo.scale_denom = ratio; + // Activating this has a quality/speed trade-off implication: + // cinfo.dct_method = JDCT_IFAST; + + jpeg_start_decompress(&cinfo); + + // check for compatible stride + const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE); + if (stride == 0) { + stride = min_stride; + } else if (stride < min_stride) { + LOG(ERROR) << "Incompatible stride: " << stride << " < " << min_stride; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + // Remember stride and height for use in Uncompress + argball->height_ = cinfo.output_height; + argball->stride_ = stride; + + uint8* const dstdata = argball->allocate_output_( + cinfo.output_width, cinfo.output_height, components); + if (dstdata == nullptr) { + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + JSAMPLE* output_line = static_cast<JSAMPLE*>(dstdata); + + // Temporary buffer used for CMYK -> RGB conversion. + const bool use_cmyk = (cinfo.out_color_space == JCS_CMYK); + tempdata = use_cmyk ? new JSAMPLE[cinfo.output_width * 4] : NULL; + + // If there is an error reading a line, this aborts the reading. + // Save the fraction of the image that has been read. + argball->fraction_read_ = 1.0; + while (cinfo.output_scanline < cinfo.output_height) { + int num_lines_read = 0; + if (cinfo.out_color_space == JCS_CMYK) { + num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1); + // Convert CMYK to RGB + for (size_t i = 0; i < cinfo.output_width; ++i) { + int c = tempdata[4 * i + 0]; + int m = tempdata[4 * i + 1]; + int y = tempdata[4 * i + 2]; + int k = tempdata[4 * i + 3]; + int r, g, b; + if (cinfo.saw_Adobe_marker) { + r = (k * c) / 255; + g = (k * m) / 255; + b = (k * y) / 255; + } else { + r = (255 - k) * (255 - c) / 255; + g = (255 - k) * (255 - m) / 255; + b = (255 - k) * (255 - y) / 255; + } + output_line[3 * i + 0] = r; + output_line[3 * i + 1] = g; + output_line[3 * i + 2] = b; + } + } else { + num_lines_read = jpeg_read_scanlines(&cinfo, &output_line, 1); + } + // Handle error cases + if (num_lines_read == 0) { + LOG(ERROR) << "Premature end of JPEG data. Stopped at line " + << cinfo.output_scanline << "/" << cinfo.output_height; + if (!flags.try_recover_truncated_jpeg) { + argball->fraction_read_ = + static_cast<float>(cinfo.output_scanline) / cinfo.output_height; + error = JPEGERRORS_UNEXPECTED_END_OF_DATA; + } else { + for (size_t line = cinfo.output_scanline; line < cinfo.output_height; + ++line) { + if (line == 0) { + // If even the first line is missing, fill with black color + memset(output_line, 0, min_stride); + } else { + // else, just replicate the line above. + memcpy(output_line, output_line - stride, min_stride); + } + output_line += stride; + } + argball->fraction_read_ = 1.0; // consider all lines as read + // prevent error-on-exit in libjpeg: + cinfo.output_scanline = cinfo.output_height; + } + break; + } + DCHECK_EQ(num_lines_read, 1); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_line, min_stride); + output_line += stride; + } + delete[] tempdata; + + // Convert the RGB data to RGBA, with alpha set to 0xFF to indicate + // opacity. + // RGBRGBRGB... --> RGBARGBARGBA... + if (components == 4) { + // Start on the last line. 
+ JSAMPLE* scanlineptr = + static_cast<JSAMPLE*>(dstdata + (cinfo.output_height - 1) * stride); + const JSAMPLE kOpaque = -1; // All ones appropriate for JSAMPLE. + const int right_rgb = (cinfo.output_width - 1) * 3; + const int right_rgba = (cinfo.output_width - 1) * 4; + + for (int y = cinfo.output_height; y-- > 0;) { + // We do all the transformations in place, going backwards for each row. + const JSAMPLE* rgb_pixel = scanlineptr + right_rgb; + JSAMPLE* rgba_pixel = scanlineptr + right_rgba; + scanlineptr -= stride; + for (int x = cinfo.output_width; x-- > 0; + rgba_pixel -= 4, rgb_pixel -= 3) { + // We copy the 3 bytes at rgb_pixel into the 4 bytes at rgba_pixel + // The "a" channel is set to be opaque. + rgba_pixel[3] = kOpaque; + rgba_pixel[2] = rgb_pixel[2]; + rgba_pixel[1] = rgb_pixel[1]; + rgba_pixel[0] = rgb_pixel[0]; + } + } + } + + switch (components) { + case 1: + if (cinfo.output_components != 1) { + error = JPEGERRORS_BAD_PARAM; + } + break; + case 3: + case 4: + if (cinfo.out_color_space == JCS_CMYK) { + if (cinfo.output_components != 4) { + error = JPEGERRORS_BAD_PARAM; + } + } else { + if (cinfo.output_components != 3) { + error = JPEGERRORS_BAD_PARAM; + } + } + break; + default: + // will never happen, should be catched by the previous switch + LOG(ERROR) << "Invalid components value " << components << std::endl; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + // save number of warnings if requested + if (nwarn != nullptr) { + *nwarn = cinfo.err->num_warnings; + } + + // Handle errors in JPEG + switch (error) { + case JPEGERRORS_OK: + jpeg_finish_decompress(&cinfo); + break; + case JPEGERRORS_UNEXPECTED_END_OF_DATA: + case JPEGERRORS_BAD_PARAM: + jpeg_abort(reinterpret_cast<j_common_ptr>(&cinfo)); + break; + default: + LOG(ERROR) << "Unhandled case " << error; + break; + } + jpeg_destroy_decompress(&cinfo); + + return dstdata; +} + +} // anonymous namespace + +// ----------------------------------------------------------------------------- +// We do the apparently silly thing of packing 5 of the arguments +// into a structure that is then passed to another routine +// that does all the work. The reason is that we want to catch +// fatal JPEG library errors with setjmp/longjmp, and g++ and +// associated libraries aren't good enough to guarantee that 7 +// parameters won't get clobbered by the longjmp. So we help +// it out a little. 
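+// A minimal sketch of the problem being avoided (State, Work and
+// SomeCallThatMayLongjmp are illustrative names, not part of this library):
+// any non-volatile automatic variable, including a parameter, that is
+// modified between setjmp() and the matching longjmp() has an indeterminate
+// value afterwards, so the mutable state is parked behind a single pointer.
+//
+//   struct State { int lines_done; };     // caller-owned storage
+//
+//   void Work(State* s) {
+//     jmp_buf env;
+//     if (setjmp(env)) return;            // s->lines_done is still valid
+//     s->lines_done = 1;                  // safe: not in this stack frame
+//     SomeCallThatMayLongjmp(env);
+//   }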
+uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* nwarn, + std::function<uint8*(int, int, int)> allocate_output) { + FewerArgsForCompiler argball(datasize, flags, nwarn, allocate_output); + uint8* const dstdata = UncompressLow(srcdata, &argball); + const float fraction_read = argball.fraction_read_; + if (dstdata == NULL || + fraction_read < std::min(1.0f, flags.min_acceptable_fraction)) { + // Major failure, none or too-partial read returned; get out + return NULL; + } + + // If there was an error in reading the jpeg data, + // set the unread pixels to black + if (fraction_read < 1.0) { + const int first_bad_line = + static_cast<int>(fraction_read * argball.height_); + uint8* start = dstdata + first_bad_line * argball.stride_; + const int nbytes = (argball.height_ - first_bad_line) * argball.stride_; + memset(static_cast<void*>(start), 0, nbytes); + } + + return dstdata; +} + +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* pwidth, int* pheight, + int* pcomponents, int* nwarn) { + uint8* buffer = NULL; + uint8* result = + Uncompress(srcdata, datasize, flags, nwarn, + [=, &buffer](int width, int height, int components) { + if (pwidth != nullptr) *pwidth = width; + if (pheight != nullptr) *pheight = height; + if (pcomponents != nullptr) *pcomponents = components; + buffer = new uint8[height * width * components]; + return buffer; + }); + if (!result) delete[] buffer; + return result; +} + +// ---------------------------------------------------------------------------- +// Computes image information from jpeg header. +// Returns true on success; false on failure. +bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, + int* components) { + // Init in case of failure + if (width) *width = 0; + if (height) *height = 0; + if (components) *components = 0; + + // If empty image, return + if (datasize == 0 || srcdata == NULL) return false; + + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. 
+ struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + return false; + } + + // set up, read header, set image parameters, save size + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, srcdata, datasize, false); + + jpeg_read_header(&cinfo, TRUE); + jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo + if (width) *width = cinfo.output_width; + if (height) *height = cinfo.output_height; + if (components) *components = cinfo.output_components; + + jpeg_destroy_decompress(&cinfo); + + return true; +} + +// ----------------------------------------------------------------------------- +// Compression + +namespace { +bool CompressInternal(const uint8* srcdata, int width, int height, + const CompressFlags& flags, string* output) { + output->clear(); + const int components = (static_cast<int>(flags.format) & 0xff); + int in_stride = flags.stride; + if (in_stride == 0) { + in_stride = width * (static_cast<int>(flags.format) & 0xff); + } else if (in_stride < width * components) { + LOG(ERROR) << "Incompatible input stride"; + return false; + } + + JOCTET* buffer = 0; + + // NOTE: for broader use xmp_metadata should be made a unicode string + CHECK(srcdata != nullptr); + CHECK(output != nullptr); + // This struct contains the JPEG compression parameters and pointers to + // working space + struct jpeg_compress_struct cinfo; + // This struct represents a JPEG error handler. + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; // recovery point in case of error + + // Step 1: allocate and initialize JPEG compression object + // Use the usual jpeg error manager. + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + output->clear(); + delete[] buffer; + return false; + } + + jpeg_create_compress(&cinfo); + + // Step 2: specify data destination + // We allocate a buffer of reasonable size. If we have a small image, just + // estimate the size of the output using the number of bytes of the input. + // If this is getting too big, we will append to the string by chunks of 1MB. + // This seems like a reasonable compromise between performance and memory. + int bufsize = std::min(width * height * components, 1 << 20); + buffer = new JOCTET[bufsize]; + SetDest(&cinfo, buffer, bufsize, output); + + // Step 3: set parameters for compression + cinfo.image_width = width; + cinfo.image_height = height; + switch (components) { + case 1: + cinfo.input_components = 1; + cinfo.in_color_space = JCS_GRAYSCALE; + break; + case 3: + case 4: + cinfo.input_components = 3; + cinfo.in_color_space = JCS_RGB; + break; + default: + LOG(ERROR) << " Invalid components value " << components << std::endl; + output->clear(); + delete[] buffer; + return false; + } + jpeg_set_defaults(&cinfo); + if (flags.optimize_jpeg_size) cinfo.optimize_coding = TRUE; + + cinfo.density_unit = flags.density_unit; // JFIF code for pixel size units: + // 1 = in, 2 = cm + cinfo.X_density = flags.x_density; // Horizontal pixel density + cinfo.Y_density = flags.y_density; // Vertical pixel density + jpeg_set_quality(&cinfo, flags.quality, TRUE); + + if (flags.progressive) { + jpeg_simple_progression(&cinfo); + } + + if (!flags.chroma_downsampling) { + // Turn off chroma subsampling (it is on by default). 
For more details on + // chroma subsampling, see http://en.wikipedia.org/wiki/Chroma_subsampling. + for (int i = 0; i < cinfo.num_components; ++i) { + cinfo.comp_info[i].h_samp_factor = 1; + cinfo.comp_info[i].v_samp_factor = 1; + } + } + + jpeg_start_compress(&cinfo, TRUE); + + // Embed XMP metadata if any + if (!flags.xmp_metadata.empty()) { + // XMP metadata is embedded in the APP1 tag of JPEG and requires this + // namespace header string (null-terminated) + const string name_space = "http://ns.adobe.com/xap/1.0/"; + const int name_space_length = name_space.size(); + const int metadata_length = flags.xmp_metadata.size(); + const int packet_length = metadata_length + name_space_length + 1; + std::unique_ptr<JOCTET[]> joctet_packet(new JOCTET[packet_length]); + + for (int i = 0; i < name_space_length; i++) { + // Conversion char --> JOCTET + joctet_packet[i] = name_space[i]; + } + joctet_packet[name_space_length] = 0; // null-terminate namespace string + + for (int i = 0; i < metadata_length; i++) { + // Conversion char --> JOCTET + joctet_packet[i + name_space_length + 1] = flags.xmp_metadata[i]; + } + jpeg_write_marker(&cinfo, JPEG_APP0 + 1, joctet_packet.get(), + packet_length); + } + + // JSAMPLEs per row in image_buffer + std::unique_ptr<JSAMPLE[]> row_temp( + new JSAMPLE[width * cinfo.input_components]); + while (cinfo.next_scanline < cinfo.image_height) { + JSAMPROW row_pointer[1]; // pointer to JSAMPLE row[s] + const uint8* r = &srcdata[cinfo.next_scanline * in_stride]; + uint8* p = static_cast<uint8*>(row_temp.get()); + switch (flags.format) { + case FORMAT_RGBA: { + for (int i = 0; i < width; ++i, p += 3, r += 4) { + p[0] = r[0]; + p[1] = r[1]; + p[2] = r[2]; + } + row_pointer[0] = row_temp.get(); + break; + } + case FORMAT_ABGR: { + for (int i = 0; i < width; ++i, p += 3, r += 4) { + p[0] = r[3]; + p[1] = r[2]; + p[2] = r[1]; + } + row_pointer[0] = row_temp.get(); + break; + } + default: { + row_pointer[0] = reinterpret_cast<JSAMPLE*>(const_cast<JSAMPLE*>(r)); + } + } + CHECK_EQ(jpeg_write_scanlines(&cinfo, row_pointer, 1), 1); + } + jpeg_finish_compress(&cinfo); + + // release JPEG compression object + jpeg_destroy_compress(&cinfo); + delete[] buffer; + return true; +} + +} // anonymous namespace + +// ----------------------------------------------------------------------------- + +bool Compress(const void* srcdata, int width, int height, + const CompressFlags& flags, string* output) { + return CompressInternal(static_cast<const uint8*>(srcdata), width, height, + flags, output); +} + +string Compress(const void* srcdata, int width, int height, + const CompressFlags& flags) { + string temp; + CompressInternal(static_cast<const uint8*>(srcdata), width, height, flags, + &temp); + // If CompressInternal fails, temp will be empty. + return temp; +} + +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h new file mode 100644 index 0000000000..19ba7d4acf --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem.h @@ -0,0 +1,130 @@ +// This file defines functions to compress and uncompress JPEG files +// to and from memory. It provides interfaces for raw images +// (data array and size fields). +// Direct manipulation of JPEG strings are supplied: Flip, Rotate, Crop.. 
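+//
+// Sample usage for decoding (an illustrative sketch; "jpeg_data" stands in
+// for real JPEG bytes and error handling is omitted):
+//
+//   string jpeg_data;                      /* fill with input JPEG data */
+//   UncompressFlags flags;
+//   flags.components = 3;                  // force RGB output
+//   int width, height, components;
+//   uint8* pixels =
+//       Uncompress(jpeg_data.data(), jpeg_data.size(), flags,
+//                  &width, &height, &components, nullptr /* nwarn */);
+//   // pixels is NULL on error; otherwise the caller owns the buffer.
+//   delete[] pixels;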
+ +#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ +#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ + +#include <functional> +#include <string> +#include <vector> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace jpeg { + +// Flags for Uncompress +struct UncompressFlags { + // ratio can be 1, 2, 4, or 8 and represent the denominator for the scaling + // factor (eg ratio = 4 means that the resulting image will be at 1/4 original + // size in both directions). + int ratio = 1; + + // The number of bytes per pixel (1, 3 or 4), or 0 for autodetect. + int components = 0; + + // If true, decoder will use a slower but nicer upscaling of the chroma + // planes (yuv420/422 only). + bool fancy_upscaling = true; + + // If true, will attempt to fill in missing lines of truncated files + bool try_recover_truncated_jpeg = false; + + // The minimum required fraction of lines read before the image is accepted. + float min_acceptable_fraction = 1.0; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; +}; + +// Uncompress some raw JPEG data given by the pointer srcdata and the length +// datasize. +// - width and height are the address where to store the size of the +// uncompressed image in pixels. May be nullptr. +// - components is the address where the number of read components are +// stored. This is *output only*: to request a specific number of +// components use flags.components. May be nullptr. +// - nwarn is the address in which to store the number of warnings. +// May be nullptr. +// The function returns a pointer to the raw uncompressed data or NULL if +// there was an error. The caller of the function is responsible for +// freeing the memory (using delete []). +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* width, int* height, + int* components, // Output only: useful with autodetect + int* nwarn); + +// Version of Uncompress that allocates memory via a callback. The callback +// arguments are (width, height, components). If the size is known ahead of +// time this function can return an existing buffer; passing a callback allows +// the buffer to be shaped based on the JPEG header. The caller is responsible +// for freeing the memory *even along error paths*. +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* nwarn, + std::function<uint8*(int, int, int)> allocate_output); + +// Read jpeg header and get image information. Returns true on success. +// The width, height, and components points may be null. +bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, + int* components); + +// Note: (format & 0xff) = number of components (<=> bytes per pixels) +enum Format { + FORMAT_GRAYSCALE = 0x001, // 1 byte/pixel + FORMAT_RGB = 0x003, // 3 bytes/pixel RGBRGBRGBRGB... + FORMAT_RGBA = 0x004, // 4 bytes/pixel RGBARGBARGBARGBA... + FORMAT_ABGR = 0x104 // 4 bytes/pixel ABGRABGRABGR... 
+}; + +// Flags for compression +struct CompressFlags { + // Encoding of the input data for compression + Format format; + + // Quality of the compression from 0-100 + int quality = 95; + + // If true, create a jpeg image that loads progressively + bool progressive = false; + + // If true, reduce jpeg size without changing quality (at the cost of CPU/RAM) + bool optimize_jpeg_size = false; + + // See http://en.wikipedia.org/wiki/Chroma_subsampling + bool chroma_downsampling = true; + + // Resolution + int density_unit = 1; // 1 = in, 2 = cm + int x_density = 300; + int y_density = 300; + + // If not empty, embed this XMP metadata in the image header + StringPiece xmp_metadata; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; +}; + +// Compress some raw image given in srcdata, the data is a 2D array of size +// stride*height with one of the formats enumerated above. +// The encoded data is returned as a string. +// If not empty, XMP metadata can be embedded in the image header +// On error, returns the empty string (which is never a valid jpeg). +string Compress(const void* srcdata, int width, int height, + const CompressFlags& flags); + +// On error, returns false and sets output to empty. +bool Compress(const void* srcdata, int width, int height, + const CompressFlags& flags, string* output); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc new file mode 100644 index 0000000000..23e72f9d57 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc @@ -0,0 +1,304 @@ +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +#include <setjmp.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include <memory> + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include <gtest/gtest.h> + +#include "tensorflow/core/lib/core/casts.h" + +namespace tensorflow { +namespace jpeg { +namespace { + +const char kTestData[] = "tensorflow/core/lib/jpeg/testdata/"; + +int ComputeSumAbsoluteDifference(const uint8* a, const uint8* b, int width, + int height, int a_stride, int b_stride) { + int totalerr = 0; + for (int i = 0; i < height; i++) { + const uint8* const pa = a + i * a_stride; + const uint8* const pb = b + i * b_stride; + for (int j = 0; j < 3 * width; j++) { + totalerr += abs(static_cast<int>(pa[j]) - static_cast<int>(pb[j])); + } + } + return totalerr; +} + +// Reads the contents of the file into output +void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { + TF_CHECK_OK(ReadFileToString(env, filename, output)); +} + +void TestJPEG(Env* env, const string& jpegfile) { + // Read the data from the jpeg file into memory + string jpeg; + ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg); + const int fsize = jpeg.size(); + const uint8* const temp = bit_cast<const uint8*>(jpeg.data()); + + // try partial decoding (half of the data) + int w, h, c; + std::unique_ptr<uint8[]> imgdata; + + UncompressFlags flags; + flags.components = 3; + + // set min_acceptable_fraction to something insufficient + flags.min_acceptable_fraction = 0.8; + imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL)); + 
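+  // With only half of the data available, fewer than the required 80% of the
+  // scanlines decode, so Uncompress() is expected to give up and return NULL.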
CHECK(imgdata.get() == NULL); + + // now, use a value that makes fsize/2 be enough for a black-filling + flags.min_acceptable_fraction = 0.01; + imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); + + // finally, uncompress the whole data + flags.min_acceptable_fraction = 1.0; + imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); + + // Uncompress the data to RGBA, too + flags.min_acceptable_fraction = 1.0; + flags.components = 4; + imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); +} + +TEST(JpegMemTest, Jpeg) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Name of a valid jpeg file on the disk + TestJPEG(env, data_path + "jpeg_merge_test1.jpg"); + + // Exercise CMYK machinery as well + TestJPEG(env, data_path + "jpeg_merge_test1_cmyk.jpg"); +} + +TEST(JpegMemTest, Jpeg2) { + // create known data, for size in_w x in_h + const int in_w = 256; + const int in_h = 256; + const int stride1 = 3 * in_w; + const std::unique_ptr<uint8[]> refdata1(new uint8[stride1 * in_h]); + for (int i = 0; i < in_h; i++) { + for (int j = 0; j < in_w; j++) { + const int offset = i * stride1 + 3 * j; + refdata1[offset + 0] = i; + refdata1[offset + 1] = j; + refdata1[offset + 2] = static_cast<uint8>((i + j) >> 1); + } + } + + // duplicate with weird input stride + const int stride2 = 3 * 357; + const std::unique_ptr<uint8[]> refdata2(new uint8[stride2 * in_h]); + for (int i = 0; i < in_h; i++) { + memcpy(&refdata2[i * stride2], &refdata1[i * stride1], 3 * in_w); + } + + // Test compression + string cpdata1, cpdata2; + { + const string kXMP = "XMP_TEST_123"; + + // Compress it to JPEG + CompressFlags flags; + flags.format = FORMAT_RGB; + flags.quality = 97; + flags.xmp_metadata = kXMP; + cpdata1 = Compress(refdata1.get(), in_w, in_h, flags); + flags.stride = stride2; + cpdata2 = Compress(refdata2.get(), in_w, in_h, flags); + // Different input stride shouldn't change the output + CHECK_EQ(cpdata1, cpdata2); + + // Verify valid XMP. + CHECK_NE(string::npos, cpdata1.find(kXMP)); + + // Test the other API, where a storage string is supplied + string cptest; + flags.stride = 0; + Compress(refdata1.get(), in_w, in_h, flags, &cptest); + CHECK_EQ(cptest, cpdata1); + flags.stride = stride2; + Compress(refdata2.get(), in_w, in_h, flags, &cptest); + CHECK_EQ(cptest, cpdata2); + } + + // Uncompress twice: once with 3 components and once with autodetect + std::unique_ptr<uint8[]> imgdata1; + for (const int components : {0, 3}) { + // Uncompress it + UncompressFlags flags; + flags.components = components; + int w, h, c; + imgdata1.reset( + Uncompress(cpdata1.c_str(), cpdata1.length(), flags, &w, &h, &c, NULL)); + + // Check obvious formatting stuff + CHECK_EQ(w, in_w); + CHECK_EQ(h, in_h); + CHECK_EQ(c, 3); + CHECK(imgdata1.get()); + + // Compare the two images + const int totalerr = ComputeSumAbsoluteDifference( + imgdata1.get(), refdata1.get(), in_w, in_h, stride1, stride1); + CHECK_LE(totalerr, 85000); + } + + // check the second image too. Should be bitwise identical to the first. 
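+  // The block below exercises the callback form of Uncompress(): the lambda
+  // hands back a caller-owned buffer laid out with a custom stride and
+  // verifies the (width, height, components) values it is given.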
+ // uncompress using a weird stride + { + UncompressFlags flags; + flags.stride = 3 * 411; + const std::unique_ptr<uint8[]> imgdata2(new uint8[flags.stride * in_h]); + CHECK(imgdata2.get() == Uncompress(cpdata2.c_str(), cpdata2.length(), flags, + NULL, [&imgdata2](int w, int h, int c) { + CHECK_EQ(w, in_w); + CHECK_EQ(h, in_h); + CHECK_EQ(c, 3); + return imgdata2.get(); + })); + const int totalerr = ComputeSumAbsoluteDifference( + imgdata1.get(), imgdata2.get(), in_w, in_h, stride1, flags.stride); + CHECK_EQ(totalerr, 0); + } +} + +// Takes JPEG data and reads its headers to determine whether or not the JPEG +// was chroma downsampled. +bool IsChromaDownsampled(const string& jpegdata) { + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. + struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) return false; + + // set up, read header, set image parameters, save size + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, jpegdata.c_str(), jpegdata.size(), false); + + jpeg_read_header(&cinfo, TRUE); + jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo + const int components = cinfo.output_components; + if (components == 1) return false; + + // Check validity + CHECK_EQ(3, components); + CHECK_EQ(cinfo.comp_info[1].h_samp_factor, cinfo.comp_info[2].h_samp_factor) + << "The h sampling factors should be the same."; + CHECK_EQ(cinfo.comp_info[1].v_samp_factor, cinfo.comp_info[2].v_samp_factor) + << "The v sampling factors should be the same."; + for (int i = 0; i < components; ++i) { + CHECK_GT(cinfo.comp_info[i].h_samp_factor, 0) << "Invalid sampling factor."; + CHECK_EQ(cinfo.comp_info[i].h_samp_factor, cinfo.comp_info[i].v_samp_factor) + << "The sampling factor should be the same in both directions."; + } + + // We're downsampled if we use fewer samples for color than for brightness. + // Do this before deallocating cinfo. + const bool downsampled = + cinfo.comp_info[1].h_samp_factor < cinfo.comp_info[0].h_samp_factor; + + jpeg_destroy_decompress(&cinfo); + return downsampled; +} + +TEST(JpegMemTest, ChromaDownsampling) { + // Read the data from a test jpeg file into memory + const string jpegfile = string(kTestData) + "jpeg_merge_test1.jpg"; + string jpeg; + ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg); + + // Verify that compressing the JPEG with chroma downsampling works. + // + // First, uncompress the JPEG. 
+ UncompressFlags unflags; + unflags.components = 3; + int w, h, c, num_warnings; + std::unique_ptr<uint8[]> uncompressed(Uncompress( + jpeg.c_str(), jpeg.size(), unflags, &w, &h, &c, &num_warnings)); + CHECK(uncompressed.get() != NULL); + CHECK_EQ(num_warnings, 0); + + // Recompress the JPEG with and without chroma downsampling + for (const bool downsample : {false, true}) { + CompressFlags flags; + flags.format = FORMAT_RGB; + flags.quality = 85; + flags.chroma_downsampling = downsample; + string recompressed; + Compress(uncompressed.get(), w, h, flags, &recompressed); + CHECK(!recompressed.empty()); + CHECK_EQ(IsChromaDownsampled(recompressed), downsample); + } +} + +void TestBadJPEG(Env* env, const string& bad_jpeg_file, int expected_width, + int expected_height, const string& reference_RGB_file, + const bool try_recover_truncated_jpeg) { + string jpeg; + ReadFileToStringOrDie(env, bad_jpeg_file, &jpeg); + + UncompressFlags flags; + flags.components = 3; + flags.try_recover_truncated_jpeg = try_recover_truncated_jpeg; + + int width, height, components; + std::unique_ptr<uint8[]> imgdata; + imgdata.reset(Uncompress(jpeg.c_str(), jpeg.size(), flags, &width, &height, + &components, NULL)); + if (expected_width > 0) { // we expect the file to decode into 'something' + CHECK_EQ(width, expected_width); + CHECK_EQ(height, expected_height); + CHECK_EQ(components, 3); + CHECK(imgdata.get()); + if (!reference_RGB_file.empty()) { + string ref; + ReadFileToStringOrDie(env, reference_RGB_file, &ref); + CHECK(!memcmp(ref.data(), imgdata.get(), ref.size())); + } + } else { // no decodable + CHECK(!imgdata.get()) << "file:" << bad_jpeg_file; + } +} + +TEST(JpegMemTest, BadJpeg) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Test corrupt file + TestBadJPEG(env, data_path + "bad_huffman.jpg", 1024, 768, "", false); + TestBadJPEG(env, data_path + "corrupt.jpg", 0 /*120*/, 90, "", false); + + // Truncated files, undecodable because of missing lines: + TestBadJPEG(env, data_path + "corrupt34_2.jpg", 0, 3300, "", false); + TestBadJPEG(env, data_path + "corrupt34_3.jpg", 0, 3300, "", false); + TestBadJPEG(env, data_path + "corrupt34_4.jpg", 0, 3300, "", false); + + // Try in 'recover' mode now: + TestBadJPEG(env, data_path + "corrupt34_2.jpg", 2544, 3300, "", true); + TestBadJPEG(env, data_path + "corrupt34_3.jpg", 2544, 3300, "", true); + TestBadJPEG(env, data_path + "corrupt34_4.jpg", 2544, 3300, "", true); +} + +} // namespace +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg Binary files differnew file mode 100644 index 0000000000..ef5b6f12c5 --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg Binary files differnew file mode 100644 index 0000000000..5e2fe6c56f --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg Binary files differnew file mode 100644 index 0000000000..4211155c45 --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg Binary files differnew file mode 100644 index 0000000000..c1c2a9d1e1 --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg diff --git 
a/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg Binary files differnew file mode 100644 index 0000000000..b8e7308ba0 --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg Binary files differnew file mode 100644 index 0000000000..5e348a12fd --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg Binary files differnew file mode 100644 index 0000000000..15f895960d --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc new file mode 100644 index 0000000000..43b84e41e0 --- /dev/null +++ b/tensorflow/core/lib/png/png_io.cc @@ -0,0 +1,385 @@ +// Functions to read and write images in PNG format. + +#include <string.h> +#include <sys/types.h> +#include <string> +#include <utility> +#include <vector> +// NOTE(skal): we don't '#include <setjmp.h>' before png/png.h as it otherwise +// provokes a compile error. We instead let png.h include what is needed. + +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/png/png_io.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" // endian +#include "external/png_archive/libpng-1.2.53/png.h" + +namespace tensorflow { +namespace png { + +//////////////////////////////////////////////////////////////////////////////// +// Encode an 8- or 16-bit rgb/grayscale image to PNG string +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +#define PTR_INC(type, ptr, del) (ptr = \ + reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del))) +#define CPTR_INC(type, ptr, del) (ptr = \ + reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + (del))) + +// Convert from 8 bit components to 16. This works in-place. +static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes, + int width, int height, uint16* p16, + int p16_row_bytes) { + // Adjust pointers to copy backwards + width *= num_comps; + CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + + (width - 1) * sizeof(*p8)); + PTR_INC(uint16, p16, (height - 1) * p16_row_bytes + + (width - 1) * sizeof(*p16)); + int bump8 = width * sizeof(*p8) - p8_row_bytes; + int bump16 = width * sizeof(*p16) - p16_row_bytes; + for (; height-- != 0; + CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) { + for (int w = width; w-- != 0; --p8, --p16) { + uint pix = *p8; + pix |= pix << 8; + *p16 = static_cast<uint16>(pix); + } + } +} + +#undef PTR_INC +#undef CPTR_INC + +void ErrorHandler(png_structp png_ptr, png_const_charp msg) { + DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); + ctx->error_condition = true; + // To prevent log spam, errors are logged as VLOG(1) instead of ERROR. 
+ VLOG(1) << "PNG error: " << msg; + longjmp(png_jmpbuf(png_ptr), 1); +} + +void WarningHandler(png_structp png_ptr, png_const_charp msg) { + LOG(WARNING) << "PNG warning: " << msg; +} + +void StringReader(png_structp png_ptr, + png_bytep data, png_size_t length) { + DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); + if (static_cast<png_size_t>(ctx->data_left) < length) { + if (!ctx->error_condition) { + VLOG(1) << "PNG read decoding error"; + ctx->error_condition = true; + } + memset(data, 0, length); + } else { + memcpy(data, ctx->data, length); + ctx->data += length; + ctx->data_left -= length; + } +} + +void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) { + string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr)); + s->append(bit_cast<const char*>(data), length); +} + +void StringWriterFlush(png_structp png_ptr) { +} + +char* check_metadata_string(const string& s) { + const char* const c_string = s.c_str(); + const size_t length = s.size(); + if (strlen(c_string) != length) { + LOG(WARNING) << "Warning! Metadata contains \\0 character(s)."; + } + return const_cast<char*>(c_string); +} + +} // namespace + +// We move CommonInitDecode() and CommonFinishDecode() +// out of the CommonDecode() template to save code space. +void CommonFreeDecode(DecodeContext* context) { + if (context->png_ptr) { + png_destroy_read_struct(&context->png_ptr, + context->info_ptr ? &context->info_ptr : NULL, 0); + context->png_ptr = nullptr; + context->info_ptr = nullptr; + } +} + +bool DecodeHeader(StringPiece png_string, int* width, int* height, + int* components, int* channel_bit_depth, + std::vector<std::pair<string, string> >* metadata) { + DecodeContext context; + // Ask for 16 bits even if there may be fewer. This assures that sniffing + // the metadata will succeed in all cases. + // + // TODO(skal): CommonInitDecode() mixes the operation of sniffing the + // metadata with setting up the data conversions. These should be separated. + constexpr int kDesiredNumChannels = 1; + constexpr int kDesiredChannelBits = 16; + if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits, + &context)) { + return false; + } + CHECK_NOTNULL(width); + *width = static_cast<int>(context.width); + CHECK_NOTNULL(height); + *height = static_cast<int>(context.height); + if (components != NULL) { + switch (context.color_type) { + case PNG_COLOR_TYPE_PALETTE: + *components = (context.info_ptr->valid & PNG_INFO_tRNS) ? 
4 : 3; + break; + case PNG_COLOR_TYPE_GRAY: + *components = 1; + break; + case PNG_COLOR_TYPE_GRAY_ALPHA: + *components = 2; + break; + case PNG_COLOR_TYPE_RGB: + *components = 3; + break; + case PNG_COLOR_TYPE_RGB_ALPHA: + *components = 4; + break; + default: + *components = 0; + break; + } + } + if (channel_bit_depth != NULL) { + *channel_bit_depth = context.bit_depth; + } + if (metadata != NULL) { + metadata->clear(); + for (int i = 0; i < context.info_ptr->num_text; i++) { + const png_text& text = context.info_ptr->text[i]; + metadata->push_back(std::make_pair(text.key, text.text)); + } + } + CommonFreeDecode(&context); + return true; +} + +bool CommonInitDecode(StringPiece png_string, int desired_channels, + int desired_channel_bits, DecodeContext* context) { + CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) + << "desired_channel_bits = " << desired_channel_bits; + CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " + << desired_channels; + context->error_condition = false; + context->channels = desired_channels; + context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, + ErrorHandler, WarningHandler); + if (!context->png_ptr) { + VLOG(1) << ": DecodePNG <- png_create_read_struct failed"; + return false; + } + if (setjmp(png_jmpbuf(context->png_ptr))) { + VLOG(1) << ": DecodePNG error trapped."; + CommonFreeDecode(context); + return false; + } + context->info_ptr = png_create_info_struct(context->png_ptr); + if (!context->info_ptr || context->error_condition) { + VLOG(1) << ": DecodePNG <- png_create_info_struct failed"; + CommonFreeDecode(context); + return false; + } + context->data = bit_cast<const uint8*>(png_string.data()); + context->data_left = png_string.size(); + png_set_read_fn(context->png_ptr, context, StringReader); + png_read_info(context->png_ptr, context->info_ptr); + png_get_IHDR(context->png_ptr, context->info_ptr, + &context->width, &context->height, + &context->bit_depth, &context->color_type, + 0, 0, 0); + if (context->error_condition) { + VLOG(1) << ": DecodePNG <- error during header parsing."; + CommonFreeDecode(context); + return false; + } + if (context->width <= 0 || context->height <= 0) { + VLOG(1) << ": DecodePNG <- invalid dimensions"; + CommonFreeDecode(context); + return false; + } + if (context->channels == 0) { // Autodetect number of channels + context->channels = context->info_ptr->channels; + } + const bool has_tRNS = (context->info_ptr->valid & PNG_INFO_tRNS) != 0; + const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0; + if ((context->channels & 1) == 0) { // We desire alpha + if (has_alpha) { // There is alpha + } else if (has_tRNS) { + png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha + } else { + png_set_add_alpha( + context->png_ptr, (1 << context->bit_depth) - 1, PNG_FILLER_AFTER); + } + } else { // We don't want alpha + if (has_alpha || has_tRNS) { // There is alpha + png_set_strip_alpha(context->png_ptr); // Strip alpha + } + } + + // If we only want 8 bits, but are given 16, strip off the LS 8 bits + if (context->bit_depth > 8 && desired_channel_bits <= 8) + png_set_strip_16(context->png_ptr); + + context->need_to_synthesize_16 = + (context->bit_depth <= 8 && desired_channel_bits == 16); + + png_set_packing(context->png_ptr); + context->num_passes = png_set_interlace_handling(context->png_ptr); + png_read_update_info(context->png_ptr, context->info_ptr); + +#ifdef IS_LITTLE_ENDIAN + if (desired_channel_bits > 8) + 
png_set_swap(context->png_ptr); +#endif // IS_LITTLE_ENDIAN + + // convert palette to rgb(a) if needs be. + if (context->color_type == PNG_COLOR_TYPE_PALETTE) + png_set_palette_to_rgb(context->png_ptr); + + // handle grayscale case for source or destination + const bool want_gray = (context->channels < 3); + const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR); + if (is_gray) { // upconvert gray to 8-bit if needed. + if (context->bit_depth < 8) + png_set_gray_1_2_4_to_8(context->png_ptr); + } + if (want_gray) { // output is grayscale + if (!is_gray) + png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG + } else { // output is rgb(a) + if (is_gray) + png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion + } + return true; +} + +bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) { + CHECK_NOTNULL(data); + + // we need to re-set the jump point so that we trap the errors + // within *this* function (and not CommonInitDecode()) + if (setjmp(png_jmpbuf(context->png_ptr))) { + VLOG(1) << ": DecodePNG error trapped."; + CommonFreeDecode(context); + return false; + } + // png_read_row() takes care of offsetting the pointer based on interlacing + for (int p = 0; p < context->num_passes; ++p) { + png_bytep row = data; + for (int h = context->height; h-- != 0; row += row_bytes) { + png_read_row(context->png_ptr, row, NULL); + } + } + + context->info_ptr->valid |= PNG_INFO_IDAT; + png_read_end(context->png_ptr, context->info_ptr); + + // Clean up. + const bool ok = !context->error_condition; + CommonFreeDecode(context); + + // Synthesize 16 bits from 8 if requested. + if (context->need_to_synthesize_16) + Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes, + context->width, context->height, bit_cast<uint16*>(data), + row_bytes); + return ok; +} + +bool WriteImageToBuffer( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, string* png_string, + const std::vector<std::pair<string, string> >* metadata) { + CHECK_NOTNULL(image); + CHECK_NOTNULL(png_string); + // Although this case is checked inside png.cc and issues an error message, + // that error causes memory corruption. + if (width == 0 || height == 0) + return false; + + png_string->resize(0); + png_infop info_ptr = NULL; + png_structp png_ptr = + png_create_write_struct(PNG_LIBPNG_VER_STRING, + NULL, ErrorHandler, WarningHandler); + if (png_ptr == NULL) return false; + if (setjmp(png_jmpbuf(png_ptr))) { + png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : NULL); + return false; + } + info_ptr = png_create_info_struct(png_ptr); + if (info_ptr == NULL) { + png_destroy_write_struct(&png_ptr, NULL); + return false; + } + + int color_type = -1; + switch (num_channels) { + case 1: + color_type = PNG_COLOR_TYPE_GRAY; + break; + case 2: + color_type = PNG_COLOR_TYPE_GRAY_ALPHA; + break; + case 3: + color_type = PNG_COLOR_TYPE_RGB; + break; + case 4: + color_type = PNG_COLOR_TYPE_RGB_ALPHA; + break; + default: + png_destroy_write_struct(&png_ptr, &info_ptr); + return false; + } + + png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush); + if (compression < 0) compression = Z_DEFAULT_COMPRESSION; + png_set_compression_level(png_ptr, compression); + png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL); + // There used to be a call to png_set_filter here turning off filtering + // entirely, but it produced pessimal compression ratios. I'm not sure + // why it was there. 
+ png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type, + PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT); + // If we have metadata write to it. + if (metadata && !metadata->empty()) { + std::vector<png_text> text; + for (const auto& pair : *metadata) { + png_text txt; + txt.compression = PNG_TEXT_COMPRESSION_NONE; + txt.key = check_metadata_string(pair.first); + txt.text = check_metadata_string(pair.second); + text.push_back(txt); + } + png_set_text(png_ptr, info_ptr, &text[0], text.size()); + } + + png_write_info(png_ptr, info_ptr); +#ifdef IS_LITTLE_ENDIAN + if (channel_bits > 8) + png_set_swap(png_ptr); +#endif // IS_LITTLE_ENDIAN + + png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image)); + for (; height--; row += row_bytes) png_write_row(png_ptr, row); + png_write_end(png_ptr, NULL); + + png_destroy_write_struct(&png_ptr, &info_ptr); + return true; +} + +} // namespace png +} // namespace tensorflow diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h new file mode 100644 index 0000000000..df9bff7be8 --- /dev/null +++ b/tensorflow/core/lib/png/png_io.h @@ -0,0 +1,88 @@ +// Functions to read and write images in PNG format. +// +// The advantage over image/codec/png{enc,dec}ocder.h is that this library +// supports both 8 and 16 bit images. +// +// The decoding routine accepts binary image data as a StringPiece. These are +// implicitly constructed from strings or char* so they're completely +// transparent to the caller. They're also very cheap to construct so this +// doesn't introduce any additional overhead. +// +// The primary benefit of StringPieces being, in this case, that APIs already +// returning StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer; +// only when they're flat, though) or protocol buffer fields typed to either of +// these can be decoded without copying the data into a C++ string. + +#ifndef TENSORFLOW_LIB_PNG_PNG_IO_H_ +#define TENSORFLOW_LIB_PNG_PNG_IO_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "external/png_archive/libpng-1.2.53/png.h" + +namespace tensorflow { +namespace png { + +// Handy container for decoding informations and struct pointers +struct DecodeContext { + const uint8* data; + int data_left; + png_structp png_ptr; + png_infop info_ptr; + png_uint_32 width, height; + int num_passes; + int color_type; + int bit_depth; + int channels; + bool need_to_synthesize_16; + bool error_condition; + DecodeContext() : png_ptr(NULL), info_ptr(NULL) {} +}; + +bool DecodeHeader(StringPiece png_string, int* width, int* height, + int* components, int* channel_bit_depth, + std::vector<std::pair<string, string> >* metadata); + +// Sample usage for reading PNG: +// +// string png_string; /* fill with input PNG format data */ +// DecodeContext context; +// CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context)); +// char* image_buffer = new char[3*context.width*context.height]; +// CHECK(CommonFinishDecode(bit_cast<png_byte*>(image_buffer), +// 3*context.width /*stride*/, &context)); +// +// desired_channels may be 0 to detected it from the input. + +bool CommonInitDecode(StringPiece png_string, int desired_channels, + int desired_channel_bits, DecodeContext* context); + +bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context); + +// Normally called automatically from CommonFinishDecode. 
If CommonInitDecode +// is called but not CommonFinishDecode, call this to clean up. Safe to call +// extra times. +void CommonFreeDecode(DecodeContext* context); + +// Sample usage for writing PNG: +// +// uint16* image_buffer = new uint16[width*height]; /* fill with pixels */ +// string png_string; +// CHECK(WriteImageToBuffer(image_buffer, width, height, 2*width /*stride*/, +// 1 /*gray*/, 16 /*uint16*/, &png_string, NULL)); +// +// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow +// and strong, and -1 is the zlib default. + +bool WriteImageToBuffer( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, string* png_string, + const std::vector<std::pair<string, string> >* metadata); + +} // namespace png +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_PNG_PNG_IO_H_ diff --git a/tensorflow/core/lib/png/testdata/lena_gray.png b/tensorflow/core/lib/png/testdata/lena_gray.png Binary files differnew file mode 100644 index 0000000000..8bc73159b0 --- /dev/null +++ b/tensorflow/core/lib/png/testdata/lena_gray.png diff --git a/tensorflow/core/lib/png/testdata/lena_rgba.png b/tensorflow/core/lib/png/testdata/lena_rgba.png Binary files differnew file mode 100644 index 0000000000..79f1f84a62 --- /dev/null +++ b/tensorflow/core/lib/png/testdata/lena_rgba.png diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc new file mode 100644 index 0000000000..341f1bd595 --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler.cc @@ -0,0 +1,80 @@ +#include "tensorflow/core/lib/random/distribution_sampler.h" + +#include <memory> +#include <vector> + +namespace tensorflow { +namespace random { + +DistributionSampler::DistributionSampler( + const gtl::ArraySlice<float>& weights) { + DCHECK(!weights.empty()); + int n = weights.size(); + num_ = n; + data_.reset(new std::pair<float, int>[n]); + + std::unique_ptr<double[]> pr(new double[n]); + + double sum = 0.0; + for (int i = 0; i < n; i++) { + sum += weights[i]; + set_alt(i, -1); + } + + // These are long/short items - called high/low because of reserved keywords. + std::vector<int> high; + high.reserve(n); + std::vector<int> low; + low.reserve(n); + + // compute propotional weights + for (int i = 0; i < n; i++) { + double p = (weights[i] * n) / sum; + pr[i] = p; + if (p < 1.0) { + low.push_back(i); + } else { + high.push_back(i); + } + } + + // Now pair high with low. + while (!high.empty() && !low.empty()) { + int l = low.back(); + low.pop_back(); + int h = high.back(); + high.pop_back(); + + set_alt(l, h); + DCHECK_GE(pr[h], 1.0); + double remaining = pr[h] - (1.0 - pr[l]); + pr[h] = remaining; + + if (remaining < 1.0) { + low.push_back(h); + } else { + high.push_back(h); + } + } + // Transfer pr to prob with rounding errors. + for (int i = 0; i < n; i++) { + set_prob(i, pr[i]); + } + // Because of rounding errors, both high and low may have elements, that are + // close to 1.0 prob. 
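+  // Any index still left in either list at this point has a probability that
+  // differs from 1.0 only by accumulated rounding error, so it is safe to
+  // clamp it to exactly 1.0 and make it its own alias.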
+ for (size_t i = 0; i < high.size(); i++) { + int idx = high[i]; + set_prob(idx, 1.0); + // set alt to self to prevent rounding errors returning 0 + set_alt(idx, idx); + } + for (size_t i = 0; i < low.size(); i++) { + int idx = low[i]; + set_prob(idx, 1.0); + // set alt to self to prevent rounding errors returning 0 + set_alt(idx, idx); + } +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h new file mode 100644 index 0000000000..ab9598a205 --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler.h @@ -0,0 +1,79 @@ +// DistributionSampler allows generating a discrete random variable with a given +// distribution. +// The values taken by the variable are [0, N) and relative weights for each +// value are specified using a vector of size N. +// +// The Algorithm takes O(N) time to precompute data at construction time and +// takes O(1) time (2 random number generation, 2 lookups) for each sample. +// The data structure takes O(N) memory. +// +// In contrast, util/random/weighted-picker.h provides O(lg N) sampling. +// The advantage of that implementation is that weights can be adjusted +// dynamically, while DistributionSampler doesn't allow weight adjustment. +// +// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. + +#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +class DistributionSampler { + public: + explicit DistributionSampler(const gtl::ArraySlice<float>& weights); + + ~DistributionSampler() {} + + int Sample(SimplePhilox* rand) const { + float r = rand->RandFloat(); + // Since n is typically low, we don't bother with UnbiasedUniform. + int idx = rand->Uniform(num_); + if (r < prob(idx)) return idx; + // else pick alt from that bucket. 
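+    // Worked example (illustrative): for weights {2, 1, 1} the constructor
+    // builds prob = {1.0, 0.75, 0.75} and alt = {0, 0, 0}, so
+    //   P(0) = (1.0 + 0.25 + 0.25) / 3 = 0.5,  P(1) = P(2) = 0.75 / 3 = 0.25,
+    // which matches the normalized weights 2/4, 1/4 and 1/4.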
+ DCHECK_NE(-1, alt(idx)); + return alt(idx); + } + + int num() const { return num_; } + + private: + float prob(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].first; + } + + int alt(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].second; + } + + void set_prob(int idx, float f) { + DCHECK_LT(idx, num_); + data_[idx].first = f; + } + + void set_alt(int idx, int val) { + DCHECK_LT(idx, num_); + data_[idx].second = val; + } + + int num_; + std::unique_ptr<std::pair<float, int>[]> data_; + + TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler); +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ diff --git a/tensorflow/core/lib/random/distribution_sampler_test.cc b/tensorflow/core/lib/random/distribution_sampler_test.cc new file mode 100644 index 0000000000..d61a8daa0f --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler_test.cc @@ -0,0 +1,90 @@ +#include "tensorflow/core/lib/random/distribution_sampler.h" + +#include <string.h> +#include <memory> +#include <vector> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { + +class DistributionSamplerTest : public ::testing::Test { + protected: + // Returns the Chi-Squared statistic for the two distributions. + float TestWeights(const std::vector<float>& weights, int trials_per_bin) { + int iters = weights.size() * trials_per_bin; + std::unique_ptr<float[]> counts(new float[weights.size()]); + memset(counts.get(), 0, sizeof(float) * weights.size()); + DistributionSampler sampler(weights); + PhiloxRandom philox(testing::RandomSeed(), 17); + SimplePhilox random(&philox); + for (int i = 0; i < iters; i++) { + int r = sampler.Sample(&random); + EXPECT_LT(r, weights.size()); + EXPECT_GE(r, 0); + counts[r] += 1.0; + } + float chi2 = 0.0; + for (size_t i = 0; i < weights.size(); i++) { + counts[i] /= iters; + float err = (counts[i] - weights[i]); + chi2 += (err * err) / weights[i]; + } + return chi2; + } + + void TestDistribution(float* arr, int n) { + std::vector<float> w; + w.reserve(n); + for (int i = 0; i < n; i++) { + w.push_back(arr[i]); + } + float var = TestWeights(w, 1000); + if (var < 0.001) return; + // Maybe a statistical skew. Let's try more iterations. 
+ var = TestWeights(w, 100000); + if (var < 0.001) return; + EXPECT_TRUE(false) << "Chi2 is " << var << " in " << n * 100000 + << "iterations"; + } +}; + +TEST_F(DistributionSamplerTest, KnownDistribution) { + float kEven2[] = {0.5, 0.5}; + float kEven3[] = {0.33333333, 0.33333333, 0.33333333}; + float kEven4[] = {0.25, 0.25, 0.25, 0.25}; + + float kDist1[] = {0.8, 0.15, 0.05}; + + TestDistribution(kEven2, TF_ARRAYSIZE(kEven2)); + TestDistribution(kEven3, TF_ARRAYSIZE(kEven3)); + TestDistribution(kEven4, TF_ARRAYSIZE(kEven4)); + TestDistribution(kDist1, TF_ARRAYSIZE(kDist1)); +} + +static void BM_DistributionSampler(int iters, int n) { + testing::StopTiming(); + PhiloxRandom philox(173, 371); + SimplePhilox rand(&philox); + std::vector<float> weights(n, 0); + for (int i = 0; i < n; i++) { + weights[i] = rand.Uniform(100); + } + DistributionSampler picker(weights); + testing::StartTiming(); + int r = 0; + for (int i = 0; i < iters; i++) { + r |= picker.Sample(&rand); + } + CHECK_NE(r, kint32max); +} + +BENCHMARK(BM_DistributionSampler)->Arg(10)->Arg(100)->Arg(1000); + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h new file mode 100644 index 0000000000..616354cc5c --- /dev/null +++ b/tensorflow/core/lib/random/exact_uniform_int.h @@ -0,0 +1,68 @@ +// Exact uniform integers using rejection sampling + +#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ +#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ + +#include <type_traits> + +namespace tensorflow { +namespace random { + +template <typename UintType, typename RandomBits> +UintType ExactUniformInt(const UintType n, const RandomBits& random) { + static_assert(std::is_unsigned<UintType>::value, + "UintType must be an unsigned int"); + static_assert(std::is_same<UintType, decltype(random())>::value, + "random() should return UintType"); + if (n == 0) { + // Consume a value anyway + // TODO(irving): Assert n != 0, since this case makes no sense. + return random() * n; + } else if (0 == (n & (n - 1))) { + // N is a power of two, so just mask off the lower bits. + return random() & (n - 1); + } else { + // Reject all numbers that skew the distribution towards 0. + + // random's output is uniform in the half-open interval [0, 2^{bits}). + // For any interval [m,n), the number of elements in it is n-m. + + const UintType range = ~static_cast<UintType>(0); + const UintType rem = (range % n) + 1; + UintType rnd; + + // rem = ((2^bits-1) \bmod n) + 1 + // 1 <= rem <= n + + // NB: rem == n is impossible, since n is not a power of 2 (from + // earlier check). + + do { + rnd = random(); // rnd uniform over [0, 2^{bits}) + } while (rnd < rem); // reject [0, rem) + // rnd is uniform over [rem, 2^{bits}) + // + // The number of elements in the half-open interval is + // + // 2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1 + // = 2^{bits}-1 - ((2^{bits}-1) \bmod n) + // = n \cdot \lfloor (2^{bits}-1)/n \rfloor + // + // therefore n evenly divides the number of integers in the + // interval. + // + // The function v \rightarrow v % n takes values from [bias, + // 2^{bits}) to [0, n). Each integer in the range interval [0, n) + // will have exactly \lfloor (2^{bits}-1)/n \rfloor preimages from + // the domain interval. + // + // Therefore, v % n is uniform over [0, n). QED. 
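+    // Worked example (illustrative): with 8-bit outputs (values 0..255) and
+    // n = 10, rem = (255 % 10) + 1 = 6. Rejecting rnd < 6 leaves the 250
+    // values 6..255, a multiple of 10, so every residue modulo 10 occurs
+    // exactly 25 times and rnd % n is exactly uniform.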
+ + return rnd % n; + } +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h new file mode 100644 index 0000000000..2c3cd0c4b9 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random.h @@ -0,0 +1,232 @@ +// Implement the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + +#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ +#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ + +#include <stdlib.h> + +#include "tensorflow/core/platform/port.h" + +// Function qualifiers that need to work on both CPU and GPU. +#ifdef __CUDA_ARCH__ +// For nvcc. +#define PHILOX_DEVICE_FUNC __host__ __device__ +#define PHILOX_INLINE __inline__ +#else +// For non-nvcc. +#define PHILOX_DEVICE_FUNC +#define PHILOX_INLINE inline +#endif +#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE + +#include <math.h> + +namespace tensorflow { +namespace random { + +// A class that represents an inline array. It can be used on both CPU and GPU, +// and also trivially copyable between CPU and GPU. +// Arguments: +// T: the array element type; +// ElementCount: the fixed size of the array; +template <typename T, int ElementCount> +class Array { + public: + PHILOX_DEVICE_INLINE Array() { + for (int i = 0; i < ElementCount; ++i) { + data_[i] = T(); + } + } + + PHILOX_DEVICE_INLINE const T& operator[](int index) const { + return data_[index]; + } + + PHILOX_DEVICE_INLINE T& operator[](int index) { return data_[index]; } + + size_t size() const { return ElementCount; } + + private: + T data_[ElementCount]; +}; + +// A class that encapsulates all the states for a random number generator using +// the philox_4x32_10 algorithm. Each invocation returns a 128-bit random bits +// in the form of four uint32. +// There are multiple variants of this algorithm, we picked the 4x32_10 version +// that is most suited for our applications. +// Since this class is meant to be copied between CPU to GPU, it maintains a +// value semantics. +// +// For example: To use this class and populate an array of 1024 randoms on CPU +// with two threads, +// +// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) { +// assert(start % 4 == 0); +// assert(limit % 4 == 0); +// rnd.Skip(start / 4); +// for (int i = start; i < limit; i += 4) { +// auto sample = rnd(); +// ... copy sample[0..3] to output[i..i+3] +// } +// } +// +// PhiloxRandom rng(seed); +// PhiloxRandom rng_copy = rng; +// rng.Skip(1000/4); +// +// ... schedule Fill(rng_copy, output, 0, 512) in thread 1; +// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2; +// ... wait for thread 1 & 2 to finish executing Fill(). +// +// NOTE: +// 1. PhiloxRandom is trivially copyable. +// 2. PhiloxRandom is compilable by gcc and nvcc. +class PhiloxRandom { + public: + typedef Array<uint32, 4> ResultType; + typedef uint32 ResultElementType; + // The number of elements that will be returned. 
+ static const int kResultElementCount = 4; + + PHILOX_DEVICE_INLINE + PhiloxRandom() {} + + PHILOX_DEVICE_INLINE + explicit PhiloxRandom(uint64 seed) { + key_[0] = static_cast<uint32>(seed); + key_[1] = static_cast<uint32>(seed >> 32); + } + + PHILOX_DEVICE_INLINE + explicit PhiloxRandom(uint64 seed_lo, uint64 seed_hi) { + key_[0] = static_cast<uint32>(seed_lo); + key_[1] = static_cast<uint32>(seed_lo >> 32); + counter_[2] = static_cast<uint32>(seed_hi); + counter_[3] = static_cast<uint32>(seed_hi >> 32); + } + + // Skip the specified number of samples of 128-bits in the current stream. + PHILOX_DEVICE_INLINE + void Skip(uint64 count) { + const uint32 count_lo = static_cast<uint32>(count); + uint32 count_hi = static_cast<uint32>(count >> 32); + + counter_[0] += count_lo; + if (counter_[0] < count_lo) { + ++count_hi; + } + + counter_[1] += count_hi; + if (counter_[1] < count_hi) { + if (++counter_[2] == 0) { + ++counter_[3]; + } + } + } + + // Returns a group of four random numbers using the underlying Philox + // algorithm. + PHILOX_DEVICE_INLINE ResultType operator()() { + ResultType counter = counter_; + Key key = key_; + + // Run the single rounds for ten times. Manually unrolling the loop + // for better performance. + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + + SkipOne(); + + return counter; + } + + private: + // The type for the 64-bit key stored in the form of two 32-bit uint + // that are used in the diffusion process. + typedef Array<uint32, 2> Key; + + // We use the same constants as recommended by the original paper. + static const uint32 kPhiloxW32A = 0x9E3779B9; + static const uint32 kPhiloxW32B = 0xBB67AE85; + static const uint32 kPhiloxM4x32A = 0xD2511F53; + static const uint32 kPhiloxM4x32B = 0xCD9E8D57; + + // Helper function to skip the next sample of 128-bits in the current stream. + PHILOX_DEVICE_INLINE void SkipOne() { + if (++counter_[0] == 0) { + if (++counter_[1] == 0) { + if (++counter_[2] == 0) { + ++counter_[3]; + } + } + } + } + + // Helper function to return the lower and higher 32-bits from two 32-bit + // integer multiplications. + PHILOX_DEVICE_INLINE + static void MultiplyHighLow(uint32 a, uint32 b, uint32* result_low, + uint32* result_high) { +#ifndef __GCUDACC__ + const uint64 product = static_cast<uint64>(a) * b; + *result_low = static_cast<uint32>(product); + *result_high = static_cast<uint32>(product >> 32); +#else + *result_low = a * b; + *result_high = __umulhi(a, b); +#endif + } + + // Helper function for a single round of the underlying Philox algorithm. 
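+  // Each round multiplies counter[0] and counter[2] by the two Philox
+  // multipliers, passes the low halves of the products through unchanged,
+  // and XORs the high halves with counter[1], counter[3] and the two key
+  // words, following the Philox-4x32 construction from the paper above.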
+ PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound( + const ResultType& counter, const Key& key) { + uint32 lo0; + uint32 hi0; + MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0); + + uint32 lo1; + uint32 hi1; + MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1); + + ResultType result; + result[0] = hi1 ^ counter[1] ^ key[0]; + result[1] = lo1; + result[2] = hi0 ^ counter[3] ^ key[1]; + result[3] = lo0; + return result; + } + + PHILOX_DEVICE_INLINE void RaiseKey(Key* key) { + (*key)[0] += kPhiloxW32A; + (*key)[1] += kPhiloxW32B; + } + + private: + ResultType counter_; + Key key_; +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc new file mode 100644 index 0000000000..997c0263b7 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random_test.cc @@ -0,0 +1,58 @@ +#include "tensorflow/core/lib/random/philox_random.h" + +#include <math.h> +#include <algorithm> +#include <functional> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/random/philox_random_test_utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { +namespace { + +// A trivial distribution that just returns the PhiloxRandom as a distribution +class TrivialPhiloxDistribution { + public: + // The number of elements that will be returned. + static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount; + typedef PhiloxRandom::ResultType ResultType; + typedef PhiloxRandom::ResultElementType ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(PhiloxRandom* gen) { return (*gen)(); } +}; + +// This test checks that skipping certain number of samples, is equivalent to +// generate the same number of samples without skipping. +TEST(PhiloxRandomTest, SkipMatchTest) { + constexpr int count = 1024; + constexpr int skip_count = 2048; + + uint64 test_seed = GetTestSeed(); + std::vector<uint32> v1(count); + { + PhiloxRandom gen(test_seed); + gen.Skip(skip_count / 4); + FillRandoms<TrivialPhiloxDistribution>(gen, &v1[0], v1.size()); + } + + std::vector<uint32> v2(count + skip_count); + { + PhiloxRandom gen(test_seed); + FillRandoms<TrivialPhiloxDistribution>(gen, &v2[0], v2.size()); + } + + for (int i = 0; i < count; ++i) { + ASSERT_EQ(v1[i], v2[i + skip_count]); + } +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h new file mode 100644 index 0000000000..d22f6b36e4 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random_test_utils.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ + +#include <algorithm> + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace random { + +// Return a random seed. +inline uint64 GetTestSeed() { return New64(); } + +// A utility function to fill the given array with samples from the given +// distribution. 
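+// The size must be a multiple of Distribution::kResultElementCount, since
+// samples are copied out one full batch at a time (see the CHECK below).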
+template <class Distribution> +void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, + int64 size) { + const int granularity = Distribution::kResultElementCount; + + CHECK(size % granularity == 0) << " size: " << size + << " granularity: " << granularity; + + Distribution dist; + for (int i = 0; i < size; i += granularity) { + const auto sample = dist(&gen); + std::copy(&sample[0], &sample[0] + granularity, &p[i]); + } +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc new file mode 100644 index 0000000000..2959b05382 --- /dev/null +++ b/tensorflow/core/lib/random/random.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/lib/random/random.h" + +#include <random> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +std::mt19937_64* InitRng() { + std::random_device device("/dev/random"); + return new std::mt19937_64(device()); +} + +uint64 New64() { + static std::mt19937_64* rng = InitRng(); + static mutex mu; + mutex_lock l(mu); + return (*rng)(); +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h new file mode 100644 index 0000000000..1a20436c4e --- /dev/null +++ b/tensorflow/core/lib/random/random.h @@ -0,0 +1,16 @@ +#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_ +#define TENSORFLOW_LIB_RANDOM_RANDOM_H_ + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +// Return a 64-bit random value. Different sequences are generated +// in different processes. +uint64 New64(); + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_ diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h new file mode 100644 index 0000000000..caafcde513 --- /dev/null +++ b/tensorflow/core/lib/random/random_distributions.h @@ -0,0 +1,361 @@ +#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ + +#include <math.h> +#include <string.h> +#include <algorithm> + +#include "tensorflow/core/lib/random/philox_random.h" + +namespace tensorflow { +namespace random { + +// Helper function to convert a 32-bit integer to a float between [0..1). +PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x); +// Helper function to convert two 32-bit integers to a double between [0..1). +PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1); + +// A class that generates uniform distribution random numbers from the +// underlying random integer generator. +// Arguments: +// Generator: a generator type that returns a number of uint32 upon each +// each invocation. It needs to define kResultElementCount for the +// sample count for each invocation, and ResultType for actual +// returned sample type. +// RealType: the data type of the real numberes that will be returned by the +// distribution. This could be either float or double for now. +// This class is meant to be implemented through specialization. The default +// is not defined by design. +template <class Generator, typename RealType> +class UniformDistribution; + +template <class Generator> +class UniformDistribution<Generator, float> { + public: + // The number of elements that will be returned. 
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<float, kResultElementCount> ResultType;
+  typedef float ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = Uint32ToFloat(sample[i]);
+    }
+    return result;
+  }
+};
+
+template <class Generator>
+class UniformDistribution<Generator, double> {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<double, kResultElementCount> ResultType;
+  typedef double ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]);
+    }
+    return result;
+  }
+};
+
+// A class that adapts a generator whose native output is a batch of samples
+// so that it returns a single sample at a time.
+template <class Generator>
+class SingleSampleAdapter {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = 1;
+  // The number of elements that will be returned by the underlying generator.
+  static const int kNativeElementCount = Generator::kResultElementCount;
+  typedef typename Generator::ResultElementType ResultType;
+  typedef typename Generator::ResultElementType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  explicit SingleSampleAdapter(Generator* gen)
+      : generator_(gen), used_result_index_(Generator::kResultElementCount) {}
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()() {
+    if (used_result_index_ == Generator::kResultElementCount) {
+      unused_results_ = (*generator_)();
+      used_result_index_ = 0;
+    }
+
+    return unused_results_[used_result_index_++];
+  }
+
+ private:
+  Generator* generator_;
+  typename Generator::ResultType unused_results_;
+  int used_result_index_;
+};
+
+// A class that generates unit normal distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+//   Generator: a generator type that returns a number of uint32 upon each
+//     invocation. It needs to define kResultElementCount for the
+//     sample count for each invocation, and ResultType for the actual
+//     returned sample type.
+//   RealType: the data type of the real numbers that will be returned by the
+//     distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class NormalDistribution;
+
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1);
+
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+                     double* d1);
+
+template <class Generator>
+class NormalDistribution<Generator, float> {
+ public:
+  // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = false; + typedef Array<float, kResultElementCount> ResultType; + typedef float ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; i += 2) { + BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]); + } + return result; + } +}; + +template <class Generator> +class NormalDistribution<Generator, double> { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount / 2; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = false; + typedef Array<double, kResultElementCount> ResultType; + typedef double ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; i += 2) { + const int i2 = 2 * i; + BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2], + sample[i2 + 3], &result[i], &result[i + 1]); + } + return result; + } +}; + +// A class that returns standard normal distribution between +// [-kTruncateValue, kTruncateValue]. +// Arguments: +// Generator: a generator type that returns a number of uint32 upon each +// each invocation. It needs to define kResultElementCount for the +// sample count for each invocation, and ResultType for actual +// returned sample type. +// RealType: the data type of the real numberes that will be returned by the +// distribution. This could be either float or double for now. +// This class is meant to be implemented through specialization. The default +// is not defined by design. +template <class SingleSampleGenerator, typename RealType> +class TruncatedNormalDistribution; + +// Partial specialization for float. +template <class SingleSampleGenerator> +class TruncatedNormalDistribution<SingleSampleGenerator, float> { + public: + // The number of elements that will be returned. + static const int kResultElementCount = + SingleSampleGenerator::kNativeElementCount; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = true; + // The threshold where the normal distribution is truncated. + const float kTruncateValue = 2.0f; + + typedef Array<float, kResultElementCount> ResultType; + typedef float ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(SingleSampleGenerator* gen) { + ResultType results; + int index = 0; + while (true) { + // Repeatedly take samples from the normal distribution, until we have + // the desired number of elements that fall within the pre-defined cutoff + // threshold. + const uint32 x0 = (*gen)(); + const uint32 x1 = (*gen)(); + float f[2]; + BoxMullerFloat(x0, x1, &f[0], &f[1]); + + for (int i = 0; i < 2; ++i) { + if (fabs(f[i]) < kTruncateValue) { + results[index++] = f[i]; + if (index >= kResultElementCount) { + return results; + } + } + } + } + } +}; + +// Partial specialization for double. 
+template <class SingleSampleGenerator> +class TruncatedNormalDistribution<SingleSampleGenerator, double> { + public: + // The number of elements that will be returned. + static const int kResultElementCount = + (SingleSampleGenerator::kNativeElementCount > 1) + ? SingleSampleGenerator::kNativeElementCount / 2 + : 1; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = true; + typedef Array<double, kResultElementCount> ResultType; + typedef double ResultElementType; + const double kTruncateValue = 2.0; + + PHILOX_DEVICE_INLINE + ResultType operator()(SingleSampleGenerator* gen) { + ResultType results; + int index = 0; + while (1) { + const uint32 x0 = (*gen)(); + const uint32 x1 = (*gen)(); + const uint32 x2 = (*gen)(); + const uint32 x3 = (*gen)(); + double d[2]; + BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]); + + for (int i = 0; i < 2; ++i) { + if (fabs(d[i]) < kTruncateValue) { + results[index++] = d[i]; + if (index >= kResultElementCount) { + return results; + } + } + } + } + } +}; + +// Helper function to convert two 32-bit uniform integers to two floats +// under the unit normal distribution. +PHILOX_DEVICE_INLINE +void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) { + // This function implements the Box-Muller transform: + // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form + // Do not send a really small number to log(). + // We cannot mark "epsilon" as "static const" because NVCC would complain + const float epsilon = 1.0e-7f; + float u1 = Uint32ToFloat(x0); + if (u1 < epsilon) { + u1 = epsilon; + } + const float v1 = 2.0f * M_PI * Uint32ToFloat(x1); + const float u2 = sqrt(-2.0f * log(u1)); +#if defined(__linux) + sincosf(v1, f0, f1); +#else + *f0 = sinf(v1); + *f1 = cosf(v1); +#endif + *f0 *= u2; + *f1 *= u2; +} + +// Helper function to convert four 32-bit uniform integers to two doubles +// under the unit normal distribution. +PHILOX_DEVICE_INLINE +void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0, + double* d1) { + // This function implements the Box-Muller transform: + // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form + // Do not send a really small number to log(). + // We cannot mark "epsilon" as "static const" because NVCC would complain + const double epsilon = 1.0e-7; + double u1 = Uint64ToDouble(x0, x1); + if (u1 < epsilon) { + u1 = epsilon; + } + const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3); + const double u2 = sqrt(-2.0 * log(u1)); +#if defined(__linux) + sincos(v1, d0, d1); +#else + *d0 = sin(v1); + *d1 = cos(v1); +#endif + *d0 *= u2; + *d1 *= u2; +} + +// Helper function to convert an 32-bit integer to a float between [0..1). +PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) { + // IEEE754 floats are formatted as follows (MSB first): + // sign(1) exponent(8) mantissa(23) + // Conceptually construct the following: + // sign == 0 + // exponent == 127 -- an excess 127 representation of a zero exponent + // mantissa == 23 random bits + const uint32 man = x & 0x7fffffu; // 23 bit mantissa + const uint32 exp = static_cast<uint32>(127); + const uint32 val = (exp << 23) | man; + + // Assumes that endian-ness is same for float and uint32. + float result; + memcpy(&result, &val, sizeof(val)); + return result - 1.0f; +} + +// Helper function to convert two 32-bit integers to a double between [0..1). 
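+// As with Uint32ToFloat above, the random bits form the mantissa of a value
+// in [1, 2), and the final subtraction maps it to [0, 1).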
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) { + // IEEE754 doubles are formatted as follows (MSB first): + // sign(1) exponent(11) mantissa(52) + // Conceptually construct the following: + // sign == 0 + // exponent == 1023 -- an excess 1023 representation of a zero exponent + // mantissa == 52 random bits + const uint32 mhi = x0 & 0xfffffu; // upper 20 bits of mantissa + const uint32 mlo = x1; // lower 32 bits of mantissa + const uint64 man = (static_cast<uint64>(mhi) << 32) | mlo; // mantissa + const uint64 exp = static_cast<uint64>(1023); + const uint64 val = (exp << 52) | man; + // Assumes that endian-ness is same for double and uint64. + double result; + memcpy(&result, &val, sizeof(val)); + return result - 1.0; +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc new file mode 100644 index 0000000000..3ce86a907a --- /dev/null +++ b/tensorflow/core/lib/random/random_distributions_test.cc @@ -0,0 +1,270 @@ +#include "tensorflow/core/lib/random/random_distributions.h" + +#include <math.h> +#include <algorithm> +#include <functional> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/philox_random_test_utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { +namespace { + +// The largest z-value we want to tolerate. Since the z-test approximates a +// unit normal distribution, it should almost definitely never exceed 6. +static constexpr float kZLimit = 6.0; + +// A utility function to fill the given array with samples from the given +// distribution, using the single adatper of the underlying generator +template <class Distribution> +void FillRandomsWithSingles(PhiloxRandom gen, + typename Distribution::ResultElementType* p, + int64 size) { + int granularity = Distribution::kResultElementCount; + + CHECK(size % granularity == 0) << " size: " << size + << " granularity: " << granularity; + + SingleSampleAdapter<PhiloxRandom> single_samples(&gen); + + Distribution dist; + for (int i = 0; i < size; i += granularity) { + auto sample = dist(&single_samples); + std::copy(&sample[0], &sample[0] + granularity, &p[i]); + } +} + +// Check the given array of samples matches the given theoretical moment +// function at different orders. The test is considered passing if the z-tests +// of all statistical moments are all below z_limit. +// typename T in the template argument could be either float or double. 
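+// For each moment, the sample mean is compared with the theoretical value;
+// the difference divided by its estimated standard error is approximately a
+// unit normal, which is why the threshold is expressed as a z-limit.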
+// Arguments: +// samples: an array of samples to be tested for their statistical properties; +// theoretical_moments: a functor that can calculate arbitrary order of +// of the given distribution; +// max_moments: the largest moments of the uniform distribution to be tested; +// stride: the distance between samples to check for statistical properties +// 0 means the n-th moment of each sample +// any other strides tests for spatial correlation between samples; +// z_limit: the maximum z-test we would consider the test to pass; +template <typename T> +bool CheckSamplesMoments(const std::vector<T>& samples, + std::function<double(int)> theoretical_moments, + int max_moments, int stride, T z_limit) { + const T* const samples_data = &samples[0]; + const int samples_size = samples.size(); + std::vector<double> moments(max_moments + 1); + double* const moments_data = &moments[0]; + std::vector<int> moments_sample_count(max_moments + 1); + int* const moments_sample_count_data = &moments_sample_count[0]; + + for (int k = 0; k < samples_size; ++k) { + double moment = 1.; + for (int i = 0; i <= max_moments; ++i) { + int index = k + i * stride; + if (index >= samples_size) { + break; + } + // moments[i] store the i-th order measured moments. + // bypass std::vector::opeartor[] because they are too slow in the debug + // mode, given the large number of samples. + moments_data[i] += moment; + ++moments_sample_count_data[i]; + moment *= samples_data[index]; + } + } + + // normalize the moments + for (int i = 0; i <= max_moments; ++i) { + moments[i] /= moments_sample_count[i]; + } + + bool status = true; + + for (int i = 1; i <= max_moments; ++i) { + // Calculate the theoretical mean and variance + const double moments_i_mean = (stride == 0) + ? theoretical_moments(i) + : std::pow(theoretical_moments(1), i); + const double moments_i_squared = (stride == 0) + ? theoretical_moments(2 * i) + : std::pow(theoretical_moments(2), i); + const double moments_i_var = + moments_i_squared - moments_i_mean * moments_i_mean; + + // assume every operation has a small numerical error. + static const double kNumericalError = 1e-6; + // it takes i multiplications to calculate one i-th moment. + const double error_per_moment = i * kNumericalError; + const double total_variance = + moments_i_var / moments_sample_count[i] + error_per_moment; + // z_test is approximately a unit normal distribution. + const double z_test = + fabs((moments[i] - moments_i_mean) / sqrt(total_variance)); + + if (z_test > z_limit) { + LOG(ERROR) << "failing z_test:" + << " moment: " << i << " stride: " << stride + << " z_test: " << z_test << " z_limit: " << z_limit + << " measured moments: " << moments[i] + << " theoretical mean of the moments: " << moments_i_mean + << " theoretical var of the moments: " << moments_i_var + << " sample count: " << moments_sample_count[i]; + status = false; + } + } + + return status; +} + +// This tests checks that the generated samples match the theoretical moments +// of the uniform distribution. +template <typename T> +void UniformMomentsTest(int count, int max_moments, + const std::vector<int>& strides, T z_limit) { + auto uniform_moments = [](int n) -> double { return 1. 
/ (n + 1); }; + + std::vector<T> v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandoms<UniformDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size()); + for (int stride : strides) { + bool status = CheckSamplesMoments<T>(v1, uniform_moments, max_moments, + stride, z_limit); + ASSERT_TRUE(status) << " UniformMomentsTest failing. seed: " << seed; + } +} + +// This test checks that the generated samples match the theoretical moments +// of the unit normal distribution. +template <typename T> +void NormalMomentsTest(int count, int max_moments, + const std::vector<int>& strides, T z_limit) { + auto normal_moments = [](int n) -> double { + if (n % 2 == 1) { + // For an odd order, the moment of a unit normal distribution is zero. + return 0.; + } else { + // For an even order, the moment of a unit normal distribution is. + // (n-1)!! + double v = 1.; + for (int i = n - 1; i >= 1; i -= 2) { + v *= i; + } + return v; + } + }; + + std::vector<T> v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandoms<NormalDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size()); + + for (int stride : strides) { + bool status = CheckSamplesMoments<T>(v1, normal_moments, max_moments, + stride, z_limit); + ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed; + } +} + +// A functor to calculate the moments for the truncated normal distribution. +// For any odd order, the moment is zero. But for any other n, it can be proven +// that the following recursive relationship for the moments of the truncated +// standard normal: +// m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1) +// where v is the cut-off value, f(v) is the p.d.f of the standard +// normal, and Phi(v) is the c.d.f of the standard normal. +class TruncatedNormalMoments { + public: + double operator()(int n) { + if (n == 0) { + return 1; + } + if (n % 2 == 1) { + // For an odd order, the moment is always zero + return 0.; + } + + // Memoization and check the cached results. + auto iter = cached_results_.find(n); + if (iter != cached_results_.end()) { + return iter->second; + } + + // The real computation of the moment. + double bias = 2.0 * std::pow(kV, n - 1) * kFV / (2.0 * kPhiV - 1.0); + double moment_n_minus_2 = (*this)(n - 2); + double moment_n = (n - 1) * moment_n_minus_2 - bias; + + cached_results_[n] = moment_n; + return moment_n; + } + + private: + const double kV = 2.0; + // f(v), where f is the p.d.f of the normal distribution and v=2. + const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0); + // The numerical evaluation of Phi(v), where v is the truncate value. + // v = 2 in the current implementation. + const double kPhiV = 0.977249868051821; + std::unordered_map<int, double> cached_results_; +}; + +// This test checks that the generated samples matche the theoretical moments +// of the truncated normal distribution. +template <typename T> +void RandomParametersMomentsTest(int count, int max_moments, + const std::vector<int>& strides, T z_limit) { + std::vector<T> v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandomsWithSingles< + TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, T> >( + gen, &v1[0], v1.size()); + + for (int stride : strides) { + bool status = CheckSamplesMoments<T>(v1, TruncatedNormalMoments(), + max_moments, stride, z_limit); + ASSERT_TRUE(status) << " NormalMomentsTest failing. 
seed: " << seed; + } +} + +TEST(PhiloxRandomTest, UniformFloatMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + UniformMomentsTest<float>(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, NormalFloatMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + NormalMomentsTest<float>(8 << 20, 25, strides, kZLimit); +} + +TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + RandomParametersMomentsTest<float>(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, UniformDoubleMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + UniformMomentsTest<double>(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, NormalDoubleMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + NormalMomentsTest<double>(8 << 20, 25, strides, kZLimit); +} + +TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) { + const std::vector<int> strides = {0, 1, 4, 17}; + RandomParametersMomentsTest<double>(1 << 20, 40, strides, kZLimit); +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/random_test.cc b/tensorflow/core/lib/random/random_test.cc new file mode 100644 index 0000000000..7ed37c8b5e --- /dev/null +++ b/tensorflow/core/lib/random/random_test.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/lib/random/random.h" + +#include <set> +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { +namespace { + +TEST(New64Test, SanityCheck) { + std::set<uint64> values; + for (int i = 0; i < 1000000; i++) { + uint64 x = New64(); + EXPECT_TRUE(values.insert(x).second) << "duplicate " << x; + } +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/simple_philox.cc b/tensorflow/core/lib/random/simple_philox.cc new file mode 100644 index 0000000000..1035e1f017 --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox.cc @@ -0,0 +1,24 @@ +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/random/exact_uniform_int.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace random { + +uint32 SimplePhilox::Uniform(uint32 n) { + return ExactUniformInt<uint32>(n, [this]() { return Rand32(); }); +} + +uint64 SimplePhilox::Uniform64(uint64 n) { + return ExactUniformInt<uint64>(n, [this]() { return Rand64(); }); +} + +uint32 SimplePhilox::Skewed(int max_log) { + CHECK(0 <= max_log && max_log <= 32); + const int shift = Rand32() % (max_log + 1); + const uint32 mask = shift == 32 ? 
~static_cast<uint32>(0) : (1 << shift) - 1; + return Rand32() & mask; +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h new file mode 100644 index 0000000000..12b15d7616 --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox.h @@ -0,0 +1,61 @@ +#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ +#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ + +#include <math.h> +#include <string.h> +#include <algorithm> + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { +namespace random { + +// A simple imperative interface to Philox +class SimplePhilox { + public: + PHILOX_DEVICE_INLINE + explicit SimplePhilox(PhiloxRandom* gen) : single_(gen) {} + + // 32 random bits + PHILOX_DEVICE_INLINE uint32 Rand32() { return single_(); } + + // 64 random bits + PHILOX_DEVICE_INLINE uint64 Rand64() { + const uint32 lo = single_(), hi = single_(); + return lo | static_cast<uint64>(hi) << 32; + } + + // Uniform float in [0, 1) + PHILOX_DEVICE_INLINE float RandFloat() { return Uint32ToFloat(single_()); } + + // Uniform double in [0, 1) + PHILOX_DEVICE_INLINE double RandDouble() { + const uint32 x0 = single_(), x1 = single_(); + return Uint64ToDouble(x0, x1); + } + + // Uniform integer in [0, n). + // Uses rejection sampling, so may need more than one 32-bit sample. + uint32 Uniform(uint32 n); + + // Approximately uniform integer in [0, n). + // Uses rejection sampling, so may need more than one 64-bit sample. + uint64 Uniform64(uint64 n); + + // True with probability 1/n. + bool OneIn(uint32 n) { return Uniform(n) == 0; } + + // Skewed: pick "base" uniformly from range [0,max_log] and then + // return "base" random bits. The effect is to pick a number in the + // range [0,2^max_log-1] with bias towards smaller numbers. 
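+  // For example, Skewed(0) always returns 0, while Skewed(32) can return
+  // any 32-bit value, with small values still the most likely.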
+ uint32 Skewed(int max_log); + + private: + SingleSampleAdapter<PhiloxRandom> single_; +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ diff --git a/tensorflow/core/lib/random/simple_philox_test.cc b/tensorflow/core/lib/random/simple_philox_test.cc new file mode 100644 index 0000000000..4246b8b4dd --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox_test.cc @@ -0,0 +1,120 @@ +#include "tensorflow/core/lib/random/simple_philox.h" + +#include <set> +#include <string> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { +namespace { + +TEST(SimplePhiloxTest, FloatTest) { + PhiloxRandom philox(7, 7); + SimplePhilox gen(&philox); + static const int kIters = 1000000; + for (int i = 0; i < kIters; ++i) { + float f = gen.RandFloat(); + EXPECT_LE(0.0f, f); + EXPECT_GT(1.0f, f); + } + for (int i = 0; i < kIters; ++i) { + double d = gen.RandDouble(); + EXPECT_LE(0.0, d); + EXPECT_GT(1.0, d); + } +} + +static void DifferenceTest(const char *names, SimplePhilox *gen1, + SimplePhilox *gen2) { + static const int kIters = 100; + bool different = false; + for (int i = 0; i < kIters; ++i) { + if (gen1->Rand32() != gen2->Rand32()) { + different = true; + break; + } + } + CHECK(different) << "different seeds but same output!"; +} + +TEST(SimplePhiloxTest, DifferenceTest) { + PhiloxRandom philox1(1, 1), philox2(17, 17); + SimplePhilox gen1(&philox1), gen2(&philox2); + + DifferenceTest("SimplePhilox: different seeds", &gen1, &gen2); +} + +TEST(SimplePhiloxTest, DifferenceTestCloseSeeds) { + PhiloxRandom philox1(1, 1), philox2(2, 1); + SimplePhilox gen1(&philox1), gen2(&philox2); + + DifferenceTest("SimplePhilox: close seeds", &gen1, &gen2); +} + +TEST(SimplePhiloxTest, Regression_CloseSeedsAreDifferent) { + const int kCount = 1000; + + // Two seeds differ only by the last bit. + PhiloxRandom philox1(0, 1), philox2(1, 1); + SimplePhilox gen1(&philox1), gen2(&philox2); + + std::set<uint32> first; + std::set<uint32> all; + for (int i = 0; i < kCount; ++i) { + uint32 v = gen1.Rand32(); + first.insert(v); + all.insert(v); + all.insert(gen2.Rand32()); + } + + // Broken array initialization implementation (before 2009-08-18) using the + // above seeds return <1000, 1007>, generating output that is >99% similar. + // The fix returns <1000, 2000> for completely disjoint sets. 
+ EXPECT_EQ(kCount, first.size()); + EXPECT_EQ(2 * kCount, all.size()); +} + +TEST(SimplePhiloxTest, TestUniform) { + PhiloxRandom philox(17, 17); + SimplePhilox gen(&philox); + + uint32 range = 3 * (1L << 29); + uint32 threshold = 1L << 30; + + size_t count = 0; + static const int kTrials = 100000; + for (int i = 0; i < kTrials; ++i) { + uint32 rnd = gen.Uniform(range); + if (rnd < threshold) { + ++count; + } + } + + EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005); +} + +TEST(SimplePhiloxTest, TestUniform64) { + PhiloxRandom philox(17, 17); + SimplePhilox gen(&philox); + + uint64 range = 3 * (1LL << 59); + uint64 threshold = 1LL << 60; + + size_t count = 0; + static const int kTrials = 100000; + for (int i = 0; i < kTrials; ++i) { + uint64 rnd = gen.Uniform64(range); + if (rnd < threshold) { + ++count; + } + } + + EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005); +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc new file mode 100644 index 0000000000..f96da578ec --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker.cc @@ -0,0 +1,203 @@ +#include "tensorflow/core/lib/random/weighted_picker.h" + +#include <string.h> +#include <algorithm> + +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace random { + +WeightedPicker::WeightedPicker(int N) { + CHECK_GE(N, 0); + N_ = N; + + // Find the number of levels + num_levels_ = 1; + while (LevelSize(num_levels_ - 1) < N) { + num_levels_++; + } + + // Initialize the levels + level_ = new int32*[num_levels_]; + for (int l = 0; l < num_levels_; l++) { + level_[l] = new int32[LevelSize(l)]; + } + + SetAllWeights(1); +} + +WeightedPicker::~WeightedPicker() { + for (int l = 0; l < num_levels_; l++) { + delete[] level_[l]; + } + delete[] level_; +} + +static int32 UnbiasedUniform(SimplePhilox* r, int32 n) { + CHECK_LE(0, n); + const uint32 range = ~static_cast<uint32>(0); + if (n == 0) { + return r->Rand32() * n; + } else if (0 == (n & (n - 1))) { + // N is a power of two, so just mask off the lower bits. + return r->Rand32() & (n - 1); + } else { + // Reject all numbers that skew the distribution towards 0. + + // Rand32's output is uniform in the half-open interval [0, 2^{32}). + // For any interval [m,n), the number of elements in it is n-m. + + uint32 rem = (range % n) + 1; + uint32 rnd; + + // rem = ((2^{32}-1) \bmod n) + 1 + // 1 <= rem <= n + + // NB: rem == n is impossible, since n is not a power of 2 (from + // earlier check). + + do { + rnd = r->Rand32(); // rnd uniform over [0, 2^{32}) + } while (rnd < rem); // reject [0, rem) + // rnd is uniform over [rem, 2^{32}) + // + // The number of elements in the half-open interval is + // + // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1 + // = 2^{32}-1 - ((2^{32}-1) \bmod n) + // = n \cdot \lfloor (2^{32}-1)/n \rfloor + // + // therefore n evenly divides the number of integers in the + // interval. + // + // The function v \rightarrow v % n takes values from [bias, + // 2^{32}) to [0, n). Each integer in the range interval [0, n) + // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from + // the domain interval. + // + // Therefore, v % n is uniform over [0, n). QED. 
+ + return rnd % n; + } +} + +int WeightedPicker::Pick(SimplePhilox* rnd) const { + if (total_weight() == 0) return -1; + + // using unbiased uniform distribution to avoid bias + // toward low elements resulting from a possible use + // of big weights. + return PickAt(UnbiasedUniform(rnd, total_weight())); +} + +int WeightedPicker::PickAt(int32 weight_index) const { + if (weight_index < 0 || weight_index >= total_weight()) return -1; + + int32 position = weight_index; + int index = 0; + + for (int l = 1; l < num_levels_; l++) { + // Pick left or right child of "level_[l-1][index]" + const int32 left_weight = level_[l][2 * index]; + if (position < left_weight) { + // Descend to left child + index = 2 * index; + } else { + // Descend to right child + index = 2 * index + 1; + position -= left_weight; + } + } + CHECK_GE(index, 0); + CHECK_LT(index, N_); + CHECK_LE(position, level_[num_levels_ - 1][index]); + return index; +} + +void WeightedPicker::set_weight(int index, int32 weight) { + assert(index >= 0); + assert(index < N_); + + // Adjust the sums all the way up to the root + const int32 delta = weight - get_weight(index); + for (int l = num_levels_ - 1; l >= 0; l--) { + level_[l][index] += delta; + index >>= 1; + } +} + +void WeightedPicker::SetAllWeights(int32 weight) { + // Initialize leaves + int32* leaves = level_[num_levels_ - 1]; + for (int i = 0; i < N_; i++) leaves[i] = weight; + for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; + + // Now sum up towards the root + RebuildTreeWeights(); +} + +void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) { + Resize(N); + + // Initialize leaves + int32* leaves = level_[num_levels_ - 1]; + for (int i = 0; i < N_; i++) leaves[i] = weights[i]; + for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; + + // Now sum up towards the root + RebuildTreeWeights(); +} + +void WeightedPicker::RebuildTreeWeights() { + for (int l = num_levels_ - 2; l >= 0; l--) { + int32* level = level_[l]; + int32* children = level_[l + 1]; + for (int i = 0; i < LevelSize(l); i++) { + level[i] = children[2 * i] + children[2 * i + 1]; + } + } +} + +void WeightedPicker::Append(int32 weight) { + Resize(num_elements() + 1); + set_weight(num_elements() - 1, weight); +} + +void WeightedPicker::Resize(int new_size) { + CHECK_GE(new_size, 0); + if (new_size <= LevelSize(num_levels_ - 1)) { + // The new picker fits in the existing levels. + + // First zero out any of the weights that are being dropped so + // that the levels are correct (only needed when shrinking) + for (int i = new_size; i < N_; i++) { + set_weight(i, 0); + } + + // We do not need to set any new weights when enlarging because + // the unneeded entries always have weight zero. + N_ = new_size; + return; + } + + // We follow the simple strategy of just copying the old + // WeightedPicker into a new WeightedPicker. The cost is + // O(N) regardless. 
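+  // (Constructing the larger picker and the final RebuildTreeWeights() are
+  // both O(N) as well, since the level sizes sum to roughly 2N.)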
+ assert(new_size > N_); + WeightedPicker new_picker(new_size); + int32* dst = new_picker.level_[new_picker.num_levels_ - 1]; + int32* src = this->level_[this->num_levels_ - 1]; + memcpy(dst, src, sizeof(dst[0]) * N_); + memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_)); + new_picker.RebuildTreeWeights(); + + // Now swap the two pickers + std::swap(new_picker.N_, this->N_); + std::swap(new_picker.num_levels_, this->num_levels_); + std::swap(new_picker.level_, this->level_); + assert(this->N_ == new_size); +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h new file mode 100644 index 0000000000..3d2c2dbb39 --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker.h @@ -0,0 +1,118 @@ + +// An abstraction to pick from one of N elements with a specified +// weight per element. +// +// The weight for a given element can be changed in O(lg N) time +// An element can be picked in O(lg N) time. +// +// Uses O(N) bytes of memory. +// +// Alternative: distribution-sampler.h allows O(1) time picking, but no weight +// adjustment after construction. + +#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ +#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ + +#include <assert.h> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +class SimplePhilox; + +class WeightedPicker { + public: + // REQUIRES N >= 0 + // Initializes the elements with a weight of one per element + explicit WeightedPicker(int N); + + // Releases all resources + ~WeightedPicker(); + + // Pick a random element with probability proportional to its weight. + // If total weight is zero, returns -1. + int Pick(SimplePhilox* rnd) const; + + // Deterministically pick element x whose weight covers the + // specified weight_index. + // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ] + int PickAt(int32 weight_index) const; + + // Get the weight associated with an element + // REQUIRES 0 <= index < N + int32 get_weight(int index) const; + + // Set the weight associated with an element + // REQUIRES weight >= 0.0f + // REQUIRES 0 <= index < N + void set_weight(int index, int32 weight); + + // Get the total combined weight of all elements + int32 total_weight() const; + + // Get the number of elements in the picker + int num_elements() const; + + // Set weight of each element to "weight" + void SetAllWeights(int32 weight); + + // Resizes the picker to N and + // sets the weight of each element i to weight[i]. + // The sum of the weights should not exceed 2^31 - 2 + // Complexity O(N). + void SetWeightsFromArray(int N, const int32* weights); + + // REQUIRES N >= 0 + // + // Resize the weighted picker so that it has "N" elements. + // Any newly added entries have zero weight. + // + // Note: Resizing to a smaller size than num_elements() will + // not reclaim any memory. If you wish to reduce memory usage, + // allocate a new WeightedPicker of the appropriate size. + // + // It is efficient to use repeated calls to Resize(num_elements() + 1) + // to grow the picker to size X (takes total time O(X)). + void Resize(int N); + + // Grow the picker by one and set the weight of the new entry to "weight". + // + // Repeated calls to Append() in order to grow the + // picker to size X takes a total time of O(X lg(X)). + // Consider using SetWeightsFromArray instead. 
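+  // (The O(X lg(X)) bound comes from the X set_weight() calls, each
+  // O(lg X); the occasional copy into a larger picker totals only O(X),
+  // since the underlying level size doubles between copies.)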
+ void Append(int32 weight); + + private: + // We keep a binary tree with N leaves. The "i"th leaf contains + // the weight of the "i"th element. An internal node contains + // the sum of the weights of its children. + int N_; // Number of elements + int num_levels_; // Number of levels in tree (level-0 is root) + int32** level_; // Array that holds nodes per level + + // Size of each level + static int LevelSize(int level) { return 1 << level; } + + // Rebuild the tree weights using the leaf weights + void RebuildTreeWeights(); + + TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker); +}; + +inline int32 WeightedPicker::get_weight(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, N_); + return level_[num_levels_ - 1][index]; +} + +inline int32 WeightedPicker::total_weight() const { return level_[0][0]; } + +inline int WeightedPicker::num_elements() const { return N_; } + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc new file mode 100644 index 0000000000..0b27d437d5 --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker_test.cc @@ -0,0 +1,254 @@ +#include "tensorflow/core/lib/random/weighted_picker.h" + +#include <string.h> +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { + +static void TestPicker(SimplePhilox* rnd, int size); +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void TestPickAt(int items, const int32* weights); + +TEST(WeightedPicker, Simple) { + PhiloxRandom philox(testing::RandomSeed(), 17); + SimplePhilox rnd(&philox); + + { + VLOG(0) << "======= Zero-length picker"; + WeightedPicker picker(0); + EXPECT_EQ(picker.Pick(&rnd), -1); + } + + { + VLOG(0) << "======= Singleton picker"; + WeightedPicker picker(1); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Grown picker"; + WeightedPicker picker(0); + for (int i = 0; i < 10; i++) { + picker.Append(1); + } + CheckUniform(&rnd, &picker, 100000); + } + + { + VLOG(0) << "======= Grown picker with zero weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Shrink picker and check weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + for (int i = 0; i < 10; i++) { + picker.set_weight(i, i); + } + EXPECT_EQ(picker.total_weight(), 45); + picker.Resize(5); + EXPECT_EQ(picker.total_weight(), 10); + picker.Resize(2); + EXPECT_EQ(picker.total_weight(), 1); + picker.Resize(1); + EXPECT_EQ(picker.total_weight(), 0); + } +} + +TEST(WeightedPicker, BigWeights) { + PhiloxRandom philox(testing::RandomSeed() + 1, 17); + SimplePhilox rnd(&philox); + VLOG(0) << "======= Check uniform with big weights"; + WeightedPicker picker(2); + picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3 + CheckUniform(&rnd, &picker, 100000); +} 
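+
+// PickAt() is deterministic: element x is returned for every weight_index in
+// the half-open range covered by x's weight. With the weights {1, 0, 200, 5,
+// 42} used below, PickAt(0) == 0, PickAt(1..200) == 2, PickAt(201..205) == 3,
+// and PickAt(206..247) == 4; element 1 has zero weight and is never returned.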
+ +TEST(WeightedPicker, Deterministic) { + VLOG(0) << "======= Testing deterministic pick"; + static const int32 weights[] = {1, 0, 200, 5, 42}; + TestPickAt(TF_ARRAYSIZE(weights), weights); +} + +TEST(WeightedPicker, Randomized) { + PhiloxRandom philox(testing::RandomSeed() + 10, 17); + SimplePhilox rnd(&philox); + TestPicker(&rnd, 1); + TestPicker(&rnd, 2); + TestPicker(&rnd, 3); + TestPicker(&rnd, 4); + TestPicker(&rnd, 7); + TestPicker(&rnd, 8); + TestPicker(&rnd, 9); + TestPicker(&rnd, 10); + TestPicker(&rnd, 100); +} + +static void TestPicker(SimplePhilox* rnd, int size) { + VLOG(0) << "======= Testing size " << size; + + // Check that empty picker returns -1 + { + WeightedPicker picker(size); + picker.SetAllWeights(0); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1); + } + + // Create zero weights array + std::vector<int32> weights(size); + for (int elem = 0; elem < size; elem++) { + weights[elem] = 0; + } + + // Check that singleton picker always returns the same element + for (int elem = 0; elem < size; elem++) { + WeightedPicker picker(size); + picker.SetAllWeights(0); + picker.set_weight(elem, elem + 1); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 10; + picker.SetWeightsFromArray(size, &weights[0]); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 0; + } + + // Check that uniform picker generates elements roughly uniformly + { + WeightedPicker picker(size); + CheckUniform(rnd, &picker, 100000); + } + + // Check uniform picker that was grown piecemeal + if (size / 3 > 0) { + WeightedPicker picker(size / 3); + while (picker.num_elements() != size) { + picker.Append(1); + } + CheckUniform(rnd, &picker, 100000); + } + + // Check that skewed distribution works + if (size <= 10) { + // When picker grows one element at a time + WeightedPicker picker(size); + int32 weight = 1; + for (int elem = 0; elem < size; elem++) { + picker.set_weight(elem, weight); + weights[elem] = weight; + weight *= 2; + } + CheckSkewed(rnd, &picker, 1000000); + + // When picker is created from an array + WeightedPicker array_picker(0); + array_picker.SetWeightsFromArray(size, &weights[0]); + CheckSkewed(rnd, &array_picker, 1000000); + } +} + +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, + int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + const int expected_min = int(0.9 * trials); + const int expected_max = int(1.1 * trials); + for (int i = 0; i < size; i++) { + EXPECT_GE(count[i], expected_min); + EXPECT_LE(count[i], expected_max); + } + delete[] count; +} + +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + + for (int i = 0; i < size - 1; i++) { + LOG(INFO) << i << ": " << count[i]; + const float ratio = float(count[i + 1]) / float(count[i]); + EXPECT_GE(ratio, 1.6f); + EXPECT_LE(ratio, 2.4f); + } + delete[] count; +} + +static void TestPickAt(int items, const int32* weights) { + WeightedPicker picker(items); + picker.SetWeightsFromArray(items, weights); + int 
weight_index = 0; + for (int i = 0; i < items; ++i) { + for (int j = 0; j < weights[i]; ++j) { + int pick = picker.PickAt(weight_index); + EXPECT_EQ(pick, i); + ++weight_index; + } + } + EXPECT_EQ(weight_index, picker.total_weight()); +} + +static void BM_Create(int iters, int arg) { + while (--iters > 0) { + WeightedPicker p(arg); + } +} +BENCHMARK(BM_Create)->Range(1, 1024); + +static void BM_CreateAndSetWeights(int iters, int arg) { + std::vector<int32> weights(arg); + for (int i = 0; i < arg; i++) { + weights[i] = i * 10; + } + while (--iters > 0) { + WeightedPicker p(arg); + p.SetWeightsFromArray(arg, &weights[0]); + } +} +BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024); + +static void BM_Pick(int iters, int arg) { + PhiloxRandom philox(301, 17); + SimplePhilox rnd(&philox); + WeightedPicker p(arg); + int result = 0; + while (--iters > 0) { + result += p.Pick(&rnd); + } + VLOG(4) << result; // Dummy use +} +BENCHMARK(BM_Pick)->Range(1, 1024); + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc new file mode 100644 index 0000000000..d61129fb3f --- /dev/null +++ b/tensorflow/core/lib/strings/numbers.cc @@ -0,0 +1,260 @@ +#include "tensorflow/core/lib/strings/numbers.h" + +#include <float.h> +#include <stdio.h> +#include <stdlib.h> +#include <algorithm> +#include <cmath> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace strings { + +char* FastInt32ToBufferLeft(int32 i, char* buffer) { + uint32 u = i; + if (i < 0) { + *buffer++ = '-'; + // We need to do the negation in modular (i.e., "unsigned") + // arithmetic; MSVC++ apprently warns for plain "-u", so + // we write the equivalent expression "0 - u" instead. + u = 0 - u; + } + return FastUInt32ToBufferLeft(u, buffer); +} + +char* FastUInt32ToBufferLeft(uint32 i, char* buffer) { + char* start = buffer; + do { + *buffer++ = ((i % 10) + '0'); + i /= 10; + } while (i > 0); + *buffer = 0; + std::reverse(start, buffer); + return buffer; +} + +char* FastInt64ToBufferLeft(int64 i, char* buffer) { + uint64 u = i; + if (i < 0) { + *buffer++ = '-'; + u = 0 - u; + } + return FastUInt64ToBufferLeft(u, buffer); +} + +char* FastUInt64ToBufferLeft(uint64 i, char* buffer) { + char* start = buffer; + do { + *buffer++ = ((i % 10) + '0'); + i /= 10; + } while (i > 0); + *buffer = 0; + std::reverse(start, buffer); + return buffer; +} + +static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001; + +char* DoubleToBuffer(double value, char* buffer) { + // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all + // platforms these days. Just in case some system exists where DBL_DIG + // is significantly larger -- and risks overflowing our buffer -- we have + // this assert. + static_assert(DBL_DIG < 20, "DBL_DIG is too big"); + + bool full_precision_needed = true; + if (std::abs(value) <= kDoublePrecisionCheckMax) { + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value); + + // The snprintf should never overflow because the buffer is significantly + // larger than the precision we asked for. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + + full_precision_needed = strtod(buffer, NULL) != value; + } + + if (full_precision_needed) { + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value); + + // Should never overflow; see above. 
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + } + return buffer; +} + +bool safe_strto64(const char* str, int64* value) { + if (!str) return false; + + // Skip leading space. + while (isspace(*str)) ++str; + + int64 vlimit = kint64max; + int sign = 1; + if (*str == '-') { + sign = -1; + ++str; + // Different limit for positive and negative integers. + vlimit = kint64min; + } + + if (!isdigit(*str)) return false; + + int64 result = 0; + if (sign == 1) { + do { + int digit = *str - '0'; + if ((vlimit - digit) / 10 < result) { + return false; + } + result = result * 10 + digit; + ++str; + } while (isdigit(*str)); + } else { + do { + int digit = *str - '0'; + if ((vlimit + digit) / 10 > result) { + return false; + } + result = result * 10 - digit; + ++str; + } while (isdigit(*str)); + } + + // Skip trailing space. + while (isspace(*str)) ++str; + + if (*str) return false; + + *value = result; + return true; +} + +bool safe_strto32(const char* str, int32* value) { + if (!str) return false; + + // Skip leading space. + while (isspace(*str)) ++str; + + int64 vmax = kint32max; + int sign = 1; + if (*str == '-') { + sign = -1; + ++str; + // Different max for positive and negative integers. + ++vmax; + } + + if (!isdigit(*str)) return false; + + int64 result = 0; + do { + result = result * 10 + *str - '0'; + if (result > vmax) { + return false; + } + ++str; + } while (isdigit(*str)); + + // Skip trailing space. + while (isspace(*str)) ++str; + + if (*str) return false; + + *value = result * sign; + return true; +} + +bool safe_strtof(const char* str, float* value) { + char* endptr; + *value = strtof(str, &endptr); + while (isspace(*endptr)) ++endptr; + // Ignore range errors from strtod/strtof. + // The values it returns on underflow and + // overflow are the right fallback in a + // robust setting. + return *str != '\0' && *endptr == '\0'; +} + +char* FloatToBuffer(float value, char* buffer) { + // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all + // platforms these days. Just in case some system exists where FLT_DIG + // is significantly larger -- and risks overflowing our buffer -- we have + // this assert. + static_assert(FLT_DIG < 10, "FLT_DIG is too big"); + + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value); + + // The snprintf should never overflow because the buffer is significantly + // larger than the precision we asked for. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + + float parsed_value; + if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) { + snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value); + + // Should never overflow; see above. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + } + return buffer; +} + +string FpToString(Fprint fp) { + char buf[17]; + snprintf(buf, sizeof(buf), "%016llx", static_cast<uint64>(fp)); + return string(buf); +} + +bool StringToFp(const string& s, Fprint* fp) { + char junk; + uint64 result; + if (sscanf(s.c_str(), "%llx%c", &result, &junk) == 1) { + *fp = result; + return true; + } else { + return false; + } +} + +string HumanReadableNumBytes(int64 num_bytes) { + if (num_bytes == kint64min) { + // Special case for number with not representable negation. + return "-8E"; + } + + const char* neg_str = (num_bytes < 0) ? "-" : ""; + if (num_bytes < 0) { + num_bytes = -num_bytes; + } + + // Special case for bytes. + if (num_bytes < 1024) { + // No fractions for bytes. 
+ char buf[8]; // Longest possible string is '-XXXXB' + snprintf(buf, sizeof(buf), "%s%lldB", neg_str, + static_cast<int64>(num_bytes)); + return string(buf); + } + + static const char units[] = "KMGTPE"; // int64 only goes up to E. + const char* unit = units; + while (num_bytes >= static_cast<int64>(1024) * 1024) { + num_bytes /= 1024; + ++unit; + CHECK(unit < units + TF_ARRAYSIZE(units)); + } + + // We use SI prefixes. + char buf[16]; + snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"), + neg_str, num_bytes / 1024.0, *unit); + return string(buf); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h new file mode 100644 index 0000000000..a30a862279 --- /dev/null +++ b/tensorflow/core/lib/strings/numbers.h @@ -0,0 +1,92 @@ +#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_ +#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_ + +#include <string> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace strings { + +// ---------------------------------------------------------------------- +// FastIntToBufferLeft() +// These are intended for speed. +// +// All functions take the output buffer as an arg. FastInt() uses +// at most 22 bytes, FastTime() uses exactly 30 bytes. They all +// return a pointer to the beginning of the output, which is the same as +// the beginning of the input buffer. +// +// NOTE: In 64-bit land, sizeof(time_t) is 8, so it is possible +// to pass to FastTimeToBuffer() a time whose year cannot be +// represented in 4 digits. In this case, the output buffer +// will contain the string "Invalid:<value>" +// ---------------------------------------------------------------------- + +// Previously documented minimums -- the buffers provided must be at least this +// long, though these numbers are subject to change: +// Int32, UInt32: 12 bytes +// Int64, UInt64, Int, Uint: 22 bytes +// Time: 30 bytes +// Use kFastToBufferSize rather than hardcoding constants. +static const int kFastToBufferSize = 32; + +// ---------------------------------------------------------------------- +// FastInt32ToBufferLeft() +// FastUInt32ToBufferLeft() +// FastInt64ToBufferLeft() +// FastUInt64ToBufferLeft() +// +// These functions convert their numeric argument to an ASCII +// representation of the numeric value in base 10, with the +// representation being left-aligned in the buffer. The caller is +// responsible for ensuring that the buffer has enough space to hold +// the output. The buffer should typically be at least kFastToBufferSize +// bytes. +// +// Returns a pointer to the end of the string (i.e. the null character +// terminating the string). +// ---------------------------------------------------------------------- + +char* FastInt32ToBufferLeft(int32 i, char* buffer); // at least 12 bytes +char* FastUInt32ToBufferLeft(uint32 i, char* buffer); // at least 12 bytes +char* FastInt64ToBufferLeft(int64 i, char* buffer); // at least 22 bytes +char* FastUInt64ToBufferLeft(uint64 i, char* buffer); // at least 22 bytes + +// Required buffer size for DoubleToBuffer is kFastToBufferSize. +// Required buffer size for FloatToBuffer is kFastToBufferSize. +char* DoubleToBuffer(double i, char* buffer); +char* FloatToBuffer(float i, char* buffer); + +// Convert a 64-bit fingerprint value to an ASCII representation. +string FpToString(Fprint fp); + +// Attempt to parse a fingerprint in the form encoded by FpToString. 
If +// successsful, stores the fingerprint in *fp and returns true. Otherwise, +// returns false. +bool StringToFp(const string& s, Fprint* fp); + +// Convert strings to 32bit integer values. +// Leading and trailing spaces are allowed. +// Return false with overflow or invalid input. +bool safe_strto32(const char* str, int32* value); + +// Convert strings to 64bit integer values. +// Leading and trailing spaces are allowed. +// Return false with overflow or invalid input. +bool safe_strto64(const char* str, int64* value); + +// Convert strings to floating point values. +// Leading and trailing spaces are allowed. +// Values may be rounded on over- and underflow. +bool safe_strtof(const char* str, float* value); + +// Converts from an int64 representing a number of bytes to a +// human readable string representing the same number. +// e.g. 12345678 -> "11.77MiB". +string HumanReadableNumBytes(int64 num_bytes); + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_ diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc new file mode 100644 index 0000000000..b178e6af53 --- /dev/null +++ b/tensorflow/core/lib/strings/numbers_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/strings/numbers.h" + +#include <string> +#include <gtest/gtest.h> + +namespace tensorflow { +namespace strings { + +// NOTE: most of the routines in numbers.h are tested indirectly through +// strcat_test.cc in this directory. + +// Test StrCat of ints and longs of various sizes and signdedness. +TEST(FpToString, Ints) { + for (int s = 0; s < 64; s++) { + for (int delta = -1; delta <= 1; delta++) { + uint64 fp = (1ull << s) + delta; + string s = FpToString(fp); + uint64 fp2; + EXPECT_TRUE(StringToFp(s, &fp2)); + EXPECT_EQ(fp, fp2); + } + } + Fprint dummy; + EXPECT_FALSE(StringToFp("", &dummy)); + EXPECT_FALSE(StringToFp("xyz", &dummy)); + EXPECT_FALSE(StringToFp("0000000000000000xyz", &dummy)); +} + +TEST(HumanReadableNumBytes, Bytes) { + EXPECT_EQ("0B", HumanReadableNumBytes(0)); + EXPECT_EQ("4B", HumanReadableNumBytes(4)); + EXPECT_EQ("1023B", HumanReadableNumBytes(1023)); + + EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1024)); + EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1025)); + EXPECT_EQ("1.5KiB", HumanReadableNumBytes(1500)); + EXPECT_EQ("1.9KiB", HumanReadableNumBytes(1927)); + + EXPECT_EQ("2.0KiB", HumanReadableNumBytes(2048)); + EXPECT_EQ("1.00MiB", HumanReadableNumBytes(1 << 20)); + EXPECT_EQ("11.77MiB", HumanReadableNumBytes(12345678)); + EXPECT_EQ("1.00GiB", HumanReadableNumBytes(1 << 30)); + + EXPECT_EQ("1.00TiB", HumanReadableNumBytes(1LL << 40)); + EXPECT_EQ("1.00PiB", HumanReadableNumBytes(1LL << 50)); + EXPECT_EQ("1.00EiB", HumanReadableNumBytes(1LL << 60)); + + // Try a few negative numbers + EXPECT_EQ("-1B", HumanReadableNumBytes(-1)); + EXPECT_EQ("-4B", HumanReadableNumBytes(-4)); + EXPECT_EQ("-1000B", HumanReadableNumBytes(-1000)); + EXPECT_EQ("-11.77MiB", HumanReadableNumBytes(-12345678)); + EXPECT_EQ("-8E", HumanReadableNumBytes(kint64min)); +} + +TEST(safe_strto32, Int32s) { + int32 result; + + EXPECT_EQ(true, safe_strto32("1", &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto32("123", &result)); + EXPECT_EQ(123, result); + EXPECT_EQ(true, safe_strto32(" -123 ", &result)); + EXPECT_EQ(-123, result); + EXPECT_EQ(true, safe_strto32("2147483647", &result)); + EXPECT_EQ(2147483647, result); + EXPECT_EQ(true, safe_strto32("-2147483648", &result)); + EXPECT_EQ(-2147483648, result); + + // 
Invalid argument + EXPECT_EQ(false, safe_strto32(" 132as ", &result)); + EXPECT_EQ(false, safe_strto32(" 132.2 ", &result)); + EXPECT_EQ(false, safe_strto32(" -", &result)); + EXPECT_EQ(false, safe_strto32("", &result)); + EXPECT_EQ(false, safe_strto32(" ", &result)); + EXPECT_EQ(false, safe_strto32("123 a", &result)); + + // Overflow + EXPECT_EQ(false, safe_strto32("2147483648", &result)); + EXPECT_EQ(false, safe_strto32("-2147483649", &result)); +} + +TEST(safe_strto64, Int64s) { + int64 result; + + EXPECT_EQ(true, safe_strto64("1", &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto64("123", &result)); + EXPECT_EQ(123, result); + EXPECT_EQ(true, safe_strto64(" -123 ", &result)); + EXPECT_EQ(-123, result); + EXPECT_EQ(true, safe_strto64("9223372036854775807", &result)); + EXPECT_EQ(9223372036854775807, result); + EXPECT_EQ(true, safe_strto64("-9223372036854775808", &result)); + // kint64min == -9223372036854775808 + // Use -9223372036854775808 directly results in out of range error + EXPECT_EQ(kint64min, result); + + // Invalid argument + EXPECT_EQ(false, safe_strto64(" 132as ", &result)); + EXPECT_EQ(false, safe_strto64(" 132.2 ", &result)); + EXPECT_EQ(false, safe_strto64(" -", &result)); + EXPECT_EQ(false, safe_strto64("", &result)); + EXPECT_EQ(false, safe_strto64(" ", &result)); + EXPECT_EQ(false, safe_strto64("123 a", &result)); + + // Overflow + EXPECT_EQ(false, safe_strto64("9223372036854775808", &result)); + EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result)); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc new file mode 100644 index 0000000000..ec67595ebb --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code.cc @@ -0,0 +1,515 @@ +#include "tensorflow/core/lib/strings/ordered_code.h" + +#include <assert.h> +#include <stddef.h> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace strings { + +// We encode a string in different ways depending on whether the item +// should be in lexicographically increasing or decreasing order. +// +// +// Lexicographically increasing order +// +// We want a string-to-string mapping F(x) such that for any two strings +// +// x < y => F(x) < F(y) +// +// In addition to the normal characters '\x00' through '\xff', we want to +// encode a few extra symbols in strings: +// +// <sep> Separator between items +// <infinity> Infinite string +// +// Therefore we need an alphabet with at least 258 symbols. Each +// character '\1' through '\xfe' is mapped to itself. The other four are +// encoded into two-letter sequences starting with '\0' and '\xff': +// +// <sep> encoded as => \0\1 +// \0 encoded as => \0\xff +// \xff encoded as => \xff\x00 +// <infinity> encoded as => \xff\xff +// +// The remaining two-letter sequences starting with '\0' and '\xff' are +// currently unused. +// +// F(<infinity>) is defined above. For any finite string x, F(x) is the +// the encodings of x's characters followed by the encoding for <sep>. The +// ordering of two finite strings is the same as the ordering of the +// respective characters at the first position where they differ, which in +// turn is the same as the ordering of the encodings of those two +// characters. Moreover, for every finite string x, F(x) < F(<infinity>). 
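+//
+// For example, for the finite string x = ab\x00 (two ordinary bytes
+// followed by a zero byte):
+//
+//    F(x) = ab\x00\xff\x00\x01   escape the embedded \0, append F(<sep>)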
+// +// +// Lexicographically decreasing order +// +// We want a string-to-string mapping G(x) such that for any two strings, +// whether finite or not, +// +// x < y => G(x) > G(y) +// +// To achieve this, define G(x) to be the inversion of F(x): I(F(x)). In +// other words, invert every bit in F(x) to get G(x). For example, +// +// x = \x00\x13\xff +// F(x) = \x00\xff\x13\xff\x00\x00\x01 escape \0, \xff, append F(<sep>) +// G(x) = \xff\x00\xec\x00\xff\xff\xfe invert every bit in F(x) +// +// x = <infinity> +// F(x) = \xff\xff +// G(x) = \x00\x00 +// +// Another example is +// +// x F(x) G(x) = I(F(x)) +// - ---- -------------- +// <infinity> \xff\xff \x00\x00 +// "foo" foo\0\1 \x99\x90\x90\xff\xfe +// "aaa" aaa\0\1 \x9e\x9e\x9e\xff\xfe +// "aa" aa\0\1 \x9e\x9e\xff\xfe +// "" \0\1 \xff\xfe +// +// More generally and rigorously, if for any two strings x and y +// +// F(x) < F(y) => I(F(x)) > I(F(y)) (1) +// +// it would follow that x < y => G(x) > G(y) because +// +// x < y => F(x) < F(y) => G(x) = I(F(x)) > I(F(y)) = G(y) +// +// We now show why (1) is true, in two parts. Notice that for any two +// strings x < y, F(x) is *not* a proper prefix of F(y). Suppose x is a +// proper prefix of y (say, x="abc" < y="abcd"). F(x) and F(y) diverge at +// the F(<sep>) in F(x) (v. F('d') in the example). Suppose x is not a +// proper prefix of y (say, x="abce" < y="abd"), F(x) and F(y) diverge at +// their respective encodings of the characters where x and y diverge +// (F('c') v. F('d')). Finally, if y=<infinity>, we can see that +// F(y)=\xff\xff is not the prefix of F(x) for any finite string x, simply +// by considering all the possible first characters of F(x). +// +// Given that F(x) is not a proper prefix F(y), the order of F(x) and F(y) +// is determined by the byte where F(x) and F(y) diverge. For example, the +// order of F(x)="eefh" and F(y)="eeg" is determined by their third +// characters. I(p) inverts each byte in p, which effectively subtracts +// each byte from 0xff. So, in this example, I('f') > I('g'), and thus +// I(F(x)) > I(F(y)). +// +// +// Implementation +// +// To implement G(x) efficiently, we use C++ template to instantiate two +// versions of the code to produce F(x), one for normal encoding (giving us +// F(x)) and one for inverted encoding (giving us G(x) = I(F(x))). + +static const char kEscape1 = '\000'; +static const char kNullCharacter = '\xff'; // Combined with kEscape1 +static const char kSeparator = '\001'; // Combined with kEscape1 + +static const char kEscape2 = '\xff'; +static const char kInfinity = '\xff'; // Combined with kEscape2 +static const char kFFCharacter = '\000'; // Combined with kEscape2 + +static const char kEscape1_Separator[2] = {kEscape1, kSeparator}; + +// Append to "*dest" the "len" bytes starting from "*src". +inline static void AppendBytes(string* dest, const char* src, int len) { + dest->append(src, len); +} + +inline bool IsSpecialByte(char c) { return ((unsigned char)(c + 1)) < 2; } + +// Return a pointer to the first byte in the range "[start..limit)" +// whose value is 0 or 255 (kEscape1 or kEscape2). If no such byte +// exists in the range, returns "limit". 
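+// The scan below relies on IsSpecialByte() above, which uses unsigned
+// wrap-around: (unsigned char)(c + 1) < 2 holds exactly when c is 0
+// (kEscape1) or '\xff' (kEscape2).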
+inline const char* SkipToNextSpecialByte(const char* start, const char* limit) { + // If these constants were ever changed, this routine needs to change + DCHECK_EQ(kEscape1, 0); + DCHECK_EQ(kEscape2 & 0xffu, 255u); + const char* p = start; + while (p < limit && !IsSpecialByte(*p)) { + p++; + } + return p; +} + +// Expose SkipToNextSpecialByte for testing purposes +const char* OrderedCode::TEST_SkipToNextSpecialByte(const char* start, + const char* limit) { + return SkipToNextSpecialByte(start, limit); +} + +// Helper routine to encode "s" and append to "*dest", escaping special +// characters. +inline static void EncodeStringFragment(string* dest, StringPiece s) { + const char* p = s.data(); + const char* limit = p + s.size(); + const char* copy_start = p; + while (true) { + p = SkipToNextSpecialByte(p, limit); + if (p >= limit) break; // No more special characters that need escaping + char c = *(p++); + DCHECK(IsSpecialByte(c)); + if (c == kEscape1) { + AppendBytes(dest, copy_start, p - copy_start - 1); + dest->push_back(kEscape1); + dest->push_back(kNullCharacter); + copy_start = p; + } else { + assert(c == kEscape2); + AppendBytes(dest, copy_start, p - copy_start - 1); + dest->push_back(kEscape2); + dest->push_back(kFFCharacter); + copy_start = p; + } + } + if (p > copy_start) { + AppendBytes(dest, copy_start, p - copy_start); + } +} + +void OrderedCode::WriteString(string* dest, StringPiece s) { + EncodeStringFragment(dest, s); + AppendBytes(dest, kEscape1_Separator, 2); +} + +void OrderedCode::WriteNumIncreasing(string* dest, uint64 val) { + // Values are encoded with a single byte length prefix, followed + // by the actual value in big-endian format with leading 0 bytes + // dropped. + unsigned char buf[9]; // 8 bytes for value plus one byte for length + int len = 0; + while (val > 0) { + len++; + buf[9 - len] = (val & 0xff); + val >>= 8; + } + buf[9 - len - 1] = (unsigned char)len; + len++; + AppendBytes(dest, reinterpret_cast<const char*>(buf + 9 - len), len); +} + +// Parse the encoding of a previously encoded string. +// If parse succeeds, return true, consume encoding from +// "*src", and if result != NULL append the decoded string to "*result". +// Otherwise, return false and leave both undefined. +inline static bool ReadStringInternal(StringPiece* src, string* result) { + const char* start = src->data(); + const char* string_limit = src->data() + src->size(); + + // We only scan up to "limit-2" since a valid string must end with + // a two character terminator: 'kEscape1 kSeparator' + const char* limit = string_limit - 1; + const char* copy_start = start; + while (true) { + start = SkipToNextSpecialByte(start, limit); + if (start >= limit) break; // No terminator sequence found + const char c = *(start++); + // If inversion is required, instead of inverting 'c', we invert the + // character constants to which 'c' is compared. We get the same + // behavior but save the runtime cost of inverting 'c'. 
+ DCHECK(IsSpecialByte(c)); + if (c == kEscape1) { + if (result) { + AppendBytes(result, copy_start, start - copy_start - 1); + } + // kEscape1 kSeparator ends component + // kEscape1 kNullCharacter represents '\0' + const char next = *(start++); + if (next == kSeparator) { + src->remove_prefix(start - src->data()); + return true; + } else if (next == kNullCharacter) { + if (result) { + *result += '\0'; + } + } else { + return false; + } + copy_start = start; + } else { + assert(c == kEscape2); + if (result) { + AppendBytes(result, copy_start, start - copy_start - 1); + } + // kEscape2 kFFCharacter represents '\xff' + // kEscape2 kInfinity is an error + const char next = *(start++); + if (next == kFFCharacter) { + if (result) { + *result += '\xff'; + } + } else { + return false; + } + copy_start = start; + } + } + return false; +} + +bool OrderedCode::ReadString(StringPiece* src, string* result) { + return ReadStringInternal(src, result); +} + +bool OrderedCode::ReadNumIncreasing(StringPiece* src, uint64* result) { + if (src->empty()) { + return false; // Not enough bytes + } + + // Decode length byte + const size_t len = static_cast<unsigned char>((*src)[0]); + + // If len > 0 and src is longer than 1, the first byte of "payload" + // must be non-zero (otherwise the encoding is not minimal). + // In opt mode, we don't enforce that encodings must be minimal. + DCHECK(0 == len || src->size() == 1 || (*src)[1] != '\0') + << "invalid encoding"; + + if (len + 1 > src->size() || len > 8) { + return false; // Not enough bytes or too many bytes + } + + if (result) { + uint64 tmp = 0; + for (size_t i = 0; i < len; i++) { + tmp <<= 8; + tmp |= static_cast<unsigned char>((*src)[1 + i]); + } + *result = tmp; + } + src->remove_prefix(len + 1); + return true; +} + +void OrderedCode::TEST_Corrupt(string* str, int k) { + int seen_seps = 0; + for (size_t i = 0; i + 1 < str->size(); i++) { + if ((*str)[i] == kEscape1 && (*str)[i + 1] == kSeparator) { + seen_seps++; + if (seen_seps == k) { + (*str)[i + 1] = kSeparator + 1; + return; + } + } + } +} + +// Signed number encoding/decoding ///////////////////////////////////// +// +// The format is as follows: +// +// The first bit (the most significant bit of the first byte) +// represents the sign, 0 if the number is negative and +// 1 if the number is >= 0. +// +// Any unbroken sequence of successive bits with the same value as the sign +// bit, up to 9 (the 8th and 9th are the most significant bits of the next +// byte), are size bits that count the number of bytes after the first byte. +// That is, the total length is between 1 and 10 bytes. +// +// The value occupies the bits after the sign bit and the "size bits" +// till the end of the string, in network byte order. If the number +// is negative, the bits are in 2-complement. 
+// +// +// Example 1: number 0x424242 -> 4 byte big-endian hex string 0xf0424242: +// +// +---------------+---------------+---------------+---------------+ +// 1 1 1 1 0 0 0 0 0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 1 0 +// +---------------+---------------+---------------+---------------+ +// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// | | | | payload: the remaining bits after the sign and size bits +// | | | | and the delimiter bit, the value is 0x424242 +// | | | | +// | size bits: 3 successive bits with the same value as the sign bit +// | (followed by a delimiter bit with the opposite value) +// | mean that there are 3 bytes after the first byte, 4 total +// | +// sign bit: 1 means that the number is non-negative +// +// Example 2: negative number -0x800 -> 2 byte big-endian hex string 0x3800: +// +// +---------------+---------------+ +// 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 +// +---------------+---------------+ +// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// | | payload: the remaining bits after the sign and size bits and the +// | | delimiter bit, 2-complement because of the negative sign, +// | | value is ~0x7ff, represents the value -0x800 +// | | +// | size bits: 1 bit with the same value as the sign bit +// | (followed by a delimiter bit with the opposite value) +// | means that there is 1 byte after the first byte, 2 total +// | +// sign bit: 0 means that the number is negative +// +// +// Compared with the simpler unsigned format used for uint64 numbers, +// this format is more compact for small numbers, namely one byte encodes +// numbers in the range [-64,64), two bytes cover the range [-2^13,2^13), etc. +// In general, n bytes encode numbers in the range [-2^(n*7-1),2^(n*7-1)). +// (The cross-over point for compactness of representation is 8 bytes, +// where this format only covers the range [-2^55,2^55), +// whereas an encoding with sign bit and length in the first byte and +// payload in all following bytes would cover [-2^56,2^56).) + +static const int kMaxSigned64Length = 10; + +// This array maps encoding length to header bits in the first two bytes. +static const char kLengthToHeaderBits[1 + kMaxSigned64Length][2] = { + {0, 0}, {'\x80', 0}, {'\xc0', 0}, {'\xe0', 0}, + {'\xf0', 0}, {'\xf8', 0}, {'\xfc', 0}, {'\xfe', 0}, + {'\xff', 0}, {'\xff', '\x80'}, {'\xff', '\xc0'}}; + +// This array maps encoding lengths to the header bits that overlap with +// the payload and need fixing when reading. +static const uint64 kLengthToMask[1 + kMaxSigned64Length] = { + 0ULL, + 0x80ULL, + 0xc000ULL, + 0xe00000ULL, + 0xf0000000ULL, + 0xf800000000ULL, + 0xfc0000000000ULL, + 0xfe000000000000ULL, + 0xff00000000000000ULL, + 0x8000000000000000ULL, + 0ULL}; + +// This array maps the number of bits in a number to the encoding +// length produced by WriteSignedNumIncreasing. +// For positive numbers, the number of bits is 1 plus the most significant +// bit position (the highest bit position in a positive int64 is 63). +// For a negative number n, we count the bits in ~n. +// That is, length = kBitsToLength[Bits::Log2Floor64(n < 0 ? ~n : n) + 1]. 
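+// For example, n = 64 has Log2Floor64(64) = 6, so the table below gives
+// length = kBitsToLength[7] = 2, consistent with the note above that a
+// single byte only covers [-64,64).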
+static const int8 kBitsToLength[1 + 63] = { + 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, + 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10}; + +#if defined(__GNUC__) +// Returns floor(lg(n)). Returns -1 if n == 0. +static int Log2Floor64(uint64 n) { + return n == 0 ? -1 : 63 ^ __builtin_clzll(n); +} +#else +// Portable slow version +static int Log2Floor32_Portable(uint32 n) { + if (n == 0) return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; +} +// Returns floor(lg(n)). Returns -1 if n == 0. +static int Log2Floor64(uint64 n) { + const uint32 topbits = static_cast<uint32>(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return Log2Floor32_Portable(static_cast<uint32>(n)); + } else { + return 32 + Log2Floor32_Portable(topbits); + } +} +#endif + +// Calculates the encoding length in bytes of the signed number n. +static inline int SignedEncodingLength(int64 n) { + return kBitsToLength[Log2Floor64(n < 0 ? ~n : n) + 1]; +} + +static void StoreBigEndian64(char* dst, uint64 v) { + for (int i = 0; i < 8; i++) { + dst[i] = (v >> (56 - 8 * i)) & 0xff; + } +} + +static uint64 LoadBigEndian64(const char* src) { + uint64 result = 0; + for (int i = 0; i < 8; i++) { + unsigned char c = static_cast<unsigned char>(src[i]); + result |= static_cast<uint64>(c) << (56 - 8 * i); + } + return result; +} + +void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) { + const uint64 x = val < 0 ? ~val : val; + if (x < 64) { // fast path for encoding length == 1 + *dest += kLengthToHeaderBits[1][0] ^ val; + return; + } + // buf = val in network byte order, sign extended to 10 bytes + const char sign_byte = val < 0 ? '\xff' : '\0'; + char buf[10] = { + sign_byte, sign_byte, + }; + StoreBigEndian64(buf + 2, val); + static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); + const int len = SignedEncodingLength(x); + DCHECK_GE(len, 2); + char* const begin = buf + sizeof(buf) - len; + begin[0] ^= kLengthToHeaderBits[len][0]; + begin[1] ^= kLengthToHeaderBits[len][1]; // ok because len >= 2 + dest->append(begin, len); +} + +bool OrderedCode::ReadSignedNumIncreasing(StringPiece* src, int64* result) { + if (src->empty()) return false; + const uint64 xor_mask = (!((*src)[0] & 0x80)) ? 
~0ULL : 0ULL; + const unsigned char first_byte = (*src)[0] ^ (xor_mask & 0xff); + + // now calculate and test length, and set x to raw (unmasked) result + int len; + uint64 x; + if (first_byte != 0xff) { + len = 7 - Log2Floor64(first_byte ^ 0xff); + if (src->size() < static_cast<size_t>(len)) return false; + x = xor_mask; // sign extend using xor_mask + for (int i = 0; i < len; ++i) + x = (x << 8) | static_cast<unsigned char>((*src)[i]); + } else { + len = 8; + if (src->size() < static_cast<size_t>(len)) return false; + const unsigned char second_byte = (*src)[1] ^ (xor_mask & 0xff); + if (second_byte >= 0x80) { + if (second_byte < 0xc0) { + len = 9; + } else { + const unsigned char third_byte = (*src)[2] ^ (xor_mask & 0xff); + if (second_byte == 0xc0 && third_byte < 0x80) { + len = 10; + } else { + return false; // either len > 10 or len == 10 and #bits > 63 + } + } + if (src->size() < static_cast<size_t>(len)) return false; + } + x = LoadBigEndian64(src->data() + len - 8); + } + + x ^= kLengthToMask[len]; // remove spurious header bits + + DCHECK_EQ(len, SignedEncodingLength(x)) << "invalid encoding"; + + if (result) *result = x; + src->remove_prefix(len); + return true; +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/ordered_code.h b/tensorflow/core/lib/strings/ordered_code.h new file mode 100644 index 0000000000..39f1df9a94 --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code.h @@ -0,0 +1,77 @@ +// This module provides routines for encoding a sequence of typed +// entities into a string. The resulting strings can be +// lexicographically compared to yield the same comparison value that +// would have been generated if the encoded items had been compared +// one by one according to their type. +// +// More precisely, suppose: +// 1. string A is generated by encoding the sequence of items [A_1..A_n] +// 2. string B is generated by encoding the sequence of items [B_1..B_n] +// 3. The types match; i.e., for all i: A_i was encoded using +// the same routine as B_i +// Then: +// Comparing A vs. B lexicographically is the same as comparing +// the vectors [A_1..A_n] and [B_1..B_n] lexicographically. +// +// Furthermore, if n < m, the encoding of [A_1..A_n] is a strict prefix of +// [A_1..A_m] (unless m = n+1 and A_m is the empty string encoded with +// WriteTrailingString, in which case the encodings are equal). +// +// This module is often useful when generating multi-part sstable +// keys that have to be ordered in a particular fashion. + +#ifndef TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ +#define TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ + +#include <string> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +class StringPiece; + +namespace strings { + +class OrderedCode { + public: + // ------------------------------------------------------------------- + // Encoding routines: each one of the following routines append + // one item to "*dest" in an encoding where larger values are + // ordered lexicographically after smaller values. + static void WriteString(string* dest, StringPiece str); + static void WriteNumIncreasing(string* dest, uint64 num); + static void WriteSignedNumIncreasing(string* dest, int64 num); + + // ------------------------------------------------------------------- + // Decoding routines: these extract an item earlier encoded using + // the corresponding WriteXXX() routines above. 
The item is read + // from "*src"; "*src" is modified to point past the decoded item; + // and if "result" is non-NULL, "*result" is modified to contain the + // result. In case of string result, the decoded string is appended to + // "*result". Returns true if the next item was read successfully, false + // otherwise. + static bool ReadString(StringPiece* src, string* result); + static bool ReadNumIncreasing(StringPiece* src, uint64* result); + static bool ReadSignedNumIncreasing(StringPiece* src, int64* result); + + // Helper for testing: corrupt "*str" by changing the kth item separator + // in the string. + static void TEST_Corrupt(string* str, int k); + + // Helper for testing. + // SkipToNextSpecialByte is an internal routine defined in the .cc file + // with the following semantics. Return a pointer to the first byte + // in the range "[start..limit)" whose value is 0 or 255. If no such + // byte exists in the range, returns "limit". + static const char* TEST_SkipToNextSpecialByte(const char* start, + const char* limit); + + private: + // This has only static methods, so disallow construction entirely + OrderedCode(); + TF_DISALLOW_COPY_AND_ASSIGN(OrderedCode); +}; + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc new file mode 100644 index 0000000000..d517d14f4a --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code_test.cc @@ -0,0 +1,1183 @@ +#include "tensorflow/core/lib/strings/ordered_code.h" + +#include <float.h> +#include <stddef.h> +#include <limits> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace strings { + +static string RandomString(random::SimplePhilox* rnd, int len) { + string x; + for (int i = 0; i < len; i++) { + x += rnd->Uniform(256); + } + return x; +} + +// --------------------------------------------------------------------- +// Utility template functions (they help templatize the tests below) + +// Read/WriteIncreasing are defined for string, uint64, int64 below. 
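+//
+// Illustrative sketch of how the wrappers below compose (not itself a test):
+//
+//   string buf;
+//   OCWriteToString<uint64>(&buf, 42);   // forwards to WriteNumIncreasing
+//   StringPiece s(buf);
+//   uint64 v;
+//   CHECK(OCRead<uint64>(&s, &v));       // forwards to ReadNumIncreasing
+//   CHECK_EQ(42, v);
+//   CHECK(s.empty());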
+template <typename T> +static void OCWriteIncreasing(string* dest, const T& val); +template <typename T> +static bool OCReadIncreasing(StringPiece* src, T* result); + +// Read/WriteIncreasing<string> +template <> +void OCWriteIncreasing<string>(string* dest, const string& val) { + OrderedCode::WriteString(dest, val); +} +template <> +bool OCReadIncreasing<string>(StringPiece* src, string* result) { + return OrderedCode::ReadString(src, result); +} + +// Read/WriteIncreasing<uint64> +template <> +void OCWriteIncreasing<uint64>(string* dest, const uint64& val) { + OrderedCode::WriteNumIncreasing(dest, val); +} +template <> +bool OCReadIncreasing<uint64>(StringPiece* src, uint64* result) { + return OrderedCode::ReadNumIncreasing(src, result); +} + +// Read/WriteIncreasing<int64> +template <> +void OCWriteIncreasing<int64>(string* dest, const int64& val) { + OrderedCode::WriteSignedNumIncreasing(dest, val); +} +template <> +bool OCReadIncreasing<int64>(StringPiece* src, int64* result) { + return OrderedCode::ReadSignedNumIncreasing(src, result); +} + +template <typename T> +string OCWrite(T val) { + string result; + OCWriteIncreasing<T>(&result, val); + return result; +} + +template <typename T> +void OCWriteToString(string* result, T val) { + OCWriteIncreasing<T>(result, val); +} + +template <typename T> +bool OCRead(StringPiece* s, T* val) { + return OCReadIncreasing<T>(s, val); +} + +// --------------------------------------------------------------------- +// Numbers + +template <typename T> +static T TestRead(const string& a) { + // gracefully reject any proper prefix of an encoding + for (int i = 0; i < a.size() - 1; ++i) { + StringPiece s(a.data(), i); + CHECK(!OCRead<T>(&s, NULL)); + CHECK_EQ(s, a.substr(0, i)); + } + + StringPiece s(a); + T v; + CHECK(OCRead<T>(&s, &v)); + CHECK(s.empty()); + return v; +} + +template <typename T> +static void TestWriteRead(T expected) { + EXPECT_EQ(expected, TestRead<T>(OCWrite<T>(expected))); +} + +// Verifies that the second Write* call appends a non-empty string to its +// output. 
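+// (Exercised below by TestNumbers() and by the String.EncodeDecode test.)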
+template <typename T, typename U> +static void TestWriteAppends(T first, U second) { + string encoded; + OCWriteToString<T>(&encoded, first); + string encoded_first_only = encoded; + OCWriteToString<U>(&encoded, second); + EXPECT_NE(encoded, encoded_first_only); + EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only)); +} + +template <typename T> +static void TestNumbers(T multiplier) { + // first test powers of 2 (and nearby numbers) + for (T x = std::numeric_limits<T>().max(); x != 0; x /= 2) { + TestWriteRead(multiplier * (x - 1)); + TestWriteRead(multiplier * x); + if (x != std::numeric_limits<T>::max()) { + TestWriteRead(multiplier * (x + 1)); + } else if (multiplier < 0 && multiplier == -1) { + TestWriteRead(-x - 1); + } + } + + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int bits = 1; bits <= std::numeric_limits<T>().digits; ++bits) { + // test random non-negative numbers with given number of significant bits + const uint64 mask = (~0ULL) >> (64 - bits); + for (int i = 0; i < 1000; i++) { + T x = rnd.Rand64() & mask; + TestWriteRead(multiplier * x); + T y = rnd.Rand64() & mask; + TestWriteAppends(multiplier * x, multiplier * y); + } + } +} + +// Return true iff 'a' is "before" 'b' +static bool CompareStrings(const string& a, const string& b) { return (a < b); } + +template <typename T> +static void TestNumberOrdering() { + // first the negative numbers (if T is signed, otherwise no-op) + string laststr = OCWrite<T>(std::numeric_limits<T>().min()); + for (T num = std::numeric_limits<T>().min() / 2; num != 0; num /= 2) { + string strminus1 = OCWrite<T>(num - 1); + string str = OCWrite<T>(num); + string strplus1 = OCWrite<T>(num + 1); + + CHECK(CompareStrings(strminus1, str)); + CHECK(CompareStrings(str, strplus1)); + + // Compare 'str' with 'laststr'. When we approach 0, 'laststr' is + // not necessarily before 'strminus1'. + CHECK(CompareStrings(laststr, str)); + laststr = str; + } + + // then the positive numbers + laststr = OCWrite<T>(0); + T num = 1; + while (num < std::numeric_limits<T>().max() / 2) { + num *= 2; + string strminus1 = OCWrite<T>(num - 1); + string str = OCWrite<T>(num); + string strplus1 = OCWrite<T>(num + 1); + + CHECK(CompareStrings(strminus1, str)); + CHECK(CompareStrings(str, strplus1)); + + // Compare 'str' with 'laststr'. + CHECK(CompareStrings(laststr, str)); + laststr = str; + } +} + +// Helper routine for testing TEST_SkipToNextSpecialByte +static int FindSpecial(const string& x) { + const char* p = x.data(); + const char* limit = p + x.size(); + const char* result = OrderedCode::TEST_SkipToNextSpecialByte(p, limit); + return result - p; +} + +TEST(OrderedCode, SkipToNextSpecialByte) { + for (size_t len = 0; len < 256; len++) { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string x; + while (x.size() < len) { + char c = 1 + rnd.Uniform(254); + ASSERT_NE(c, 0); + ASSERT_NE(c, 255); + x += c; // No 0 bytes, no 255 bytes + } + EXPECT_EQ(FindSpecial(x), x.size()); + for (size_t special_pos = 0; special_pos < len; special_pos++) { + for (size_t special_test = 0; special_test < 2; special_test++) { + const char special_byte = (special_test == 0) ? 
0 : 255; + string y = x; + y[special_pos] = special_byte; + EXPECT_EQ(FindSpecial(y), special_pos); + if (special_pos < 16) { + // Add some special bytes after the one at special_pos to make sure + // we still return the earliest special byte in the string + for (size_t rest = special_pos + 1; rest < len; rest++) { + if (rnd.OneIn(3)) { + y[rest] = rnd.OneIn(2) ? 0 : 255; + EXPECT_EQ(FindSpecial(y), special_pos); + } + } + } + } + } + } +} + +TEST(OrderedCode, ExhaustiveFindSpecial) { + char buf[16]; + char* limit = buf + sizeof(buf); + int count = 0; + for (int start_offset = 0; start_offset <= 5; start_offset += 5) { + // We test exhaustively with all combinations of 3 bytes starting + // at offset 0 and offset 5 (so as to test with the bytes at both + // ends of a 64-bit word). + for (size_t i = 0; i < sizeof(buf); i++) { + buf[i] = 'a'; // Not a special byte + } + for (int b0 = 0; b0 < 256; b0++) { + for (int b1 = 0; b1 < 256; b1++) { + for (int b2 = 0; b2 < 256; b2++) { + buf[start_offset + 0] = b0; + buf[start_offset + 1] = b1; + buf[start_offset + 2] = b2; + char* expected; + if (b0 == 0 || b0 == 255) { + expected = &buf[start_offset]; + } else if (b1 == 0 || b1 == 255) { + expected = &buf[start_offset + 1]; + } else if (b2 == 0 || b2 == 255) { + expected = &buf[start_offset + 2]; + } else { + expected = limit; + } + count++; + EXPECT_EQ(expected, + OrderedCode::TEST_SkipToNextSpecialByte(buf, limit)); + } + } + } + } + EXPECT_EQ(count, 256 * 256 * 256 * 2); +} + +TEST(Uint64, EncodeDecode) { TestNumbers<uint64>(1); } + +TEST(Uint64, Ordering) { TestNumberOrdering<uint64>(); } + +TEST(Int64, EncodeDecode) { + TestNumbers<int64>(1); + TestNumbers<int64>(-1); +} + +TEST(Int64, Ordering) { TestNumberOrdering<int64>(); } + +// Returns the bitwise complement of s. +static inline string StrNot(const string& s) { + string result; + for (string::const_iterator it = s.begin(); it != s.end(); ++it) + result.push_back(~*it); + return result; +} + +template <typename T> +static void TestInvalidEncoding(const string& s) { + StringPiece p(s); + EXPECT_FALSE(OCRead<T>(&p, static_cast<T*>(NULL))); + EXPECT_EQ(s, p); +} + +TEST(OrderedCodeInvalidEncodingsTest, Overflow) { + // 1U << 64, increasing and decreasing + const string k2xx64U = "\x09\x01" + string(8, 0); + TestInvalidEncoding<uint64>(k2xx64U); + + // 1 << 63 and ~(1 << 63), increasing and decreasing + const string k2xx63 = "\xff\xc0\x80" + string(7, 0); + TestInvalidEncoding<int64>(k2xx63); + TestInvalidEncoding<int64>(StrNot(k2xx63)); +} + +TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) { + // Test "ambiguous"/"non-canonical" encodings. + // These are non-minimal (but otherwise "valid") encodings that + // differ from the minimal encoding chosen by OrderedCode::WriteXXX + // and thus should be avoided to not mess up the string ordering of + // encodings. + + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + + for (int n = 2; n <= 9; ++n) { + // The zero in non_minimal[1] is "redundant". + string non_minimal = + string(1, n - 1) + string(1, 0) + RandomString(&rnd, n - 2); + EXPECT_EQ(n, non_minimal.length()); + + EXPECT_NE(OCWrite<uint64>(0), non_minimal); +#ifndef NDEBUG + StringPiece s(non_minimal); + EXPECT_DEATH(OrderedCode::ReadNumIncreasing(&s, NULL), "invalid encoding"); +#else + TestRead<uint64>(non_minimal); +#endif + } + + for (int n = 2; n <= 10; ++n) { + // Header with 1 sign bit and n-1 size bits. 
+ string header = string(n / 8, 0xff) + string(1, 0xff << (8 - (n % 8))); + // There are more than 7 zero bits between header bits and "payload". + string non_minimal = header + + string(1, rnd.Uniform(256) & ~*header.rbegin()) + + RandomString(&rnd, n - header.length() - 1); + EXPECT_EQ(n, non_minimal.length()); + + EXPECT_NE(OCWrite<int64>(0), non_minimal); +#ifndef NDEBUG + StringPiece s(non_minimal); + EXPECT_DEATH(OrderedCode::ReadSignedNumIncreasing(&s, NULL), + "invalid encoding") + << n; +#else + TestRead<int64>(non_minimal); +#endif + } +} + +// Returns random number with specified number of bits, +// i.e., in the range [2^(bits-1),2^bits). +static uint64 NextBits(random::SimplePhilox* rnd, int bits) { + return (bits != 0) + ? (rnd->Rand64() % (1LL << (bits - 1))) + (1LL << (bits - 1)) + : 0; +} + +template <typename T> +static void BM_WriteNum(int n, T multiplier) { + static const int kValues = 64; + T values[kValues]; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + // Use enough distinct values to confuse the branch predictor + for (int i = 0; i < kValues; i++) { + values[i] = NextBits(&rnd, n % 64) * multiplier; + } + string result; + int index = 0; + while (n-- > 0) { + result.clear(); + OCWriteToString<T>(&result, values[index % kValues]); + index++; + } +} + +template <typename T> +static void BM_ReadNum(int n, T multiplier) { + string x; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + // Use enough distinct values to confuse the branch predictor + static const int kValues = 64; + string values[kValues]; + for (int i = 0; i < kValues; i++) { + T val = NextBits(&rnd, i % 64) * multiplier; + values[i] = OCWrite<T>(val); + } + uint32 index = 0; + while (n-- > 0) { + T val; + StringPiece s = values[index++ % kValues]; + OCRead<T>(&s, &val); + } +} + +#define BENCHMARK_NUM(name, T, multiplier) \ + static void BM_Write##name(int n) { BM_WriteNum<T>(n, multiplier); } \ + BENCHMARK(BM_Write##name); \ + static void BM_Read##name(int n) { BM_ReadNum<T>(n, multiplier); } \ + BENCHMARK(BM_Read##name) + +BENCHMARK_NUM(NumIncreasing, uint64, 1); +BENCHMARK_NUM(SignedNum, int64, 1); +BENCHMARK_NUM(SignedNumNegative, int64, -1); + +#undef BENCHMARK_NUM + +// --------------------------------------------------------------------- +// Strings + +TEST(String, EncodeDecode) { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + + for (int len = 0; len < 256; len++) { + const string a = RandomString(&rnd, len); + TestWriteRead(a); + for (int len2 = 0; len2 < 64; len2++) { + const string b = RandomString(&rnd, len2); + + TestWriteAppends(a, b); + + string out; + OCWriteToString<string>(&out, a); + OCWriteToString<string>(&out, b); + + string a2, b2, dummy; + StringPiece s = out; + StringPiece s2 = out; + CHECK(OCRead<string>(&s, &a2)); + CHECK(OCRead<string>(&s2, NULL)); + CHECK_EQ(s, s2); + + CHECK(OCRead<string>(&s, &b2)); + CHECK(OCRead<string>(&s2, NULL)); + CHECK_EQ(s, s2); + + CHECK(!OCRead<string>(&s, &dummy)); + CHECK(!OCRead<string>(&s2, NULL)); + CHECK_EQ(a, a2); + CHECK_EQ(b, b2); + CHECK(s.empty()); + CHECK(s2.empty()); + } + } +} + +// 'str' is a static C-style string that may contain '\0' +#define STATIC_STR(str) StringPiece((str), sizeof(str) - 1) + +static string EncodeStringIncreasing(StringPiece value) { + string encoded; + OrderedCode::WriteString(&encoded, value); + return encoded; +} + +TEST(String, Increasing) { + // Here are a series of strings in non-decreasing order, including + // 
consecutive strings such that the second one is equal to, a proper + // prefix of, or has the same length as the first one. Most also contain + // the special escaping characters '\x00' and '\xff'. + ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("")), + EncodeStringIncreasing(STATIC_STR(""))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("")), + EncodeStringIncreasing(STATIC_STR("\x00"))); + + ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("\x00")), + EncodeStringIncreasing(STATIC_STR("\x00"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x00")), + EncodeStringIncreasing(STATIC_STR("\x01"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x01")), + EncodeStringIncreasing(STATIC_STR("a"))); + + ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("a")), + EncodeStringIncreasing(STATIC_STR("a"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("a")), + EncodeStringIncreasing(STATIC_STR("aa"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("aa")), + EncodeStringIncreasing(STATIC_STR("\xff"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff")), + EncodeStringIncreasing(STATIC_STR("\xff\x00"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff\x00")), + EncodeStringIncreasing(STATIC_STR("\xff\x01"))); +} + +TEST(EncodingIsExpected, String) { + std::vector<std::pair<string, string>> data = { + {"", string("\x00\x01", 2)}, + {"foo", string("foo\x00\x01", 5)}, + {"hello", string("hello\x00\x01", 7)}, + {string("\x00\x01\xff", 3), string("\x00\xff\x01\xff\x00\x00\x01", 7)}, + }; + for (const auto& t : data) { + string result; + OrderedCode::WriteString(&result, t.first); + EXPECT_EQ(t.second, result); + + StringPiece in = result; + string decoded; + EXPECT_TRUE(OrderedCode::ReadString(&in, &decoded)); + EXPECT_EQ(t.first, decoded); + EXPECT_EQ("", in); + } +} + +TEST(EncodingIsExpected, Unsigned) { + std::vector<std::pair<uint64, string>> data = { + {0x0ull, string("\000", 1)}, + {0x1ull, string("\001\001", 2)}, + {0x2ull, string("\001\002", 2)}, + {0x1ull, string("\001\001", 2)}, + {0x2ull, string("\001\002", 2)}, + {0x3ull, string("\001\003", 2)}, + {0x3ull, string("\001\003", 2)}, + {0x4ull, string("\001\004", 2)}, + {0x5ull, string("\001\005", 2)}, + {0x7ull, string("\001\007", 2)}, + {0x8ull, string("\001\010", 2)}, + {0x9ull, string("\001\t", 2)}, + {0xfull, string("\001\017", 2)}, + {0x10ull, string("\001\020", 2)}, + {0x11ull, string("\001\021", 2)}, + {0x1full, string("\001\037", 2)}, + {0x20ull, string("\001 ", 2)}, + {0x21ull, string("\001!", 2)}, + {0x3full, string("\001?", 2)}, + {0x40ull, string("\001@", 2)}, + {0x41ull, string("\001A", 2)}, + {0x7full, string("\001\177", 2)}, + {0x80ull, string("\001\200", 2)}, + {0x81ull, string("\001\201", 2)}, + {0xffull, string("\001\377", 2)}, + {0x100ull, string("\002\001\000", 3)}, + {0x101ull, string("\002\001\001", 3)}, + {0x1ffull, string("\002\001\377", 3)}, + {0x200ull, string("\002\002\000", 3)}, + {0x201ull, string("\002\002\001", 3)}, + {0x3ffull, string("\002\003\377", 3)}, + {0x400ull, string("\002\004\000", 3)}, + {0x401ull, string("\002\004\001", 3)}, + {0x7ffull, string("\002\007\377", 3)}, + {0x800ull, string("\002\010\000", 3)}, + {0x801ull, string("\002\010\001", 3)}, + {0xfffull, string("\002\017\377", 3)}, + {0x1000ull, string("\002\020\000", 3)}, + {0x1001ull, string("\002\020\001", 3)}, + {0x1fffull, string("\002\037\377", 3)}, + {0x2000ull, string("\002 \000", 3)}, + {0x2001ull, string("\002 \001", 3)}, + {0x3fffull, string("\002?\377", 3)}, + {0x4000ull, string("\002@\000", 3)}, + {0x4001ull, 
string("\002@\001", 3)}, + {0x7fffull, string("\002\177\377", 3)}, + {0x8000ull, string("\002\200\000", 3)}, + {0x8001ull, string("\002\200\001", 3)}, + {0xffffull, string("\002\377\377", 3)}, + {0x10000ull, string("\003\001\000\000", 4)}, + {0x10001ull, string("\003\001\000\001", 4)}, + {0x1ffffull, string("\003\001\377\377", 4)}, + {0x20000ull, string("\003\002\000\000", 4)}, + {0x20001ull, string("\003\002\000\001", 4)}, + {0x3ffffull, string("\003\003\377\377", 4)}, + {0x40000ull, string("\003\004\000\000", 4)}, + {0x40001ull, string("\003\004\000\001", 4)}, + {0x7ffffull, string("\003\007\377\377", 4)}, + {0x80000ull, string("\003\010\000\000", 4)}, + {0x80001ull, string("\003\010\000\001", 4)}, + {0xfffffull, string("\003\017\377\377", 4)}, + {0x100000ull, string("\003\020\000\000", 4)}, + {0x100001ull, string("\003\020\000\001", 4)}, + {0x1fffffull, string("\003\037\377\377", 4)}, + {0x200000ull, string("\003 \000\000", 4)}, + {0x200001ull, string("\003 \000\001", 4)}, + {0x3fffffull, string("\003?\377\377", 4)}, + {0x400000ull, string("\003@\000\000", 4)}, + {0x400001ull, string("\003@\000\001", 4)}, + {0x7fffffull, string("\003\177\377\377", 4)}, + {0x800000ull, string("\003\200\000\000", 4)}, + {0x800001ull, string("\003\200\000\001", 4)}, + {0xffffffull, string("\003\377\377\377", 4)}, + {0x1000000ull, string("\004\001\000\000\000", 5)}, + {0x1000001ull, string("\004\001\000\000\001", 5)}, + {0x1ffffffull, string("\004\001\377\377\377", 5)}, + {0x2000000ull, string("\004\002\000\000\000", 5)}, + {0x2000001ull, string("\004\002\000\000\001", 5)}, + {0x3ffffffull, string("\004\003\377\377\377", 5)}, + {0x4000000ull, string("\004\004\000\000\000", 5)}, + {0x4000001ull, string("\004\004\000\000\001", 5)}, + {0x7ffffffull, string("\004\007\377\377\377", 5)}, + {0x8000000ull, string("\004\010\000\000\000", 5)}, + {0x8000001ull, string("\004\010\000\000\001", 5)}, + {0xfffffffull, string("\004\017\377\377\377", 5)}, + {0x10000000ull, string("\004\020\000\000\000", 5)}, + {0x10000001ull, string("\004\020\000\000\001", 5)}, + {0x1fffffffull, string("\004\037\377\377\377", 5)}, + {0x20000000ull, string("\004 \000\000\000", 5)}, + {0x20000001ull, string("\004 \000\000\001", 5)}, + {0x3fffffffull, string("\004?\377\377\377", 5)}, + {0x40000000ull, string("\004@\000\000\000", 5)}, + {0x40000001ull, string("\004@\000\000\001", 5)}, + {0x7fffffffull, string("\004\177\377\377\377", 5)}, + {0x80000000ull, string("\004\200\000\000\000", 5)}, + {0x80000001ull, string("\004\200\000\000\001", 5)}, + {0xffffffffull, string("\004\377\377\377\377", 5)}, + {0x100000000ull, string("\005\001\000\000\000\000", 6)}, + {0x100000001ull, string("\005\001\000\000\000\001", 6)}, + {0x1ffffffffull, string("\005\001\377\377\377\377", 6)}, + {0x200000000ull, string("\005\002\000\000\000\000", 6)}, + {0x200000001ull, string("\005\002\000\000\000\001", 6)}, + {0x3ffffffffull, string("\005\003\377\377\377\377", 6)}, + {0x400000000ull, string("\005\004\000\000\000\000", 6)}, + {0x400000001ull, string("\005\004\000\000\000\001", 6)}, + {0x7ffffffffull, string("\005\007\377\377\377\377", 6)}, + {0x800000000ull, string("\005\010\000\000\000\000", 6)}, + {0x800000001ull, string("\005\010\000\000\000\001", 6)}, + {0xfffffffffull, string("\005\017\377\377\377\377", 6)}, + {0x1000000000ull, string("\005\020\000\000\000\000", 6)}, + {0x1000000001ull, string("\005\020\000\000\000\001", 6)}, + {0x1fffffffffull, string("\005\037\377\377\377\377", 6)}, + {0x2000000000ull, string("\005 \000\000\000\000", 6)}, + {0x2000000001ull, 
string("\005 \000\000\000\001", 6)}, + {0x3fffffffffull, string("\005?\377\377\377\377", 6)}, + {0x4000000000ull, string("\005@\000\000\000\000", 6)}, + {0x4000000001ull, string("\005@\000\000\000\001", 6)}, + {0x7fffffffffull, string("\005\177\377\377\377\377", 6)}, + {0x8000000000ull, string("\005\200\000\000\000\000", 6)}, + {0x8000000001ull, string("\005\200\000\000\000\001", 6)}, + {0xffffffffffull, string("\005\377\377\377\377\377", 6)}, + {0x10000000000ull, string("\006\001\000\000\000\000\000", 7)}, + {0x10000000001ull, string("\006\001\000\000\000\000\001", 7)}, + {0x1ffffffffffull, string("\006\001\377\377\377\377\377", 7)}, + {0x20000000000ull, string("\006\002\000\000\000\000\000", 7)}, + {0x20000000001ull, string("\006\002\000\000\000\000\001", 7)}, + {0x3ffffffffffull, string("\006\003\377\377\377\377\377", 7)}, + {0x40000000000ull, string("\006\004\000\000\000\000\000", 7)}, + {0x40000000001ull, string("\006\004\000\000\000\000\001", 7)}, + {0x7ffffffffffull, string("\006\007\377\377\377\377\377", 7)}, + {0x80000000000ull, string("\006\010\000\000\000\000\000", 7)}, + {0x80000000001ull, string("\006\010\000\000\000\000\001", 7)}, + {0xfffffffffffull, string("\006\017\377\377\377\377\377", 7)}, + {0x100000000000ull, string("\006\020\000\000\000\000\000", 7)}, + {0x100000000001ull, string("\006\020\000\000\000\000\001", 7)}, + {0x1fffffffffffull, string("\006\037\377\377\377\377\377", 7)}, + {0x200000000000ull, string("\006 \000\000\000\000\000", 7)}, + {0x200000000001ull, string("\006 \000\000\000\000\001", 7)}, + {0x3fffffffffffull, string("\006?\377\377\377\377\377", 7)}, + {0x400000000000ull, string("\006@\000\000\000\000\000", 7)}, + {0x400000000001ull, string("\006@\000\000\000\000\001", 7)}, + {0x7fffffffffffull, string("\006\177\377\377\377\377\377", 7)}, + {0x800000000000ull, string("\006\200\000\000\000\000\000", 7)}, + {0x800000000001ull, string("\006\200\000\000\000\000\001", 7)}, + {0xffffffffffffull, string("\006\377\377\377\377\377\377", 7)}, + {0x1000000000000ull, string("\007\001\000\000\000\000\000\000", 8)}, + {0x1000000000001ull, string("\007\001\000\000\000\000\000\001", 8)}, + {0x1ffffffffffffull, string("\007\001\377\377\377\377\377\377", 8)}, + {0x2000000000000ull, string("\007\002\000\000\000\000\000\000", 8)}, + {0x2000000000001ull, string("\007\002\000\000\000\000\000\001", 8)}, + {0x3ffffffffffffull, string("\007\003\377\377\377\377\377\377", 8)}, + {0x4000000000000ull, string("\007\004\000\000\000\000\000\000", 8)}, + {0x4000000000001ull, string("\007\004\000\000\000\000\000\001", 8)}, + {0x7ffffffffffffull, string("\007\007\377\377\377\377\377\377", 8)}, + {0x8000000000000ull, string("\007\010\000\000\000\000\000\000", 8)}, + {0x8000000000001ull, string("\007\010\000\000\000\000\000\001", 8)}, + {0xfffffffffffffull, string("\007\017\377\377\377\377\377\377", 8)}, + {0x10000000000000ull, string("\007\020\000\000\000\000\000\000", 8)}, + {0x10000000000001ull, string("\007\020\000\000\000\000\000\001", 8)}, + {0x1fffffffffffffull, string("\007\037\377\377\377\377\377\377", 8)}, + {0x20000000000000ull, string("\007 \000\000\000\000\000\000", 8)}, + {0x20000000000001ull, string("\007 \000\000\000\000\000\001", 8)}, + {0x3fffffffffffffull, string("\007?\377\377\377\377\377\377", 8)}, + {0x40000000000000ull, string("\007@\000\000\000\000\000\000", 8)}, + {0x40000000000001ull, string("\007@\000\000\000\000\000\001", 8)}, + {0x7fffffffffffffull, string("\007\177\377\377\377\377\377\377", 8)}, + {0x80000000000000ull, 
string("\007\200\000\000\000\000\000\000", 8)}, + {0x80000000000001ull, string("\007\200\000\000\000\000\000\001", 8)}, + {0xffffffffffffffull, string("\007\377\377\377\377\377\377\377", 8)}, + {0x100000000000000ull, string("\010\001\000\000\000\000\000\000\000", 9)}, + {0x100000000000001ull, string("\010\001\000\000\000\000\000\000\001", 9)}, + {0x1ffffffffffffffull, string("\010\001\377\377\377\377\377\377\377", 9)}, + {0x200000000000000ull, string("\010\002\000\000\000\000\000\000\000", 9)}, + {0x200000000000001ull, string("\010\002\000\000\000\000\000\000\001", 9)}, + {0x3ffffffffffffffull, string("\010\003\377\377\377\377\377\377\377", 9)}, + {0x400000000000000ull, string("\010\004\000\000\000\000\000\000\000", 9)}, + {0x400000000000001ull, string("\010\004\000\000\000\000\000\000\001", 9)}, + {0x7ffffffffffffffull, string("\010\007\377\377\377\377\377\377\377", 9)}, + {0x800000000000000ull, string("\010\010\000\000\000\000\000\000\000", 9)}, + {0x800000000000001ull, string("\010\010\000\000\000\000\000\000\001", 9)}, + {0xfffffffffffffffull, string("\010\017\377\377\377\377\377\377\377", 9)}, + {0x1000000000000000ull, + string("\010\020\000\000\000\000\000\000\000", 9)}, + {0x1000000000000001ull, + string("\010\020\000\000\000\000\000\000\001", 9)}, + {0x1fffffffffffffffull, + string("\010\037\377\377\377\377\377\377\377", 9)}, + {0x2000000000000000ull, string("\010 \000\000\000\000\000\000\000", 9)}, + {0x2000000000000001ull, string("\010 \000\000\000\000\000\000\001", 9)}, + {0x3fffffffffffffffull, string("\010?\377\377\377\377\377\377\377", 9)}, + {0x4000000000000000ull, string("\010@\000\000\000\000\000\000\000", 9)}, + {0x4000000000000001ull, string("\010@\000\000\000\000\000\000\001", 9)}, + {0x7fffffffffffffffull, + string("\010\177\377\377\377\377\377\377\377", 9)}, + {0x8000000000000000ull, + string("\010\200\000\000\000\000\000\000\000", 9)}, + {0x8000000000000001ull, + string("\010\200\000\000\000\000\000\000\001", 9)}, + }; + for (const auto& t : data) { + uint64 num = t.first; + string result; + OrderedCode::WriteNumIncreasing(&result, num); + EXPECT_EQ(t.second, result) << std::hex << num; + + StringPiece in = result; + uint64 decoded; + EXPECT_TRUE(OrderedCode::ReadNumIncreasing(&in, &decoded)); + EXPECT_EQ(num, decoded); + EXPECT_EQ("", in); + } +} + +TEST(EncodingIsExpected, Signed) { + std::vector<std::pair<int64, string>> data = { + {0ll, string("\200", 1)}, + {1ll, string("\201", 1)}, + {2ll, string("\202", 1)}, + {1ll, string("\201", 1)}, + {2ll, string("\202", 1)}, + {3ll, string("\203", 1)}, + {3ll, string("\203", 1)}, + {4ll, string("\204", 1)}, + {5ll, string("\205", 1)}, + {7ll, string("\207", 1)}, + {8ll, string("\210", 1)}, + {9ll, string("\211", 1)}, + {15ll, string("\217", 1)}, + {16ll, string("\220", 1)}, + {17ll, string("\221", 1)}, + {31ll, string("\237", 1)}, + {32ll, string("\240", 1)}, + {33ll, string("\241", 1)}, + {63ll, string("\277", 1)}, + {64ll, string("\300@", 2)}, + {65ll, string("\300A", 2)}, + {127ll, string("\300\177", 2)}, + {128ll, string("\300\200", 2)}, + {129ll, string("\300\201", 2)}, + {255ll, string("\300\377", 2)}, + {256ll, string("\301\000", 2)}, + {257ll, string("\301\001", 2)}, + {511ll, string("\301\377", 2)}, + {512ll, string("\302\000", 2)}, + {513ll, string("\302\001", 2)}, + {1023ll, string("\303\377", 2)}, + {1024ll, string("\304\000", 2)}, + {1025ll, string("\304\001", 2)}, + {2047ll, string("\307\377", 2)}, + {2048ll, string("\310\000", 2)}, + {2049ll, string("\310\001", 2)}, + {4095ll, string("\317\377", 2)}, + 
{4096ll, string("\320\000", 2)}, + {4097ll, string("\320\001", 2)}, + {8191ll, string("\337\377", 2)}, + {8192ll, string("\340 \000", 3)}, + {8193ll, string("\340 \001", 3)}, + {16383ll, string("\340?\377", 3)}, + {16384ll, string("\340@\000", 3)}, + {16385ll, string("\340@\001", 3)}, + {32767ll, string("\340\177\377", 3)}, + {32768ll, string("\340\200\000", 3)}, + {32769ll, string("\340\200\001", 3)}, + {65535ll, string("\340\377\377", 3)}, + {65536ll, string("\341\000\000", 3)}, + {65537ll, string("\341\000\001", 3)}, + {131071ll, string("\341\377\377", 3)}, + {131072ll, string("\342\000\000", 3)}, + {131073ll, string("\342\000\001", 3)}, + {262143ll, string("\343\377\377", 3)}, + {262144ll, string("\344\000\000", 3)}, + {262145ll, string("\344\000\001", 3)}, + {524287ll, string("\347\377\377", 3)}, + {524288ll, string("\350\000\000", 3)}, + {524289ll, string("\350\000\001", 3)}, + {1048575ll, string("\357\377\377", 3)}, + {1048576ll, string("\360\020\000\000", 4)}, + {1048577ll, string("\360\020\000\001", 4)}, + {2097151ll, string("\360\037\377\377", 4)}, + {2097152ll, string("\360 \000\000", 4)}, + {2097153ll, string("\360 \000\001", 4)}, + {4194303ll, string("\360?\377\377", 4)}, + {4194304ll, string("\360@\000\000", 4)}, + {4194305ll, string("\360@\000\001", 4)}, + {8388607ll, string("\360\177\377\377", 4)}, + {8388608ll, string("\360\200\000\000", 4)}, + {8388609ll, string("\360\200\000\001", 4)}, + {16777215ll, string("\360\377\377\377", 4)}, + {16777216ll, string("\361\000\000\000", 4)}, + {16777217ll, string("\361\000\000\001", 4)}, + {33554431ll, string("\361\377\377\377", 4)}, + {33554432ll, string("\362\000\000\000", 4)}, + {33554433ll, string("\362\000\000\001", 4)}, + {67108863ll, string("\363\377\377\377", 4)}, + {67108864ll, string("\364\000\000\000", 4)}, + {67108865ll, string("\364\000\000\001", 4)}, + {134217727ll, string("\367\377\377\377", 4)}, + {134217728ll, string("\370\010\000\000\000", 5)}, + {134217729ll, string("\370\010\000\000\001", 5)}, + {268435455ll, string("\370\017\377\377\377", 5)}, + {268435456ll, string("\370\020\000\000\000", 5)}, + {268435457ll, string("\370\020\000\000\001", 5)}, + {536870911ll, string("\370\037\377\377\377", 5)}, + {536870912ll, string("\370 \000\000\000", 5)}, + {536870913ll, string("\370 \000\000\001", 5)}, + {1073741823ll, string("\370?\377\377\377", 5)}, + {1073741824ll, string("\370@\000\000\000", 5)}, + {1073741825ll, string("\370@\000\000\001", 5)}, + {2147483647ll, string("\370\177\377\377\377", 5)}, + {2147483648ll, string("\370\200\000\000\000", 5)}, + {2147483649ll, string("\370\200\000\000\001", 5)}, + {4294967295ll, string("\370\377\377\377\377", 5)}, + {4294967296ll, string("\371\000\000\000\000", 5)}, + {4294967297ll, string("\371\000\000\000\001", 5)}, + {8589934591ll, string("\371\377\377\377\377", 5)}, + {8589934592ll, string("\372\000\000\000\000", 5)}, + {8589934593ll, string("\372\000\000\000\001", 5)}, + {17179869183ll, string("\373\377\377\377\377", 5)}, + {17179869184ll, string("\374\004\000\000\000\000", 6)}, + {17179869185ll, string("\374\004\000\000\000\001", 6)}, + {34359738367ll, string("\374\007\377\377\377\377", 6)}, + {34359738368ll, string("\374\010\000\000\000\000", 6)}, + {34359738369ll, string("\374\010\000\000\000\001", 6)}, + {68719476735ll, string("\374\017\377\377\377\377", 6)}, + {68719476736ll, string("\374\020\000\000\000\000", 6)}, + {68719476737ll, string("\374\020\000\000\000\001", 6)}, + {137438953471ll, string("\374\037\377\377\377\377", 6)}, + {137438953472ll, string("\374 
\000\000\000\000", 6)}, + {137438953473ll, string("\374 \000\000\000\001", 6)}, + {274877906943ll, string("\374?\377\377\377\377", 6)}, + {274877906944ll, string("\374@\000\000\000\000", 6)}, + {274877906945ll, string("\374@\000\000\000\001", 6)}, + {549755813887ll, string("\374\177\377\377\377\377", 6)}, + {549755813888ll, string("\374\200\000\000\000\000", 6)}, + {549755813889ll, string("\374\200\000\000\000\001", 6)}, + {1099511627775ll, string("\374\377\377\377\377\377", 6)}, + {1099511627776ll, string("\375\000\000\000\000\000", 6)}, + {1099511627777ll, string("\375\000\000\000\000\001", 6)}, + {2199023255551ll, string("\375\377\377\377\377\377", 6)}, + {2199023255552ll, string("\376\002\000\000\000\000\000", 7)}, + {2199023255553ll, string("\376\002\000\000\000\000\001", 7)}, + {4398046511103ll, string("\376\003\377\377\377\377\377", 7)}, + {4398046511104ll, string("\376\004\000\000\000\000\000", 7)}, + {4398046511105ll, string("\376\004\000\000\000\000\001", 7)}, + {8796093022207ll, string("\376\007\377\377\377\377\377", 7)}, + {8796093022208ll, string("\376\010\000\000\000\000\000", 7)}, + {8796093022209ll, string("\376\010\000\000\000\000\001", 7)}, + {17592186044415ll, string("\376\017\377\377\377\377\377", 7)}, + {17592186044416ll, string("\376\020\000\000\000\000\000", 7)}, + {17592186044417ll, string("\376\020\000\000\000\000\001", 7)}, + {35184372088831ll, string("\376\037\377\377\377\377\377", 7)}, + {35184372088832ll, string("\376 \000\000\000\000\000", 7)}, + {35184372088833ll, string("\376 \000\000\000\000\001", 7)}, + {70368744177663ll, string("\376?\377\377\377\377\377", 7)}, + {70368744177664ll, string("\376@\000\000\000\000\000", 7)}, + {70368744177665ll, string("\376@\000\000\000\000\001", 7)}, + {140737488355327ll, string("\376\177\377\377\377\377\377", 7)}, + {140737488355328ll, string("\376\200\000\000\000\000\000", 7)}, + {140737488355329ll, string("\376\200\000\000\000\000\001", 7)}, + {281474976710655ll, string("\376\377\377\377\377\377\377", 7)}, + {281474976710656ll, string("\377\001\000\000\000\000\000\000", 8)}, + {281474976710657ll, string("\377\001\000\000\000\000\000\001", 8)}, + {562949953421311ll, string("\377\001\377\377\377\377\377\377", 8)}, + {562949953421312ll, string("\377\002\000\000\000\000\000\000", 8)}, + {562949953421313ll, string("\377\002\000\000\000\000\000\001", 8)}, + {1125899906842623ll, string("\377\003\377\377\377\377\377\377", 8)}, + {1125899906842624ll, string("\377\004\000\000\000\000\000\000", 8)}, + {1125899906842625ll, string("\377\004\000\000\000\000\000\001", 8)}, + {2251799813685247ll, string("\377\007\377\377\377\377\377\377", 8)}, + {2251799813685248ll, string("\377\010\000\000\000\000\000\000", 8)}, + {2251799813685249ll, string("\377\010\000\000\000\000\000\001", 8)}, + {4503599627370495ll, string("\377\017\377\377\377\377\377\377", 8)}, + {4503599627370496ll, string("\377\020\000\000\000\000\000\000", 8)}, + {4503599627370497ll, string("\377\020\000\000\000\000\000\001", 8)}, + {9007199254740991ll, string("\377\037\377\377\377\377\377\377", 8)}, + {9007199254740992ll, string("\377 \000\000\000\000\000\000", 8)}, + {9007199254740993ll, string("\377 \000\000\000\000\000\001", 8)}, + {18014398509481983ll, string("\377?\377\377\377\377\377\377", 8)}, + {18014398509481984ll, string("\377@\000\000\000\000\000\000", 8)}, + {18014398509481985ll, string("\377@\000\000\000\000\000\001", 8)}, + {36028797018963967ll, string("\377\177\377\377\377\377\377\377", 8)}, + {36028797018963968ll, 
string("\377\200\200\000\000\000\000\000\000", 9)}, + {36028797018963969ll, string("\377\200\200\000\000\000\000\000\001", 9)}, + {72057594037927935ll, string("\377\200\377\377\377\377\377\377\377", 9)}, + {72057594037927936ll, string("\377\201\000\000\000\000\000\000\000", 9)}, + {72057594037927937ll, string("\377\201\000\000\000\000\000\000\001", 9)}, + {144115188075855871ll, string("\377\201\377\377\377\377\377\377\377", 9)}, + {144115188075855872ll, string("\377\202\000\000\000\000\000\000\000", 9)}, + {144115188075855873ll, string("\377\202\000\000\000\000\000\000\001", 9)}, + {288230376151711743ll, string("\377\203\377\377\377\377\377\377\377", 9)}, + {288230376151711744ll, string("\377\204\000\000\000\000\000\000\000", 9)}, + {288230376151711745ll, string("\377\204\000\000\000\000\000\000\001", 9)}, + {576460752303423487ll, string("\377\207\377\377\377\377\377\377\377", 9)}, + {576460752303423488ll, string("\377\210\000\000\000\000\000\000\000", 9)}, + {576460752303423489ll, string("\377\210\000\000\000\000\000\000\001", 9)}, + {1152921504606846975ll, + string("\377\217\377\377\377\377\377\377\377", 9)}, + {1152921504606846976ll, + string("\377\220\000\000\000\000\000\000\000", 9)}, + {1152921504606846977ll, + string("\377\220\000\000\000\000\000\000\001", 9)}, + {2305843009213693951ll, + string("\377\237\377\377\377\377\377\377\377", 9)}, + {2305843009213693952ll, + string("\377\240\000\000\000\000\000\000\000", 9)}, + {2305843009213693953ll, + string("\377\240\000\000\000\000\000\000\001", 9)}, + {4611686018427387903ll, + string("\377\277\377\377\377\377\377\377\377", 9)}, + {4611686018427387904ll, + string("\377\300@\000\000\000\000\000\000\000", 10)}, + {4611686018427387905ll, + string("\377\300@\000\000\000\000\000\000\001", 10)}, + {9223372036854775807ll, + string("\377\300\177\377\377\377\377\377\377\377", 10)}, + {-9223372036854775807ll, + string("\000?\200\000\000\000\000\000\000\001", 10)}, + {0ll, string("\200", 1)}, + {-1ll, string("\177", 1)}, + {-2ll, string("~", 1)}, + {-1ll, string("\177", 1)}, + {-2ll, string("~", 1)}, + {-3ll, string("}", 1)}, + {-3ll, string("}", 1)}, + {-4ll, string("|", 1)}, + {-5ll, string("{", 1)}, + {-7ll, string("y", 1)}, + {-8ll, string("x", 1)}, + {-9ll, string("w", 1)}, + {-15ll, string("q", 1)}, + {-16ll, string("p", 1)}, + {-17ll, string("o", 1)}, + {-31ll, string("a", 1)}, + {-32ll, string("`", 1)}, + {-33ll, string("_", 1)}, + {-63ll, string("A", 1)}, + {-64ll, string("@", 1)}, + {-65ll, string("?\277", 2)}, + {-127ll, string("?\201", 2)}, + {-128ll, string("?\200", 2)}, + {-129ll, string("?\177", 2)}, + {-255ll, string("?\001", 2)}, + {-256ll, string("?\000", 2)}, + {-257ll, string(">\377", 2)}, + {-511ll, string(">\001", 2)}, + {-512ll, string(">\000", 2)}, + {-513ll, string("=\377", 2)}, + {-1023ll, string("<\001", 2)}, + {-1024ll, string("<\000", 2)}, + {-1025ll, string(";\377", 2)}, + {-2047ll, string("8\001", 2)}, + {-2048ll, string("8\000", 2)}, + {-2049ll, string("7\377", 2)}, + {-4095ll, string("0\001", 2)}, + {-4096ll, string("0\000", 2)}, + {-4097ll, string("/\377", 2)}, + {-8191ll, string(" \001", 2)}, + {-8192ll, string(" \000", 2)}, + {-8193ll, string("\037\337\377", 3)}, + {-16383ll, string("\037\300\001", 3)}, + {-16384ll, string("\037\300\000", 3)}, + {-16385ll, string("\037\277\377", 3)}, + {-32767ll, string("\037\200\001", 3)}, + {-32768ll, string("\037\200\000", 3)}, + {-32769ll, string("\037\177\377", 3)}, + {-65535ll, string("\037\000\001", 3)}, + {-65536ll, string("\037\000\000", 3)}, + {-65537ll, 
string("\036\377\377", 3)}, + {-131071ll, string("\036\000\001", 3)}, + {-131072ll, string("\036\000\000", 3)}, + {-131073ll, string("\035\377\377", 3)}, + {-262143ll, string("\034\000\001", 3)}, + {-262144ll, string("\034\000\000", 3)}, + {-262145ll, string("\033\377\377", 3)}, + {-524287ll, string("\030\000\001", 3)}, + {-524288ll, string("\030\000\000", 3)}, + {-524289ll, string("\027\377\377", 3)}, + {-1048575ll, string("\020\000\001", 3)}, + {-1048576ll, string("\020\000\000", 3)}, + {-1048577ll, string("\017\357\377\377", 4)}, + {-2097151ll, string("\017\340\000\001", 4)}, + {-2097152ll, string("\017\340\000\000", 4)}, + {-2097153ll, string("\017\337\377\377", 4)}, + {-4194303ll, string("\017\300\000\001", 4)}, + {-4194304ll, string("\017\300\000\000", 4)}, + {-4194305ll, string("\017\277\377\377", 4)}, + {-8388607ll, string("\017\200\000\001", 4)}, + {-8388608ll, string("\017\200\000\000", 4)}, + {-8388609ll, string("\017\177\377\377", 4)}, + {-16777215ll, string("\017\000\000\001", 4)}, + {-16777216ll, string("\017\000\000\000", 4)}, + {-16777217ll, string("\016\377\377\377", 4)}, + {-33554431ll, string("\016\000\000\001", 4)}, + {-33554432ll, string("\016\000\000\000", 4)}, + {-33554433ll, string("\r\377\377\377", 4)}, + {-67108863ll, string("\014\000\000\001", 4)}, + {-67108864ll, string("\014\000\000\000", 4)}, + {-67108865ll, string("\013\377\377\377", 4)}, + {-134217727ll, string("\010\000\000\001", 4)}, + {-134217728ll, string("\010\000\000\000", 4)}, + {-134217729ll, string("\007\367\377\377\377", 5)}, + {-268435455ll, string("\007\360\000\000\001", 5)}, + {-268435456ll, string("\007\360\000\000\000", 5)}, + {-268435457ll, string("\007\357\377\377\377", 5)}, + {-536870911ll, string("\007\340\000\000\001", 5)}, + {-536870912ll, string("\007\340\000\000\000", 5)}, + {-536870913ll, string("\007\337\377\377\377", 5)}, + {-1073741823ll, string("\007\300\000\000\001", 5)}, + {-1073741824ll, string("\007\300\000\000\000", 5)}, + {-1073741825ll, string("\007\277\377\377\377", 5)}, + {-2147483647ll, string("\007\200\000\000\001", 5)}, + {-2147483648ll, string("\007\200\000\000\000", 5)}, + {-2147483649ll, string("\007\177\377\377\377", 5)}, + {-4294967295ll, string("\007\000\000\000\001", 5)}, + {-4294967296ll, string("\007\000\000\000\000", 5)}, + {-4294967297ll, string("\006\377\377\377\377", 5)}, + {-8589934591ll, string("\006\000\000\000\001", 5)}, + {-8589934592ll, string("\006\000\000\000\000", 5)}, + {-8589934593ll, string("\005\377\377\377\377", 5)}, + {-17179869183ll, string("\004\000\000\000\001", 5)}, + {-17179869184ll, string("\004\000\000\000\000", 5)}, + {-17179869185ll, string("\003\373\377\377\377\377", 6)}, + {-34359738367ll, string("\003\370\000\000\000\001", 6)}, + {-34359738368ll, string("\003\370\000\000\000\000", 6)}, + {-34359738369ll, string("\003\367\377\377\377\377", 6)}, + {-68719476735ll, string("\003\360\000\000\000\001", 6)}, + {-68719476736ll, string("\003\360\000\000\000\000", 6)}, + {-68719476737ll, string("\003\357\377\377\377\377", 6)}, + {-137438953471ll, string("\003\340\000\000\000\001", 6)}, + {-137438953472ll, string("\003\340\000\000\000\000", 6)}, + {-137438953473ll, string("\003\337\377\377\377\377", 6)}, + {-274877906943ll, string("\003\300\000\000\000\001", 6)}, + {-274877906944ll, string("\003\300\000\000\000\000", 6)}, + {-274877906945ll, string("\003\277\377\377\377\377", 6)}, + {-549755813887ll, string("\003\200\000\000\000\001", 6)}, + {-549755813888ll, string("\003\200\000\000\000\000", 6)}, + {-549755813889ll, 
string("\003\177\377\377\377\377", 6)}, + {-1099511627775ll, string("\003\000\000\000\000\001", 6)}, + {-1099511627776ll, string("\003\000\000\000\000\000", 6)}, + {-1099511627777ll, string("\002\377\377\377\377\377", 6)}, + {-2199023255551ll, string("\002\000\000\000\000\001", 6)}, + {-2199023255552ll, string("\002\000\000\000\000\000", 6)}, + {-2199023255553ll, string("\001\375\377\377\377\377\377", 7)}, + {-4398046511103ll, string("\001\374\000\000\000\000\001", 7)}, + {-4398046511104ll, string("\001\374\000\000\000\000\000", 7)}, + {-4398046511105ll, string("\001\373\377\377\377\377\377", 7)}, + {-8796093022207ll, string("\001\370\000\000\000\000\001", 7)}, + {-8796093022208ll, string("\001\370\000\000\000\000\000", 7)}, + {-8796093022209ll, string("\001\367\377\377\377\377\377", 7)}, + {-17592186044415ll, string("\001\360\000\000\000\000\001", 7)}, + {-17592186044416ll, string("\001\360\000\000\000\000\000", 7)}, + {-17592186044417ll, string("\001\357\377\377\377\377\377", 7)}, + {-35184372088831ll, string("\001\340\000\000\000\000\001", 7)}, + {-35184372088832ll, string("\001\340\000\000\000\000\000", 7)}, + {-35184372088833ll, string("\001\337\377\377\377\377\377", 7)}, + {-70368744177663ll, string("\001\300\000\000\000\000\001", 7)}, + {-70368744177664ll, string("\001\300\000\000\000\000\000", 7)}, + {-70368744177665ll, string("\001\277\377\377\377\377\377", 7)}, + {-140737488355327ll, string("\001\200\000\000\000\000\001", 7)}, + {-140737488355328ll, string("\001\200\000\000\000\000\000", 7)}, + {-140737488355329ll, string("\001\177\377\377\377\377\377", 7)}, + {-281474976710655ll, string("\001\000\000\000\000\000\001", 7)}, + {-281474976710656ll, string("\001\000\000\000\000\000\000", 7)}, + {-281474976710657ll, string("\000\376\377\377\377\377\377\377", 8)}, + {-562949953421311ll, string("\000\376\000\000\000\000\000\001", 8)}, + {-562949953421312ll, string("\000\376\000\000\000\000\000\000", 8)}, + {-562949953421313ll, string("\000\375\377\377\377\377\377\377", 8)}, + {-1125899906842623ll, string("\000\374\000\000\000\000\000\001", 8)}, + {-1125899906842624ll, string("\000\374\000\000\000\000\000\000", 8)}, + {-1125899906842625ll, string("\000\373\377\377\377\377\377\377", 8)}, + {-2251799813685247ll, string("\000\370\000\000\000\000\000\001", 8)}, + {-2251799813685248ll, string("\000\370\000\000\000\000\000\000", 8)}, + {-2251799813685249ll, string("\000\367\377\377\377\377\377\377", 8)}, + {-4503599627370495ll, string("\000\360\000\000\000\000\000\001", 8)}, + {-4503599627370496ll, string("\000\360\000\000\000\000\000\000", 8)}, + {-4503599627370497ll, string("\000\357\377\377\377\377\377\377", 8)}, + {-9007199254740991ll, string("\000\340\000\000\000\000\000\001", 8)}, + {-9007199254740992ll, string("\000\340\000\000\000\000\000\000", 8)}, + {-9007199254740993ll, string("\000\337\377\377\377\377\377\377", 8)}, + {-18014398509481983ll, string("\000\300\000\000\000\000\000\001", 8)}, + {-18014398509481984ll, string("\000\300\000\000\000\000\000\000", 8)}, + {-18014398509481985ll, string("\000\277\377\377\377\377\377\377", 8)}, + {-36028797018963967ll, string("\000\200\000\000\000\000\000\001", 8)}, + {-36028797018963968ll, string("\000\200\000\000\000\000\000\000", 8)}, + {-36028797018963969ll, string("\000\177\177\377\377\377\377\377\377", 9)}, + {-72057594037927935ll, string("\000\177\000\000\000\000\000\000\001", 9)}, + {-72057594037927936ll, string("\000\177\000\000\000\000\000\000\000", 9)}, + {-72057594037927937ll, string("\000~\377\377\377\377\377\377\377", 9)}, + 
{-144115188075855871ll, string("\000~\000\000\000\000\000\000\001", 9)}, + {-144115188075855872ll, string("\000~\000\000\000\000\000\000\000", 9)}, + {-144115188075855873ll, string("\000}\377\377\377\377\377\377\377", 9)}, + {-288230376151711743ll, string("\000|\000\000\000\000\000\000\001", 9)}, + {-288230376151711744ll, string("\000|\000\000\000\000\000\000\000", 9)}, + {-288230376151711745ll, string("\000{\377\377\377\377\377\377\377", 9)}, + {-576460752303423487ll, string("\000x\000\000\000\000\000\000\001", 9)}, + {-576460752303423488ll, string("\000x\000\000\000\000\000\000\000", 9)}, + {-576460752303423489ll, string("\000w\377\377\377\377\377\377\377", 9)}, + {-1152921504606846975ll, string("\000p\000\000\000\000\000\000\001", 9)}, + {-1152921504606846976ll, string("\000p\000\000\000\000\000\000\000", 9)}, + {-1152921504606846977ll, string("\000o\377\377\377\377\377\377\377", 9)}, + {-2305843009213693951ll, string("\000`\000\000\000\000\000\000\001", 9)}, + {-2305843009213693952ll, string("\000`\000\000\000\000\000\000\000", 9)}, + {-2305843009213693953ll, string("\000_\377\377\377\377\377\377\377", 9)}, + {-4611686018427387903ll, string("\000@\000\000\000\000\000\000\001", 9)}, + {-4611686018427387904ll, string("\000@\000\000\000\000\000\000\000", 9)}, + {-4611686018427387905ll, + string("\000?\277\377\377\377\377\377\377\377", 10)}, + {-9223372036854775807ll, + string("\000?\200\000\000\000\000\000\000\001", 10)}, + {9223372036854775807ll, + string("\377\300\177\377\377\377\377\377\377\377", 10)}, + }; + for (const auto& t : data) { + int64 num = t.first; + string result; + OrderedCode::WriteSignedNumIncreasing(&result, num); + EXPECT_EQ(t.second, result) << std::hex << num; + + StringPiece in = result; + int64 decoded; + EXPECT_TRUE(OrderedCode::ReadSignedNumIncreasing(&in, &decoded)); + EXPECT_EQ(num, decoded); + EXPECT_EQ("", in); + } +} + +static void BM_WriteString(int n, int len) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string x; + for (int i = 0; i < len; i++) { + x += rnd.Uniform(256); + } + string y; + + testing::BytesProcessed(n * len); + testing::StartTiming(); + while (n-- > 0) { + y.clear(); + OCWriteToString<string>(&y, x); + } +} + +static void BM_ReadString(int n, int len) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string x; + for (int i = 0; i < len; i++) { + x += rnd.Uniform(256); + } + string data; + OCWriteToString<string>(&data, x); + string result; + + testing::BytesProcessed(n * len); + testing::StartTiming(); + while (n-- > 0) { + result.clear(); + StringPiece s = data; + OCRead<string>(&s, &result); + } +} + +static void BM_WriteStringIncreasing(int n, int len) { BM_WriteString(n, len); } +static void BM_ReadStringIncreasing(int n, int len) { BM_ReadString(n, len); } + +BENCHMARK(BM_WriteStringIncreasing)->Range(0, 1024); +BENCHMARK(BM_ReadStringIncreasing)->Range(0, 1024); + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc new file mode 100644 index 0000000000..cccd50c7ff --- /dev/null +++ b/tensorflow/core/lib/strings/str_util.cc @@ -0,0 +1,312 @@ +#include "tensorflow/core/lib/strings/str_util.h" +#include <ctype.h> + +namespace tensorflow { +namespace str_util { + +static char hex_char[] = "0123456789abcdef"; + +string CEscape(const string& src) { + string dest; + + for (unsigned char c : src) { + switch (c) { + case '\n': + 
dest.append("\\n"); + break; + case '\r': + dest.append("\\r"); + break; + case '\t': + dest.append("\\t"); + break; + case '\"': + dest.append("\\\""); + break; + case '\'': + dest.append("\\'"); + break; + case '\\': + dest.append("\\\\"); + break; + default: + // Note that if we emit \xNN and the src character after that is a hex + // digit then that digit must be escaped too to prevent it being + // interpreted as part of the character code by C. + if ((c >= 0x80) || !isprint(c)) { + dest.append("\\"); + dest.push_back(hex_char[c / 64]); + dest.push_back(hex_char[(c % 64) / 8]); + dest.push_back(hex_char[c % 8]); + } else { + dest.push_back(c); + break; + } + } + } + + return dest; +} + +namespace { // Private helpers for CUnescape(). + +inline bool is_octal_digit(unsigned char c) { return c >= '0' && c <= '7'; } + +inline bool ascii_isxdigit(unsigned char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || + (c >= 'A' && c <= 'F'); +} + +inline int hex_digit_to_int(char c) { + int x = static_cast<unsigned char>(c); + if (x > '9') { + x += 9; + } + return x & 0xf; +} + +bool CUnescapeInternal(StringPiece source, char* dest, int* dest_len, + string* error) { + char* d = dest; + const char* p = source.data(); + const char* end = source.end(); + const char* last_byte = end - 1; + + // Small optimization for case where source = dest and there's no escaping + while (p == d && p < end && *p != '\\') p++, d++; + + while (p < end) { + if (*p != '\\') { + *d++ = *p++; + } else { + if (++p > last_byte) { // skip past the '\\' + if (error) *error = "String cannot end with \\"; + return false; + } + switch (*p) { + case 'a': + *d++ = '\a'; + break; + case 'b': + *d++ = '\b'; + break; + case 'f': + *d++ = '\f'; + break; + case 'n': + *d++ = '\n'; + break; + case 'r': + *d++ = '\r'; + break; + case 't': + *d++ = '\t'; + break; + case 'v': + *d++ = '\v'; + break; + case '\\': + *d++ = '\\'; + break; + case '?': + *d++ = '\?'; + break; // \? Who knew? 
+ case '\'': + *d++ = '\''; + break; + case '"': + *d++ = '\"'; + break; + case '0': + case '1': + case '2': + case '3': // octal digit: 1 to 3 digits + case '4': + case '5': + case '6': + case '7': { + const char* octal_start = p; + unsigned int ch = *p - '0'; + if (p < last_byte && is_octal_digit(p[1])) ch = ch * 8 + *++p - '0'; + if (p < last_byte && is_octal_digit(p[1])) + ch = ch * 8 + *++p - '0'; // now points at last digit + if (ch > 0xff) { + if (error) { + *error = "Value of \\" + + string(octal_start, p + 1 - octal_start) + + " exceeds 0xff"; + } + return false; + } + *d++ = ch; + break; + } + case 'x': + case 'X': { + if (p >= last_byte) { + if (error) *error = "String cannot end with \\x"; + return false; + } else if (!ascii_isxdigit(p[1])) { + if (error) *error = "\\x cannot be followed by a non-hex digit"; + return false; + } + unsigned int ch = 0; + const char* hex_start = p; + while (p < last_byte && ascii_isxdigit(p[1])) + // Arbitrarily many hex digits + ch = (ch << 4) + hex_digit_to_int(*++p); + if (ch > 0xFF) { + if (error) { + *error = "Value of \\" + string(hex_start, p + 1 - hex_start) + + " exceeds 0xff"; + } + return false; + } + *d++ = ch; + break; + } + default: { + if (error) *error = string("Unknown escape sequence: \\") + *p; + return false; + } + } + p++; // read past letter we escaped + } + } + *dest_len = d - dest; + return true; +} + +} // namespace + +bool CUnescape(StringPiece source, string* dest, string* error) { + dest->resize(source.size()); + int dest_size; + if (!CUnescapeInternal(source, const_cast<char*>(dest->data()), &dest_size, + error)) { + return false; + } + dest->erase(dest_size); + return true; +} + +bool NumericParse32(const string& text, int32* val) { + // Slow, but this code is not performance critical, and this + // doesn't bring in any new dependencies + char junk; + if (sscanf(text.c_str(), "%d%c", val, &junk) == 1) { + return true; + } else { + return false; + } +} + +void StripTrailingWhitespace(string* s) { + string::size_type i; + for (i = s->size(); i > 0 && isspace((*s)[i - 1]); --i) { + } + s->resize(i); +} + +// Return lower-cased version of s. +string Lowercase(StringPiece s) { + string result(s.data(), s.size()); + for (char& c : result) { + c = tolower(c); + } + return result; +} + +// Return upper-cased version of s. 
+string Uppercase(StringPiece s) { + string result(s.data(), s.size()); + for (char& c : result) { + c = toupper(c); + } + return result; +} + +void TitlecaseString(string* s, StringPiece delimiters) { + bool upper = true; + for (string::iterator ss = s->begin(); ss != s->end(); ++ss) { + if (upper) { + *ss = toupper(*ss); + } + upper = (delimiters.find(*ss) != StringPiece::npos); + } +} + +size_t RemoveLeadingWhitespace(StringPiece* text) { + size_t count = 0; + const char* ptr = text->data(); + while (count < text->size() && isspace(*ptr)) { + count++; + ptr++; + } + text->remove_prefix(count); + return count; +} + +size_t RemoveTrailingWhitespace(StringPiece* text) { + size_t count = 0; + const char* ptr = text->data() + text->size() - 1; + while (count < text->size() && isspace(*ptr)) { + ++count; + --ptr; + } + text->remove_suffix(count); + return count; +} + +size_t RemoveWhitespaceContext(StringPiece* text) { + // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job + return (RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text)); +} + +bool ConsumePrefix(StringPiece* s, StringPiece expected) { + if (s->starts_with(expected)) { + s->remove_prefix(expected.size()); + return true; + } + return false; +} + +bool ConsumeLeadingDigits(StringPiece* s, uint64* val) { + const char* p = s->data(); + const char* limit = p + s->size(); + uint64 v = 0; + while (p < limit) { + const char c = *p; + if (c < '0' || c > '9') break; + uint64 new_v = (v * 10) + (c - '0'); + if (new_v < v) { + // Overflow occurred + return false; + } + v = new_v; + p++; + } + if (p > s->data()) { + // Consume some digits + s->remove_prefix(p - s->data()); + *val = v; + return true; + } else { + return false; + } +} + +bool SplitAndParseAsInts(StringPiece text, char delim, + std::vector<int32>* result) { + result->clear(); + std::vector<string> num_strings = Split(text, delim); + for (const auto& s : num_strings) { + int32 num; + if (!NumericParse32(s, &num)) return false; + result->push_back(num); + } + return true; +} + +} // namespace str_util +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h new file mode 100644 index 0000000000..34ea462b2d --- /dev/null +++ b/tensorflow/core/lib/strings/str_util.h @@ -0,0 +1,149 @@ +#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ +#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ + +#include <string> +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +// Basic string utility routines +namespace tensorflow { +namespace str_util { + +// Returns a version of 'src' where unprintable characters have been +// escaped using C-style escape sequences. +string CEscape(const string& src); + +// Copies "source" to "dest", rewriting C-style escape sequences -- +// '\n', '\r', '\\', '\ooo', etc -- to their ASCII equivalents. +// +// Errors: Sets the description of the first encountered error in +// 'error'. To disable error reporting, set 'error' to NULL. +// +// NOTE: Does not support \u or \U! +bool CUnescape(StringPiece source, string* dest, string* error); + +// If "text" can be successfully parsed as the ASCII representation of +// an integer, sets "*val" to the value and returns true. Otherwise, +// returns false. +bool NumericParse32(const string& text, int32* val); + +// Removes any trailing whitespace from "*s". 
+void StripTrailingWhitespace(string* s);
+
+// Removes leading ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveLeadingWhitespace(StringPiece* text);
+
+// Removes trailing ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveTrailingWhitespace(StringPiece* text);
+
+// Removes leading and trailing ascii_isspace() chars.
+// Returns number of chars removed.
+size_t RemoveWhitespaceContext(StringPiece* text);
+
+// Consume a leading positive integer value. If any digits were
+// found, store the value of the leading unsigned number in "*val",
+// advance "*s" past the consumed number, and return true. Returns
+// false if there are no leading digits or if the value would
+// overflow a uint64.
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val);
+
+// If "*s" starts with "expected", consume it and return true.
+// Otherwise, return false.
+bool ConsumePrefix(StringPiece* s, StringPiece expected);
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s);
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s);
+
+// Capitalize first character of each word in "*s". "delimiters" is a
+// set of characters that can be used as word boundaries.
+void TitlecaseString(string* s, StringPiece delimiters);
+
+// Join functionality
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep);
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep);
+
+struct AllowEmpty {
+  bool operator()(StringPiece sp) const { return true; }
+};
+struct SkipEmpty {
+  bool operator()(StringPiece sp) const { return !sp.empty(); }
+};
+struct SkipWhitespace {
+  bool operator()(StringPiece sp) const {
+    RemoveTrailingWhitespace(&sp);
+    return !sp.empty();
+  }
+};
+
+std::vector<string> Split(StringPiece text, char delim);
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p);
+
+// Split "text" at "delim" characters, and parse each component as
+// an integer. If successful, adds the individual numbers in order
+// to "*result" and returns true. Otherwise returns false.
+bool SplitAndParseAsInts(StringPiece text, char delim,
+                         std::vector<int32>* result);
+
+// ------------------------------------------------------------------
+// Implementation details below
+namespace internal {
+template <typename T>
+string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin,
+                  typename gtl::ArraySlice<T>::const_iterator end,
+                  const char* sep) {
+  string result;
+  bool first = true;
+  for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end;
+       ++it) {
+    tensorflow::strings::StrAppend(&result, (first ?
"" : sep), *it); + first = false; + } + return result; +} +} // namespace internal + +template <typename T> +string Join(const std::vector<T>& s, const char* sep) { + return Join<T>(gtl::ArraySlice<T>(s), sep); +} + +template <typename T> +string Join(const gtl::ArraySlice<T>& s, const char* sep) { + return internal::JoinHelper<T>(s.begin(), s.end(), sep); +} + +inline std::vector<string> Split(StringPiece text, char delim) { + return Split(text, delim, AllowEmpty()); +} + +template <typename Predicate> +std::vector<string> Split(StringPiece text, char delim, Predicate p) { + std::vector<string> result; + int token_start = 0; + if (!text.empty()) { + for (int i = 0; i < text.size() + 1; i++) { + if ((i == text.size()) || (text[i] == delim)) { + StringPiece token(text.data() + token_start, i - token_start); + if (p(token)) { + result.push_back(token.ToString()); + } + token_start = i + 1; + } + } + } + return result; +} + +} // namespace str_util +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc new file mode 100644 index 0000000000..f71cc6c609 --- /dev/null +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -0,0 +1,258 @@ +#include "tensorflow/core/lib/strings/str_util.h" + +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(CEscape, Basic) { + EXPECT_EQ(str_util::CEscape("hello"), "hello"); + EXPECT_EQ(str_util::CEscape("hello\n"), "hello\\n"); + EXPECT_EQ(str_util::CEscape("hello\r"), "hello\\r"); + EXPECT_EQ(str_util::CEscape("\t\r\"'"), "\\t\\r\\\"\\'"); + EXPECT_EQ(str_util::CEscape("\320hi\200"), "\\320hi\\200"); +} + +string ExpectCUnescapeSuccess(StringPiece source) { + string dest; + string error; + EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)) << error; + return dest; +} + +TEST(CUnescape, Basic) { + EXPECT_EQ("hello", ExpectCUnescapeSuccess("hello")); + EXPECT_EQ("hello\n", ExpectCUnescapeSuccess("hello\\n")); + EXPECT_EQ("hello\r", ExpectCUnescapeSuccess("hello\\r")); + EXPECT_EQ("\t\r\"'", ExpectCUnescapeSuccess("\\t\\r\\\"\\'")); + EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200")); +} + +TEST(NumericParse32, Basic) { + int32 val = -1234; + EXPECT_TRUE(str_util::NumericParse32("0", &val) && val == 0); + EXPECT_TRUE(str_util::NumericParse32("123", &val) && val == 123); + EXPECT_TRUE(str_util::NumericParse32("-375", &val) && val == -375); + EXPECT_FALSE(str_util::NumericParse32("123hello", &val)); + EXPECT_FALSE(str_util::NumericParse32("hello123", &val)); +} + +TEST(StripTrailingWhitespace, Basic) { + string test; + test = "hello"; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, "hello"); + + test = "foo "; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, "foo"); + + test = " "; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, ""); + + test = ""; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, ""); + + test = " abc\t"; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, " abc"); +} + +TEST(RemoveLeadingWhitespace, Basic) { + string text = " \t \n \r Quick\t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 11); + EXPECT_EQ(data, StringPiece("Quick\t")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("Quick\t")); +} + +TEST(RemoveLeadingWhitespace, TerminationHandling) { + // check termination 
handling + string text = "\t"; + StringPiece data(text); + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 1); + EXPECT_EQ(data, StringPiece("")); + + // check termination handling again + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +TEST(RemoveTrailingWhitespace, Basic) { + string text = " \t \n \r Quick \t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 2); + EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); +} + +TEST(RemoveTrailingWhitespace, TerminationHandling) { + // check termination handling + string text = "\t"; + StringPiece data(text); + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 1); + EXPECT_EQ(data, StringPiece("")); + + // check termination handling again + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +TEST(RemoveWhitespaceContext, Basic) { + string text = " \t \n \r Quick \t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 13); + EXPECT_EQ(data, StringPiece("Quick")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); + EXPECT_EQ(data, StringPiece("Quick")); + + // Test empty string + text = ""; + data = text; + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +void TestConsumeLeadingDigits(StringPiece s, int64 expected, + StringPiece remaining) { + uint64 v; + StringPiece input(s); + if (str_util::ConsumeLeadingDigits(&input, &v)) { + EXPECT_EQ(v, static_cast<uint64>(expected)); + EXPECT_EQ(input, remaining); + } else { + EXPECT_LT(expected, 0); + EXPECT_EQ(input, remaining); + } +} + +TEST(ConsumeLeadingDigits, Basic) { + TestConsumeLeadingDigits("123", 123, ""); + TestConsumeLeadingDigits("a123", -1, "a123"); + TestConsumeLeadingDigits("9_", 9, "_"); + TestConsumeLeadingDigits("11111111111xyz", 11111111111ll, "xyz"); + + // Overflow case + TestConsumeLeadingDigits("1111111111111111111111111111111xyz", -1, + "1111111111111111111111111111111xyz"); + + // 2^64 + TestConsumeLeadingDigits("18446744073709551616xyz", -1, + "18446744073709551616xyz"); + // 2^64-1 + TestConsumeLeadingDigits("18446744073709551615xyz", 18446744073709551615ull, + "xyz"); +} + +TEST(ConsumePrefix, Basic) { + string s("abcdef"); + StringPiece input(s); + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdefg")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abce")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_TRUE(str_util::ConsumePrefix(&input, "")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdeg")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcdef")); + EXPECT_EQ(input, ""); + + input = s; + EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcde")); + EXPECT_EQ(input, "f"); +} + +TEST(JoinStrings, Basic) { + std::vector<string> s; + s = {"hi"}; + EXPECT_EQ(str_util::Join(s, " "), "hi"); + s = {"hi", "there", "strings"}; + EXPECT_EQ(str_util::Join(s, " "), "hi there strings"); + + std::vector<StringPiece> sp; + sp = {"hi"}; + EXPECT_EQ(str_util::Join(sp, ",,"), "hi"); + sp = {"hi", "there", "strings"}; + EXPECT_EQ(str_util::Join(sp, "--"), "hi--there--strings"); +} + +TEST(Split, 
Basic) { + EXPECT_TRUE(str_util::Split("", ',').empty()); + EXPECT_EQ(str_util::Join(str_util::Split("a", ','), "|"), "a"); + EXPECT_EQ(str_util::Join(str_util::Split(",", ','), "|"), "|"); + EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c"); + EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"), + "a|||b||c|"); + EXPECT_EQ(str_util::Join( + str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"), + "a|b|c"); + EXPECT_EQ( + str_util::Join( + str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"), + "a|b|c"); +} + +TEST(SplitAndParseAsInts, Basic) { + std::vector<int32> nums; + EXPECT_TRUE(str_util::SplitAndParseAsInts("", ',', &nums)); + EXPECT_EQ(nums.size(), 0); + + EXPECT_TRUE(str_util::SplitAndParseAsInts("134", ',', &nums)); + EXPECT_EQ(nums.size(), 1); + EXPECT_EQ(nums[0], 134); + + EXPECT_TRUE(str_util::SplitAndParseAsInts("134,2,13,-5", ',', &nums)); + EXPECT_EQ(nums.size(), 4); + EXPECT_EQ(nums[0], 134); + EXPECT_EQ(nums[1], 2); + EXPECT_EQ(nums[2], 13); + EXPECT_EQ(nums[3], -5); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("abc", ',', &nums)); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("-13,abc", ',', &nums)); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("13,abc,5", ',', &nums)); +} + +TEST(Lowercase, Basic) { + EXPECT_EQ("", str_util::Lowercase("")); + EXPECT_EQ("hello", str_util::Lowercase("hello")); + EXPECT_EQ("hello world", str_util::Lowercase("Hello World")); +} + +TEST(Uppercase, Basic) { + EXPECT_EQ("", str_util::Uppercase("")); + EXPECT_EQ("HELLO", str_util::Uppercase("hello")); + EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World")); +} + +TEST(TitlecaseString, Basic) { + string s = "sparse_lookup"; + str_util::TitlecaseString(&s, "_"); + ASSERT_EQ(s, "Sparse_Lookup"); + + s = "sparse_lookup"; + str_util::TitlecaseString(&s, " "); + ASSERT_EQ(s, "Sparse_lookup"); + + s = "dense"; + str_util::TitlecaseString(&s, " "); + ASSERT_EQ(s, "Dense"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc new file mode 100644 index 0000000000..e564b9eb73 --- /dev/null +++ b/tensorflow/core/lib/strings/strcat.cc @@ -0,0 +1,194 @@ +#include "tensorflow/core/lib/strings/strcat.h" + +#include <stdarg.h> +#include <stdint.h> +#include <stdio.h> +#include <string.h> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/gtl/stl_util.h" + +namespace tensorflow { +namespace strings { + +AlphaNum gEmptyAlphaNum(""); + +AlphaNum::AlphaNum(Hex hex) { + char *const end = &digits_[kFastToBufferSize]; + char *writer = end; + uint64 value = hex.value; + uint64 width = hex.spec; + // We accomplish minimum width by OR'ing in 0x10000 to the user's value, + // where 0x10000 is the smallest hex number that is as wide as the user + // asked for. + uint64 mask = ((static_cast<uint64>(1) << (width - 1) * 4)) | value; + static const char hexdigits[] = "0123456789abcdef"; + do { + *--writer = hexdigits[value & 0xF]; + value >>= 4; + mask >>= 4; + } while (mask != 0); + piece_.set(writer, end - writer); +} + +// ---------------------------------------------------------------------- +// StrCat() +// This merges the given strings or integers, with no delimiter. This +// is designed to be the fastest possible way to construct a string out +// of a mix of raw C strings, StringPieces, strings, and integer values. 
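+// A couple of concrete cases, mirroring expectations in strcat_test.cc later
+// in this change: StrCat("Hello", ", ", "World") returns "Hello, World", and
+// StrCat(false, true, 2, 3) returns "0123".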
+// ---------------------------------------------------------------------- + +// Append is merely a version of memcpy that returns the address of the byte +// after the area just overwritten. It comes in multiple flavors to minimize +// call overhead. +static char *Append1(char *out, const AlphaNum &x) { + memcpy(out, x.data(), x.size()); + return out + x.size(); +} + +static char *Append2(char *out, const AlphaNum &x1, const AlphaNum &x2) { + memcpy(out, x1.data(), x1.size()); + out += x1.size(); + + memcpy(out, x2.data(), x2.size()); + return out + x2.size(); +} + +static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2, + const AlphaNum &x3, const AlphaNum &x4) { + memcpy(out, x1.data(), x1.size()); + out += x1.size(); + + memcpy(out, x2.data(), x2.size()); + out += x2.size(); + + memcpy(out, x3.data(), x3.size()); + out += x3.size(); + + memcpy(out, x4.data(), x4.size()); + return out + x4.size(); +} + +string StrCat(const AlphaNum &a, const AlphaNum &b) { + string result; + gtl::STLStringResizeUninitialized(&result, a.size() + b.size()); + char *const begin = &*result.begin(); + char *out = Append2(begin, a, b); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c) { + string result; + gtl::STLStringResizeUninitialized(&result, a.size() + b.size() + c.size()); + char *const begin = &*result.begin(); + char *out = Append2(begin, a, b); + out = Append1(out, c); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d) { + string result; + gtl::STLStringResizeUninitialized(&result, + a.size() + b.size() + c.size() + d.size()); + char *const begin = &*result.begin(); + char *out = Append4(begin, a, b, c, d); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +namespace internal { + +// Do not call directly - these are not part of the public API. +string CatPieces(std::initializer_list<StringPiece> pieces) { + string result; + size_t total_size = 0; + for (const StringPiece piece : pieces) total_size += piece.size(); + gtl::STLStringResizeUninitialized(&result, total_size); + + char *const begin = &*result.begin(); + char *out = begin; + for (const StringPiece piece : pieces) { + const size_t this_size = piece.size(); + memcpy(out, piece.data(), this_size); + out += this_size; + } + DCHECK_EQ(out, begin + result.size()); + return result; +} + +// It's possible to call StrAppend with a StringPiece that is itself a fragment +// of the string we're appending to. However the results of this are random. +// Therefore, check for this in debug mode. Use unsigned math so we only have +// to do one comparison. 
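+// Concretely, the subtraction in the check below is done as an unsigned
+// value: a source that starts before the destination buffer wraps around to
+// a huge number and passes the DCHECK_GE, while a source that starts inside
+// [dest.data(), dest.data() + dest.size()) produces a value smaller than
+// dest.size() and fails it, so a single comparison covers both sides.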
+#define DCHECK_NO_OVERLAP(dest, src) \ + DCHECK_GE(uintptr_t((src).data() - (dest).data()), uintptr_t((dest).size())) + +void AppendPieces(string *result, std::initializer_list<StringPiece> pieces) { + size_t old_size = result->size(); + size_t total_size = old_size; + for (const StringPiece piece : pieces) { + DCHECK_NO_OVERLAP(*result, piece); + total_size += piece.size(); + } + gtl::STLStringResizeUninitialized(result, total_size); + + char *const begin = &*result->begin(); + char *out = begin + old_size; + for (const StringPiece piece : pieces) { + const size_t this_size = piece.size(); + memcpy(out, piece.data(), this_size); + out += this_size; + } + DCHECK_EQ(out, begin + result->size()); +} + +} // namespace internal + +void StrAppend(string *result, const AlphaNum &a) { + DCHECK_NO_OVERLAP(*result, a); + result->append(a.data(), a.size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized(result, old_size + a.size() + b.size()); + char *const begin = &*result->begin(); + char *out = Append2(begin + old_size, a, b); + DCHECK_EQ(out, begin + result->size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + DCHECK_NO_OVERLAP(*result, c); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized(result, + old_size + a.size() + b.size() + c.size()); + char *const begin = &*result->begin(); + char *out = Append2(begin + old_size, a, b); + out = Append1(out, c); + DCHECK_EQ(out, begin + result->size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + DCHECK_NO_OVERLAP(*result, c); + DCHECK_NO_OVERLAP(*result, d); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized( + result, old_size + a.size() + b.size() + c.size() + d.size()); + char *const begin = &*result->begin(); + char *out = Append4(begin + old_size, a, b, c, d); + DCHECK_EQ(out, begin + result->size()); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h new file mode 100644 index 0000000000..763ad8368a --- /dev/null +++ b/tensorflow/core/lib/strings/strcat.h @@ -0,0 +1,229 @@ +// #status: RECOMMENDED +// #category: operations on strings +// #summary: Merges strings or numbers with no delimiter. +// +#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_ +#define TENSORFLOW_LIB_STRINGS_STRCAT_H_ + +#include <string> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/port.h" + +// The AlphaNum type was designed to be used as the parameter type for StrCat(). +// Any routine accepting either a string or a number may accept it. +// The basic idea is that by accepting a "const AlphaNum &" as an argument +// to your function, your callers will automagically convert bools, integers, +// and floating point values to strings for you. +// +// NOTE: Use of AlphaNum outside of the //strings package is unsupported except +// for the specific case of function parameters of type "AlphaNum" or "const +// AlphaNum &". In particular, instantiating AlphaNum directly as a stack +// variable is not supported. 
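+// As an illustration (the function name below is made up, not something
+// defined in this library), a routine declared as
+//   void LogDimension(const AlphaNum &label, const AlphaNum &value);
+// can be called as LogDimension("rows", 128) or LogDimension("scale", 0.5),
+// with the conversion of each argument to text happening automatically.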
+// +// Conversion from 8-bit values is not accepted because if it were, then an +// attempt to pass ':' instead of ":" might result in a 58 ending up in your +// result. +// +// Bools convert to "0" or "1". +// +// Floating point values are converted to a string which, if passed to strtod(), +// would produce the exact same original double (except in case of NaN; all NaNs +// are considered the same value). We try to keep the string short but it's not +// guaranteed to be as short as possible. +// +// You can convert to Hexadecimal output rather than Decimal output using Hex. +// To do this, pass strings::Hex(my_int) as a parameter to StrCat. You may +// specify a minimum field width using a separate parameter, so the equivalent +// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::ZERO_PAD_4)) +// +// This class has implicit constructors. +namespace tensorflow { +namespace strings { + +enum PadSpec { + NO_PAD = 1, + ZERO_PAD_2, + ZERO_PAD_3, + ZERO_PAD_4, + ZERO_PAD_5, + ZERO_PAD_6, + ZERO_PAD_7, + ZERO_PAD_8, + ZERO_PAD_9, + ZERO_PAD_10, + ZERO_PAD_11, + ZERO_PAD_12, + ZERO_PAD_13, + ZERO_PAD_14, + ZERO_PAD_15, + ZERO_PAD_16, +}; + +struct Hex { + uint64 value; + enum PadSpec spec; + template <class Int> + explicit Hex(Int v, PadSpec s = NO_PAD) + : spec(s) { + // Prevent sign-extension by casting integers to + // their unsigned counterparts. + static_assert( + sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8, + "Unknown integer type"); + value = sizeof(v) == 1 + ? static_cast<uint8>(v) + : sizeof(v) == 2 ? static_cast<uint16>(v) + : sizeof(v) == 4 ? static_cast<uint32>(v) + : static_cast<uint64>(v); + } +}; + +class AlphaNum { + public: + // No bool ctor -- bools convert to an integral type. + // A bool ctor would also convert incoming pointers (bletch). 
+
+  AlphaNum(int i32)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastInt32ToBufferLeft(i32, digits_) - &digits_[0]) {}
+  AlphaNum(unsigned int u32)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastUInt32ToBufferLeft(u32, digits_) - &digits_[0]) {}
+  AlphaNum(long x)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+  AlphaNum(unsigned long x)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastUInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+  AlphaNum(long long int i64)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastInt64ToBufferLeft(i64, digits_) - &digits_[0]) {}
+  AlphaNum(unsigned long long int u64)  // NOLINT(runtime/explicit)
+      : piece_(digits_, FastUInt64ToBufferLeft(u64, digits_) - &digits_[0]) {}
+
+  AlphaNum(float f)  // NOLINT(runtime/explicit)
+      : piece_(digits_, strlen(FloatToBuffer(f, digits_))) {}
+  AlphaNum(double f)  // NOLINT(runtime/explicit)
+      : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {}
+
+  AlphaNum(Hex hex);  // NOLINT(runtime/explicit)
+
+  AlphaNum(const char *c_str) : piece_(c_str) {}   // NOLINT(runtime/explicit)
+  AlphaNum(const StringPiece &pc) : piece_(pc) {}  // NOLINT(runtime/explicit)
+  AlphaNum(const tensorflow::string &str)  // NOLINT(runtime/explicit)
+      : piece_(str) {}
+
+  StringPiece::size_type size() const { return piece_.size(); }
+  const char *data() const { return piece_.data(); }
+  StringPiece Piece() const { return piece_; }
+
+ private:
+  StringPiece piece_;
+  char digits_[kFastToBufferSize];
+
+  // Use ":" not ':'
+  AlphaNum(char c);  // NOLINT(runtime/explicit)
+
+  TF_DISALLOW_COPY_AND_ASSIGN(AlphaNum);
+};
+
+extern AlphaNum gEmptyAlphaNum;
+
+using strings::AlphaNum;
+using strings::gEmptyAlphaNum;
+
+// ----------------------------------------------------------------------
+// StrCat()
+// This merges the given strings or numbers, with no delimiter. This
+// is designed to be the fastest possible way to construct a string out
+// of a mix of raw C strings, StringPieces, strings, bool values,
+// and numeric values.
+//
+// Don't use this for user-visible strings. The localization process
+// works poorly on strings built up out of fragments.
+//
+// For clarity and performance, don't use StrCat when appending to a
+// string. In particular, avoid using any of these (anti-)patterns:
+//   str.append(StrCat(...))
+//   str += StrCat(...)
+//   str = StrCat(str, ...)
+// where the last is the worst, with the potential to change a loop
+// from a linear time operation with O(1) dynamic allocations into a
+// quadratic time operation with O(n) dynamic allocations. StrAppend
+// is a better choice than any of the above, subject to the restriction
+// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may
+// be a reference into str.
+// ----------------------------------------------------------------------
+
+// For performance reasons, we have specializations for <= 4 args.
+string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b,
+              const AlphaNum &c) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+              const AlphaNum &d) TF_MUST_USE_RESULT;
+
+// inline definitions must be duplicated due to TF_MUST_USE_RESULT
+inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
+
+namespace internal {
+
+// Do not call directly - this is not part of the public API.
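+// (These two are the helpers behind the variadic StrCat() and StrAppend()
+// overloads declared below.)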
+string CatPieces(std::initializer_list<StringPiece> pieces); +void AppendPieces(string *dest, std::initializer_list<StringPiece> pieces); + +} // namespace internal + +// Support 5 or more arguments +template <typename... AV> +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, + const AV &... args) TF_MUST_USE_RESULT; + +template <typename... AV> +inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, const AV &... args) { + return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(), + e.Piece(), + static_cast<const AlphaNum &>(args).Piece()...}); +} + +// ---------------------------------------------------------------------- +// StrAppend() +// Same as above, but adds the output to the given string. +// WARNING: For speed, StrAppend does not try to check each of its input +// arguments to be sure that they are not a subset of the string being +// appended to. That is, while this will work: +// +// string s = "foo"; +// s += s; +// +// This will not (necessarily) work: +// +// string s = "foo"; +// StrAppend(&s, s); +// +// Note: while StrCat supports appending up to 26 arguments, StrAppend +// is currently limited to 9. That's rarely an issue except when +// automatically transforming StrCat to StrAppend, and can easily be +// worked around as consecutive calls to StrAppend are quite efficient. +// ---------------------------------------------------------------------- + +void StrAppend(string *dest, const AlphaNum &a); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d); + +// Support 5 or more arguments +template <typename... AV> +inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d, const AlphaNum &e, + const AV &... args) { + internal::AppendPieces(dest, + {a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(), + static_cast<const AlphaNum &>(args).Piece()...}); +} + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_ diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc new file mode 100644 index 0000000000..9ff7d81af9 --- /dev/null +++ b/tensorflow/core/lib/strings/strcat_test.cc @@ -0,0 +1,324 @@ +#include "tensorflow/core/lib/strings/strcat.h" + +#include <string> + +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace strings { + +// Test StrCat of ints and longs of various sizes and signdedness. 
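Aside (editorial, not part of the commit): before the TEST cases that the comment above introduces, here is a minimal usage sketch of the strcat.h API added earlier in this diff — StrCat, StrAppend, and Hex with a PadSpec. It is a sketch under the assumption that the header builds in this tree and that tensorflow::string is the usual string alias pulled in via port.h; the function name StrCatUsageSketch is hypothetical.

// Hedged sketch; mirrors the cases the tests below exercise.
#include "tensorflow/core/lib/strings/strcat.h"

void StrCatUsageSketch() {
  using tensorflow::string;
  namespace strings = tensorflow::strings;

  // Heterogeneous arguments, no delimiter; bools come out as "0"/"1".
  string s = strings::StrCat("x=", 7, " y=", 2.5, " flag=", true);
  // s == "x=7 y=2.5 flag=1"

  // StrAppend grows an existing string; its arguments must not alias s.
  // Hex(v, ZERO_PAD_4) is the StrCat analogue of Printf("%04x", v).
  strings::StrAppend(&s, " hex=", strings::Hex(255, strings::ZERO_PAD_4));
  // s now ends with "hex=00ff"
}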
+TEST(StrCat, Ints) { + const int16 s = -1; + const uint16 us = 2; + const int i = -3; + const unsigned int ui = 4; + const int32 l = -5; + const uint32 ul = 6; + const int64 ll = -7; + const uint64 ull = 8; + const ptrdiff_t ptrdiff = -9; + const size_t size = 10; + const ssize_t ssize = -11; + const intptr_t intptr = -12; + const uintptr_t uintptr = 13; + string answer; + answer = StrCat(s, us); + EXPECT_EQ(answer, "-12"); + answer = StrCat(i, ui); + EXPECT_EQ(answer, "-34"); + answer = StrCat(l, ul); + EXPECT_EQ(answer, "-56"); + answer = StrCat(ll, ull); + EXPECT_EQ(answer, "-78"); + answer = StrCat(ptrdiff, size); + EXPECT_EQ(answer, "-910"); + answer = StrCat(ssize, intptr); + EXPECT_EQ(answer, "-11-12"); + answer = StrCat(uintptr, 0); + EXPECT_EQ(answer, "130"); +} + +TEST(StrCat, Basics) { + string result; + + string strs[] = {"Hello", "Cruel", "World"}; + + StringPiece pieces[] = {"Hello", "Cruel", "World"}; + + const char *c_strs[] = {"Hello", "Cruel", "World"}; + + int32 i32s[] = {'H', 'C', 'W'}; + uint64 ui64s[] = {12345678910LL, 10987654321LL}; + + result = StrCat(false, true, 2, 3); + EXPECT_EQ(result, "0123"); + + result = StrCat(-1); + EXPECT_EQ(result, "-1"); + + result = StrCat(0.5); + EXPECT_EQ(result, "0.5"); + + result = StrCat(strs[1], pieces[2]); + EXPECT_EQ(result, "CruelWorld"); + + result = StrCat(strs[0], ", ", pieces[2]); + EXPECT_EQ(result, "Hello, World"); + + result = StrCat(strs[0], ", ", strs[1], " ", strs[2], "!"); + EXPECT_EQ(result, "Hello, Cruel World!"); + + result = StrCat(pieces[0], ", ", pieces[1], " ", pieces[2]); + EXPECT_EQ(result, "Hello, Cruel World"); + + result = StrCat(c_strs[0], ", ", c_strs[1], " ", c_strs[2]); + EXPECT_EQ(result, "Hello, Cruel World"); + + result = StrCat("ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!"); + EXPECT_EQ(result, "ASCII 72, 67 87!"); + + result = StrCat(ui64s[0], ", ", ui64s[1], "!"); + EXPECT_EQ(result, "12345678910, 10987654321!"); + + string one = "1"; // Actually, it's the size of this string that we want; a + // 64-bit build distinguishes between size_t and uint64, + // even though they're both unsigned 64-bit values. + result = StrCat("And a ", one.size(), " and a ", &result[2] - &result[0], + " and a ", one, " 2 3 4", "!"); + EXPECT_EQ(result, "And a 1 and a 2 and a 1 2 3 4!"); + + // result = StrCat("Single chars won't compile", '!'); + // result = StrCat("Neither will NULLs", NULL); + result = StrCat("To output a char by ASCII/numeric value, use +: ", '!' 
+ 0); + EXPECT_EQ(result, "To output a char by ASCII/numeric value, use +: 33"); + + float f = 100000.5; + result = StrCat("A hundred K and a half is ", f); + EXPECT_EQ(result, "A hundred K and a half is 100000.5"); + + double d = f; + d *= d; + result = StrCat("A hundred K and a half squared is ", d); + EXPECT_EQ(result, "A hundred K and a half squared is 10000100000.25"); + + result = StrCat(1, 2, 333, 4444, 55555, 666666, 7777777, 88888888, 999999999); + EXPECT_EQ(result, "12333444455555666666777777788888888999999999"); +} + +TEST(StrCat, MaxArgs) { + string result; + // Test 10 up to 26 arguments, the current maximum + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a"); + EXPECT_EQ(result, "123456789a"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b"); + EXPECT_EQ(result, "123456789ab"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c"); + EXPECT_EQ(result, "123456789abc"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d"); + EXPECT_EQ(result, "123456789abcd"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e"); + EXPECT_EQ(result, "123456789abcde"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f"); + EXPECT_EQ(result, "123456789abcdef"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g"); + EXPECT_EQ(result, "123456789abcdefg"); + result = + StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", "h"); + EXPECT_EQ(result, "123456789abcdefgh"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i"); + EXPECT_EQ(result, "123456789abcdefghi"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j"); + EXPECT_EQ(result, "123456789abcdefghij"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k"); + EXPECT_EQ(result, "123456789abcdefghijk"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l"); + EXPECT_EQ(result, "123456789abcdefghijkl"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m"); + EXPECT_EQ(result, "123456789abcdefghijklm"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n"); + EXPECT_EQ(result, "123456789abcdefghijklmn"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o"); + EXPECT_EQ(result, "123456789abcdefghijklmno"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o", "p"); + EXPECT_EQ(result, "123456789abcdefghijklmnop"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o", "p", "q"); + EXPECT_EQ(result, "123456789abcdefghijklmnopq"); + // No limit thanks to C++11's variadic templates + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", "f", + "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", + "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", + "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", + "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"); + EXPECT_EQ(result, + "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"); +} + +TEST(StrAppend, Basics) { + string result = "existing text"; + + string strs[] = {"Hello", "Cruel", "World"}; + + StringPiece pieces[] = {"Hello", "Cruel", 
"World"}; + + const char *c_strs[] = {"Hello", "Cruel", "World"}; + + int32 i32s[] = {'H', 'C', 'W'}; + uint64 ui64s[] = {12345678910LL, 10987654321LL}; + + string::size_type old_size = result.size(); + StrAppend(&result, strs[0]); + EXPECT_EQ(result.substr(old_size), "Hello"); + + old_size = result.size(); + StrAppend(&result, strs[1], pieces[2]); + EXPECT_EQ(result.substr(old_size), "CruelWorld"); + + old_size = result.size(); + StrAppend(&result, strs[0], ", ", pieces[2]); + EXPECT_EQ(result.substr(old_size), "Hello, World"); + + old_size = result.size(); + StrAppend(&result, strs[0], ", ", strs[1], " ", strs[2], "!"); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World!"); + + old_size = result.size(); + StrAppend(&result, pieces[0], ", ", pieces[1], " ", pieces[2]); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World"); + + old_size = result.size(); + StrAppend(&result, c_strs[0], ", ", c_strs[1], " ", c_strs[2]); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World"); + + old_size = result.size(); + StrAppend(&result, "ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!"); + EXPECT_EQ(result.substr(old_size), "ASCII 72, 67 87!"); + + old_size = result.size(); + StrAppend(&result, ui64s[0], ", ", ui64s[1], "!"); + EXPECT_EQ(result.substr(old_size), "12345678910, 10987654321!"); + + string one = "1"; // Actually, it's the size of this string that we want; a + // 64-bit build distinguishes between size_t and uint64, + // even though they're both unsigned 64-bit values. + old_size = result.size(); + StrAppend(&result, "And a ", one.size(), " and a ", &result[2] - &result[0], + " and a ", one, " 2 3 4", "!"); + EXPECT_EQ(result.substr(old_size), "And a 1 and a 2 and a 1 2 3 4!"); + + // result = StrCat("Single chars won't compile", '!'); + // result = StrCat("Neither will NULLs", NULL); + old_size = result.size(); + StrAppend(&result, "To output a char by ASCII/numeric value, use +: ", + '!' 
+ 0); + EXPECT_EQ(result.substr(old_size), + "To output a char by ASCII/numeric value, use +: 33"); + + float f = 100000.5; + old_size = result.size(); + StrAppend(&result, "A hundred K and a half is ", f); + EXPECT_EQ(result.substr(old_size), "A hundred K and a half is 100000.5"); + + double d = f; + d *= d; + old_size = result.size(); + StrAppend(&result, "A hundred K and a half squared is ", d); + EXPECT_EQ(result.substr(old_size), + "A hundred K and a half squared is 10000100000.25"); + + // Test 9 arguments, the old maximum + old_size = result.size(); + StrAppend(&result, 1, 22, 333, 4444, 55555, 666666, 7777777, 88888888, 9); + EXPECT_EQ(result.substr(old_size), "1223334444555556666667777777888888889"); + + // No limit thanks to C++11's variadic templates + old_size = result.size(); + StrAppend(&result, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", + "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", + "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E", + "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", + "S", "T", "U", "V", "W", "X", "Y", "Z", + "No limit thanks to C++11's variadic templates"); + EXPECT_EQ(result.substr(old_size), + "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + "No limit thanks to C++11's variadic templates"); +} + +TEST(StrAppend, Death) { + string s = "self"; + EXPECT_DEBUG_DEATH(StrAppend(&s, s.c_str() + 1), "Check failed:"); + EXPECT_DEBUG_DEATH(StrAppend(&s, s), "Check failed:"); +} + +static void CheckHex64(uint64 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16)); + string expected = Printf("%016llx", static_cast<unsigned long long>(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + expected = Printf("%08llx", static_cast<unsigned long long>(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%llx", static_cast<unsigned long long>(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void CheckHex32(uint32 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + string expected = Printf("%08x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void CheckHexSigned32(int32 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + string expected = Printf("%08x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void TestFastPrints() { + using tensorflow::strings::Hex; + + // Test min int to make sure that works + for (int i = 0; i < 10000; i++) { + CheckHex64(i); + CheckHex32(i); + CheckHexSigned32(i); + CheckHexSigned32(-i); + } + CheckHex64(0x123456789abcdef0ull); + CheckHex32(0x12345678); + + int8 minus_one_8bit = -1; + EXPECT_EQ("ff", StrCat(Hex(minus_one_8bit))); + + int16 minus_one_16bit = -1; + EXPECT_EQ("ffff", StrCat(Hex(minus_one_16bit))); +} + +TEST(Numbers, TestFunctionsMovedOverFromNumbersMain) { TestFastPrints(); } + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/stringprintf.cc b/tensorflow/core/lib/strings/stringprintf.cc new file mode 100644 
index 0000000000..b354706cbd --- /dev/null +++ b/tensorflow/core/lib/strings/stringprintf.cc @@ -0,0 +1,85 @@ +#include "tensorflow/core/lib/strings/stringprintf.h" + +#include <errno.h> +#include <stdarg.h> // For va_list and related operations +#include <stdio.h> // MSVC requires this for _vsnprintf +#include <vector> + +namespace tensorflow { +namespace strings { + +#ifdef COMPILER_MSVC +enum { IS_COMPILER_MSVC = 1 }; +#else +enum { IS_COMPILER_MSVC = 0 }; +#endif + +void Appendv(string* dst, const char* format, va_list ap) { + // First try with a small fixed size buffer + static const int kSpaceLength = 1024; + char space[kSpaceLength]; + + // It's possible for methods that use a va_list to invalidate + // the data in it upon use. The fix is to make a copy + // of the structure before using it and use that copy instead. + va_list backup_ap; + va_copy(backup_ap, ap); + int result = vsnprintf(space, kSpaceLength, format, backup_ap); + va_end(backup_ap); + + if (result < kSpaceLength) { + if (result >= 0) { + // Normal case -- everything fit. + dst->append(space, result); + return; + } + + if (IS_COMPILER_MSVC) { + // Error or MSVC running out of space. MSVC 8.0 and higher + // can be asked about space needed with the special idiom below: + va_copy(backup_ap, ap); + result = vsnprintf(NULL, 0, format, backup_ap); + va_end(backup_ap); + } + + if (result < 0) { + // Just an error. + return; + } + } + + // Increase the buffer size to the size requested by vsnprintf, + // plus one for the closing \0. + int length = result + 1; + char* buf = new char[length]; + + // Restore the va_list before we use it again + va_copy(backup_ap, ap); + result = vsnprintf(buf, length, format, backup_ap); + va_end(backup_ap); + + if (result >= 0 && result < length) { + // It fit + dst->append(buf, result); + } + delete[] buf; +} + +string Printf(const char* format, ...) { + va_list ap; + va_start(ap, format); + string result; + Appendv(&result, format, ap); + va_end(ap); + return result; +} + +void Appendf(string* dst, const char* format, ...) { + va_list ap; + va_start(ap, format); + Appendv(dst, format, ap); + va_end(ap); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h new file mode 100644 index 0000000000..23ca2583ca --- /dev/null +++ b/tensorflow/core/lib/strings/stringprintf.h @@ -0,0 +1,37 @@ +// Printf variants that place their output in a C++ string. +// +// Usage: +// string result = strings::Printf("%d %s\n", 10, "hello"); +// strings::SPrintf(&result, "%d %s\n", 10, "hello"); +// strings::Appendf(&result, "%d %s\n", 20, "there"); + +#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ +#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ + +#include <stdarg.h> +#include <string> +#include <vector> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace strings { + +// Return a C++ string +extern string Printf(const char* format, ...) + // Tell the compiler to do printf format string checking. + TF_PRINTF_ATTRIBUTE(1, 2); + +// Append result to a supplied string +extern void Appendf(string* dst, const char* format, ...) + // Tell the compiler to do printf format string checking. + TF_PRINTF_ATTRIBUTE(2, 3); + +// Lower-level routine that takes a va_list and appends to a specified +// string. All other routines are just convenience wrappers around it. 
+extern void Appendv(string* dst, const char* format, va_list ap); + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ diff --git a/tensorflow/core/lib/strings/stringprintf_test.cc b/tensorflow/core/lib/strings/stringprintf_test.cc new file mode 100644 index 0000000000..737ed5c0e0 --- /dev/null +++ b/tensorflow/core/lib/strings/stringprintf_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/strings/stringprintf.h" + +#include <string> + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace strings { +namespace { + +TEST(PrintfTest, Empty) { + EXPECT_EQ("", Printf("%s", string().c_str())); + EXPECT_EQ("", Printf("%s", "")); +} + +TEST(PrintfTest, Misc) { +// MSVC does not support $ format specifier. +#if !defined(COMPILER_MSVC) + EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123)); +#endif // !COMPILER_MSVC +} + +TEST(AppendfTest, Empty) { + string value("Hello"); + const char* empty = ""; + Appendf(&value, "%s", empty); + EXPECT_EQ("Hello", value); +} + +TEST(AppendfTest, EmptyString) { + string value("Hello"); + Appendf(&value, "%s", ""); + EXPECT_EQ("Hello", value); +} + +TEST(AppendfTest, String) { + string value("Hello"); + Appendf(&value, " %s", "World"); + EXPECT_EQ("Hello World", value); +} + +TEST(AppendfTest, Int) { + string value("Hello"); + Appendf(&value, " %d", 123); + EXPECT_EQ("Hello 123", value); +} + +TEST(PrintfTest, Multibyte) { + // If we are in multibyte mode and feed invalid multibyte sequence, + // Printf should return an empty string instead of running + // out of memory while trying to determine destination buffer size. + // see b/4194543. + + char* old_locale = setlocale(LC_CTYPE, NULL); + // Push locale with multibyte mode + setlocale(LC_CTYPE, "en_US.utf8"); + + const char kInvalidCodePoint[] = "\375\067s"; + string value = Printf("%.*s", 3, kInvalidCodePoint); + + // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf + // returns error given an invalid codepoint. Other versions + // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim. + // We test that the output is one of the above. + EXPECT_TRUE(value.empty() || value == kInvalidCodePoint); + + // Repeat with longer string, to make sure that the dynamically + // allocated path in StringAppendV is handled correctly. + int n = 2048; + char* buf = new char[n + 1]; + memset(buf, ' ', n - 3); + memcpy(buf + n - 3, kInvalidCodePoint, 4); + value = Printf("%.*s", n, buf); + // See GRTEv2 vs. GRTEv3 comment above. + EXPECT_TRUE(value.empty() || value == buf); + delete[] buf; + + setlocale(LC_CTYPE, old_locale); +} + +TEST(PrintfTest, NoMultibyte) { + // No multibyte handling, but the string contains funny chars. + char* old_locale = setlocale(LC_CTYPE, NULL); + setlocale(LC_CTYPE, "POSIX"); + string value = Printf("%.*s", 3, "\375\067s"); + setlocale(LC_CTYPE, old_locale); + EXPECT_EQ("\375\067s", value); +} + +TEST(PrintfTest, DontOverwriteErrno) { + // Check that errno isn't overwritten unless we're printing + // something significantly larger than what people are normally + // printing in their badly written PLOG() statements. + errno = ECHILD; + string value = Printf("Hello, %s!", "World"); + EXPECT_EQ(ECHILD, errno); +} + +TEST(PrintfTest, LargeBuf) { + // Check that the large buffer is handled correctly. 
+ int n = 2048;
+ char* buf = new char[n + 1];
+ memset(buf, ' ', n);
+ buf[n] = 0;
+ string value = Printf("%s", buf);
+ EXPECT_EQ(buf, value);
+ delete[] buf;
+}
+
+} // namespace
+
+} // namespace strings
+} // namespace tensorflow
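Finally, a hedged usage sketch (editorial, not part of the commit) of the stringprintf.h helpers these tests cover. Printf and Appendf are the printf-style entry points; Appendv is the va_list core that, per the header comment, the other routines simply wrap — which also makes it the natural hook for writing your own varargs front end. The LogfSketch and StringprintfUsageSketch names below are hypothetical.

#include <cstdarg>

#include "tensorflow/core/lib/strings/stringprintf.h"

// Hypothetical wrapper: forwards its varargs to Appendv, the same way
// Printf and Appendf do inside stringprintf.cc.
static void LogfSketch(tensorflow::string* dst, const char* format, ...) {
  va_list ap;
  va_start(ap, format);
  tensorflow::strings::Appendv(dst, format, ap);
  va_end(ap);
}

void StringprintfUsageSketch() {
  using tensorflow::string;
  namespace strings = tensorflow::strings;

  string s = strings::Printf("%d %s", 10, "hello");  // "10 hello"
  strings::Appendf(&s, " %04x", 255);                // "10 hello 00ff"
  LogfSketch(&s, " and %s", "more");                 // "10 hello 00ff and more"
}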