diff options
author | 2018-01-11 12:42:35 -0800 | |
---|---|---|
committer | 2018-01-11 12:46:30 -0800 | |
commit | d4de943eb47c6d096c910074a545128bb16b224d (patch) | |
tree | c9ad034e86a7eb0391a2707e2bf9fe423d206876 | |
parent | ec1ca2419ef33af62f5e7865d901981a96dbf6c9 (diff) |
Add option to specify output alignment for BundleWriter.
PiperOrigin-RevId: 181647536
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.h | 12 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc | 88 |
3 files changed, 118 insertions, 2 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index f030983c02..5e1a640472 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -345,10 +345,27 @@ table::Options TableBuilderOptions() { return o; } +// Writes zeros to output buffer to align the next write to the requested +// alignment. "size" is the current size of the buffer and is updated to the +// new size. +Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) { + int bytes_over = *size % alignment; + if (bytes_over == 0) { + return Status::OK(); + } + int bytes_to_write = alignment - bytes_over; + Status status = out->Append(string(bytes_to_write, '\0')); + if (status.ok()) { + *size += bytes_to_write; + } + return status; +} + } // namespace -BundleWriter::BundleWriter(Env* env, StringPiece prefix) +BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options) : env_(env), + options_(options), prefix_(prefix.ToString()), tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate", random::New64())), @@ -402,6 +419,7 @@ Status BundleWriter::Add(StringPiece key, const Tensor& val) { entry->set_size(data_bytes_written); entry->set_crc32c(crc32c::Mask(crc32c)); size_ += data_bytes_written; + status_ = PadAlignment(out_.get(), options_.data_alignment, &size_); } return status_; } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index 129646cb69..02db764025 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -107,7 +107,14 @@ extern const char* const kHeaderEntryKey; // All threads accessing the same BundleWriter must synchronize. class BundleWriter { public: - BundleWriter(Env* env, StringPiece prefix); + struct Options { + Options() {} + // Alignment, in bytes, for tensor data. + // Must be >= 1. The default size of 1 densely packs tensors. + int data_alignment{1}; + }; + BundleWriter(Env* env, StringPiece prefix, + const Options& options = Options()); // Adds the tensor "val" under key "key". // Across calls "key" must be unique but can be added in any order. @@ -140,6 +147,7 @@ class BundleWriter { private: Env* const env_; // Not owned. + const Options options_; const string prefix_; const string tmp_metadata_path_; const string tmp_data_path_; @@ -297,6 +305,8 @@ class BundleReader { // TODO(b/64763924): Remove after Jan 1st 2018. bool lenient_names_; + friend class TensorBundleAlignmentTest; // For testing data alignment. + TF_DISALLOW_COPY_AND_ASSIGN(BundleReader); }; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 341aae36f4..08f1aa7125 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/io/table_builder.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { @@ -770,4 +771,91 @@ TEST(TensorBundleTest, VersionTest) { } } +class TensorBundleAlignmentTest : public ::testing::Test { + protected: + template <typename T> + void ExpectAlignment(BundleReader* reader, const string& key, int alignment) { + BundleEntryProto full_tensor_entry; + TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry)); + EXPECT_EQ(0, full_tensor_entry.offset() % alignment); + } +}; + +TEST_F(TensorBundleAlignmentTest, AlignmentTest) { + { + BundleWriter::Options opts; + opts.data_alignment = 42; + BundleWriter writer(Env::Default(), Prefix("foo"), opts); + TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<float>(3))); + TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<float>(0))); + TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<float>(2))); + TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<float>(1))); + TF_ASSERT_OK(writer.Finish()); + } + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + EXPECT_EQ( + AllTensorKeys(&reader), + std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"})); + Expect<float>(&reader, "foo_000", Constant_2x3<float>(0)); + Expect<float>(&reader, "foo_001", Constant_2x3<float>(1)); + Expect<float>(&reader, "foo_002", Constant_2x3<float>(2)); + Expect<float>(&reader, "foo_003", Constant_2x3<float>(3)); + } + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + ExpectNext<float>(&reader, Constant_2x3<float>(0)); + ExpectNext<float>(&reader, Constant_2x3<float>(1)); + ExpectNext<float>(&reader, Constant_2x3<float>(2)); + ExpectNext<float>(&reader, Constant_2x3<float>(3)); + EXPECT_TRUE(reader.Valid()); + reader.Next(); + EXPECT_FALSE(reader.Valid()); + } + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + ExpectAlignment<float>(&reader, "foo_000", 42); + ExpectAlignment<float>(&reader, "foo_001", 42); + ExpectAlignment<float>(&reader, "foo_002", 42); + ExpectAlignment<float>(&reader, "foo_003", 42); + } +} + +static void BM_BundleAlignmentByteOff(int iters, int alignment, + int tensor_size) { + testing::StopTiming(); + { + BundleWriter::Options opts; + opts.data_alignment = alignment; + BundleWriter writer(Env::Default(), Prefix("foo"), opts); + TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1})))); + TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size})))); + TF_CHECK_OK(writer.Finish()); + } + BundleReader reader(Env::Default(), Prefix("foo")); + TF_CHECK_OK(reader.status()); + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + Tensor t; + TF_CHECK_OK(reader.Lookup("big", &t)); + } + testing::StopTiming(); +} + +#define BM_BundleAlignment(ALIGN, SIZE) \ + static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \ + BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \ + } \ + BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE) + +BM_BundleAlignment(1, 512); +BM_BundleAlignment(1, 4096); +BM_BundleAlignment(1, 1048576); +BM_BundleAlignment(4096, 512); +BM_BundleAlignment(4096, 4096); +BM_BundleAlignment(4096, 1048576); + } // namespace tensorflow |