diff options
author | 2018-09-05 16:38:33 -0700 | |
---|---|---|
committer | 2018-09-05 16:43:49 -0700 | |
commit | 7dfc0756439aede05ec471193780a4de9f61874e (patch) | |
tree | 89bd16a8f8a36b664f98877c863cb868fba2f521 | |
parent | 2c8bc1587e9480a44c10146d0e9472c1d6f9c7d7 (diff) |
Propagate eager output tensor types in TFLite
PiperOrigin-RevId: 211721354
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/delegate_test.cc | 20 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/kernel.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/test_util.cc | 43 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/test_util.h | 28 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util.cc | 36 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util.h | 13 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util_test.cc | 38 |
7 files changed, 145 insertions, 35 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc index eb47f46c0b..984f8bbc98 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) { ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + ASSERT_EQ(GetType(8), kTfLiteFloat32); +} + +TEST_F(DelegateTest, NonFloatTypeInference) { + AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2}); + + AddTfOp(testing::kAdd, {0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2}); + SetTypedValues<int>(0, {1, 2, 3, 4}); + SetShape(1, {2, 2}); + SetTypedValues<int>(1, {4, 3, 2, 1}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2)); + ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5)); + ASSERT_EQ(GetType(2), kTfLiteInt32); } TEST_F(DelegateTest, MixedGraph) { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index f8467c7cb2..0ee4db1ffb 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -278,7 +278,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* tensor = &context->tensors[tensor_index]; TF_LITE_ENSURE_OK( context, - CopyShape(context, buffer_map->GetTensor(tensor_index), tensor)); + CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor)); tensor->buffer_handle = tensor_index; tensor->data_is_stale = true; } diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc index b8c9e2652a..8584999ace 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.cc +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -25,19 +25,6 @@ namespace testing { bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } -void EagerModelTest::SetValues(int tensor_index, - const std::vector<float>& values) { - float* v = interpreter_->typed_tensor<float>(tensor_index); - for (float f : values) { - *v++ = f; - } -} - -std::vector<float> EagerModelTest::GetValues(int tensor_index) { - TfLiteTensor* o = interpreter_->tensor(tensor_index); - return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float)); -} - void EagerModelTest::SetShape(int tensor_index, const std::vector<int>& values) { ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); @@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) { return result; } +TfLiteType EagerModelTest::GetType(int tensor_index) { + return interpreter_->tensor(tensor_index)->type; +} + void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs, const std::vector<int>& outputs, - const TfLiteType& type, - const std::vector<int>& dims) { + TfLiteType type, const std::vector<int>& dims) { interpreter_->AddTensors(num_tensors); for (int i = 0; i < num_tensors; ++i) { TfLiteQuantizationParams quant; + // Suppress explicit output type specification to ensure type inference + // works properly. + if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) { + type = kTfLiteFloat32; + } CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, /*name=*/"", /*dims=*/dims, quant), @@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs, return " attr{ key: '" + key + "' value {" + value + "}}"; }; + // Crude type attribution, will need fleshing out as more tests are added. + // TODO(b/113613439): Use nodedef string utilities to properly handle + // all types. + string type_attribute = attr("T", "type: DT_FLOAT"); + if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) { + type_attribute = attr("T", "type: DT_INT32"); + } + if (op == kUnpack) { - string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + - attr("axis", "i: 0"); + string attributes = + type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h index 0eab9e1135..816db41931 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.h +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test { bool Invoke(); + // Sets the (typed) tensor's values at the given index. + template <typename T> + void SetTypedValues(int tensor_index, const std::vector<T>& values) { + memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(), + values.size() * sizeof(T)); + } + + // Returns the (typed) tensor's values at the given index. + template <typename T> + std::vector<T> GetTypedValues(int tensor_index) { + const TfLiteTensor* t = interpreter_->tensor(tensor_index); + const T* tdata = interpreter_->typed_tensor<T>(tensor_index); + return std::vector<T>(tdata, tdata + t->bytes / sizeof(T)); + } + // Sets the tensor's values at the given index. - void SetValues(int tensor_index, const std::vector<float>& values); + void SetValues(int tensor_index, const std::vector<float>& values) { + SetTypedValues<float>(tensor_index, values); + } // Returns the tensor's values at the given index. - std::vector<float> GetValues(int tensor_index); + std::vector<float> GetValues(int tensor_index) { + return GetTypedValues<float>(tensor_index); + } // Sets the tensor's shape at the given index. void SetShape(int tensor_index, const std::vector<int>& values); @@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test { // Returns the tensor's shape at the given index. std::vector<int> GetShape(int tensor_index); + // Returns the tensor's type at the given index. + TfLiteType GetType(int tensor_index); + const TestErrorReporter& error_reporter() const { return error_reporter_; } // Adds `num_tensor` tensors to the model. `inputs` contains the indices of // the input tensors and `outputs` contains the indices of the output // tensors. All tensors are set to have `type` and `dims`. void AddTensors(int num_tensors, const std::vector<int>& inputs, - const std::vector<int>& outputs, const TfLiteType& type, + const std::vector<int>& outputs, TfLiteType type, const std::vector<int>& dims); // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index 4426c653e6..051246bf86 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context, return kTfLiteOk; } -TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, - TfLiteTensor* tensor) { +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype())); + if (tensor->type == kTfLiteNoType) { + context->ReportError(context, + "TF Lite does not support TensorFlow data type: %s", + DataTypeString(src.dtype()).c_str()); + return kTfLiteError; + } + int num_dims = src.dims(); TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); for (int j = 0; j < num_dims; ++j) { @@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { } } +TfLiteType GetTensorFlowLiteType(TF_DataType type) { + switch (type) { + case TF_FLOAT: + return kTfLiteFloat32; + case TF_INT16: + return kTfLiteInt16; + case TF_INT32: + return kTfLiteInt32; + case TF_UINT8: + return kTfLiteUInt8; + case TF_INT64: + return kTfLiteInt64; + case TF_COMPLEX64: + return kTfLiteComplex64; + case TF_STRING: + return kTfLiteString; + case TF_BOOL: + return kTfLiteBool; + default: + return kTfLiteNoType; + } +} + } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index a9407be071..ff500d18f3 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -28,14 +28,19 @@ namespace eager { TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status); -// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an -// error and returns kTfLiteError if the shape can't be converted. -TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, - TfLiteTensor* tensor); +// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite +// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't +// be converted. +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor); // Returns the TF C API Data type that corresponds to the given TfLiteType. TF_DataType GetTensorFlowDataType(TfLiteType type); +// Returns the TfLiteType that corresponds to the given TF C API Data type. +TfLiteType GetTensorFlowLiteType(TF_DataType); + } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc index 53378a1eaf..aebc91149c 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -26,6 +26,7 @@ namespace eager { namespace { using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; using tensorflow::Tensor; using ::testing::ElementsAre; @@ -71,27 +72,41 @@ TEST(UtilTest, ConvertStatus) { EXPECT_TRUE(context.error.empty()); } -TEST(UtilTest, CopyShape) { +TEST(UtilTest, CopyShapeAndType) { TestContext context; context.ReportError = ReportError; context.ResizeTensor = ResizeTensor; TfLiteTensor dst; - EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk); + EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk); EXPECT_THAT(context.new_size, ElementsAre(0)); + EXPECT_EQ(dst.type, kTfLiteFloat32); - EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk); + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst), + kTfLiteOk); EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteFloat32); - EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst), + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteInt32); + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), kTfLiteError); EXPECT_EQ(context.error, "Dimension value in TensorFlow shape is larger than supported by " "TF Lite"); + + EXPECT_EQ( + CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "TF Lite does not support TensorFlow data type: half"); } -TEST(UtilTest, TypeConversions) { +TEST(UtilTest, TypeConversionsFromTFLite) { EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); @@ -103,6 +118,19 @@ TEST(UtilTest, TypeConversions) { EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); } +TEST(UtilTest, TypeConversionsFromTensorFlow) { + EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT)); + EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16)); + EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32)); + EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8)); + EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64)); + EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64)); + EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING)); + EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT)); +} + } // namespace } // namespace eager } // namespace tflite |