aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-16 15:25:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-16 16:44:52 -0700
commitc092f31cb54f4964da6ea0476df926902a8fe4f9 (patch)
treeeed70ff5d5371b265817bd868ddddc67b942f13a
parent433c8c89d2daadc7c19bbb3dcabe9d8afcfd03df (diff)
[Tensorflow] Expose API to lookup TensorSlice.
Change: 150384503
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc12
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.h6
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc21
3 files changed, 34 insertions, 5 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 6e46ca6f75..a62b32fcfd 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -712,6 +712,18 @@ Status BundleReader::Lookup(StringPiece key, Tensor* val) {
}
}
+Status BundleReader::LookupTensorSlices(StringPiece key,
+ std::vector<TensorSlice>* slices) {
+ slices->clear();
+ BundleEntryProto entry;
+ TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
+ slices->reserve(entry.slices_size());
+ for (const auto& slice : entry.slices()) {
+ slices->emplace_back(slice);
+ }
+ return Status::OK();
+}
+
Status BundleReader::LookupSlice(StringPiece full_tensor_key,
const TensorSlice& slice_spec, Tensor* val) {
BundleEntryProto entry;
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index 6a8104ade5..bca3910f59 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -207,6 +207,12 @@ class BundleReader {
// REQUIRES: status().ok()
Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
+ // Looks up the slices of the tensor keyed by "key". On OK, "slices"
+ // is non-empty if and only if the tensor is a partitioned tensor.
+ // REQUIRES: status().ok()
+ Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
+ TF_MUST_USE_RESULT;
+
// Looks up a specific slice of a partitioned tensor.
// It is only required that the stored slices cover the requested slice,
// namely "slice_spec" is a subset of the union of the stored slices.
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 3d1ae10816..de8576b55a 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -245,15 +245,14 @@ TEST(TensorBundleTest, PartitionedVariables) {
// Adds two slices.
// First slice: column 0, all zeros.
// Second slice: column 1 to rest, all ones.
+ TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
+ TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
{
BundleWriter writer(Env::Default(), Prefix("foo"));
- TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
- TF_ASSERT_OK(writer.AddSlice("foo", kFullShape,
- TensorSlice::ParseOrDie("-:0,1"),
+ TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
Constant<float>(0., TensorShape({5, 1}))));
- TF_ASSERT_OK(writer.AddSlice("foo", kFullShape,
- TensorSlice::ParseOrDie("-:1,9"),
+ TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
Constant<float>(1., TensorShape({5, 9}))));
TF_ASSERT_OK(writer.Finish());
}
@@ -274,6 +273,18 @@ TEST(TensorBundleTest, PartitionedVariables) {
TF_ASSERT_OK(reader.Lookup("foo", &val));
test::ExpectTensorEqual<float>(val, expected_val);
}
+ // Reads all slices.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+
+ std::vector<TensorSlice> slices;
+ TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
+
+ EXPECT_EQ(2, slices.size());
+ EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
+ EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
+ }
// Reads a slice consisting of first two columns, "cutting" both slices.
{
BundleReader reader(Env::Default(), Prefix("foo"));