aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 14:25:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:33:36 -0700
commitc551a7dbd08685160c233ccecd444f774666f98e (patch)
tree6dab4b4519a38caf7a0ff7e44e802685055dc52b /tensorflow/contrib/lite/kernels
parentdd37be0e66934369bb7f5e4b5a88b982351fbff0 (diff)
Kernel signature reworking, update kernel DepthConcatenation.
PiperOrigin-RevId: 214668695
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h27
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index cd9e1b255d..f3f1595035 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1991,12 +1991,35 @@ void PackWithScaling(const PackParams& params,
}
}
+template <typename Scalar>
+void DepthConcatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape, Scalar* output_data) {
+ auto params_copy = params;
+ params_copy.axis = 3;
+ Concatenation(params_copy, input_shapes, input_data, output_shape,
+ output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
template <FusedActivationFunctionType Ac, typename Scalar>
void DepthConcatenation(const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.inputs_count = inputs_count;
+
+ DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
}
inline void LstmCell(