diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 11df3c43c7..e540ecfa8d 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -394,6 +394,28 @@ output: A `Tensor` with the concatenation of values stacked along the in `concat_dim` where it has the sum of the sizes. )doc"); +// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops +// are not to be made user-accessible. +#ifdef INTEL_MKL +REGISTER_OP("_MklConcatV2") + .Input("values: N * T") + .Input("axis: Tidx") + .Input("mkl_values: N * uint8") + .Input("mkl_axis: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("N: int >= 2") + .Attr("T: type") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn(shape_inference::ConcatV2Shape) + .Doc(R"doc( +MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); +#endif + REGISTER_OP("ConcatOffset") .Input("concat_dim: int32") .Input("shape: N * int32") @@ -1638,6 +1660,21 @@ reshape(t, []) ==> 7 shape: Defines the shape of the output tensor. )Doc"); +#ifdef INTEL_MKL +REGISTER_OP("_MklReshape") + .Input("tensor: T") + .Input("shape: Tshape") + .Input("mkl_tensor: uint8") + .Input("mkl_shape: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: type") + .Attr("Tshape: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); }) + .Doc(R"Doc( MKL implementation of ReshapeOp. +)Doc"); +#endif // INTEL_MKL + // -------------------------------------------------------------------------- REGISTER_OP("InvertPermutation") .Input("x: T") @@ -4965,6 +5002,27 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`: `sum_per_d(gradients * (inputs > max))`. )doc"); +#ifdef INTEL_MKL +REGISTER_OP("_MklConcat") + .Input("concat_dim: int32") + .Input("values: N * T") + .Input("mkl_concat_dim: uint8") + .Input("mkl_values: N * uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("N: int >= 2") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::ConcatShape(c, c->num_inputs() - 3); + }) + .Doc(R"doc( +MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); +#endif + // Deprecated op registrations: // The following can be deleted after 10mar2017. |