diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-16 15:25:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-16 16:44:52 -0700 |
commit | c092f31cb54f4964da6ea0476df926902a8fe4f9 (patch) | |
tree | eed70ff5d5371b265817bd868ddddc67b942f13a | |
parent | 433c8c89d2daadc7c19bbb3dcabe9d8afcfd03df (diff) |
[Tensorflow] Expose API to lookup TensorSlice.
Change: 150384503
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.h | 6 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc | 21 |
3 files changed, 34 insertions, 5 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 6e46ca6f75..a62b32fcfd 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -712,6 +712,18 @@ Status BundleReader::Lookup(StringPiece key, Tensor* val) { } } +Status BundleReader::LookupTensorSlices(StringPiece key, + std::vector<TensorSlice>* slices) { + slices->clear(); + BundleEntryProto entry; + TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); + slices->reserve(entry.slices_size()); + for (const auto& slice : entry.slices()) { + slices->emplace_back(slice); + } + return Status::OK(); +} + Status BundleReader::LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec, Tensor* val) { BundleEntryProto entry; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index 6a8104ade5..bca3910f59 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -207,6 +207,12 @@ class BundleReader { // REQUIRES: status().ok() Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT; + // Looks up the slices of the tensor keyed by "key". On OK, "slices" + // is non-empty if and only if the tensor is a partitioned tensor. + // REQUIRES: status().ok() + Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices) + TF_MUST_USE_RESULT; + // Looks up a specific slice of a partitioned tensor. // It is only required that the stored slices cover the requested slice, // namely "slice_spec" is a subset of the union of the stored slices. diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 3d1ae10816..de8576b55a 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -245,15 +245,14 @@ TEST(TensorBundleTest, PartitionedVariables) { // Adds two slices. // First slice: column 0, all zeros. // Second slice: column 1 to rest, all ones. + TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1"); + TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9"); { BundleWriter writer(Env::Default(), Prefix("foo")); - TensorSlice slice = TensorSlice::ParseOrDie("-:0,1"); - TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, - TensorSlice::ParseOrDie("-:0,1"), + TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1, Constant<float>(0., TensorShape({5, 1})))); - TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, - TensorSlice::ParseOrDie("-:1,9"), + TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2, Constant<float>(1., TensorShape({5, 9})))); TF_ASSERT_OK(writer.Finish()); } @@ -274,6 +273,18 @@ TEST(TensorBundleTest, PartitionedVariables) { TF_ASSERT_OK(reader.Lookup("foo", &val)); test::ExpectTensorEqual<float>(val, expected_val); } + // Reads all slices. + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + + std::vector<TensorSlice> slices; + TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices)); + + EXPECT_EQ(2, slices.size()); + EXPECT_EQ(slice1.DebugString(), slices[0].DebugString()); + EXPECT_EQ(slice2.DebugString(), slices[1].DebugString()); + } // Reads a slice consisting of first two columns, "cutting" both slices. { BundleReader reader(Env::Default(), Prefix("foo")); |