Diffstat (limited to 'tensorflow/core/lib/core')
29 files changed, 2612 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc new file mode 100644 index 0000000000..ceb1001af0 --- /dev/null +++ b/tensorflow/core/lib/core/arena.cc @@ -0,0 +1,246 @@ +// This approach to arenas overcomes many of the limitations described +// in the "Specialized allocators" section of +// http://www.pdos.lcs.mit.edu/~dm/c++-new.html +// +// A somewhat similar approach to Gladiator, but for heap-detection, was +// suggested by Ron van der Wal and Scott Meyers at +// http://www.aristeia.com/BookErrata/M27Comments_frames.html + +#include "tensorflow/core/lib/core/arena.h" + +#include <assert.h> +#include <unistd.h> + +#include <vector> + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { +namespace core { + +static const int kPageSize = getpagesize(); + +// ---------------------------------------------------------------------- +// Arena::Arena() +// Arena::~Arena() +// Destroying the arena automatically calls Reset() +// ---------------------------------------------------------------------- + +Arena::Arena(const size_t block_size) + : remaining_(0), + block_size_(block_size), + freestart_(NULL), // set for real in Reset() + blocks_alloced_(1), + overflow_blocks_(NULL) { + assert(block_size > kDefaultAlignment); + + first_blocks_[0].mem = reinterpret_cast<char*>(malloc(block_size_)); + + first_blocks_[0].size = block_size_; + + Reset(); +} + +Arena::~Arena() { + FreeBlocks(); + assert(overflow_blocks_ == NULL); // FreeBlocks() should do that + // The first X blocks stay allocated always by default. Delete them now. + for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem); +} + +// Returns true iff it advances freestart_ to the first position +// satisfying alignment without exhausting the current block. +bool Arena::SatisfyAlignment(size_t alignment) { + const size_t overage = reinterpret_cast<size_t>(freestart_) & (alignment - 1); + if (overage > 0) { + const size_t waste = alignment - overage; + if (waste >= remaining_) { + return false; + } + freestart_ += waste; + remaining_ -= waste; + } + DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1)); + return true; +} + +// ---------------------------------------------------------------------- +// Arena::Reset() +// Clears all the memory an arena is using. +// ---------------------------------------------------------------------- + +void Arena::Reset() { + FreeBlocks(); + freestart_ = first_blocks_[0].mem; + remaining_ = first_blocks_[0].size; + + // There is no guarantee the first block is properly aligned, so + // enforce that now. + CHECK(SatisfyAlignment(kDefaultAlignment)); + + freestart_when_empty_ = freestart_; +} + +// ---------------------------------------------------------------------- +// Arena::MakeNewBlock() +// Our sbrk() equivalent. We always make blocks of the same size +// (though GetMemory() can also make a new block for really big +// data. +// ---------------------------------------------------------------------- + +void Arena::MakeNewBlock(const uint32 alignment) { + AllocatedBlock* block = AllocNewBlock(block_size_, alignment); + freestart_ = block->mem; + remaining_ = block->size; + CHECK(SatisfyAlignment(alignment)); +} + +// The following simple numeric routines also exist in util/math/mathutil.h +// but we don't want to depend on that library. + +// Euclid's algorithm for Greatest Common Denominator. 
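+// A quick worked example of the loop below: GCD(36, 24) computes
+// 36 % 24 == 12, then 24 % 12 == 0, and returns 12, the greatest
+// common divisor of its two arguments.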
+static uint32 GCD(uint32 x, uint32 y) { + while (y != 0) { + uint32 r = x % y; + x = y; + y = r; + } + return x; +} + +static uint32 LeastCommonMultiple(uint32 a, uint32 b) { + if (a > b) { + return (a / GCD(a, b)) * b; + } else if (a < b) { + return (b / GCD(b, a)) * a; + } else { + return a; + } +} + +// ------------------------------------------------------------- +// Arena::AllocNewBlock() +// Adds and returns an AllocatedBlock. +// The returned AllocatedBlock* is valid until the next call +// to AllocNewBlock or Reset. (i.e. anything that might +// affect overflow_blocks_). +// ------------------------------------------------------------- + +Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size, + const uint32 alignment) { + AllocatedBlock* block; + // Find the next block. + if (blocks_alloced_ < TF_ARRAYSIZE(first_blocks_)) { + // Use one of the pre-allocated blocks + block = &first_blocks_[blocks_alloced_++]; + } else { // oops, out of space, move to the vector + if (overflow_blocks_ == NULL) + overflow_blocks_ = new std::vector<AllocatedBlock>; + // Adds another block to the vector. + overflow_blocks_->resize(overflow_blocks_->size() + 1); + // block points to the last block of the vector. + block = &overflow_blocks_->back(); + } + + // NOTE(tucker): this utility is made slightly more complex by + // not disallowing the case where alignment > block_size. + // Can we, without breaking existing code? + + // Must be a multiple of kDefaultAlignment, unless requested + // alignment is 1, in which case we don't care at all. + const uint32 adjusted_alignment = + (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1); + + CHECK_LE(adjusted_alignment, 1 << 20) + << "Alignment on boundaries greater than 1MB not supported."; + + // If block_size > alignment we force block_size to be a multiple + // of alignment; if block_size < alignment we make no adjustment. + size_t adjusted_block_size = block_size; + if (adjusted_alignment > 1) { + if (adjusted_block_size > adjusted_alignment) { + const uint32 excess = adjusted_block_size % adjusted_alignment; + adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0); + } + block->mem = reinterpret_cast<char*>( + port::aligned_malloc(adjusted_block_size, adjusted_alignment)); + } else { + block->mem = reinterpret_cast<char*>(malloc(adjusted_block_size)); + } + block->size = adjusted_block_size; + CHECK(NULL != block->mem) << "block_size=" << block_size + << " adjusted_block_size=" << adjusted_block_size + << " alignment=" << alignment + << " adjusted_alignment=" << adjusted_alignment; + + return block; +} + +// ---------------------------------------------------------------------- +// Arena::GetMemoryFallback() +// We take memory out of our pool, aligned on the byte boundary +// requested. If we don't have space in our current pool, we +// allocate a new block (wasting the remaining space in the +// current block) and give you that. If your memory needs are +// too big for a single block, we make a special your-memory-only +// allocation -- this is equivalent to not using the arena at all. +// ---------------------------------------------------------------------- + +void* Arena::GetMemoryFallback(const size_t size, const int alignment) { + if (0 == size) { + return NULL; // stl/stl_alloc.h says this is okay + } + + // alignment must be a positive power of 2. 
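+  // A power of two has exactly one bit set, so clearing the lowest set
+  // bit via (alignment & (alignment - 1)) must leave zero: 8 & 7 == 0
+  // passes, while a non-power such as 12 gives 12 & 11 == 8 and would
+  // trip the CHECK below.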
+ CHECK(alignment > 0 && 0 == (alignment & (alignment - 1))); + + // If the object is more than a quarter of the block size, allocate + // it separately to avoid wasting too much space in leftover bytes. + if (block_size_ == 0 || size > block_size_ / 4) { + return AllocNewBlock(size, alignment)->mem; + } + + // Enforce alignment on freestart_ then check for adequate space, + // which may require starting a new block. + if (!SatisfyAlignment(alignment) || size > remaining_) { + MakeNewBlock(alignment); + } + CHECK_LE(size, remaining_); + + remaining_ -= size; + void* result = freestart_; + freestart_ += size; + + return result; +} + +// ---------------------------------------------------------------------- +// Arena::ReturnMemoryFallback() +// Arena::FreeBlocks() +// Unlike GetMemory(), which does actual work, ReturnMemory() is a +// no-op: we don't "free" memory until Reset() is called. We do +// update some stats, though. Note we do no checking that the +// pointer you pass in was actually allocated by us, or that it +// was allocated for the size you say, so be careful here! +// FreeBlocks() does the work for Reset(), actually freeing all +// memory allocated in one fell swoop. +// ---------------------------------------------------------------------- + +void Arena::FreeBlocks() { + for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced + free(first_blocks_[i].mem); + first_blocks_[i].mem = NULL; + first_blocks_[i].size = 0; + } + blocks_alloced_ = 1; + if (overflow_blocks_ != NULL) { + std::vector<AllocatedBlock>::iterator it; + for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) { + free(it->mem); + } + delete overflow_blocks_; // These should be used very rarely + overflow_blocks_ = NULL; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h new file mode 100644 index 0000000000..59896803bb --- /dev/null +++ b/tensorflow/core/lib/core/arena.h @@ -0,0 +1,90 @@ +// TODO(vrv): Switch this to an open-sourced version of Arena. + +#ifndef TENSORFLOW_LIB_CORE_ARENA_H_ +#define TENSORFLOW_LIB_CORE_ARENA_H_ + +#include <assert.h> + +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// This class is "thread-compatible": different threads can access the +// arena at the same time without locking, as long as they use only +// const methods. +class Arena { + public: + // Allocates a thread-compatible arena with the specified block size. + explicit Arena(const size_t block_size); + ~Arena(); + + char* Alloc(const size_t size) { + return reinterpret_cast<char*>(GetMemory(size, 1)); + } + + void Reset(); + +// This should be the worst-case alignment for any type. This is +// good for IA-32, SPARC version 7 (the last one I know), and +// supposedly Alpha. i386 would be more time-efficient with a +// default alignment of 8, but ::operator new() uses alignment of 4, +// and an assertion will fail below after the call to MakeNewBlock() +// if you try to use a larger alignment. 
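+// As an illustration: if a fresh block happens to start at an address
+// ending in 0xc, SatisfyAlignment(8) advances freestart_ by 4 bytes
+// (and shrinks remaining_ accordingly) before the first allocation.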
+#ifdef __i386__ + static const int kDefaultAlignment = 4; +#else + static const int kDefaultAlignment = 8; +#endif + + protected: + bool SatisfyAlignment(const size_t alignment); + void MakeNewBlock(const uint32 alignment); + void* GetMemoryFallback(const size_t size, const int align); + void* GetMemory(const size_t size, const int align) { + assert(remaining_ <= block_size_); // an invariant + if (size > 0 && size < remaining_ && align == 1) { // common case + void* result = freestart_; + freestart_ += size; + remaining_ -= size; + return result; + } + return GetMemoryFallback(size, align); + } + + size_t remaining_; + + private: + struct AllocatedBlock { + char* mem; + size_t size; + }; + + // Allocate new new block of at least block_size, with the specified + // alignment. + // The returned AllocatedBlock* is valid until the next call to AllocNewBlock + // or Reset (i.e. anything that might affect overflow_blocks_). + AllocatedBlock* AllocNewBlock(const size_t block_size, + const uint32 alignment); + + const size_t block_size_; + char* freestart_; // beginning of the free space in most recent block + char* freestart_when_empty_; // beginning of the free space when we're empty + // STL vector isn't as efficient as it could be, so we use an array at first + size_t blocks_alloced_; // how many of the first_blocks_ have been alloced + AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary + // if the first_blocks_ aren't enough, expand into overflow_blocks_. + std::vector<AllocatedBlock>* overflow_blocks_; + + void FreeBlocks(); // Frees all except first block + + TF_DISALLOW_COPY_AND_ASSIGN(Arena); +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ARENA_H_ diff --git a/tensorflow/core/lib/core/arena_test.cc b/tensorflow/core/lib/core/arena_test.cc new file mode 100644 index 0000000000..fa147c3014 --- /dev/null +++ b/tensorflow/core/lib/core/arena_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/arena.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { +namespace { + +// Write random data to allocated memory +static void TestMemory(void* mem, int size) { + // Check that we can memset the entire memory + memset(mem, 0xaa, size); + + // Do some memory allocation to check that the arena doesn't mess up + // the internal memory allocator + char* tmp[100]; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + tmp[i] = new char[i * i + 1]; + } + + memset(mem, 0xcc, size); + + // Free up the allocated memory; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + delete[] tmp[i]; + } + + // Check that we can memset the entire memory + memset(mem, 0xee, size); +} + +TEST(ArenaTest, TestBasicArena) { + Arena a(1024); + char* memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); + + // Allocate again + memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); +} + +TEST(ArenaTest, TestVariousArenaSizes) { + { + Arena a(1024); + + // Allocate blocksize + char* memory = a.Alloc(1024); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 1024); + + // Allocate another blocksize + char* memory2 = a.Alloc(1024); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1024); + } + + // Allocate an arena and allocate two blocks + // that together exceed a block size + { + Arena a(1024); + + // + char* memory = a.Alloc(768); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 768); + + // Allocate another blocksize + char* memory2 = a.Alloc(768); + ASSERT_NE(memory2, nullptr); + 
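+    // 768 bytes is more than a quarter of the 1024-byte block size, so
+    // the arena serves each of these requests from its own dedicated block.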
TestMemory(memory2, 768); + } + + // Allocate larger than a blocksize + { + Arena a(1024); + + char* memory = a.Alloc(10240); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 10240); + + // Allocate another blocksize + char* memory2 = a.Alloc(1234); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1234); + } +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc new file mode 100644 index 0000000000..0ea583e96f --- /dev/null +++ b/tensorflow/core/lib/core/bit_cast_test.cc @@ -0,0 +1,95 @@ +// Unit test for bit_cast template. + +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +// Marshall and unmarshall. +// ISO spec C++ section 3.9 promises this will work. + +template <int N> +struct marshall { + char buf[N]; +}; + +template <class T> +void TestMarshall(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + marshall<sizeof(T)> m0 = bit_cast<marshall<sizeof(T)> >(t0); + T t1 = bit_cast<T>(m0); + marshall<sizeof(T)> m1 = bit_cast<marshall<sizeof(T)> >(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T))); + } +} + +// Convert back and forth to an integral type. The C++ standard does +// not guarantee this will work. +// +// There are implicit assumptions about sizeof(float) and +// sizeof(double). These assumptions are quite extant everywhere. + +template <class T, class I> +void TestIntegral(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + I i0 = bit_cast<I>(t0); + T t1 = bit_cast<T>(i0); + I i1 = bit_cast<I>(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(i0, i1); + } +} + +TEST(BitCast, Bool) { + LOG(INFO) << "Test bool"; + static const bool bool_list[] = {false, true}; + TestMarshall<bool>(bool_list, TF_ARRAYSIZE(bool_list)); +} + +TEST(BitCast, Int32) { + static const int32 int_list[] = {0, 1, 100, 2147483647, + -1, -100, -2147483647, -2147483647 - 1}; + TestMarshall<int32>(int_list, TF_ARRAYSIZE(int_list)); +} + +TEST(BitCast, Int64) { + static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)}; + TestMarshall<int64>(int64_list, TF_ARRAYSIZE(int64_list)); +} + +TEST(BitCast, Uint64) { + static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63}; + TestMarshall<uint64>(uint64_list, TF_ARRAYSIZE(uint64_list)); +} + +TEST(BitCast, Float) { + static const float float_list[] = {0.0, 1.0, -1.0, 10.0, -10.0, 1e10, + 1e20, 1e-10, 1e-20, 2.71828, 3.14159}; + TestMarshall<float>(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral<float, int32>(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral<float, uint32>(float_list, TF_ARRAYSIZE(float_list)); +} + +TEST(BitCast, Double) { + static const double double_list[] = { + 0.0, + 1.0, + -1.0, + 10.0, + -10.0, + 1e10, + 1e100, + 1e-10, + 1e-100, + 2.718281828459045, + 3.141592653589793238462643383279502884197169399375105820974944}; + TestMarshall<double>(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral<double, int64>(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral<double, uint64>(double_list, TF_ARRAYSIZE(double_list)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h new file mode 100644 index 0000000000..5456a63168 --- /dev/null +++ b/tensorflow/core/lib/core/bits.h @@ -0,0 +1,84 @@ +#ifndef 
TENSORFLOW_LIB_CORE_BITS_H_ +#define TENSORFLOW_LIB_CORE_BITS_H_ + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Floor(uint32 n); +int Log2Floor64(uint64 n); + +// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Ceiling(uint32 n); +int Log2Ceiling64(uint64 n); + +// ------------------------------------------------------------------------ +// Implementation details follow +// ------------------------------------------------------------------------ + +#if defined(__GNUC__) + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + return n == 0 ? -1 : 31 ^ __builtin_clz(n); +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor64(uint64 n) { + return n == 0 ? -1 : 63 ^ __builtin_clzll(n); +} + +#else + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + if (n == 0) + return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +// Log2Floor64() is defined in terms of Log2Floor32() +inline int Log2Floor64(uint64 n) { + const uint32 topbits = static_cast<uint32>(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return Log2Floor(static_cast<uint32>(n)); + } else { + return 32 + Log2Floor(topbits); + } +} + +#endif + +inline int Log2Ceiling(uint32 n) { + int floor = Log2Floor(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +inline int Log2Ceiling64(uint64 n) { + int floor = Log2Floor64(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BITS_H_ diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h new file mode 100644 index 0000000000..f141be2c76 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter.h @@ -0,0 +1,41 @@ +#ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class BlockingCounter { + public: + BlockingCounter(int initial_count) : count_(initial_count) { + CHECK_GE(count_, 0); + } + + ~BlockingCounter() {} + + inline void DecrementCount() { + mutex_lock l(mu_); + --count_; + CHECK(count_ >= 0); + if (count_ == 0) { + cond_var_.notify_all(); + } + } + + inline void Wait() { + mutex_lock l(mu_); + while (count_ > 0) { + cond_var_.wait(l); + } + } + + private: + int count_; + mutex mu_; + condition_variable cond_var_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ diff --git a/tensorflow/core/lib/core/blocking_counter_test.cc b/tensorflow/core/lib/core/blocking_counter_test.cc new file mode 100644 index 0000000000..feb0342086 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter_test.cc @@ -0,0 +1,36 @@ +#include <gtest/gtest.h> + +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +TEST(BlockingCounterTest, 
TestZero) { + BlockingCounter bc(0); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestSingleThread) { + BlockingCounter bc(2); + bc.DecrementCount(); + bc.DecrementCount(); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestMultipleThread) { + int N = 3; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", N); + + BlockingCounter bc(N); + for (int i = 0; i < N; ++i) { + thread_pool->Schedule([&bc] { bc.DecrementCount(); }); + } + + bc.Wait(); + delete thread_pool; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h new file mode 100644 index 0000000000..5b72048ac5 --- /dev/null +++ b/tensorflow/core/lib/core/casts.h @@ -0,0 +1,85 @@ +// Various Google-specific casting templates. +// +// This code is compiled directly on many platforms, including client +// platforms like Windows, Mac, and embedded systems. Before making +// any changes here, make sure that you're not breaking any platforms. +// + +#ifndef TENSORFLOW_LIB_CORE_CASTS_H_ +#define TENSORFLOW_LIB_CORE_CASTS_H_ + +#include <string.h> // for memcpy + +namespace tensorflow { + +// bit_cast<Dest,Source> is a template function that implements the +// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in +// very low-level functions like the protobuf library and fast math +// support. +// +// float f = 3.14159265358979; +// int i = bit_cast<int32>(f); +// // i = 0x40490fdb +// +// The classical address-casting method is: +// +// // WRONG +// float f = 3.14159265358979; // WRONG +// int i = * reinterpret_cast<int*>(&f); // WRONG +// +// The address-casting method actually produces undefined behavior +// according to ISO C++ specification section 3.10 -15 -. Roughly, this +// section says: if an object in memory has one type, and a program +// accesses it with a different type, then the result is undefined +// behavior for most values of "different type". +// +// This is true for any cast syntax, either *(int*)&f or +// *reinterpret_cast<int*>(&f). And it is particularly true for +// conversions between integral lvalues and floating-point lvalues. +// +// The purpose of 3.10 -15- is to allow optimizing compilers to assume +// that expressions with different types refer to different memory. gcc +// 4.0.1 has an optimizer that takes advantage of this. So a +// non-conforming program quietly produces wildly incorrect output. +// +// The problem is not the use of reinterpret_cast. The problem is type +// punning: holding an object in memory of one type and reading its bits +// back using a different type. +// +// The C++ standard is more subtle and complex than this, but that +// is the basic idea. +// +// Anyways ... +// +// bit_cast<> calls memcpy() which is blessed by the standard, +// especially by the example in section 3.9 . Also, of course, +// bit_cast<> wraps up the nasty logic in one place. +// +// Fortunately memcpy() is very fast. In optimized mode, with a +// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline +// code with the minimal amount of data movement. On a 32-bit system, +// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8) +// compiles to two loads and two stores. +// +// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1. +// +// WARNING: if Dest or Source is a non-POD type, the result of the memcpy +// is likely to surprise you. 
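+//
+// As one concrete illustration (IEEE 754 single precision):
+//
+//   float f = 1.0f;
+//   uint32 bits = bit_cast<uint32>(f);  // bits == 0x3f800000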
+// +// Props to Bill Gibbons for the compile time assertion technique and +// Art Komninos and Igor Tandetnik for the msvc experiments. +// +// -- mec 2005-10-17 + +template <class Dest, class Source> +inline Dest bit_cast(const Source& source) { + static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match"); + + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CASTS_H_ diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc new file mode 100644 index 0000000000..efff554742 --- /dev/null +++ b/tensorflow/core/lib/core/coding.cc @@ -0,0 +1,164 @@ +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace core { + +void EncodeFixed32(char* buf, uint32 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + } +} + +void EncodeFixed64(char* buf, uint64 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + buf[4] = (value >> 32) & 0xff; + buf[5] = (value >> 40) & 0xff; + buf[6] = (value >> 48) & 0xff; + buf[7] = (value >> 56) & 0xff; + } +} + +void PutFixed32(string* dst, uint32 value) { + char buf[sizeof(value)]; + EncodeFixed32(buf, value); + dst->append(buf, sizeof(buf)); +} + +void PutFixed64(string* dst, uint64 value) { + char buf[sizeof(value)]; + EncodeFixed64(buf, value); + dst->append(buf, sizeof(buf)); +} + +char* EncodeVarint32(char* dst, uint32 v) { + // Operate on characters as unsigneds + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + static const int B = 128; + if (v < (1 << 7)) { + *(ptr++) = v; + } else if (v < (1 << 14)) { + *(ptr++) = v | B; + *(ptr++) = v >> 7; + } else if (v < (1 << 21)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = v >> 14; + } else if (v < (1 << 28)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = v >> 21; + } else { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = (v >> 21) | B; + *(ptr++) = v >> 28; + } + return reinterpret_cast<char*>(ptr); +} + +void PutVarint32(string* dst, uint32 v) { + char buf[5]; + char* ptr = EncodeVarint32(buf, v); + dst->append(buf, ptr - buf); +} + +char* EncodeVarint64(char* dst, uint64 v) { + static const int B = 128; + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + while (v >= B) { + *(ptr++) = (v & (B - 1)) | B; + v >>= 7; + } + *(ptr++) = static_cast<unsigned char>(v); + return reinterpret_cast<char*>(ptr); +} + +void PutVarint64(string* dst, uint64 v) { + char buf[10]; + char* ptr = EncodeVarint64(buf, v); + dst->append(buf, ptr - buf); +} + +int VarintLength(uint64_t v) { + int len = 1; + while (v >= 128) { + v >>= 7; + len++; + } + return len; +} + +const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value) { + uint32 result = 0; + for (uint32 shift = 0; shift <= 28 && p < limit; shift += 7) { + uint32 byte = *(reinterpret_cast<const unsigned char*>(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast<const char*>(p); + } + } + return NULL; +} + +bool GetVarint32(StringPiece* input, uint32* 
value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint32Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +const char* GetVarint64Ptr(const char* p, const char* limit, uint64* value) { + uint64 result = 0; + for (uint32 shift = 0; shift <= 63 && p < limit; shift += 7) { + uint64 byte = *(reinterpret_cast<const unsigned char*>(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast<const char*>(p); + } + } + return NULL; +} + +bool GetVarint64(StringPiece* input, uint64* value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint64Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h new file mode 100644 index 0000000000..0c14bf1bbf --- /dev/null +++ b/tensorflow/core/lib/core/coding.h @@ -0,0 +1,55 @@ +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_LIB_CORE_CODING_H_ +#define TENSORFLOW_LIB_CORE_CODING_H_ + +#include "tensorflow/core/lib/core/raw_coding.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Put... that write directly into a character buffer +// REQUIRES: dst has enough space for the value being written +extern void EncodeFixed32(char* dst, uint32 value); +extern void EncodeFixed64(char* dst, uint64 value); +extern void PutFixed32(string* dst, uint32 value); +extern void PutFixed64(string* dst, uint64 value); + +extern void PutVarint32(string* dst, uint32 value); +extern void PutVarint64(string* dst, uint64 value); + +extern bool GetVarint32(StringPiece* input, uint32* value); +extern bool GetVarint64(StringPiece* input, uint64* value); + +extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v); +extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v); + +// Internal routine for use by fallback path of GetVarint32Ptr +extern const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value); +inline const char* GetVarint32Ptr(const char* p, const char* limit, + uint32* value) { + if (p < limit) { + uint32 result = *(reinterpret_cast<const unsigned char*>(p)); + if ((result & 128) == 0) { + *value = result; + return p + 1; + } + } + return GetVarint32PtrFallback(p, limit, value); +} + +extern char* EncodeVarint64(char* dst, uint64 v); + +// Returns the length of the varint32 or varint64 encoding of "v" +extern int VarintLength(uint64_t v); + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CODING_H_ diff --git a/tensorflow/core/lib/core/coding_test.cc b/tensorflow/core/lib/core/coding_test.cc new file mode 100644 index 0000000000..5e9e2c5e96 --- /dev/null +++ b/tensorflow/core/lib/core/coding_test.cc @@ -0,0 +1,168 @@ +#include "tensorflow/core/lib/core/coding.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { + +TEST(Coding, 
Fixed32) { + static const int N = 100000; + + string s; + for (uint32 v = 0; v < N; v++) { + char buf[sizeof(uint32)]; + EncodeFixed32(buf, v); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (uint32 v = 0; v < N; v++) { + uint32 actual = DecodeFixed32(p); + ASSERT_EQ(v, actual); + p += sizeof(uint32); + } +} + +TEST(Coding, Fixed64) { + string s; + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast<uint64>(1) << power; + char buf[sizeof(uint64)]; + EncodeFixed64(buf, v - 1); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 0); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 1); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast<uint64>(1) << power; + uint64 actual; + actual = DecodeFixed64(p); + ASSERT_EQ(v - 1, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 0, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 1, actual); + p += sizeof(uint64); + } +} + +// Test that encoding routines generate little-endian encodings +TEST(Coding, EncodingOutput) { + char dst[8]; + EncodeFixed32(dst, 0x04030201); + ASSERT_EQ(0x01, static_cast<int>(dst[0])); + ASSERT_EQ(0x02, static_cast<int>(dst[1])); + ASSERT_EQ(0x03, static_cast<int>(dst[2])); + ASSERT_EQ(0x04, static_cast<int>(dst[3])); + + EncodeFixed64(dst, 0x0807060504030201ull); + ASSERT_EQ(0x01, static_cast<int>(dst[0])); + ASSERT_EQ(0x02, static_cast<int>(dst[1])); + ASSERT_EQ(0x03, static_cast<int>(dst[2])); + ASSERT_EQ(0x04, static_cast<int>(dst[3])); + ASSERT_EQ(0x05, static_cast<int>(dst[4])); + ASSERT_EQ(0x06, static_cast<int>(dst[5])); + ASSERT_EQ(0x07, static_cast<int>(dst[6])); + ASSERT_EQ(0x08, static_cast<int>(dst[7])); +} + +TEST(Coding, Varint32) { + string s; + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 v = (i / 32) << (i % 32); + PutVarint32(&s, v); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 expected = (i / 32) << (i % 32); + uint32 actual; + p = GetVarint32Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(expected, actual); + } + ASSERT_EQ(p, s.data() + s.size()); +} + +TEST(Coding, Varint64) { + // Construct the list of values to check + std::vector<uint64> values; + // Some special values + values.push_back(0); + values.push_back(100); + values.push_back(~static_cast<uint64>(0)); + values.push_back(~static_cast<uint64>(0) - 1); + for (uint32 k = 0; k < 64; k++) { + // Test values near powers of two + const uint64 power = 1ull << k; + values.push_back(power); + values.push_back(power - 1); + values.push_back(power + 1); + } + + string s; + for (size_t i = 0; i < values.size(); i++) { + PutVarint64(&s, values[i]); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (size_t i = 0; i < values.size(); i++) { + ASSERT_TRUE(p < limit); + uint64 actual; + p = GetVarint64Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(values[i], actual); + } + ASSERT_EQ(p, limit); +} + +TEST(Coding, Varint32Overflow) { + uint32 result; + string input("\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint32Truncation) { + uint32 large_value = (1u << 31) + 100; + string s; + PutVarint32(&s, large_value); + uint32 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == 
NULL); + } + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +TEST(Coding, Varint64Overflow) { + uint64 result; + string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint64Truncation) { + uint64 large_value = (1ull << 63) + 100ull; + string s; + PutVarint64(&s, large_value); + uint64 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == NULL); + } + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc new file mode 100644 index 0000000000..0f1072ffaa --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/lib/core/command_line_flags.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { +namespace { + +// Templated function to convert a string to target values. +// Return true if the conversion is successful. Otherwise, return false. +template <typename T> +bool StringToValue(const string& content, T* value); + +template <> +bool StringToValue<int32>(const string& content, int* value) { + return str_util::NumericParse32(content, value); +} + +// Parse a single argument by linearly searching through the command table. +// The input format is: --argument=value. +// Return OK if the argument is used. It stores the extracted value into the +// matching flag. +// Return NOT_FOUND if the argument is not recognized. +// Return INVALID_ARGUMENT if the command is recognized, but fails to extract +// its value. +template <typename T> +Status ParseArgument(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry<T>::Instance()->commands) { + string prefix = strings::StrCat("--", command.name, "="); + if (tensorflow::StringPiece(argument).starts_with(prefix)) { + string content = argument.substr(prefix.length()); + if (StringToValue<T>(content, command.value)) { + return Status::OK(); + } + return Status(error::INVALID_ARGUMENT, + strings::StrCat("Cannot parse integer in: ", argument)); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} + +// A specialization for booleans. The input format is: +// "--argument" or "--noargument". +// Parse a single argument by linearly searching through the command table. +// Return OK if the argument is used. The value is stored in the matching flag. +// Return NOT_FOUND if the argument is not recognized. 
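+// For a flag registered as, say, "verbose" (an illustrative name, not one
+// defined in this library), "--verbose" stores true into the flag and
+// "--noverbose" stores false; any other spelling falls through to NOT_FOUND.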
+template <> +Status ParseArgument<bool>(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry<bool>::Instance()->commands) { + if (argument == strings::StrCat("--", command.name)) { + *command.value = true; + return Status::OK(); + } else if (argument == strings::StrCat("--no", command.name)) { + *command.value = false; + return Status::OK(); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} +} // namespace + +Status ParseCommandLineFlags(int* argc, char* argv[]) { + int unused_argc = 1; + for (int index = 1; index < *argc; ++index) { + Status s; + // Search bool commands. + s = ParseArgument<bool>(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Search int32 commands. + s = ParseArgument<int32>(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Pointer swap the unused argument to the front. + std::swap(argv[unused_argc++], argv[index]); + } + *argc = unused_argc; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h new file mode 100644 index 0000000000..f1a94c11f9 --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace internal { + +template <typename T> +struct CommandLineFlagRegistry { + static CommandLineFlagRegistry* Instance() { + static CommandLineFlagRegistry instance_; + return &instance_; + } + struct Command { + string name; + T* value; + string text; + }; + std::vector<Command> commands; + + private: + CommandLineFlagRegistry() {} + TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry); +}; + +template <typename T> +struct CommandLineFlagRegister { + CommandLineFlagRegister(const string& name, T* val, const string& text) { + CommandLineFlagRegistry<T>::Instance()->commands.push_back( + {name, val, text}); + } +}; + +#define TF_DEFINE_variable(type, name, default_value, text) \ + type FLAGS_##name = default_value; \ + namespace TF_flags_internal { \ + tensorflow::internal::CommandLineFlagRegister<type> \ + TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \ + } // namespace TF_flags_internal + +} // namespace internal + +#define TF_DEFINE_int32(name, default_value, text) \ + TF_DEFINE_variable(int32, name, default_value, text); + +#define TF_DEFINE_bool(name, default_value, text) \ + TF_DEFINE_variable(bool, name, default_value, text); + +// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv. +// Returned the number of unused arguments in *argc. +// Return error Status if the parsing encounters errors. +// TODO(opensource): switch to a command line argument parser that can be +// shared with other tests. +Status ParseCommandLineFlags(int* argc, char* argv[]); + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto new file mode 100644 index 0000000000..6735fd8f88 --- /dev/null +++ b/tensorflow/core/lib/core/error_codes.proto @@ -0,0 +1,145 @@ +syntax = "proto3"; + +package tensorflow.error; +// option cc_enable_arenas = true; + +// The canonical error codes for TensorFlow APIs. 
+// +// Warnings: +// +// - Do not change any numeric assignments. +// - Changes to this list should only be made if there is a compelling +// need that can't be satisfied in another way. Such changes +// must be approved by at least two OWNERS. +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply. +// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION. +enum Code { + // Not an error; returned on success + OK = 0; + + // The operation was cancelled (typically by the caller). + CANCELLED = 1; + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + UNKNOWN = 2; + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3; + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5; + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors). PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8; + + // Operation was rejected because the system is not in a state + // required for the operation's execution. For example, directory + // to be deleted may be non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: + // (a) Use UNAVAILABLE if the client can retry just the failing call. + // (b) Use ABORTED if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FAILED_PRECONDITION if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FAILED_PRECONDITION + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. 
+ // (d) Use FAILED_PRECONDITION if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue + // like sequencer check failures, transaction aborts, etc. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + ABORTED = 10; + + // Operation was attempted past the valid range. E.g., seeking or + // reading past end of file. + // + // Unlike INVALID_ARGUMENT, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate INVALID_ARGUMENT if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OUT_OF_RANGE if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between FAILED_PRECONDITION and + // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OUT_OF_RANGE error to detect when + // they are done. + OUT_OF_RANGE = 11; + + // Operation is not implemented or not supported/enabled in this service. + UNIMPLEMENTED = 12; + + // Internal errors. Means some invariants expected by underlying + // system has been broken. If you see one of these errors, + // something is very broken. + INTERNAL = 13; + + // The service is currently unavailable. This is a most likely a + // transient condition and may be corrected by retrying with + // a backoff. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + DATA_LOSS = 15; + + // An extra enum entry to prevent people from writing code that + // fails to compile when a new code is added. + // + // Nobody should ever reference this enumeration entry. In particular, + // if you write C++ code that switches on this enumeration, add a default: + // case instead of a case that mentions this enumeration entry. + // + // Nobody should rely on the value (currently 20) listed here. It + // may change in the future. + DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20; +} diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h new file mode 100644 index 0000000000..b0badd8c4d --- /dev/null +++ b/tensorflow/core/lib/core/errors.h @@ -0,0 +1,131 @@ +#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_ +#define TENSORFLOW_LIB_CORE_ERRORS_H_ + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace errors { + +typedef ::tensorflow::error::Code Code; + +// Append some context to an error message. Each time we append +// context put it on a new line, since it is possible for there +// to be several layers of additional context. +template <typename... Args> +void AppendToMessage(::tensorflow::Status* status, Args... args) { + *status = ::tensorflow::Status( + status->code(), + strings::StrCat(status->error_message(), "\n\t", args...)); +} + +// For propagating errors when calling a function. 
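+// For example, a function that opens and then parses a file could be
+// written as follows (the helper names are illustrative only):
+//
+//   Status LoadConfig(const string& path) {
+//     TF_RETURN_IF_ERROR(OpenFile(path));
+//     TF_RETURN_IF_ERROR(ParseFile(path));
+//     return Status::OK();
+//   }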
+#define TF_RETURN_IF_ERROR(expr) \ + do { \ + const ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + +#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ + do { \ + ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Convenience functions for generating and using error status. +// Example usage: +// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); +// if (errors::IsInvalidArgument(status)) { ... } +// switch (status.code()) { case error::INVALID_ARGUMENT: ... } + +#define DECLARE_ERROR(FUNC, CONST) \ + template <typename... Args> \ + inline ::tensorflow::Status FUNC(Args... args) { \ + return ::tensorflow::Status(::tensorflow::error::CONST, \ + strings::StrCat(args...)); \ + } \ + inline bool Is##FUNC(const ::tensorflow::Status& status) { \ + return status.code() == ::tensorflow::error::CONST; \ + } + +DECLARE_ERROR(Cancelled, CANCELLED) +DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT) +DECLARE_ERROR(NotFound, NOT_FOUND) +DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS) +DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED) +DECLARE_ERROR(Unavailable, UNAVAILABLE) +DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION) +DECLARE_ERROR(OutOfRange, OUT_OF_RANGE) +DECLARE_ERROR(Unimplemented, UNIMPLEMENTED) +DECLARE_ERROR(Internal, INTERNAL) +DECLARE_ERROR(Aborted, ABORTED) +DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED) +DECLARE_ERROR(DataLoss, DATA_LOSS) +DECLARE_ERROR(Unknown, UNKNOWN) +DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED) +DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) + +#undef DECLARE_ERROR + +// The CanonicalCode() for non-errors. +using ::tensorflow::error::OK; + +// Convenience macros for asserting and handling exceptional conditions. +// Analogous to the CHECK* macros provided by logging.h. +// +// Example use: +// void Compute(OperationContext* context) { +// OP_REQUIRES(context, context->num_inputs() == 2, +// errors::InvalidArgument("FooOp requires 2 arguments")); +// ... +// Status status = SomeUncertainMethod(); +// OP_REQUIRES_OK(context, status); +// ... 
+// } + +#define OP_REQUIRES(CTX, EXP, STATUS) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } + +#define OP_REQUIRES_OK(CTX, STATUS) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } + +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +} // namespace errors +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ERRORS_H_ diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h new file mode 100644 index 0000000000..071e24285a --- /dev/null +++ b/tensorflow/core/lib/core/notification.h @@ -0,0 +1,42 @@ +#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_ +#define TENSORFLOW_UTIL_NOTIFICATION_H_ + +#include <assert.h> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Notification { + public: + Notification() : notified_(false) {} + ~Notification() {} + + void Notify() { + mutex_lock l(mu_); + assert(!notified_); + notified_ = true; + cv_.notify_all(); + } + + bool HasBeenNotified() { + mutex_lock l(mu_); + return notified_; + } + + void WaitForNotification() { + mutex_lock l(mu_); + while (!notified_) { + cv_.wait(l); + } + } + + private: + mutex mu_; + condition_variable cv_; + bool notified_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_NOTIFICATION_H_ diff --git a/tensorflow/core/lib/core/notification_test.cc b/tensorflow/core/lib/core/notification_test.cc new file mode 100644 index 0000000000..a9e8942f05 --- /dev/null +++ b/tensorflow/core/lib/core/notification_test.cc @@ -0,0 +1,64 @@ +#include <gtest/gtest.h> + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +TEST(NotificationTest, TestSingleNotification) { + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", 1); + + int counter = 0; + Notification start; + Notification proceed; + thread_pool->Schedule([&start, &proceed, &counter] { + start.Notify(); + proceed.WaitForNotification(); + ++counter; + }); + + // Wait for the thread to start + start.WaitForNotification(); + + // The thread should be waiting for the 'proceed' notification. + EXPECT_EQ(0, counter); + + // Unblock the thread + proceed.Notify(); + + delete thread_pool; // Wait for closure to finish. + + // Verify the counter has been incremented + EXPECT_EQ(1, counter); +} + +TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) { + const int num_closures = 4; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", num_closures); + + mutex lock; + int counter = 0; + Notification n; + + for (int i = 0; i < num_closures; ++i) { + thread_pool->Schedule([&n, &lock, &counter] { + n.WaitForNotification(); + mutex_lock l(lock); + ++counter; + }); + } + sleep(1); + + EXPECT_EQ(0, counter); + + n.Notify(); + delete thread_pool; // Wait for all closures to finish. 
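+  // Each of the four closures observed the single Notify() call and
+  // incremented the mutex-guarded counter exactly once.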
+ EXPECT_EQ(4, counter); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h new file mode 100644 index 0000000000..1fe49b75bb --- /dev/null +++ b/tensorflow/core/lib/core/raw_coding.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_ +#define TENSORFLOW_LIB_CORE_RAW_CODING_H_ + +#include <string.h> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Get... that read directly from a character buffer +// without any bounds checking. + +inline uint32 DecodeFixed32(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint32 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) | + (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24)); + } +} + +inline uint64 DecodeFixed64(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint64 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + uint64 lo = DecodeFixed32(ptr); + uint64 hi = DecodeFixed32(ptr + 4); + return (hi << 32) | lo; + } +} + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_ diff --git a/tensorflow/core/lib/core/refcount.cc b/tensorflow/core/lib/core/refcount.cc new file mode 100644 index 0000000000..3ed8c58eb8 --- /dev/null +++ b/tensorflow/core/lib/core/refcount.cc @@ -0,0 +1,35 @@ +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace core { + +RefCounted::RefCounted() : ref_(1) {} + +RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); } + +void RefCounted::Ref() const { + DCHECK_GE(ref_.load(), 1); + ref_.fetch_add(1, std::memory_order_relaxed); +} + +bool RefCounted::Unref() const { + DCHECK_GT(ref_.load(), 0); + // If ref_==1, this object is owned only by the caller. Bypass a locked op + // in that case. + if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) { + // Make DCHECK in ~RefCounted happy + DCHECK((ref_.store(0), true)); + delete this; + return true; + } else { + return false; + } +} + +bool RefCounted::RefCountIsOne() const { + return (ref_.load(std::memory_order_acquire) == 1); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h new file mode 100644 index 0000000000..f727750f9e --- /dev/null +++ b/tensorflow/core/lib/core/refcount.h @@ -0,0 +1,63 @@ +#ifndef TENSORFLOW_LIB_CORE_REFCOUNT_H_ +#define TENSORFLOW_LIB_CORE_REFCOUNT_H_ + +#include <atomic> + +namespace tensorflow { +namespace core { + +class RefCounted { + public: + // Initial reference count is one. + RefCounted(); + + // Increments reference count by one. + void Ref() const; + + // Decrements reference count by one. If the count remains + // positive, returns false. When the count reaches zero, returns + // true and deletes this, in which case the caller must not access + // the object afterward. + bool Unref() const; + + // Return whether the reference count is one. 
+ // If the reference count is used in the conventional way, a + // reference count of 1 implies that the current thread owns the + // reference and no other thread shares it. + // This call performs the test for a reference count of one, and + // performs the memory barrier needed for the owning thread + // to act on the object, knowing that it has exclusive access to the + // object. + bool RefCountIsOne() const; + + protected: + // Make destructor protected so that RefCounted objects cannot + // be instantiated directly. Only subclasses can be instantiated. + virtual ~RefCounted(); + + private: + mutable std::atomic_int_fast32_t ref_; + + RefCounted(const RefCounted&) = delete; + void operator=(const RefCounted&) = delete; +}; + +// Helper class to unref an object when out-of-scope. +class ScopedUnref { + public: + explicit ScopedUnref(RefCounted* o) : obj_(o) {} + ~ScopedUnref() { + if (obj_) obj_->Unref(); + } + + private: + RefCounted* obj_; + + ScopedUnref(const ScopedUnref&) = delete; + void operator=(const ScopedUnref&) = delete; +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_REFCOUNT_H_ diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc new file mode 100644 index 0000000000..c042be2d61 --- /dev/null +++ b/tensorflow/core/lib/core/refcount_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/refcount.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace core { +namespace { + +static int constructed = 0; +static int destroyed = 0; + +class MyRef : public RefCounted { + public: + MyRef() { constructed++; } + ~MyRef() override { destroyed++; } +}; + +class RefTest : public testing::Test { + public: + RefTest() { + constructed = 0; + destroyed = 0; + } +}; + +TEST_F(RefTest, New) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, constructed); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefUnref) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Ref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefCountOne) { + MyRef* ref = new MyRef; + ASSERT_TRUE(ref->RefCountIsOne()); + ref->Unref(); +} + +TEST_F(RefTest, RefCountNotOne) { + MyRef* ref = new MyRef; + ref->Ref(); + ASSERT_FALSE(ref->RefCountIsOne()); + ref->Unref(); + ref->Unref(); +} + +TEST_F(RefTest, ConstRefUnref) { + const MyRef* cref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + cref->Ref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, ReturnOfUnref) { + MyRef* ref = new MyRef; + ref->Ref(); + EXPECT_FALSE(ref->Unref()); + EXPECT_TRUE(ref->Unref()); +} + +TEST_F(RefTest, ScopedUnref) { + { ScopedUnref unref(new MyRef); } + EXPECT_EQ(destroyed, 1); +} + +TEST_F(RefTest, ScopedUnref_Nullptr) { + { ScopedUnref unref(nullptr); } + EXPECT_EQ(destroyed, 0); +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc new file mode 100644 index 0000000000..24ce842560 --- /dev/null +++ b/tensorflow/core/lib/core/status.cc @@ -0,0 +1,107 @@ +#include "tensorflow/core/public/status.h" +#include <stdio.h> + +namespace tensorflow { + +Status::Status(tensorflow::error::Code code, StringPiece msg) { + assert(code != 
tensorflow::error::OK); + state_ = new State; + state_->code = code; + state_->msg = msg.ToString(); +} +Status::~Status() { delete state_; } + +void Status::Update(const Status& new_status) { + if (ok()) { + *this = new_status; + } +} + +void Status::SlowCopyFrom(const State* src) { + delete state_; + if (src == nullptr) { + state_ = nullptr; + } else { + state_ = new State(*src); + } +} + +const string& Status::empty_string() { + static string* empty = new string; + return *empty; +} + +string Status::ToString() const { + if (state_ == NULL) { + return "OK"; + } else { + char tmp[30]; + const char* type; + switch (code()) { + case tensorflow::error::CANCELLED: + type = "Cancelled"; + break; + case tensorflow::error::UNKNOWN: + type = "Unknown"; + break; + case tensorflow::error::INVALID_ARGUMENT: + type = "Invalid argument"; + break; + case tensorflow::error::DEADLINE_EXCEEDED: + type = "Deadline exceeded"; + break; + case tensorflow::error::NOT_FOUND: + type = "Not found"; + break; + case tensorflow::error::ALREADY_EXISTS: + type = "Already exists"; + break; + case tensorflow::error::PERMISSION_DENIED: + type = "Permission denied"; + break; + case tensorflow::error::UNAUTHENTICATED: + type = "Unauthenticated"; + break; + case tensorflow::error::RESOURCE_EXHAUSTED: + type = "Resource exhausted"; + break; + case tensorflow::error::FAILED_PRECONDITION: + type = "Failed precondition"; + break; + case tensorflow::error::ABORTED: + type = "Aborted"; + break; + case tensorflow::error::OUT_OF_RANGE: + type = "Out of range"; + break; + case tensorflow::error::UNIMPLEMENTED: + type = "Unimplemented"; + break; + case tensorflow::error::INTERNAL: + type = "Internal"; + break; + case tensorflow::error::UNAVAILABLE: + type = "Unavailable"; + break; + case tensorflow::error::DATA_LOSS: + type = "Data loss"; + break; + default: + snprintf(tmp, sizeof(tmp), "Unknown code(%d)", + static_cast<int>(code())); + type = tmp; + break; + } + string result(type); + result += ": "; + result += state_->msg; + return result; + } +} + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc new file mode 100644 index 0000000000..3ef6b3302a --- /dev/null +++ b/tensorflow/core/lib/core/status_test.cc @@ -0,0 +1,84 @@ +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(Status, OK) { + EXPECT_EQ(Status::OK().code(), error::OK); + EXPECT_EQ(Status::OK().error_message(), ""); + EXPECT_OK(Status::OK()); + ASSERT_OK(Status::OK()); + EXPECT_EQ(Status::OK(), Status()); + Status s; + EXPECT_TRUE(s.ok()); +} + +TEST(DeathStatus, CheckOK) { + Status status(errors::InvalidArgument("Invalid")); + ASSERT_DEATH(TF_CHECK_OK(status), "Invalid"); +} + +TEST(Status, Set) { + Status status; + status = Status(error::CANCELLED, "Error message"); + EXPECT_EQ(status.code(), error::CANCELLED); + EXPECT_EQ(status.error_message(), "Error message"); +} + +TEST(Status, Copy) { + Status a(errors::InvalidArgument("Invalid")); + Status b(a); + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Assign) { + Status a(errors::InvalidArgument("Invalid")); + Status b; + b = a; + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Update) { + Status s; + s.Update(Status::OK()); + ASSERT_TRUE(s.ok()); + Status 
a(errors::InvalidArgument("Invalid"));
+  s.Update(a);
+  ASSERT_EQ(s.ToString(), a.ToString());
+  Status b(errors::Internal("Internal"));
+  s.Update(b);
+  ASSERT_EQ(s.ToString(), a.ToString());
+  s.Update(Status::OK());
+  ASSERT_EQ(s.ToString(), a.ToString());
+  ASSERT_FALSE(s.ok());
+}
+
+TEST(Status, EqualsOK) { ASSERT_EQ(Status::OK(), Status()); }
+
+TEST(Status, EqualsSame) {
+  Status a(errors::InvalidArgument("Invalid"));
+  Status b(errors::InvalidArgument("Invalid"));
+  ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsCopy) {
+  const Status a(errors::InvalidArgument("Invalid"));
+  const Status b = a;
+  ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsDifferentCode) {
+  const Status a(errors::InvalidArgument("message"));
+  const Status b(errors::Internal("message"));
+  ASSERT_NE(a, b);
+}
+
+TEST(Status, EqualsDifferentMessage) {
+  const Status a(errors::InvalidArgument("message"));
+  const Status b(errors::InvalidArgument("another"));
+  ASSERT_NE(a, b);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
new file mode 100644
index 0000000000..b3b4db429f
--- /dev/null
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -0,0 +1,20 @@
+#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+// Macros for testing the results of functions that return tensorflow::Status.
+
+#define EXPECT_OK(statement) EXPECT_EQ(::tensorflow::Status::OK(), (statement))
+#define ASSERT_OK(statement) ASSERT_EQ(::tensorflow::Status::OK(), (statement))
+
+// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not
+// provide much value (when they fail, they would just print the OK status,
+// which conveys no more information than EXPECT_FALSE(status.ok())).
+// If you want to check for particular errors, better alternatives are:
+//   EXPECT_EQ(::util::Status(...expected error...), status.StripMessage());
+//   EXPECT_THAT(status.ToString(), HasSubstr("expected error"));
+// Also, see testing/lib/util/status_util.h.
+
+#endif  // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
new file mode 100644
index 0000000000..57c5139f47
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+#include <iostream>
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+size_t StringPiece::Hasher::operator()(StringPiece s) const {
+  return Hash64(s.data(), s.size());
+}
+
+std::ostream& operator<<(std::ostream& o, StringPiece piece) {
+  o.write(piece.data(), piece.size());
+  return o;
+}
+
+bool StringPiece::contains(StringPiece s) const {
+  return memmem(data_, size_, s.data_, s.size_) != nullptr;
+}
+
+size_t StringPiece::find(char c, size_t pos) const {
+  if (pos >= size_) {
+    return npos;
+  }
+  const char* result =
+      reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
+  return result != NULL ? result - data_ : npos;
+}
+
+// Search range is [0..pos] inclusive.  If pos == npos, search everything.
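+// For example, in "abcabc", rfind('b') returns 4 (the last 'b'), while
+// rfind('b', 3) searches only positions [0..3] and returns 1.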
+size_t StringPiece::rfind(char c, size_t pos) const { + if (size_ == 0) return npos; + for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) { + if (*p == c) { + return p - data_; + } + } + return npos; +} + +bool StringPiece::Consume(StringPiece x) { + if (starts_with(x)) { + remove_prefix(x.size_); + return true; + } + return false; +} + +StringPiece StringPiece::substr(size_t pos, size_t n) const { + if (pos > size_) pos = size_; + if (n > size_ - pos) n = size_ - pos; + return StringPiece(data_ + pos, n); +} + +const StringPiece::size_type StringPiece::npos = size_type(-1); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h new file mode 100644 index 0000000000..17d4b294e9 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.h @@ -0,0 +1,159 @@ +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. +// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_ +#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_ + +#include <assert.h> +#include <stddef.h> +#include <string.h> +#include <iosfwd> +#include <string> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class StringPiece { + public: + typedef size_t size_type; + + // Create an empty slice. + StringPiece() : data_(""), size_(0) {} + + // Create a slice that refers to d[0,n-1]. + StringPiece(const char* d, size_t n) : data_(d), size_(n) {} + + // Create a slice that refers to the contents of "s" + StringPiece(const string& s) : data_(s.data()), size_(s.size()) {} + + // Create a slice that refers to s[0,strlen(s)-1] + StringPiece(const char* s) : data_(s), size_(strlen(s)) {} + + void set(const void* data, size_t len) { + data_ = reinterpret_cast<const char*>(data); + size_ = len; + } + + // Return a pointer to the beginning of the referenced data + const char* data() const { return data_; } + + // Return the length (in bytes) of the referenced data + size_t size() const { return size_; } + + // Return true iff the length of the referenced data is zero + bool empty() const { return size_ == 0; } + + typedef const char* const_iterator; + typedef const char* iterator; + iterator begin() const { return data_; } + iterator end() const { return data_ + size_; } + + static const size_t npos; + + // Return the ith byte in the referenced data. + // REQUIRES: n < size() + char operator[](size_t n) const { + assert(n < size()); + return data_[n]; + } + + // Change this slice to refer to an empty array + void clear() { + data_ = ""; + size_ = 0; + } + + // Drop the first "n" bytes from this slice. + void remove_prefix(size_t n) { + assert(n <= size()); + data_ += n; + size_ -= n; + } + + void remove_suffix(size_t n) { + assert(size_ >= n); + size_ -= n; + } + + size_t find(char c, size_t pos = 0) const; + size_t rfind(char c, size_t pos = npos) const; + bool contains(StringPiece s) const; + + // Checks whether StringPiece starts with x and if so advances the beginning + // of it to past the match. It's basically a shortcut for starts_with + // followed by remove_prefix. 
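+  // For example:
+  //   StringPiece s("--flag=on");
+  //   if (s.Consume("--flag=")) { /* s now refers to "on" */ }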
+  bool Consume(StringPiece x);
+
+  StringPiece substr(size_t pos, size_t n = npos) const;
+
+  struct Hasher {
+    size_t operator()(StringPiece arg) const;
+  };
+
+  // Return a string that contains the copy of the referenced data.
+  std::string ToString() const { return std::string(data_, size_); }
+
+  // Three-way comparison.  Returns value:
+  //   <  0 iff "*this" <  "b",
+  //   == 0 iff "*this" == "b",
+  //   >  0 iff "*this" >  "b"
+  int compare(StringPiece b) const;
+
+  // Return true iff "x" is a prefix of "*this"
+  bool starts_with(StringPiece x) const {
+    return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0));
+  }
+  // Return true iff "x" is a suffix of "*this"
+  bool ends_with(StringPiece x) const {
+    return ((size_ >= x.size_) &&
+            (memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0));
+  }
+
+ private:
+  const char* data_;
+  size_t size_;
+
+  // Intentionally copyable
+};
+
+inline bool operator==(StringPiece x, StringPiece y) {
+  return ((x.size() == y.size()) &&
+          (memcmp(x.data(), y.data(), x.size()) == 0));
+}
+
+inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
+
+inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
+inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
+inline bool operator<=(StringPiece x, StringPiece y) {
+  return x.compare(y) <= 0;
+}
+inline bool operator>=(StringPiece x, StringPiece y) {
+  return x.compare(y) >= 0;
+}
+
+inline int StringPiece::compare(StringPiece b) const {
+  const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
+  int r = memcmp(data_, b.data_, min_len);
+  if (r == 0) {
+    if (size_ < b.size_)
+      r = -1;
+    else if (size_ > b.size_)
+      r = +1;
+  }
+  return r;
+}
+
+// Allow StringPiece to be logged.
+extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_LIB_CORE_STRINGPIECE_H_
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
new file mode 100644
index 0000000000..e9b84d3102
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -0,0 +1,108 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+namespace thread {
+
+struct ThreadPool::Waiter {
+  condition_variable cv;
+  bool ready;
+};
+
+ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
+    : ThreadPool(env, ThreadOptions(), name, num_threads) {}
+
+ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
+                       const string& name, int num_threads)
+    : name_(name) {
+  CHECK_GE(num_threads, 1);
+  string name_prefix = "tf_" + name_;
+  for (int i = 0; i < num_threads; i++) {
+    threads_.push_back(env->StartThread(thread_options, name_prefix,
+                                        [this]() { WorkerLoop(); }));
+  }
+}
+
+ThreadPool::~ThreadPool() {
+  {
+    mutex_lock l(mu_);
+
+    // Inform every thread to exit: a worker exits when it dequeues a
+    // sentinel (null fn) item, after draining the real work ahead of it.
+    for (size_t i = 0; i < threads_.size(); ++i) {
+      pending_.push_back({nullptr, 0});
+    }
+
+    // Wake up all waiters.
+    for (auto w : waiters_) {
+      w->ready = true;
+      w->cv.notify_one();
+    }
+  }
+
+  // Wait for threads to finish.
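+  // Deleting a Thread returned by Env::StartThread blocks until the
+  // thread's function has returned, so this loop joins every worker.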
+  for (auto t : threads_) {
+    delete t;
+  }
+}
+
+bool ThreadPool::HasPendingClosures() const {
+  mutex_lock l(mu_);
+  return pending_.size() != 0;
+}
+
+void ThreadPool::Schedule(std::function<void()> fn) {
+  CHECK(fn != nullptr);
+  uint64 id = 0;
+  if (port::Tracing::IsActive()) {
+    id = port::Tracing::UniqueId();
+    port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
+                               id);
+  }
+
+  mutex_lock l(mu_);
+  pending_.push_back({fn, id});
+  if (!waiters_.empty()) {
+    Waiter* w = waiters_.back();
+    waiters_.pop_back();
+    w->ready = true;
+    w->cv.notify_one();
+  }
+}
+
+void ThreadPool::WorkerLoop() {
+  port::Tracing::RegisterCurrentThread(name_.c_str());
+  mutex_lock l(mu_);
+  Waiter w;
+  while (true) {
+    while (pending_.empty()) {
+      // Wait for work to be assigned to me.
+      w.ready = false;
+      waiters_.push_back(&w);
+      while (!w.ready) {
+        w.cv.wait(l);
+      }
+    }
+    // Pick up pending work.
+    Item item = pending_.front();
+    pending_.pop_front();
+    if (item.fn == nullptr) {
+      break;  // Sentinel enqueued by ~ThreadPool: time to exit.
+    }
+    // Drop the lock while running the closure so that new work can be
+    // scheduled and other workers can make progress.
+    mu_.unlock();
+    if (item.id != 0) {
+      port::Tracing::ScopedActivity region(
+          port::Tracing::EventCategory::kRunClosure, item.id);
+      item.fn();
+    } else {
+      item.fn();
+    }
+    mu_.lock();
+  }
+}
+
+}  // namespace thread
+}  // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
new file mode 100644
index 0000000000..5cf780fa86
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -0,0 +1,59 @@
+#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
+
+#include <deque>
+#include <functional>
+#include <thread>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace thread {
+
+class ThreadPool {
+ public:
+  // Construct a pool that contains "num_threads" threads with specified
+  // "name". env->StartThread() is used to create individual threads.
+  //
+  // REQUIRES: num_threads > 0
+  ThreadPool(Env* env, const string& name, int num_threads);
+
+  // Construct a pool that contains "num_threads" threads with specified
+  // "name", using "thread_options" for each thread. env->StartThread()
+  // is used to create individual threads.
+  //
+  // REQUIRES: num_threads > 0
+  ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name,
+             int num_threads);
+
+  // Wait until all scheduled work has finished and then destroy the
+  // set of threads.
+  virtual ~ThreadPool();
+
+  // Schedule fn() for execution in the pool of threads.
+  virtual void Schedule(std::function<void()> fn);
+
+  // Returns true if there are closures that have been scheduled but
+  // not yet picked up by a worker thread.
+  virtual bool HasPendingClosures() const;
+
+ private:
+  struct Waiter;
+  struct Item {
+    std::function<void()> fn;
+    uint64 id;  // Tracing id; 0 if tracing was inactive when scheduled.
+  };
+
+  void WorkerLoop();
+
+  const string name_;
+  mutable mutex mu_;
+  std::vector<Thread*> threads_;  // All threads
+  std::vector<Waiter*> waiters_;  // Stack of waiting threads.
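+  // Both waiters_ and pending_ are guarded by mu_.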
+  std::deque<Item> pending_;      // Queue of pending work
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
+};
+
+}  // namespace thread
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_LIB_CORE_THREADPOOL_H_
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
new file mode 100644
index 0000000000..f4909c445c
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -0,0 +1,93 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include <atomic>
+
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace thread {
+
+static const int kNumThreads = 30;
+
+TEST(ThreadPool, Empty) {
+  for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+    fprintf(stderr, "Testing with %d threads\n", num_threads);
+    ThreadPool pool(Env::Default(), "test", num_threads);
+  }
+}
+
+TEST(ThreadPool, DoWork) {
+  for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+    fprintf(stderr, "Testing with %d threads\n", num_threads);
+    const int kWorkItems = 15;
+    bool work[kWorkItems];
+    for (int i = 0; i < kWorkItems; i++) {
+      work[i] = false;
+    }
+    {
+      ThreadPool pool(Env::Default(), "test", num_threads);
+      for (int i = 0; i < kWorkItems; i++) {
+        pool.Schedule([&work, i]() {
+          ASSERT_FALSE(work[i]);
+          work[i] = true;
+        });
+      }
+    }
+    for (int i = 0; i < kWorkItems; i++) {
+      ASSERT_TRUE(work[i]);
+    }
+  }
+}
+
+static void BM_Sequential(int iters) {
+  ThreadPool pool(Env::Default(), "test", kNumThreads);
+  // Decrement count sequentially until 0.
+  int count = iters;
+  mutex done_lock;
+  condition_variable done;
+  bool done_flag = false;
+  std::function<void()> work = [&pool, &count, &done_lock, &done, &done_flag,
+                                &work]() {
+    if (count--) {
+      pool.Schedule(work);
+    } else {
+      mutex_lock l(done_lock);
+      done_flag = true;
+      done.notify_all();
+    }
+  };
+  work();
+  mutex_lock l(done_lock);
+  if (!done_flag) {
+    done.wait(l);
+  }
+}
+BENCHMARK(BM_Sequential);
+
+static void BM_Parallel(int iters) {
+  ThreadPool pool(Env::Default(), "test", kNumThreads);
+  // Decrement count concurrently until 0.
+  std::atomic_int_fast32_t count(iters);
+  mutex done_lock;
+  condition_variable done;
+  bool done_flag = false;
+  for (int i = 0; i < iters; ++i) {
+    pool.Schedule([&count, &done_lock, &done, &done_flag]() {
+      if (count.fetch_sub(1) == 1) {
+        mutex_lock l(done_lock);
+        done_flag = true;
+        done.notify_all();
+      }
+    });
+  }
+  mutex_lock l(done_lock);
+  if (!done_flag) {
+    done.wait(l);
+  }
+}
+BENCHMARK(BM_Parallel);
+
+}  // namespace thread
+}  // namespace tensorflow