diff options
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 44 |
1 files changed, 40 insertions, 4 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 897b174eff..6a37256ea9 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -542,8 +542,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) { return mkl_shape.dim_size(index); } -inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, - int idx_out) { +inline void CopyMklTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -563,8 +563,9 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, context->set_output(idx_meta_out, meta_output); } -inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in, - int idx_out, const TensorShape& shape) { +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, + int idx_in, int idx_out, + const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -580,6 +581,41 @@ inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in, context->set_output(idx_data_out, output); } +inline void FowardTfTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { + int num_inputs = context->num_inputs(); + int num_outputs = context->num_outputs(); + int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); + int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); + + MklShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + AllocateOutputSetMklShape(context, idx_out, mkl_shape_output); + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + } else { + context->set_output(idx_data_out, context->input(idx_data_in)); + } +} + +inline void ForwarMklTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { + int num_inputs = context->num_inputs(); + int num_outputs = context->num_outputs(); + int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); + int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs); + int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); + int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs); + + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out); + } else { + context->set_output(idx_data_out, context->input(idx_data_in)); + context->set_output(idx_meta_out, context->input(idx_meta_in)); + } +} + namespace mkl_op_registry { static const char* kMklOpLabel = "MklOp"; static const char* kMklOpLabelPattern = "label='MklOp'"; |