diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-06 14:56:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 15:07:58 -0700 |
commit | 3b44d4bbfccce918ea9155e33c3da55c770b781f (patch) | |
tree | c4ad72f378e1be32f6f3ff02ef72f4dd7bc83914 /tensorflow/contrib/lite/toco | |
parent | e0a8285d9563122a75d94a54352f5c94f287e810 (diff) |
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 211874785
Diffstat (limited to 'tensorflow/contrib/lite/toco')
4 files changed, 25 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index a75553db84..bea90f1ce8 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -372,6 +372,7 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", + "//tensorflow/contrib/lite/kernels/internal:types", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index c25be078ff..f103bb94ae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { // Compute output shape for (int axis = 0; axis < num_input_axes; ++axis) { + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op->begin_mask, op->end_mask, op->shrink_axis_mask, + op->start_indices, op->stop_indices, op->strides); int start_index = tflite::strided_slice::StartForAxis( - op->begin_mask, op->start_indices, op->strides, - input_array.shape().dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis, start_index); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); + int dim_size = ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 9d8bd4fc39..8853ed87e6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>(); std::vector<int> src_coord(num_input_axes); std::vector<int> stop_for_axis(num_input_axes); + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices, + op.stop_indices, op.strides); + for (int axis = 0; axis < num_input_axes; axis++) { - int start = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), - axis); - src_coord[axis] = start; + int start_index = tflite::strided_slice::StartForAxis( + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); + src_coord[axis] = start_index; stop_for_axis[axis] = tflite::strided_slice::StopForAxis( - op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, - input_shape.dims().data(), axis, start); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, - input_shape.dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_shape), axis); carry = true; } else { carry = false; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index bdeb203024..5f4b8cb66a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -28,6 +28,7 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/kernels/internal/types.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" @@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1); // - For the remaining indices [0..i0), d0[i0] == 1. bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1); +inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) { + return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data()); +} + bool IsArrayFullyConnectedWeights(const Model& model, const string& name); // If there is a wildcard dimension (-1), this may return a negative value. |