aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/mkl_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r--tensorflow/core/util/mkl_util.h44
1 files changed, 40 insertions, 4 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 897b174eff..6a37256ea9 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -542,8 +542,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
return mkl_shape.dim_size(index);
}
-inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
- int idx_out) {
+inline void CopyMklTensorInToOut(OpKernelContext* context,
+ int idx_in, int idx_out) {
int num_inputs = context->num_inputs();
int num_outputs = context->num_outputs();
int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
@@ -563,8 +563,9 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
context->set_output(idx_meta_out, meta_output);
}
-inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
- int idx_out, const TensorShape& shape) {
+inline void CopyTfTensorInToOutWithShape(OpKernelContext* context,
+ int idx_in, int idx_out,
+ const TensorShape& shape) {
int num_inputs = context->num_inputs();
int num_outputs = context->num_outputs();
int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
@@ -580,6 +581,41 @@ inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
context->set_output(idx_data_out, output);
}
+inline void FowardTfTensorInToOut(OpKernelContext* context,
+ int idx_in, int idx_out) {
+ int num_inputs = context->num_inputs();
+ int num_outputs = context->num_outputs();
+ int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+ int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
+ if (IsRefType(context->input_dtype(idx_data_in))) {
+ context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+ } else {
+ context->set_output(idx_data_out, context->input(idx_data_in));
+ }
+}
+
+inline void ForwarMklTensorInToOut(OpKernelContext* context,
+ int idx_in, int idx_out) {
+ int num_inputs = context->num_inputs();
+ int num_outputs = context->num_outputs();
+ int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+ int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
+ int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+ int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
+
+ if (IsRefType(context->input_dtype(idx_data_in))) {
+ context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+ context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
+ } else {
+ context->set_output(idx_data_out, context->input(idx_data_in));
+ context->set_output(idx_meta_out, context->input(idx_meta_in));
+ }
+}
+
namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";