aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-13 08:03:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 08:06:33 -0800
commit2eae1ac21ce28f3b2cafe9e12a25b3bddc475847 (patch)
tree3abd865aabb97677a3087e896187e34963fa0612 /tensorflow/contrib/lite/toco/export_tensorflow.cc
parent8f19188a14b62f2612783f3ebba0cd1c9d08aba8 (diff)
Standardize attribute naming for operators specifying a dimension to "axis". This mirrors TensorFlow's attribute naming.
PiperOrigin-RevId: 178903728
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc25
1 files changed, 11 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index e18cf46c69..bddb83206b 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -780,13 +780,12 @@ void ConvertConcatenationOperator(const Model& model,
auto* dc_op = tensorflow_graph->add_node();
dc_op->set_op("ConcatV2");
dc_op->set_name(src_op.outputs[0]);
- const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim";
- CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim,
- tensorflow_graph);
+ const string dummy_axis = src_op.outputs[0] + "/axis";
+ CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
for (const auto& input : src_op.inputs) {
*dc_op->add_input() = input;
}
- *dc_op->add_input() = dummy_concat_dim;
+ *dc_op->add_input() = dummy_axis;
(*dc_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
(*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
@@ -993,22 +992,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
const string concat_output = base + "basic_lstm_cell/concat";
// Op names have been chosen to match the tf.slim LSTM naming
// as closely as possible.
- const int concat_dim =
+ const int axis =
model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
->shape()
.dimensions_count() -
1;
// Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
// works the same since the tensor has the same underlying data layout.
- const string concat_dim_output = concat_output + "/concat_dim";
- CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim,
- tensorflow_graph);
+ const string axis_output = concat_output + "/axis";
+ CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
auto* concat_op = tensorflow_graph->add_node();
concat_op->set_op("ConcatV2");
concat_op->set_name(concat_output);
*concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
*concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
- *concat_op->add_input() = concat_dim_output;
+ *concat_op->add_input() = axis_output;
(*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
(*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs
@@ -1069,8 +1067,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Split
string split_dim_output = base + "split/split_dim";
// The dimension is the same as the concatenation dimension
- CreateDummyConcatDimTensorConst(split_dim_output, concat_dim,
- tensorflow_graph);
+ CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
string split_output = base + "split";
auto* split_op = tensorflow_graph->add_node();
split_op->set_op("Split");
@@ -1298,11 +1295,11 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
tensor->set_dtype(DT_INT32);
- for (int i = 0; i < src_op.reduction_indices.size(); ++i) {
- tensor->add_int_val(src_op.reduction_indices[i]);
+ for (int i = 0; i < src_op.axis.size(); ++i) {
+ tensor->add_int_val(src_op.axis[i]);
}
auto* shape = tensor->mutable_tensor_shape();
- shape->add_dim()->set_size(src_op.reduction_indices.size());
+ shape->add_dim()->set_size(src_op.axis.size());
}
void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,