aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-05 16:38:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 16:43:49 -0700
commit7dfc0756439aede05ec471193780a4de9f61874e (patch)
tree89bd16a8f8a36b664f98877c863cb868fba2f521
parent2c8bc1587e9480a44c10146d0e9472c1d6f9c7d7 (diff)
Propagate eager output tensor types in TFLite
PiperOrigin-RevId: 211721354
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc20
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc43
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.h28
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc36
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h13
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc38
7 files changed, 145 insertions, 35 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index eb47f46c0b..984f8bbc98 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) {
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, NonFloatTypeInference) {
+ AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
+
+ AddTfOp(testing::kAdd, {0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2});
+ SetTypedValues<int>(0, {1, 2, 3, 4});
+ SetShape(1, {2, 2});
+ SetTypedValues<int>(1, {4, 3, 2, 1});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
+ ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
+ ASSERT_EQ(GetType(2), kTfLiteInt32);
}
TEST_F(DelegateTest, MixedGraph) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index f8467c7cb2..0ee4db1ffb 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -278,7 +278,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* tensor = &context->tensors[tensor_index];
TF_LITE_ENSURE_OK(
context,
- CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+ CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
tensor->buffer_handle = tensor_index;
tensor->data_is_stale = true;
}
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index b8c9e2652a..8584999ace 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -25,19 +25,6 @@ namespace testing {
bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetValues(int tensor_index,
- const std::vector<float>& values) {
- float* v = interpreter_->typed_tensor<float>(tensor_index);
- for (float f : values) {
- *v++ = f;
- }
-}
-
-std::vector<float> EagerModelTest::GetValues(int tensor_index) {
- TfLiteTensor* o = interpreter_->tensor(tensor_index);
- return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
-}
-
void EagerModelTest::SetShape(int tensor_index,
const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
@@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
+TfLiteType EagerModelTest::GetType(int tensor_index) {
+ return interpreter_->tensor(tensor_index)->type;
+}
+
void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
const std::vector<int>& outputs,
- const TfLiteType& type,
- const std::vector<int>& dims) {
+ TfLiteType type, const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
+ // Suppress explicit output type specification to ensure type inference
+ // works properly.
+ if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
+ type = kTfLiteFloat32;
+ }
CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
/*name=*/"",
/*dims=*/dims, quant),
@@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
return " attr{ key: '" + key + "' value {" + value + "}}";
};
+ // Crude type attribution, will need fleshing out as more tests are added.
+ // TODO(b/113613439): Use nodedef string utilities to properly handle
+ // all types.
+ string type_attribute = attr("T", "type: DT_FLOAT");
+ if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) {
+ type_attribute = attr("T", "type: DT_INT32");
+ }
+
if (op == kUnpack) {
- string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
- attr("axis", "i: 0");
+ string attributes =
+ type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
index 0eab9e1135..816db41931 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test {
bool Invoke();
+ // Sets the (typed) tensor's values at the given index.
+ template <typename T>
+ void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+ memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+ values.size() * sizeof(T));
+ }
+
+ // Returns the (typed) tensor's values at the given index.
+ template <typename T>
+ std::vector<T> GetTypedValues(int tensor_index) {
+ const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+ const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+ return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+ }
+
// Sets the tensor's values at the given index.
- void SetValues(int tensor_index, const std::vector<float>& values);
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ SetTypedValues<float>(tensor_index, values);
+ }
// Returns the tensor's values at the given index.
- std::vector<float> GetValues(int tensor_index);
+ std::vector<float> GetValues(int tensor_index) {
+ return GetTypedValues<float>(tensor_index);
+ }
// Sets the tensor's shape at the given index.
void SetShape(int tensor_index, const std::vector<int>& values);
@@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test {
// Returns the tensor's shape at the given index.
std::vector<int> GetShape(int tensor_index);
+ // Returns the tensor's type at the given index.
+ TfLiteType GetType(int tensor_index);
+
const TestErrorReporter& error_reporter() const { return error_reporter_; }
// Adds `num_tensor` tensors to the model. `inputs` contains the indices of
// the input tensors and `outputs` contains the indices of the output
// tensors. All tensors are set to have `type` and `dims`.
void AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs, const TfLiteType& type,
+ const std::vector<int>& outputs, TfLiteType type,
const std::vector<int>& dims);
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 4426c653e6..051246bf86 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context,
return kTfLiteOk;
}
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor) {
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor) {
+ tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
+ if (tensor->type == kTfLiteNoType) {
+ context->ReportError(context,
+ "TF Lite does not support TensorFlow data type: %s",
+ DataTypeString(src.dtype()).c_str());
+ return kTfLiteError;
+ }
+
int num_dims = src.dims();
TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
for (int j = 0; j < num_dims; ++j) {
@@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
}
}
+TfLiteType GetTensorFlowLiteType(TF_DataType type) {
+ switch (type) {
+ case TF_FLOAT:
+ return kTfLiteFloat32;
+ case TF_INT16:
+ return kTfLiteInt16;
+ case TF_INT32:
+ return kTfLiteInt32;
+ case TF_UINT8:
+ return kTfLiteUInt8;
+ case TF_INT64:
+ return kTfLiteInt64;
+ case TF_COMPLEX64:
+ return kTfLiteComplex64;
+ case TF_STRING:
+ return kTfLiteString;
+ case TF_BOOL:
+ return kTfLiteBool;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index a9407be071..ff500d18f3 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -28,14 +28,19 @@ namespace eager {
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status);
-// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an
-// error and returns kTfLiteError if the shape can't be converted.
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor);
+// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite
+// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't
+// be converted.
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor);
// Returns the TF C API Data type that corresponds to the given TfLiteType.
TF_DataType GetTensorFlowDataType(TfLiteType type);
+// Returns the TfLiteType that corresponds to the given TF C API Data type.
+TfLiteType GetTensorFlowLiteType(TF_DataType);
+
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 53378a1eaf..aebc91149c 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -26,6 +26,7 @@ namespace eager {
namespace {
using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
using tensorflow::Tensor;
using ::testing::ElementsAre;
@@ -71,27 +72,41 @@ TEST(UtilTest, ConvertStatus) {
EXPECT_TRUE(context.error.empty());
}
-TEST(UtilTest, CopyShape) {
+TEST(UtilTest, CopyShapeAndType) {
TestContext context;
context.ReportError = ReportError;
context.ResizeTensor = ResizeTensor;
TfLiteTensor dst;
- EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(0));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst),
+ kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst),
+ kTfLiteOk);
+ EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteInt32);
+
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
kTfLiteError);
EXPECT_EQ(context.error,
"Dimension value in TensorFlow shape is larger than supported by "
"TF Lite");
+
+ EXPECT_EQ(
+ CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst),
+ kTfLiteError);
+ EXPECT_EQ(context.error,
+ "TF Lite does not support TensorFlow data type: half");
}
-TEST(UtilTest, TypeConversions) {
+TEST(UtilTest, TypeConversionsFromTFLite) {
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
@@ -103,6 +118,19 @@ TEST(UtilTest, TypeConversions) {
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
+TEST(UtilTest, TypeConversionsFromTensorFlow) {
+ EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
+ EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
+ EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
+ EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
+ EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
+ EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
+ EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
+ EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
+}
+
} // namespace
} // namespace eager
} // namespace tflite