aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-16 09:02:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-16 09:08:53 -0800
commit02ff8b8d15807c27e9752d03093e0f38c5180c2a (patch)
treec088b98b2a03c1b6de6ad9df12b66fa5ba2d5e3c
parent09f6ebea267d6685d1f967fb32a8d6d38358acaa (diff)
Expanding the optimization that avoids making an extra copy when restoring slices to include the case when the source and destination have the same shape and are full slices with respect to their shape.
Change: 147726529
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc16
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc50
2 files changed, 60 insertions, 6 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index a52e4f940c..ebaee909c8 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -779,10 +779,15 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
// hard for the caller of the tensor bundle module to allocate these
// precisely-shaped scratch storage.
- // Optimization for the common case: stored slice == to-restore slice.
- // TODO(zongheng): also include the case where "slice_spec" is full ("-"),
- // and "stored_slice" is logically full but contains actual extents.
- if (stored_slice == slice_spec) {
+ // Optimization for the common case: the stored slice can be directly
+ // copied to the destination without additional slicing. This is true when
+ // either the slices are equal or when they are both full slices having the
+ // same shape.
+ TensorShape stored_slice_shape(stored_slice_entry.shape());
+ if (stored_slice == slice_spec ||
+ (stored_slice_shape == val->shape() &&
+ IsFullSlice(stored_slice, stored_slice_shape) &&
+ IsFullSlice(slice_spec, stored_slice_shape))) {
VLOG(1) << "Optimized for common case: directly copying into "
"pre-allocated buffer; spec: "
<< slice_spec.DebugString();
@@ -790,8 +795,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
return status_;
}
- Tensor stored_slice_tensor(stored_slice_entry.dtype(),
- TensorShape(stored_slice_entry.shape()));
+ Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
if (!status_.ok()) return status_;
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 3b775e4000..3d1ae10816 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -306,6 +306,56 @@ TEST(TensorBundleTest, PartitionedVariables) {
}
}
+TEST(TensorBundleTest, EquivalentSliceTest) {
+ const TensorShape kFullShape({5, 10});
+ const Tensor kExpected(Constant<float>(1., kFullShape));
+ {
+ BundleWriter writer(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
+ TensorSlice::ParseOrDie("-:-"), kExpected));
+ TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
+ TensorSlice::ParseOrDie("0,5:0,10"),
+ kExpected));
+ TF_ASSERT_OK(writer.Finish());
+ }
+ // Slices match exactly and are fully abbreviated.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Slice match exactly and are fully specified.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Stored slice has no extents, spec has extents.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Stored slice has both extents, spec has no extents.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+}
+
TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<float>();
TestNonStandardShapes<double>();