diff options
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc | 50 |
2 files changed, 60 insertions, 6 deletions
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index a52e4f940c..ebaee909c8 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -779,10 +779,15 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key, // hard for the caller of the tensor bundle module to allocate these // precisely-shaped scratch storage. - // Optimization for the common case: stored slice == to-restore slice. - // TODO(zongheng): also include the case where "slice_spec" is full ("-"), - // and "stored_slice" is logically full but contains actual extents. - if (stored_slice == slice_spec) { + // Optimization for the common case: the stored slice can be directly + // copied to the destination without additional slicing. This is true when + // either the slices are equal or when they are both full slices having the + // same shape. + TensorShape stored_slice_shape(stored_slice_entry.shape()); + if (stored_slice == slice_spec || + (stored_slice_shape == val->shape() && + IsFullSlice(stored_slice, stored_slice_shape) && + IsFullSlice(slice_spec, stored_slice_shape))) { VLOG(1) << "Optimized for common case: directly copying into " "pre-allocated buffer; spec: " << slice_spec.DebugString(); @@ -790,8 +795,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key, return status_; } - Tensor stored_slice_tensor(stored_slice_entry.dtype(), - TensorShape(stored_slice_entry.shape())); + Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape); status_ = GetValue(stored_slice_entry, &stored_slice_tensor); if (!status_.ok()) return status_; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 3b775e4000..3d1ae10816 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -306,6 +306,56 @@ TEST(TensorBundleTest, PartitionedVariables) { } } +TEST(TensorBundleTest, EquivalentSliceTest) { + const TensorShape kFullShape({5, 10}); + const Tensor kExpected(Constant<float>(1., kFullShape)); + { + BundleWriter writer(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape, + TensorSlice::ParseOrDie("-:-"), kExpected)); + TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape, + TensorSlice::ParseOrDie("0,5:0,10"), + kExpected)); + TF_ASSERT_OK(writer.Finish()); + } + // Slices match exactly and are fully abbreviated. + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + const TensorSlice slice = TensorSlice::ParseOrDie("-:-"); + Tensor val(DT_FLOAT, TensorShape(kFullShape)); + TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val)); + test::ExpectTensorEqual<float>(val, kExpected); + } + // Slice match exactly and are fully specified. + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10"); + Tensor val(DT_FLOAT, TensorShape(kFullShape)); + TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val)); + test::ExpectTensorEqual<float>(val, kExpected); + } + // Stored slice has no extents, spec has extents. + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + const TensorSlice slice = TensorSlice::ParseOrDie("0,5:0,10"); + Tensor val(DT_FLOAT, TensorShape(kFullShape)); + TF_ASSERT_OK(reader.LookupSlice("no_extents", slice, &val)); + test::ExpectTensorEqual<float>(val, kExpected); + } + // Stored slice has both extents, spec has no extents. + { + BundleReader reader(Env::Default(), Prefix("foo")); + TF_ASSERT_OK(reader.status()); + const TensorSlice slice = TensorSlice::ParseOrDie("-:-"); + Tensor val(DT_FLOAT, TensorShape(kFullShape)); + TF_ASSERT_OK(reader.LookupSlice("both_extents", slice, &val)); + test::ExpectTensorEqual<float>(val, kExpected); + } +} + TEST(TensorBundleTest, NonStandardShapes) { TestNonStandardShapes<float>(); TestNonStandardShapes<double>(); |