about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--  tensorflow/core/ops/array_ops.cc  58
1 file changed, 58 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 11df3c43c7..e540ecfa8d 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -394,6 +394,28 @@ output: A `Tensor` with the concatenation of values stacked along the
in `concat_dim` where it has the sum of the sizes.
)doc");
+// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
+// are not to be made user-accessible.
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcatV2")
+ .Input("values: N * T")
+ .Input("axis: Tidx")
+ .Input("mkl_values: N * uint8")
+ .Input("mkl_axis: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn(shape_inference::ConcatV2Shape)
+ .Doc(R"doc(
+MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE: Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif  // INTEL_MKL
+
REGISTER_OP("ConcatOffset")
.Input("concat_dim: int32")
.Input("shape: N * int32")
@@ -1638,6 +1660,21 @@ reshape(t, []) ==> 7
shape: Defines the shape of the output tensor.
)Doc");
+#ifdef INTEL_MKL
+REGISTER_OP("_MklReshape")
+ .Input("tensor: T")
+ .Input("shape: Tshape")
+ .Input("mkl_tensor: uint8")
+ .Input("mkl_shape: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
+ .Doc(R"Doc(MKL implementation of ReshapeOp. Uses MKL DNN APIs to perform reshaping.
+)Doc");
+#endif // INTEL_MKL
+
// --------------------------------------------------------------------------
REGISTER_OP("InvertPermutation")
.Input("x: T")
@@ -4965,6 +5002,27 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
`sum_per_d(gradients * (inputs > max))`.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcat")
+ .Input("concat_dim: int32")
+ .Input("values: N * T")
+ .Input("mkl_concat_dim: uint8")
+ .Input("mkl_values: N * uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::ConcatShape(c, c->num_inputs() - 3);
+ })
+ .Doc(R"doc(
+MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE: Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif  // INTEL_MKL
+
// Deprecated op registrations:
// The following can be deleted after 10mar2017.