diff options
author | 2017-12-21 11:15:26 -0800 | |
---|---|---|
committer | 2017-12-21 11:20:15 -0800 | |
commit | 741a94013c4c9319b30534e3c40bdee3d71bd0bd (patch) | |
tree | eea6b706f90c61945131cd0b7e39a6b660142470 /tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc | |
parent | bdf5a88b8f0a8b06b5fa95c1d5540651db889406 (diff) |
Adding support for to resolve constant FloorDiv, FloorMod, StridedSlice, Stack, Rank and Range.
PiperOrigin-RevId: 179836027
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index 97946182ef..dbe69adcbd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <memory> -#include <string> -#include <unordered_map> -#include <vector> - #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -30,19 +25,14 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { if (slice_op->type != OperatorType::kStridedSlice) return false; auto* op = static_cast<StridedSliceOperator*>(slice_op); - if (!op->start_indices.empty()) return false; + if (!op->start_indices.empty()) { + // We have already resolved these attributes + return false; + } CHECK_EQ(op->inputs.size(), 4); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; - if (!IsConstantParameterArray(*model, op->inputs[3])) return false; - const auto& start_array = *model->arrays[op->inputs[1]]; if (!start_array.has_shape()) return false; - if (toco::RequiredBufferSizeForShape(start_array.shape()) != 4) { - // Only 4D arrays are supported for now. - return false; - } const auto& stop_array = *model->arrays[op->inputs[2]]; if (!stop_array.has_shape()) return false; @@ -50,12 +40,24 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { const auto& stride_array = *model->arrays[op->inputs[3]]; if (!stride_array.has_shape()) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data; op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data; op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data; - // TODO(dkalenichenko): Delete the extra inputs? + CHECK_GE(op->start_indices.size(), 1); + CHECK_LE(op->start_indices.size(), 4); + CHECK_EQ(op->stop_indices.size(), op->start_indices.size()); + CHECK_EQ(op->strides.size(), op->stop_indices.size()); + // Ideally, we would remove the input arrays after they have been resolved. + // However, we must then reconstitute these input arrays for all supported + // export formats. For now, leave the arrays so we don't have to modify our + // exporters. Ideally, we wouldn't have op attributes, and would work directly + // with the input arrays. return true; } } // namespace toco |