aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-11 12:42:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-11 12:46:30 -0800
commit d4de943eb47c6d096c910074a545128bb16b224d (patch)
tree c9ad034e86a7eb0391a2707e2bf9fe423d206876
parent ec1ca2419ef33af62f5e7865d901981a96dbf6c9 (diff)
Add option to specify output alignment for BundleWriter.
PiperOrigin-RevId: 181647536
-rw-r--r-- tensorflow/core/util/tensor_bundle/tensor_bundle.cc 20
-rw-r--r-- tensorflow/core/util/tensor_bundle/tensor_bundle.h 12
-rw-r--r-- tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc 88
3 files changed, 118 insertions, 2 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index f030983c02..5e1a640472 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -345,10 +345,27 @@ table::Options TableBuilderOptions() {
return o;
}
+// Pads "out" with zero bytes so that the next write starts at an offset that
+// is a multiple of "alignment".
+//
+// "size" is the number of bytes written to "out" so far. On success it is
+// advanced by the number of padding bytes appended; if the append fails it is
+// left unchanged and the failing status is returned.
+//
+// An "alignment" of 1 never needs padding. Values below 1 are outside the
+// documented range of the data alignment option ("Must be >= 1"); they are
+// treated like 1 (dense packing) so that the modulo below can never divide by
+// zero and the padding length can never be negative.
+Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
+  if (alignment <= 1) {
+    return Status::OK();
+  }
+  // Keep the arithmetic in int64 so the remainder is never narrowed.
+  const int64 bytes_over = *size % alignment;
+  if (bytes_over == 0) {
+    return Status::OK();
+  }
+  const int64 bytes_to_write = alignment - bytes_over;
+  Status status = out->Append(string(bytes_to_write, '\0'));
+  if (status.ok()) {
+    *size += bytes_to_write;
+  }
+  return status;
+}
+
} // namespace
-BundleWriter::BundleWriter(Env* env, StringPiece prefix)
+BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
: env_(env),
+ options_(options),
prefix_(prefix.ToString()),
tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
random::New64())),
@@ -402,6 +419,7 @@ Status BundleWriter::Add(StringPiece key, const Tensor& val) {
entry->set_size(data_bytes_written);
entry->set_crc32c(crc32c::Mask(crc32c));
size_ += data_bytes_written;
+ status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
}
return status_;
}
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index 129646cb69..02db764025 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -107,7 +107,14 @@ extern const char* const kHeaderEntryKey;
// All threads accessing the same BundleWriter must synchronize.
class BundleWriter {
public:
- BundleWriter(Env* env, StringPiece prefix);
+ // Options controlling how a BundleWriter lays out the data it writes.
+ struct Options {
+ // NOTE(review): presumably user-provided (rather than "= default") so that
+ // "Options()" can be used as the constructor's default argument while the
+ // enclosing class is still being defined -- confirm before changing.
+ Options() {}
+ // Alignment, in bytes, for tensor data.
+ // Must be >= 1. The default size of 1 densely packs tensors.
+ int data_alignment{1};
+ };
+ BundleWriter(Env* env, StringPiece prefix,
+ const Options& options = Options());
// Adds the tensor "val" under key "key".
// Across calls "key" must be unique but can be added in any order.
@@ -140,6 +147,7 @@ class BundleWriter {
private:
Env* const env_; // Not owned.
+ const Options options_;
const string prefix_;
const string tmp_metadata_path_;
const string tmp_data_path_;
@@ -297,6 +305,8 @@ class BundleReader {
// TODO(b/64763924): Remove after Jan 1st 2018.
bool lenient_names_;
+ friend class TensorBundleAlignmentTest; // For testing data alignment.
+
TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
};
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 341aae36f4..08f1aa7125 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
@@ -770,4 +771,91 @@ TEST(TensorBundleTest, VersionTest) {
}
}
+// Test fixture for checking where tensor data is placed within a bundle.
+// NOTE(review): BundleReader declares this class a friend, presumably because
+// GetBundleEntryProto() is not part of its public interface -- confirm.
+class TensorBundleAlignmentTest : public ::testing::Test {
+ protected:
+  // Expects that the entry stored under "key" starts at a byte offset that is
+  // a multiple of "alignment". If the entry cannot be fetched, a fatal failure
+  // is recorded and this helper returns early. "T" is not used by the check.
+  template <typename T>
+  void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
+    BundleEntryProto entry;
+    TF_ASSERT_OK(reader->GetBundleEntryProto(key, &entry));
+    const auto misalignment = entry.offset() % alignment;
+    EXPECT_EQ(0, misalignment);
+  }
+};
+
+// Writes four tensors with a 42-byte data alignment, then checks that they can
+// be looked up by key, are iterated in key order, and that every entry's
+// offset is a multiple of the requested alignment.
+TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
+  const int kAlignment = 42;
+  const int kNumTensors = 4;
+  // Keys in sorted order; key i is paired with Constant_2x3<float>(i).
+  const std::vector<string> kSortedKeys(
+      {"foo_000", "foo_001", "foo_002", "foo_003"});
+
+  {
+    // Write the bundle, deliberately adding the keys out of order.
+    BundleWriter::Options writer_options;
+    writer_options.data_alignment = kAlignment;
+    BundleWriter writer(Env::Default(), Prefix("foo"), writer_options);
+    const int kWriteOrder[] = {3, 0, 2, 1};
+    for (const int i : kWriteOrder) {
+      TF_EXPECT_OK(writer.Add(kSortedKeys[i], Constant_2x3<float>(i)));
+    }
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    // Every tensor is retrievable by key with the value that was written.
+    BundleReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(AllTensorKeys(&reader), kSortedKeys);
+    for (int i = 0; i < kNumTensors; ++i) {
+      Expect<float>(&reader, kSortedKeys[i], Constant_2x3<float>(i));
+    }
+  }
+  {
+    // Sequential iteration yields the tensors in key order and then ends.
+    BundleReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    for (int i = 0; i < kNumTensors; ++i) {
+      ExpectNext<float>(&reader, Constant_2x3<float>(i));
+    }
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+  {
+    // Each entry's offset honours the alignment requested at write time.
+    BundleReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    for (const string& key : kSortedKeys) {
+      ExpectAlignment<float>(&reader, key, kAlignment);
+    }
+  }
+}
+
+// Benchmarks BundleReader::Lookup() of the tensor "big" from a bundle written
+// with the given data alignment. Writing the bundle and opening the reader
+// happen while timing is stopped; only the "iters" lookups are timed.
+//
+// iters:       number of lookups to perform.
+// alignment:   value assigned to BundleWriter::Options::data_alignment.
+// tensor_size: length of the single dimension of the "big" tensor.
+static void BM_BundleAlignmentByteOff(int iters, int alignment,
+                                      int tensor_size) {
+  testing::StopTiming();
+  {
+    BundleWriter::Options writer_options;
+    writer_options.data_alignment = alignment;
+    BundleWriter writer(Env::Default(), Prefix("foo"), writer_options);
+    // NOTE(review): the one-element "small" tensor is added first, presumably
+    // so that "big" does not trivially start at offset 0 -- confirm.
+    TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
+    TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
+    TF_CHECK_OK(writer.Finish());
+  }
+  BundleReader reader(Env::Default(), Prefix("foo"));
+  TF_CHECK_OK(reader.status());
+  testing::StartTiming();
+  while (iters-- > 0) {
+    // A fresh destination tensor is used for every lookup.
+    Tensor looked_up;
+    TF_CHECK_OK(reader.Lookup("big", &looked_up));
+  }
+  testing::StopTiming();
+}
+
+// Defines a benchmark function named BM_BundleAlignment_<ALIGN>_<SIZE> that
+// forwards to BM_BundleAlignmentByteOff with alignment ALIGN and tensor_size
+// SIZE, and registers it with BENCHMARK. The invocation supplies the trailing
+// semicolon.
+#define BM_BundleAlignment(ALIGN, SIZE) \
+ static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
+ BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \
+ } \
+ BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
+
+// Alignment 1 (the densely packed default) versus 4096, each at three tensor
+// sizes.
+BM_BundleAlignment(1, 512);
+BM_BundleAlignment(1, 4096);
+BM_BundleAlignment(1, 1048576);
+BM_BundleAlignment(4096, 512);
+BM_BundleAlignment(4096, 4096);
+BM_BundleAlignment(4096, 1048576);
+
} // namespace tensorflow