aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/util.cc')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc36
1 files changed, 34 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 4426c653e6..051246bf86 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context,
return kTfLiteOk;
}
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor) {
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor) {
+ tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
+ if (tensor->type == kTfLiteNoType) {
+ context->ReportError(context,
+ "TF Lite does not support TensorFlow data type: %s",
+ DataTypeString(src.dtype()).c_str());
+ return kTfLiteError;
+ }
+
int num_dims = src.dims();
TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
for (int j = 0; j < num_dims; ++j) {
@@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
}
}
+TfLiteType GetTensorFlowLiteType(TF_DataType type) {
+ switch (type) {
+ case TF_FLOAT:
+ return kTfLiteFloat32;
+ case TF_INT16:
+ return kTfLiteInt16;
+ case TF_INT32:
+ return kTfLiteInt32;
+ case TF_UINT8:
+ return kTfLiteUInt8;
+ case TF_INT64:
+ return kTfLiteInt64;
+ case TF_COMPLEX64:
+ return kTfLiteComplex64;
+ case TF_STRING:
+ return kTfLiteString;
+ case TF_BOOL:
+ return kTfLiteBool;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
} // namespace eager
} // namespace tflite