aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc16
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc50
2 files changed, 60 insertions, 6 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index a52e4f940c..ebaee909c8 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -779,10 +779,15 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
// hard for the caller of the tensor bundle module to allocate these
// precisely-shaped scratch storage.
- // Optimization for the common case: stored slice == to-restore slice.
- // TODO(zongheng): also include the case where "slice_spec" is full ("-"),
- // and "stored_slice" is logically full but contains actual extents.
- if (stored_slice == slice_spec) {
+ // Optimization for the common case: the stored slice can be directly
+ // copied to the destination without additional slicing. This is true when
+ // either the slices are equal or when they are both full slices having the
+ // same shape.
+ TensorShape stored_slice_shape(stored_slice_entry.shape());
+ if (stored_slice == slice_spec ||
+ (stored_slice_shape == val->shape() &&
+ IsFullSlice(stored_slice, stored_slice_shape) &&
+ IsFullSlice(slice_spec, stored_slice_shape))) {
VLOG(1) << "Optimized for common case: directly copying into "
"pre-allocated buffer; spec: "
<< slice_spec.DebugString();
@@ -790,8 +795,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
return status_;
}
- Tensor stored_slice_tensor(stored_slice_entry.dtype(),
- TensorShape(stored_slice_entry.shape()));
+ Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
if (!status_.ok()) return status_;
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 3b775e4000..3d1ae10816 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -306,6 +306,56 @@ TEST(TensorBundleTest, PartitionedVariables) {
}
}
+TEST(TensorBundleTest, EquivalentSliceTest) {
+ const TensorShape kFullShape({5, 10});
+ const Tensor kExpected(Constant<float>(1., kFullShape));
+ {
+ BundleWriter writer(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
+ TensorSlice::ParseOrDie("-:-"), kExpected));
+ TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
+ TensorSlice::ParseOrDie("0,5:0,10"),
+ kExpected));
+ TF_ASSERT_OK(writer.Finish());
+ }
+ // Slices match exactly and are fully abbreviated.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Slice match exactly and are fully specified.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Stored slice has no extents, spec has extents.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+ // Stored slice has both extents, spec has no extents.
+ {
+ BundleReader reader(Env::Default(), Prefix("foo"));
+ TF_ASSERT_OK(reader.status());
+ const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ Tensor val(DT_FLOAT, TensorShape(kFullShape));
+ TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val));
+ test::ExpectTensorEqual<float>(val, kExpected);
+ }
+}
+
TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<float>();
TestNonStandardShapes<double>();