diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-19 10:02:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 10:05:35 -0700 |
commit | 7f449920f8910561a4e57cc35b96fb7faf08ef98 (patch) | |
tree | 036ee28b10d59da6bb38a1e59625a5620b94924f /tensorflow/contrib | |
parent | 5fc2bdd2d5f624a6bad9e83b992029e3799ab64e (diff) |
Refresh allocations in the presence of dynamic tensors
PiperOrigin-RevId: 201193941
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 59 |
4 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 9c804d2785..8c17c65fcc 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -184,6 +184,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 3287f9c4fd..57b2c0f32b 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -605,9 +605,17 @@ TfLiteStatus Interpreter::Invoke() { } EnsureTensorsVectorCapacity(); + tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } + + // Force execution prep for downstream ops if the latest op triggered the + // resize of a dynamic tensor. + if (tensor_resized_since_op_invoke_ && + HasDynamicTensor(context_, node.outputs)) { + next_execution_plan_index_to_prepare_ = execution_plan_index + 1; + } } if (!allow_buffer_handle_output_) { @@ -783,6 +791,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, if (tensor->allocation_type == kTfLiteArenaRw || tensor->allocation_type == kTfLiteDynamic || tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor_resized_since_op_invoke_ |= + TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 37961cd1dc..436c1007af 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -589,6 +589,11 @@ class Interpreter { bool allow_buffer_handle_output_ = false; + // Tracking bit for whether a tensor was resized in the course of an op + // invocation. This is a useful hint to ensure that dynamic tensor outputs + // trigger downstream reallocation after op invocation. + bool tensor_resized_since_op_invoke_ = false; + // Profiler for this interpreter instance. profiling::Profiler* profiler_; }; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index b977cb089c..21cdf87d1e 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,12 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { +namespace ops { +namespace builtin { +TfLiteRegistration* Register_PADV2(); +TfLiteRegistration* Register_NEG(); +} // namespace builtin +} // namespace ops namespace { // Make an interpreter that has no tensors and no nodes @@ -615,6 +621,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); } +TEST(BasicInterpreter, DynamicTensorsResizeDescendants) { + // Assemble a graph with a node that has dynamically sized output (via the + // pad op), followed by a node with a standard element-wise op (negate). + Interpreter interpreter; + interpreter.AddTensors(4); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1}, + quant); + interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant); + + TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2(); + TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG(); + interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Configure [[2,2],[4,4]] padding and execute the graph. + interpreter.typed_tensor<int>(1)[0] = 2; + interpreter.typed_tensor<int>(1)[1] = 2; + interpreter.typed_tensor<int>(1)[2] = 2; + interpreter.typed_tensor<int>(1)[3] = 2; + interpreter.typed_tensor<int>(1)[4] = 0; + interpreter.typed_tensor<int>(1)[5] = 0; + interpreter.typed_tensor<int>(1)[6] = 0; + interpreter.typed_tensor<int>(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Both the output and intermediate tensor sizes should reflect the output + // from the dynamic pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6); + + // Now configure [[4,4],[6,6]] padding and execute the graph. + interpreter.typed_tensor<int>(1)[0] = 4; + interpreter.typed_tensor<int>(1)[1] = 4; + interpreter.typed_tensor<int>(1)[2] = 6; + interpreter.typed_tensor<int>(1)[3] = 6; + interpreter.typed_tensor<int>(1)[4] = 0; + interpreter.typed_tensor<int>(1)[5] = 0; + interpreter.typed_tensor<int>(1)[6] = 0; + interpreter.typed_tensor<int>(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Again, the output and intermediate tensor sizes should reflect the *new* + // resize from the latest pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14); +} + TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), |