aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 14:56:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 15:07:58 -0700
commit3b44d4bbfccce918ea9155e33c3da55c770b781f (patch)
treec4ad72f378e1be32f6f3ff02ef72f4dd7bc83914 /tensorflow/contrib/lite/toco
parente0a8285d9563122a75d94a54352f5c94f287e810 (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 211874785
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc19
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h5
4 files changed, 25 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index a75553db84..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -372,6 +372,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index c25be078ff..f103bb94ae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op->begin_mask, op->end_mask, op->shrink_axis_mask,
+ op->start_indices, op->stop_indices, op->strides);
int start_index = tflite::strided_slice::StartForAxis(
- op->begin_mask, op->start_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis, start_index);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
+
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index bdeb203024..5f4b8cb66a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+ return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
// If there is a wildcard dimension (-1), this may return a negative value.