aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-21 11:15:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 11:20:15 -0800
commit741a94013c4c9319b30534e3c40bdee3d71bd0bd (patch)
treeeea6b706f90c61945131cd0b7e39a6b660142470 /tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
parentbdf5a88b8f0a8b06b5fa95c1d5540651db889406 (diff)
Adding support for to resolve constant FloorDiv, FloorMod, StridedSlice, Stack, Rank and Range.
PiperOrigin-RevId: 179836027
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc32
1 files changed, 17 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index 97946182ef..dbe69adcbd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -30,19 +25,14 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
if (slice_op->type != OperatorType::kStridedSlice) return false;
auto* op = static_cast<StridedSliceOperator*>(slice_op);
- if (!op->start_indices.empty()) return false;
+ if (!op->start_indices.empty()) {
+ // We have already resolved these attributes
+ return false;
+ }
CHECK_EQ(op->inputs.size(), 4);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
-
const auto& start_array = *model->arrays[op->inputs[1]];
if (!start_array.has_shape()) return false;
- if (toco::RequiredBufferSizeForShape(start_array.shape()) != 4) {
- // Only 4D arrays are supported for now.
- return false;
- }
const auto& stop_array = *model->arrays[op->inputs[2]];
if (!stop_array.has_shape()) return false;
@@ -50,12 +40,24 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
const auto& stride_array = *model->arrays[op->inputs[3]];
if (!stride_array.has_shape()) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+
op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
- // TODO(dkalenichenko): Delete the extra inputs?
+ CHECK_GE(op->start_indices.size(), 1);
+ CHECK_LE(op->start_indices.size(), 4);
+ CHECK_EQ(op->stop_indices.size(), op->start_indices.size());
+ CHECK_EQ(op->strides.size(), op->stop_indices.size());
+ // Ideally, we would remove the input arrays after they have been resolved.
+ // However, we must then reconstitute these input arrays for all supported
+ // export formats. For now, leave the arrays so we don't have to modify our
+ // exporters. Ideally, we wouldn't have op attributes, and would work directly
+ // with the input arrays.
return true;
}
} // namespace toco