diff options
author | 2017-12-13 08:03:03 -0800 | |
---|---|---|
committer | 2017-12-13 08:06:33 -0800 | |
commit | 2eae1ac21ce28f3b2cafe9e12a25b3bddc475847 (patch) | |
tree | 3abd865aabb97677a3087e896187e34963fa0612 /tensorflow/contrib/lite/toco/export_tensorflow.cc | |
parent | 8f19188a14b62f2612783f3ebba0cd1c9d08aba8 (diff) |
Standardize attribute naming for operators specifying a dimension to "axis". This mirrors TensorFlow's attribute naming.
PiperOrigin-RevId: 178903728
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/export_tensorflow.cc | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index e18cf46c69..bddb83206b 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -780,13 +780,12 @@ void ConvertConcatenationOperator(const Model& model, auto* dc_op = tensorflow_graph->add_node(); dc_op->set_op("ConcatV2"); dc_op->set_name(src_op.outputs[0]); - const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim"; - CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim, - tensorflow_graph); + const string dummy_axis = src_op.outputs[0] + "/axis"; + CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph); for (const auto& input : src_op.inputs) { *dc_op->add_input() = input; } - *dc_op->add_input() = dummy_concat_dim; + *dc_op->add_input() = dummy_axis; (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT); (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32); (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size()); @@ -993,22 +992,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, const string concat_output = base + "basic_lstm_cell/concat"; // Op names have been chosen to match the tf.slim LSTM naming // as closely as possible. - const int concat_dim = + const int axis = model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) ->shape() .dimensions_count() - 1; // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat // works the same since the tensor has the same underlying data layout. - const string concat_dim_output = concat_output + "/concat_dim"; - CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim, - tensorflow_graph); + const string axis_output = concat_output + "/axis"; + CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph); auto* concat_op = tensorflow_graph->add_node(); concat_op->set_op("ConcatV2"); concat_op->set_name(concat_output); *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]; - *concat_op->add_input() = concat_dim_output; + *concat_op->add_input() = axis_output; (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT); (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32); (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs @@ -1069,8 +1067,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Split string split_dim_output = base + "split/split_dim"; // The dimension is the same as the concatenation dimension - CreateDummyConcatDimTensorConst(split_dim_output, concat_dim, - tensorflow_graph); + CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph); string split_output = base + "split"; auto* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); @@ -1298,11 +1295,11 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); tensor->set_dtype(DT_INT32); - for (int i = 0; i < src_op.reduction_indices.size(); ++i) { - tensor->add_int_val(src_op.reduction_indices[i]); + for (int i = 0; i < src_op.axis.size(); ++i) { + tensor->add_int_val(src_op.axis[i]); } auto* shape = tensor->mutable_tensor_shape(); - shape->add_dim()->set_size(src_op.reduction_indices.size()); + shape->add_dim()->set_size(src_op.axis.size()); } void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, |