diff options
author | 2018-09-26 14:25:34 -0700 | |
---|---|---|
committer | 2018-09-26 14:33:36 -0700 | |
commit | c551a7dbd08685160c233ccecd444f774666f98e (patch) | |
tree | 6dab4b4519a38caf7a0ff7e44e802685055dc52b /tensorflow/contrib/lite/kernels | |
parent | dd37be0e66934369bb7f5e4b5a88b982351fbff0 (diff) |
Kernel signature reworking, update kernel DepthConcatenation.
PiperOrigin-RevId: 214668695
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index cd9e1b255d..f3f1595035 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -1991,12 +1991,35 @@ void PackWithScaling(const PackParams& params, } } +template <typename Scalar> +void DepthConcatenation(const ConcatenationParams& params, + const RuntimeShape* const* input_shapes, + const Scalar* const* input_data, + const RuntimeShape& output_shape, Scalar* output_data) { + auto params_copy = params; + params_copy.axis = 3; + Concatenation(params_copy, input_shapes, input_data, output_shape, + output_data); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. template <FusedActivationFunctionType Ac, typename Scalar> void DepthConcatenation(const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, Scalar* output_data, const Dims<4>& output_dims) { - Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count, - output_data, output_dims); + // For now we don't have a model with a Concatenation with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); + std::vector<RuntimeShape> input_shapes(inputs_count); + std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count); + for (int i = 0; i < inputs_count; ++i) { + ShapeFromDims(*input_dims[i], &input_shapes[i]); + input_shapes_indirect[i] = &input_shapes[i]; + } + tflite::ConcatenationParams op_params; + op_params.inputs_count = inputs_count; + + DepthConcatenation(op_params, input_shapes_indirect.data(), input_data, + DimsToShape(output_dims), output_data); } inline void LstmCell( |