aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-08-07 13:33:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 13:37:35 -0700
commit452f995e2c23cbd67c14b15b678bb3a352212633 (patch)
tree0871a8d91a74376611c4554c4c130e6cc2abf596 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent2e2486adedb5164b82b0c2fcb8b1d27f987c1428 (diff)
Fix toco compilation on Windows
The toco utility should now build successfully on Windows. However, there are a few lingering issues with execution that need to be resolved before the utility is fully functional. PiperOrigin-RevId: 207770586
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc24
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index d8d331f3d4..b7fffbce22 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1223,11 +1223,10 @@ tensorflow::Status ConvertGatherOperator(
return tensorflow::Status::OK();
}
-template <typename Op, const char* op_name>
+template <typename Op>
tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1245,6 +1244,20 @@ tensorflow::Status ConvertArgMinMaxOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertArgMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "ArgMax");
+ return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
+}
+
+tensorflow::Status ConvertArgMinOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "ArgMin");
+ return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
+}
+
tensorflow::Status ConvertResizeBilinearOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1899,17 +1912,14 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
-constexpr char kArgMax[] = "ArgMax";
-constexpr char kArgMin[] = "ArgMin";
-
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
{"Any", ConvertAnyOperator},
- {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
- {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
+ {"ArgMax", ConvertArgMaxOperator},
+ {"ArgMin", ConvertArgMinOperator},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},