// Source: root/tensorflow/contrib/lite/delegates/eager/util.cc
// (git blob 4426c653e6ff80aac52b50e06a3005173490433d)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/util.h"

namespace tflite {
namespace eager {

// Translates a TensorFlow status into a TF Lite status.
//
// An OK `status` yields kTfLiteOk. Any other status has its error message
// forwarded to `context`'s error reporter and yields kTfLiteError.
TfLiteStatus ConvertStatus(TfLiteContext* context,
                           const tensorflow::Status& status) {
  if (status.ok()) return kTfLiteOk;

  // Pass the message through "%s" so any '%' characters in it are not
  // interpreted as format directives.
  context->ReportError(context, "%s", status.error_message().c_str());
  return kTfLiteError;
}

// Resizes `tensor` so that its shape matches the shape of `src`.
//
// TensorFlow reports dimension sizes as int64 values, while the TF Lite shape
// array stores ints, so every dimension is range-checked before narrowing.
// If any dimension is >= std::numeric_limits<int>::max(), an error is
// reported through `context`, the partially filled shape array is freed, and
// kTfLiteError is returned. Otherwise the result of context->ResizeTensor is
// returned. The shape array is not freed on that path; presumably ResizeTensor
// takes ownership of it (confirm against the TfLiteContext documentation).
TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
                       TfLiteTensor* tensor) {
  constexpr auto kDimLimit = std::numeric_limits<int>::max();

  const int rank = src.dims();
  TfLiteIntArray* new_shape = TfLiteIntArrayCreate(rank);

  for (int axis = 0; axis < rank; ++axis) {
    const auto extent = src.dim_size(axis);
    // Refuse values that would not survive the int64 -> int narrowing below.
    if (extent >= kDimLimit) {
      context->ReportError(context,
                           "Dimension value in TensorFlow shape is larger than "
                           "supported by TF Lite");
      TfLiteIntArrayFree(new_shape);
      return kTfLiteError;
    }
    new_shape->data[axis] = static_cast<int>(extent);
  }

  return context->ResizeTensor(context, tensor, new_shape);
}

// Maps a TF Lite tensor element type to the corresponding TensorFlow C API
// data type.
//
// kTfLiteNoType is mapped to TF_FLOAT. The same TF_FLOAT fallback is returned
// for any `type` value not listed in the switch, so the function always
// returns a value.
TF_DataType GetTensorFlowDataType(TfLiteType type) {
  switch (type) {
    case kTfLiteNoType:
      return TF_FLOAT;
    case kTfLiteFloat32:
      return TF_FLOAT;
    case kTfLiteInt16:
      return TF_INT16;
    case kTfLiteInt32:
      return TF_INT32;
    case kTfLiteUInt8:
      return TF_UINT8;
    case kTfLiteInt64:
      return TF_INT64;
    case kTfLiteComplex64:
      return TF_COMPLEX64;
    case kTfLiteString:
      return TF_STRING;
    case kTfLiteBool:
      return TF_BOOL;
  }
  // Reached only when `type` holds a value none of the cases above cover,
  // e.g. an integer cast into the enum or an enumerator added to TfLiteType
  // without updating this switch. Without this statement control would flow
  // off the end of a non-void function, which is undefined behavior.
  // Deliberately no `default:` label above, so -Wswitch still flags
  // unhandled enumerators at compile time.
  return TF_FLOAT;
}

}  // namespace eager
}  // namespace tflite