aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 310a88484c..8a945ac435 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
-bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto expand_it = model->operators.begin() + op_index;
if (expand_it->get()->type != OperatorType::kExpandDims) {
- return false;
+ return ::tensorflow::Status::OK();
}
ExpandDimsOperator* expand_op =
static_cast<ExpandDimsOperator*>(expand_it->get());
@@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1);
if (!axis_array.buffer) {
// Yield until the input axis array is constant
- return false;
+ return ::tensorflow::Status::OK();
}
int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
std::vector<int> reshape_dims(input_array.shape().dims());
@@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(expand_it->get(), expand_op);
model->operators.erase(expand_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco