diff options
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/util.cc')
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util.cc | 36 |
1 files changed, 34 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index 4426c653e6..051246bf86 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context, return kTfLiteOk; } -TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, - TfLiteTensor* tensor) { +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype())); + if (tensor->type == kTfLiteNoType) { + context->ReportError(context, + "TF Lite does not support TensorFlow data type: %s", + DataTypeString(src.dtype()).c_str()); + return kTfLiteError; + } + int num_dims = src.dims(); TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); for (int j = 0; j < num_dims; ++j) { @@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { } } +TfLiteType GetTensorFlowLiteType(TF_DataType type) { + switch (type) { + case TF_FLOAT: + return kTfLiteFloat32; + case TF_INT16: + return kTfLiteInt16; + case TF_INT32: + return kTfLiteInt32; + case TF_UINT8: + return kTfLiteUInt8; + case TF_INT64: + return kTfLiteInt64; + case TF_COMPLEX64: + return kTfLiteComplex64; + case TF_STRING: + return kTfLiteString; + case TF_BOOL: + return kTfLiteBool; + default: + return kTfLiteNoType; + } +} + } // namespace eager } // namespace tflite |