author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-04 07:45:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-04-04 07:50:20 -0700
commit     3fbec92fa4997fc834807b57b229fa9ada179f6c (patch)
tree       34a7c15af0428e5ecf2ee2deb16e7c26b3d4320b
parent     71b19430e8484b136e0b872f6a543aff8a242587 (diff)
Turn Cast into a proper builtin operator.
PiperOrigin-RevId: 191590230
-rw-r--r--  tensorflow/contrib/lite/BUILD                        |  1
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h            |  7
-rw-r--r--  tensorflow/contrib/lite/kernels/cast.cc              |  9
-rw-r--r--  tensorflow/contrib/lite/model.cc                     | 70
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs            |  2
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h    | 38
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc      | 26
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc |  2
8 files changed, 117 insertions(+), 38 deletions(-)
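
What "proper builtin" means downstream of this change: the converter now stores the cast types in the typed CastOptions table referenced by the operator's builtin_options, and the interpreter parses them into a TfLiteCastParams struct handed to the kernel through node->builtin_data, instead of carrying them as an ad-hoc flexbuffer map in custom_options (the old toco path visible in the operator.cc diff below). A minimal kernel-side sketch, assuming only the TfLiteCastParams struct and the TfLiteNode/TfLiteType/TfLiteStatus types from the diffs below; the helper itself is hypothetical and not part of the patch:

#include "tensorflow/contrib/lite/builtin_op_data.h"

// Hypothetical helper: recover the cast types the converter recorded for this
// node. Models exported before this change may carry no CastOptions, in which
// case builtin_data is null and the caller must fall back to the tensor types.
TfLiteStatus GetCastTypes(const TfLiteNode* node, TfLiteType* in,
                          TfLiteType* out) {
  auto* params = reinterpret_cast<const TfLiteCastParams*>(node->builtin_data);
  if (params == nullptr) return kTfLiteError;
  *in = params->in_data_type;
  *out = params->out_data_type;
  return kTfLiteOk;
}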
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index ac269d540a..9c4533079c 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -89,6 +89,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
+ deps = [":context"],
)
cc_library(
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 5fc8954743..2b6c24768c 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -17,6 +17,8 @@ limitations under the License.
#include <stdint.h>
+#include "tensorflow/contrib/lite/context.h"
+
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@@ -174,6 +176,11 @@ typedef struct {
int block_size;
} TfLiteSpaceToDepthParams;
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
typedef enum {
kTfLiteCombinerTypeSum = 0,
kTfLiteCombinerTypeMean = 1,
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 19942de7bc..17ef2c572e 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -34,6 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): these two checks would make the new implementation
+ // incompatible with some existing models, where params is not specified. It
+ // is OK not to have them because toco would have set input and output types
+ // to match the parameters.
+ // auto* params = reinterpret_cast<TfLiteCastParams*>(node->builtin_data);
+ // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type);
+ // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type);
+
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 791d1378f3..606f4a5635 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -32,6 +32,32 @@ namespace tflite {
const char* kEmptyTensorName = "";
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
@@ -307,10 +333,25 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_EXP:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_CAST:
case BuiltinOperator_DEQUANTIZE:
case BuiltinOperator_PRELU:
break;
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ break;
+ }
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
MallocPOD<TfLiteLSHProjectionParams>();
@@ -707,29 +748,10 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
}
TfLiteType type;
- switch (tensor->type()) {
- case TensorType_FLOAT32:
- type = kTfLiteFloat32;
- break;
- case TensorType_INT32:
- type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- type = kTfLiteString;
- break;
- default:
- // tensorType = ArrayType::NONE;
- error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor->type()),
- tensor->type());
- status = kTfLiteError;
- continue;
+ if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
+ kTfLiteOk) {
+ status = kTfLiteError;
+ continue;
}
auto get_readonly_data = [&](const char** buffer_data,
size_t* buffer_size) {
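
The ParseTensors hunk above now routes through the same ConvertTensorType helper that the new CastOptions parsing uses. A minimal usage sketch, assuming a caller that lives alongside this code in namespace tflite with an ErrorReporter* available; the wrapper function is purely illustrative:

// Illustrative caller: map a schema-level tensor type to the runtime enum,
// propagating failure the same way ParseTensors now does.
TfLiteStatus ExampleLookup(ErrorReporter* error_reporter) {
  TfLiteType type;
  if (ConvertTensorType(TensorType_INT64, &type, error_reporter) != kTfLiteOk) {
    return kTfLiteError;  // The unsupported type was already reported.
  }
  // On success, type == kTfLiteInt64 here.
  return kTfLiteOk;
}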
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7d2e00fe32..c63bfb28cc 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -381,6 +381,8 @@ table LogSoftmaxOptions {
}
table CastOptions {
+ in_data_type: TensorType;
+ out_data_type: TensorType;
}
table DequantizeOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 66a97a1460..0735be5c8f 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -3702,14 +3702,30 @@ flatbuffers::Offset<LogSoftmaxOptions> CreateLogSoftmaxOptions(flatbuffers::Flat
struct CastOptionsT : public flatbuffers::NativeTable {
typedef CastOptions TableType;
- CastOptionsT() {
+ TensorType in_data_type;
+ TensorType out_data_type;
+ CastOptionsT()
+ : in_data_type(TensorType_FLOAT32),
+ out_data_type(TensorType_FLOAT32) {
}
};
struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef CastOptionsT NativeTableType;
+ enum {
+ VT_IN_DATA_TYPE = 4,
+ VT_OUT_DATA_TYPE = 6
+ };
+ TensorType in_data_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0));
+ }
+ TensorType out_data_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE) &&
+ VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE) &&
verifier.EndTable();
}
CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3720,6 +3736,12 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
struct CastOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
+ void add_in_data_type(TensorType in_data_type) {
+ fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0);
+ }
+ void add_out_data_type(TensorType out_data_type) {
+ fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0);
+ }
explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -3733,8 +3755,12 @@ struct CastOptionsBuilder {
};
inline flatbuffers::Offset<CastOptions> CreateCastOptions(
- flatbuffers::FlatBufferBuilder &_fbb) {
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType in_data_type = TensorType_FLOAT32,
+ TensorType out_data_type = TensorType_FLOAT32) {
CastOptionsBuilder builder_(_fbb);
+ builder_.add_out_data_type(out_data_type);
+ builder_.add_in_data_type(in_data_type);
return builder_.Finish();
}
@@ -5727,6 +5753,8 @@ inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t
inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
+ { auto _e = in_data_type(); _o->in_data_type = _e; };
+ { auto _e = out_data_type(); _o->out_data_type = _e; };
}
inline flatbuffers::Offset<CastOptions> CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -5737,8 +5765,12 @@ inline flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBuffe
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _in_data_type = _o->in_data_type;
+ auto _out_data_type = _o->out_data_type;
return tflite::CreateCastOptions(
- _fbb);
+ _fbb,
+ _in_data_type,
+ _out_data_type);
}
inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
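
With the regenerated header above, the two new fields can be populated either through CastOptionsBuilder or via the CreateCastOptions convenience overload. A short sketch using the latter; only CreateCastOptions and the TensorType values come from the generated code, and the wrapper function is illustrative:

#include "tensorflow/contrib/lite/schema/schema_generated.h"

// Illustrative only: build a CastOptions table describing a float32 -> uint8
// cast. Both fields default to TensorType_FLOAT32 if omitted.
flatbuffers::Offset<tflite::CastOptions> BuildExampleCastOptions(
    flatbuffers::FlatBufferBuilder* fbb) {
  return tflite::CreateCastOptions(*fbb, tflite::TensorType_FLOAT32,
                                   tflite::TensorType_UINT8);
}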
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0cb348bda5..f991529569 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -204,17 +204,22 @@ class BatchToSpaceND
TocoOperator* op) const override {}
};
-class Cast : public CustomOperator<CastOperator> {
+class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
+ ::tflite::BuiltinOptions_CastOptions> {
public:
- using CustomOperator::CustomOperator;
- void WriteOptions(const TocoOperator& op,
- flexbuffers::Builder* fbb) const override {
- fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
- fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateCastOptions(*builder,
+ DataType::Serialize(op.src_data_type),
+ DataType::Serialize(op.dst_data_type));
}
- void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
- op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
- op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->src_data_type = DataType::Deserialize(options.in_data_type());
+ op->dst_data_type = DataType::Deserialize(options.out_data_type());
}
};
@@ -827,9 +832,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
ops.emplace_back(
new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
+ ops.emplace_back(
+ new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
// Custom Operators.
- ops.emplace_back(new Cast("CAST", OperatorType::kCast));
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index f7a213ecfc..4783843b7f 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -131,7 +131,7 @@ TEST_F(OperatorTest, BuiltinMean) {
EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
}
-TEST_F(OperatorTest, CustomCast) {
+TEST_F(OperatorTest, BuiltinCast) {
CastOperator op;
op.src_data_type = ArrayDataType::kFloat;
op.dst_data_type = ArrayDataType::kUint8;