author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-26 12:54:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-06-26 12:57:46 -0700
commit    f3c89936e97c99dead1ca3310246691c1b221adf (patch)
tree      3c99b66936ed59028b32609115a239f52798907d /tensorflow/core
parent    0b9b09a8531004b44b133a52c3fcc67bc6759bd8 (diff)
Merge changes from github.
END_PUBLIC
Note: this CL will break builds. cl/159887762 to follow to fix all the breakages.
--- Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com> Committed by gunan<gunan@google.com>: Updated link to use HTTPS (#10998) Howdy! I just updated a link to use https instead of http. Thanks!
--- Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes run_metadata_test for SYCL This test is designed to test CUDA specific behavior
--- Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update comments
--- Commit 1699d904a authored by John Lawson<john@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes CUDA specific test run on SYCL (#56) The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL.
--- Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com> Committed by Shanqing Cai<cais@google.com>: Java API to get the size of specified input list of operations. (#10865) * Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users.
--- Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] REGISTER -> REGISTER6
--- Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com> Committed by superryanguo<superryanguo@gmail.com>: Simplify the Quickstart section with the weblink is better
--- Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr> Committed by Taehoon Lee<taehoonlee@snu.ac.kr>: Fix typos
--- Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com> Committed by GitHub<noreply@github.com>: Correct the learning rate as per the code snippet
--- Commit 03da61134 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update ir_array.cc
--- Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Another try
--- Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Transpose to go through Eigen (#10321)
--- Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848) * [OpenCL] Added RGBToHSV and HSVToRGB * Aligning '\'
--- Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers AdjustContrastv2 (#10949) * [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments
--- Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make TransferBufferToInFeed not virual so it compiles.
--- Commit e89f04d80 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix calling Literal member functions.
--- Commit 15a8df724 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix mac build clone from meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error: error: incomplete type 'xla::DeviceAssignment' used in type trait expression This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment.
--- Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Replace LiteralUtil to Literal in compiler/plugin/executor
--- Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com> Committed by Guenther Schmuelling<guschmue@microsoft.com>: expand inline for debug builds to limit number of symbols
--- Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.
--- Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org> Committed by Mark Neumann<markn@allenai.org>: fix initial test values to avoid numerical instability
--- Commit b58d98353 authored by sj6077<epik03sj@gmail.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Fixes of AutoParallel bug (#10368) * Fix the bug that auto_parallel could replicate variable snapshot name * Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item * remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel
--- Commit a286b7db8 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make debug_test slice integer.
--- Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com> Committed by GitHub<noreply@github.com>: Fixed path to seq2seq.py and minor formatting
--- Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com> Committed by Anish Shah<shah.anish07@gmail.com>: Improve docs for tf.nn.depthwise_conv2d_native
--- Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com> Committed by Yong Tang<yong.tang.github@outlook.com>: Fix mismatched delete in mkl_tfconv_op.cc This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc (the file went through clang-format so there are some additional changes) Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
--- Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com> Committed by Danny Goodman<goodman.danny@gmail.com>: fix error format
--- Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible). PiperOrigin-RevId: 159649743
--- Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add tests for select ops and while loops that produce tuples that contain predicates. PiperOrigin-RevId: 159645900
--- Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use C API to implement Operation.name property This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc). PiperOrigin-RevId: 159645767
--- Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error. PiperOrigin-RevId: 159645612
--- Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 159570549
PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 4
-rw-r--r--  tensorflow/core/debug/grpc_session_debug_test.cc | 5
-rw-r--r--  tensorflow/core/framework/graph_def_util.h | 2
-rw-r--r--  tensorflow/core/framework/op.h | 1
-rw-r--r--  tensorflow/core/framework/tensor.h | 2
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc | 277
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass_test.cc | 252
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 2
-rw-r--r--  tensorflow/core/grappler/grappler_item.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.cc | 5
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel_test.cc | 42
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 2
-rw-r--r--  tensorflow/core/kernels/BUILD | 8
-rw-r--r--  tensorflow/core/kernels/adjust_contrast_op.cc | 24
-rw-r--r--  tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc | 5
-rw-r--r--  tensorflow/core/kernels/colorspace_op.cc | 15
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.cc | 136
-rw-r--r--  tensorflow/core/kernels/cwise_op_add_2.cc | 7
-rw-r--r--  tensorflow/core/kernels/cwise_op_cosh.cc | 37
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_add.cu.cc | 3
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc | 26
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc | 26
-rw-r--r--  tensorflow/core/kernels/cwise_op_invert.cc | 2
-rw-r--r--  tensorflow/core/kernels/cwise_op_sinh.cc | 37
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h | 6
-rw-r--r--  tensorflow/core/kernels/dynamic_stitch_op.cc | 27
-rw-r--r--  tensorflow/core/kernels/map_stage_op.cc | 7
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 69
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 51
-rw-r--r--  tensorflow/core/kernels/mkl_lrn_op.cc | 112
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc | 30
-rw-r--r--  tensorflow/core/kernels/mkl_tfconv_op.cc | 54
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op.cc | 36
-rw-r--r--  tensorflow/core/kernels/priority_queue.cc | 2
-rw-r--r--  tensorflow/core/kernels/shape_ops.cc | 56
-rw-r--r--  tensorflow/core/kernels/slice_op.cc | 258
-rw-r--r--  tensorflow/core/kernels/sparse_reduce_op.cc | 341
-rw-r--r--  tensorflow/core/kernels/stack_ops.cc | 68
-rw-r--r--  tensorflow/core/kernels/topk_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/transpose_functor.h | 7
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/typed_conditional_accumulator_base.h | 2
-rw-r--r--  tensorflow/core/lib/gtl/optional.h | 2
-rw-r--r--  tensorflow/core/ops/math_grad.cc | 20
-rw-r--r--  tensorflow/core/ops/math_grad_test.cc | 20
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 8
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 2
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 75
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc | 69
-rw-r--r--  tensorflow/core/platform/cloud/retrying_utils.cc | 2
-rw-r--r--  tensorflow/core/protobuf/worker.proto | 2
-rw-r--r--  tensorflow/core/public/version.h | 2
-rw-r--r--  tensorflow/core/util/mkl_util.h | 38
54 files changed, 1897 insertions(+), 398 deletions(-)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 21a20bcc4d..4b951691fb 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -654,11 +654,11 @@ Status DirectSession::Run(const RunOptions& run_options,
// If requested via RunOptions, output the partition graphs.
if (run_options.output_partition_graphs()) {
- protobuf::RepeatedPtrField<GraphDef>* parition_graph_defs =
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
run_metadata->mutable_partition_graphs();
for (const PerPartitionExecutorsAndLib& exec_and_lib :
executors_and_keys->items) {
- GraphDef* partition_graph_def = parition_graph_defs->Add();
+ GraphDef* partition_graph_def = partition_graph_defs->Add();
exec_and_lib.graph->ToGraphDef(partition_graph_def);
}
}
diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc
index 3827596a67..d6f35fe24c 100644
--- a/tensorflow/core/debug/grpc_session_debug_test.cc
+++ b/tensorflow/core/debug/grpc_session_debug_test.cc
@@ -279,9 +279,12 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
DeleteDumpDir();
} else {
+ // CUDA and SYCL devices do not have an Identity op for strings
LOG(ERROR) << "Error: " << s;
ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
- (b_dev.device_type() == DEVICE_GPU));
+ (a_dev.device_type() == DEVICE_SYCL) ||
+ (b_dev.device_type() == DEVICE_GPU) ||
+ (b_dev.device_type() == DEVICE_SYCL));
ASSERT_FALSE(s.ok());
}
}
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
index 56355eaf36..950737c39a 100644
--- a/tensorflow/core/framework/graph_def_util.h
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -62,7 +62,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
// attr with a default was added). Note that this will not affect
// attrs with non-default values, so you must run a
// ValidateGraphDef...() function to see if the result is in fact
-// compatible. If not nulllptr, the op/attr pairs that were removed
+// compatible. If not nullptr, the op/attr pairs that were removed
// are added to '*op_attr_removed'.
//
// Expected usage, for a producer that wants to prepare a graph for
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index c5a0983a54..a4dd06de45 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -205,7 +205,6 @@ class OpDefBuilderWrapper;
template <>
class OpDefBuilderWrapper<true> {
public:
- typedef OpDefBuilderWrapper<true> WrapperType;
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
builder_.Attr(spec);
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 49eecc0b08..a164fe61b5 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -307,7 +307,7 @@ class Tensor {
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the
/// first 'begin' Tensor dimensions into the first dimension of the result and
/// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last
- /// dimension of the result. If 'begin' < 0 then the the |'begin'| leading
+ /// dimension of the result. If 'begin' < 0 then the |'begin'| leading
/// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then
/// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added.
template <typename T, size_t NDIMS = 3>
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 94741a11ff..625780e7c9 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -247,16 +247,10 @@ namespace tensorflow {
//
// P = Conv2DWithBiasBackpropBias(O, O_m)
//
-// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
-// the context matching depth. If _MklConv2DWithBias is not within the context
-// matching depth, then we do not rewrite BiasAddGrad.
-
-// How many hops do we search for matching node in the backward dataflow graph?
-// We use maxhop of 10 based on empirical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
-// directly goes to backward nodes, we do not expect the hop-distance
-// would be more than few nodes.
-static size_t kNodeMergeContextMaxDepth = 10;
+// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
+// on the matching 'context'. The term context is loosely related to which
+// forward op is _associated_ to BiasAddGrad. If it is _MklConv2DWithBias then
+// we consider it Conv2D context; if it is MatMul, then it is MatMul context.
class MklLayoutRewritePass : public GraphOptimizationPass {
public:
@@ -280,6 +274,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
+ csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
+ csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"_MklConv2DWithBiasBackpropBias";
@@ -360,16 +356,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
csinfo_.mkl_conv2d_with_bias});
- // We use maxhop of 10 based on empirical observations. Also, these are
- // maxhops in backward data-flow graph. Since input of forward nodes
- // (Conv2D) directly goes to backward nodes, we do not expect the
- // hop-distance would be more than few nodes.
biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
- kNodeMergeContextMaxDepth};
+ IsBiasAddGradInMatMulContext};
biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
csinfo_.mkl_conv2d_with_bias,
- kNodeMergeContextMaxDepth};
+ IsBiasAddGradInConv2DWithBiasContext};
cinfo_.push_back(&biasaddgrad_matmul_context_);
cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
@@ -392,9 +384,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string node; // Name of the node to be rewritten
string fwd; // Name of the node in the forward pass that this node
// corresponds to
- size_t max_hop; // Maximum number of hops the fwd is located
- // from this node. If the fwd is farther than max_hop
- // then we do not rewrite the node.
+ std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
} ContextInfo;
/// Structure to specify the name of an original node, its new name after
@@ -438,7 +428,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
- struct {
+ typedef struct {
string avg_pool;
string avg_pool_grad;
string bias_add;
@@ -457,13 +447,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string max_pool;
string max_pool_grad;
string mkl_conv2d;
+ string mkl_conv2d_grad_input;
+ string mkl_conv2d_grad_filter;
string mkl_conv2d_with_bias;
string mkl_conv2d_with_bias_backprop_bias;
string relu;
string relu_grad;
string reshape;
string split;
- } csinfo_;
+ } ConstStringsInfo;
private:
/// Maintain info about nodes to rewrite
@@ -478,6 +470,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Maintain info about nodes to rewrite
static std::vector<ContextInfo*> cinfo_;
+ /// Maintain structure of constant strings
+ static ConstStringsInfo csinfo_;
+
/// Context variables used in referencing rules
static ContextInfo biasaddgrad_matmul_context_;
static ContextInfo biasaddgrad_conv2dwithbias_context_;
@@ -629,6 +624,173 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
+ // Checks if the BiasAddGrad node 'n' is associated with the Conv2DWithBias node
+ // specified in contextinfo 'ci'. Function updates fwd_node to point
+ // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
+ //
+ // Association checks for one of the following graphs:
+ //
+ // Graph A:
+ //
+ // _ = Conv2DWithBias(F, I, _)
+ // ..
+ // _ = Conv2DBackpropFilter(F, _, G)
+ // _ = Conv2DBackpropInput(_, I, G)
+ // _ = BiasAddGrad(G)
+ //
+ // OR
+ //
+ // Graph B:
+ //
+ // _ = Conv2DWithBias(F, _, _)
+ // ..
+ // _ = Conv2DBackpropFilter(F, _, G)
+ // _ = BiasAddGrad(G)
+ //
+ // Here F, G, and I are graph nodes; _ represents graph nodes that we
+ // don't care about here.
+ //
+ // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
+ // false otherwise.
+ static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
+ const Node** fwd_node,
+ void* ci) {
+ CHECK_NOTNULL(n);
+ CHECK_NOTNULL(fwd_node);
+ CHECK_NOTNULL(ci);
+ *fwd_node = nullptr;
+
+ CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);
+
+ // Get the only 1 input of BiasAddGrad.
+ CHECK_EQ(n->num_inputs(), 1);
+ const Node* bias_add_grad_inp = nullptr;
+ TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
+ CHECK_NOTNULL(bias_add_grad_inp);
+
+ // Check if this input also goes to BackpropFilter and BackpropInput
+ // as 3rd input.
+ bool found_backprop_input = false;
+ bool found_backprop_filter = false;
+ Node* backprop_filter_node = nullptr;
+ Node* backprop_input_node = nullptr;
+
+ for (const Edge* e : bias_add_grad_inp->out_edges()) {
+ Node* third_input = nullptr;
+ if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
+ e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
+ // Third input (index 2) of BackpropInput
+ TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+ // Third input (index 2) of BackpropInput must be same as the input
+ // of BiasAddGrad.
+ if (third_input == bias_add_grad_inp) {
+ found_backprop_input = true;
+ backprop_input_node = e->dst();
+ }
+ }
+
+ if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
+ e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
+ // Third input (index 2) of BackpropFilter
+ TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+ // Third input (index 2) of BackpropFilter must be same as the input
+ // of BiasAddGrad.
+ if (third_input == bias_add_grad_inp) {
+ found_backprop_filter = true;
+ backprop_filter_node = e->dst();
+ }
+ }
+
+ // If we found both the nodes, then we can stop the search.
+ if (found_backprop_input && found_backprop_filter) {
+ break;
+ }
+ }
+
+ // If BackpropFilter node is not found, then this is not
+ // Conv2DWithBias context. For 2nd graph in the example above, only
+ // BackpropFilter would be present.
+ if (!found_backprop_filter) {
+ return false;
+ }
+
+ // Otherwise, we found the nodes.
+ CHECK_NOTNULL(backprop_filter_node);
+ if (found_backprop_input) {
+ CHECK_NOTNULL(backprop_input_node);
+ }
+
+ // Now that we confirmed that this is Conv2DWithBias context, we need to
+ // get access to the forward node (Conv2DWithBias). 2nd input of
+ // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st
+ // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter
+ // (This comes from definition of gradient computation for Conv2D).
+ if (found_backprop_input) {
+ // Graph A in the example.
+ Node* second_inp_of_input = nullptr;
+ Node* first_inp_of_filter = nullptr;
+ TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
+ TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+ CHECK_NOTNULL(second_inp_of_input);
+ CHECK_NOTNULL(first_inp_of_filter);
+
+ // Now we need to find out Conv2DWithBias node from these input nodes.
+ // Conv2DWithBias node is the node that accepts both the nodes
+ // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots.
+ for (const Edge* fe : first_inp_of_filter->out_edges()) {
+ if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+ fe->dst_input() == 0) {
+ for (const Edge* ie : second_inp_of_input->out_edges()) {
+ if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+ ie->dst_input() == 1 && fe->dst() == ie->dst()) {
+ VLOG(1) << "MklLayoutRewritePass: found "
+ << fe->dst()->DebugString()
+ << " as the forward node for matching context, backward"
+ << " node is: " << n->DebugString();
+ *fwd_node = fe->dst();
+ return true;
+ }
+ }
+ }
+ }
+ } else {
+ // We did not find BackpropInput, so we work with BackpropFilter only.
+ // Graph B in the example.
+ Node* first_inp_of_filter = nullptr;
+ TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+ CHECK_NOTNULL(first_inp_of_filter);
+
+ // Now we need to find out Conv2DWithBias node from first input of
+ // BackpropFilter. Conv2DWithBias node is the node that accepts
+ // first_inp_of_filter in 1st input slot.
+ for (const Edge* fe : first_inp_of_filter->out_edges()) {
+ if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+ fe->dst_input() == 0) {
+ VLOG(1) << "MklLayoutRewritePass: found "
+ << fe->dst()->DebugString()
+ << " as the forward node for matching context, backward"
+ << " node is: " << n->DebugString();
+ *fwd_node = fe->dst();
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ // Checks if the BiasAddGrad node 'n' is associated with the MatMul node
+ // specified in contextinfo 'ci'. Function does not update fwd_node.
+ //
+ // @return - true (if BiasAddGrad is associated with MatMul);
+ // false otherwise.
+ static bool IsBiasAddGradInMatMulContext(const Node* n,
+ const Node** fwd_node,
+ void* ci) {
+ return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
+ }
+
+
// Rewrite rule that uses context-information for matching,
// used in scenario 2.
//
@@ -639,8 +801,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
// Helper function that searches the matching contextinfo for the node.
- // Implements depth-first search in the data dependence graph for the
- // gradient op in the backward direction.
//
// @input n - Node (gradient op) whose contextinfo is to be searched,
// fwd_node - pointer to node from the forward pass that this node
@@ -788,6 +948,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
Node* orig_node);
};
+MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
MklLayoutRewritePass::ContextInfo
MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
MklLayoutRewritePass::ContextInfo
@@ -1667,12 +1828,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
const ContextInfo* ci = nullptr;
bool is_context_based_rewrite = false;
if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
- CHECK_NOTNULL(fwd_node);
is_context_based_rewrite = true;
// Sanity checks for context-based rewrite (if any)
if (orig_node->type_string() == csinfo_.bias_add_grad &&
ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
+ CHECK_NOTNULL(fwd_node);
DataType orig_T, ctx_T;
string orig_data_format, ctx_data_format;
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
@@ -1784,69 +1945,17 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
CHECK_NOTNULL(fwd_node);
*fwd_node = nullptr;
- // Search for matching contextinfo based on node name.
- // There could be more than one matching contextinfos.
- bool is_matching_cinfo_found = false;
- std::vector<const ContextInfo*> mci;
+ // Search for matching contextinfo based on node name and call
+ // callback function using matching contextinfo.
+ // There could be more than one matching contextinfo, but whichever
+ // matches first is returned.
for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
- if (n->type_string() == (*ci)->node) {
- mci.push_back(*ci);
- is_matching_cinfo_found = true;
+ if (n->type_string() == (*ci)->node &&
+ (*ci)->context_match_fn(n, fwd_node, *ci)) {
+ VLOG(1) << "Found context as matching: " << (*ci)->fwd;
+ return *ci;
}
}
- // If no matching contextinfo is found, return immediately.
- if (!is_matching_cinfo_found) {
- return nullptr;
- }
-
- VLOG(1) << "MklLayoutRewritePass: Searching graph for: " << n->type_string()
- << " in backwards.";
-
- // Now we will check for forward op name for context info in data
- // flow graph. Get the max hops we should search for the fwd node.
- // We are now going to search (breadth-first) backwards in data
- // dependence graph (for up to max hops) from n for the node
- // specified in fwd.
- // queue to maintain nodes to be visited and depth info for
- // breadth-first search
- std::queue<std::pair<const Node*, int>> nqueue;
- const Node* curr_node = n;
- size_t curr_depth = 0;
- nqueue.push(std::make_pair(curr_node, curr_depth));
-
- while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
- std::pair<const Node*, int> curr_pair = nqueue.front();
- nqueue.pop();
-
- std::set<const Node*> visited_nodes;
- curr_node = curr_pair.first;
- curr_depth = curr_pair.second;
- CHECK_NOTNULL(curr_node);
-
- VLOG(1) << "MklLayoutRewritePass: Visiting node: "
- << curr_node->type_string() << " at depth: " << curr_depth
- << " for node: " << n->type_string();
-
- // If we find a match, we return immediately.
- for (const ContextInfo* ci : mci) {
- if (curr_node->type_string() == ci->fwd) {
- *fwd_node = curr_node;
- return ci;
- }
- }
-
- // Else we explore backward edges from current node.
- // Add the source nodes of all incoming edges of the node to the queue.
- for (const Edge* e : curr_node->in_edges()) {
- // We do not visit already visited node.
- if (visited_nodes.find(e->src()) == visited_nodes.end()) {
- // Depth of these nodes is 1 more than the depth of current node.
- nqueue.push(std::make_pair(e->src(), curr_depth + 1));
- visited_nodes.insert(e->src());
- }
- }
- } /* while */
-
return nullptr;
}
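The hunks above replace the hop-bounded breadth-first search with a per-context callback: each ContextInfo now carries a predicate, and SearchMatchingContext simply returns the first context whose node type matches and whose callback confirms the association. Below is a minimal, self-contained sketch of that shape; MiniNode, MiniContextInfo and FindContext are illustrative stand-ins of my own, not the real tensorflow::Node or MklLayoutRewritePass types.

#include <functional>
#include <string>
#include <vector>

// Simplified stand-in for a graph node: just an op type and its input nodes.
struct MiniNode {
  std::string type;                     // e.g. "BiasAddGrad"
  std::vector<const MiniNode*> inputs;  // data inputs, by slot
};

// Counterpart of ContextInfo after this patch: the max-hop field is gone and a
// predicate decides whether the backward node really belongs to this context.
struct MiniContextInfo {
  std::string node;  // type of the backward node to rewrite
  std::string fwd;   // type of the forward node that defines the context
  std::function<bool(const MiniNode* n, const MiniNode** fwd_node)> match;
};

// Counterpart of SearchMatchingContext: return the first registered context
// whose node type matches and whose callback confirms the association.
const MiniContextInfo* FindContext(const MiniNode* n,
                                   const std::vector<MiniContextInfo>& contexts,
                                   const MiniNode** fwd_node) {
  *fwd_node = nullptr;
  for (const MiniContextInfo& ci : contexts) {
    if (n->type == ci.node && ci.match(n, fwd_node)) {
      return &ci;
    }
  }
  return nullptr;
}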
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 3c4a5263af..efbe2134e0 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -345,7 +345,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
// rewrite tests
-// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
@@ -364,16 +365,255 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
- "node { name: 'F' op: 'BiasAddGrad'"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'Int32Input'}"
+ "node { name: 'I' op: '_MklConv2DBackpropInput'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+ "node { name: 'J' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
- "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
- "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
- "O->D:5");
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
+ "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
+ "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
+ "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;"
+ "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput. But nodes do not match criteria for rewrite. So
+// rewrite should not happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'Int32Input'}"
+ "node { name: 'I' op: '_MklConv2DBackpropInput'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+ "node { name: 'J' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+ "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
+ "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+ "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput. But nodes do not match criteria for rewrite. So
+// rewrite should not happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'Int32Input'}"
+ "node { name: 'I' op: '_MklConv2DBackpropInput'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+ "node { name: 'J' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+ "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+ "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+ "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);"
+ "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
+ "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;"
+ "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;"
+ "N->D:4;N->G:4;O->D:5;O->G:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
+// But BackpropFilter node inputs do not satisfy criteria for rewrite.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+ "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
+ "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
+ "O->G:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
+// But BackpropFilter node inputs do not satisfy criteria for rewrite.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'Int32Input'}"
+ "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+ "node { name: 'H' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+ "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+ "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+ "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
+ "O->G:5");
}
// No _MklConv2DWithBias in context, but _MklConv2D in context.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index e959eab54e..f6595fcbb3 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -728,7 +728,7 @@ TEST_F(VirtualSchedulerTest, ComplexDependency) {
1 /* control dependency */);
EXPECT_EQ(expected_size, cpu_state.memory_usage);
- // Nodes currrently in memory: bn's port -1, 0, and 2, and x's port 0.
+ // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
std::set<std::pair<string, int>> nodes_in_memory;
std::transform(
cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 84a7681782..1e7a9dfaf5 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
index 42f2f1850f..d46b849ad4 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -168,6 +168,11 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
for (const auto& variable : item.MainVariables()) {
dont_replicate_nodes.insert(variable->name());
}
+
+ for (const auto& init : item.init_ops) {
+ dont_replicate_nodes.insert(NodeName(init));
+ }
+
// Don't replicate all input nodes, except the dequeue node.
for (const auto& input_node : input_nodes) {
if (input_node->name() != dequeue_node->name()) {
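The hunk above additionally pins initialization ops to a single replica by inserting their node names into dont_replicate_nodes. Below is a small stand-alone sketch of that bookkeeping; StripToNodeName is a hypothetical stand-in for grappler's NodeName(), and the assumption that init_ops entries may carry a leading '^' or an ':output' suffix is mine, not stated in the patch.

#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for grappler's NodeName(): reduce an op reference such
// as "^init/assign" or "var:0" to the bare node name.
std::string StripToNodeName(const std::string& name) {
  std::string s = name;
  if (!s.empty() && s[0] == '^') s.erase(0, 1);              // control-input prefix
  const std::string::size_type colon = s.rfind(':');
  if (colon != std::string::npos) s.resize(colon);           // output-slot suffix
  return s;
}

// Build the set of nodes AutoParallel must not replicate: variables as before,
// plus (new in this patch) every init op from the GrapplerItem.
std::set<std::string> BuildDontReplicateSet(
    const std::vector<std::string>& variable_names,
    const std::vector<std::string>& init_ops) {
  std::set<std::string> dont_replicate(variable_names.begin(),
                                       variable_names.end());
  for (const std::string& init : init_ops) {
    dont_replicate.insert(StripToNodeName(init));
  }
  return dont_replicate;
}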
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h
index ad90bbe028..c5d2d47782 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel.h
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
index 3d1b4a34bf..9a41b5e0b5 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
@@ -33,6 +33,7 @@ TEST_F(AutoParallelTest, SimpleParallel) {
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
+ Output identity = ops::Identity(s.WithOpName("identity"), {var});
Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
{constant_b}, {DT_FLOAT});
@@ -44,13 +45,14 @@ TEST_F(AutoParallelTest, SimpleParallel) {
GrapplerItem item;
item.init_ops.push_back("assign");
item.fetch.push_back("apply_gradient");
+ item.init_ops.push_back("assign");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
AutoParallel parallel(2);
GraphDef output;
Status status = parallel.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(20, output.node_size());
+ EXPECT_EQ(21, output.node_size());
const NodeDef& node_assign = output.node(0);
EXPECT_EQ("assign", node_assign.name());
@@ -62,60 +64,64 @@ TEST_F(AutoParallelTest, SimpleParallel) {
const NodeDef& node_fifo_queue = output.node(2);
EXPECT_EQ("fifo_queue", node_fifo_queue.name());
- const NodeDef& node_var = output.node(3);
+ const NodeDef& node_identity = output.node(3);
+ EXPECT_EQ("identity", node_identity.name());
+ EXPECT_EQ("var", node_identity.input(0));
+
+ const NodeDef& node_var = output.node(4);
EXPECT_EQ("var", node_var.name());
- const NodeDef& node_div_const0 = output.node(4);
+ const NodeDef& node_div_const0 = output.node(5);
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const",
node_div_const0.name());
- const NodeDef& node_div0 = output.node(5);
+ const NodeDef& node_div0 = output.node(6);
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient",
node_div0.name());
- const NodeDef& node_add0 = output.node(6);
+ const NodeDef& node_add0 = output.node(7);
EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name());
- const NodeDef& node_gradient0 = output.node(7);
+ const NodeDef& node_gradient0 = output.node(8);
EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name());
- const NodeDef& node_constant_a0 = output.node(8);
+ const NodeDef& node_constant_a0 = output.node(9);
EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name());
- const NodeDef& node_dequeue0 = output.node(9);
+ const NodeDef& node_dequeue0 = output.node(10);
EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name());
- const NodeDef& node_learning_rate0 = output.node(10);
+ const NodeDef& node_learning_rate0 = output.node(11);
EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name());
- const NodeDef& node_div_const1 = output.node(11);
+ const NodeDef& node_div_const1 = output.node(12);
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const",
node_div_const1.name());
- const NodeDef& node_div1 = output.node(12);
+ const NodeDef& node_div1 = output.node(13);
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient",
node_div1.name());
- const NodeDef& node_add1 = output.node(13);
+ const NodeDef& node_add1 = output.node(14);
EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name());
- const NodeDef& node_gradient1 = output.node(14);
+ const NodeDef& node_gradient1 = output.node(15);
EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name());
- const NodeDef& node_constant_a1 = output.node(15);
+ const NodeDef& node_constant_a1 = output.node(16);
EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name());
- const NodeDef& node_dequeue1 = output.node(16);
+ const NodeDef& node_dequeue1 = output.node(17);
EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name());
- const NodeDef& node_learning_rate1 = output.node(17);
+ const NodeDef& node_learning_rate1 = output.node(18);
EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name());
- const NodeDef& node_fetch = output.node(18);
+ const NodeDef& node_fetch = output.node(19);
EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0));
EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1));
- const NodeDef& node_gradient = output.node(19);
+ const NodeDef& node_gradient = output.node(20);
EXPECT_EQ("apply_gradient", node_gradient.name());
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index ded1e474ce..28d663e2f7 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -929,7 +929,7 @@ struct TuningConfig {
// Conv2DBackpropFilter will use a specialized GEMM implementation, which is
// usually faster than the NCHW implementation. The downside is that this
// might result in more non-cancellable layout conversion nodes (implemented
- // by the Tranpose op).
+ // by the Transpose op).
bool no_gemm;
};
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8b7c269a11..9c397954e1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2464,7 +2464,7 @@ tf_cc_tests(
":ops_util",
":sparse_add_op",
":sparse_dense_binary_op_shared",
- ":sparse_reduce_sum_op",
+ ":sparse_reduce_op",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -3207,7 +3207,7 @@ cc_library(
":sparse_cross_op",
":sparse_dense_binary_op_shared",
":sparse_fill_empty_rows_op",
- ":sparse_reduce_sum_op",
+ ":sparse_reduce_op",
":sparse_reorder_op",
":sparse_reshape_op",
":sparse_softmax",
@@ -3263,8 +3263,8 @@ tf_kernel_library(
)
tf_kernel_library(
- name = "sparse_reduce_sum_op",
- prefix = "sparse_reduce_sum_op",
+ name = "sparse_reduce_op",
+ prefix = "sparse_reduce_op",
deps = SPARSE_DEPS,
)
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
index c8f12f91a6..37976f7183 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -31,6 +31,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif
// AdjustContrastOp is deprecated as of GraphDef version >= 2
@@ -410,4 +413,25 @@ REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU),
AdjustContrastOpv2<GPUDevice>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+class AdjustContrastOpv2<SYCLDevice> : public AdjustContrastOpV2Base {
+ public:
+ explicit AdjustContrastOpv2(OpKernelConstruction* context)
+ : AdjustContrastOpV2Base(context) {}
+
+ void DoCompute(OpKernelContext* context,
+ const ComputeOptions& options) override {
+ const int64 shape[4] = {options.batch, options.height, options.width,
+ options.channels};
+ functor::AdjustContrastv2<SYCLDevice>()(
+ context->eigen_device<SYCLDevice>(),
+ options.input->shaped<float, 4>(shape), options.factor->scalar<float>(),
+ options.output->shaped<float, 4>(shape));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL),
+ AdjustContrastOpv2<SYCLDevice>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc
index ffd47406eb..c485f14844 100644
--- a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc
@@ -56,6 +56,11 @@ static Graph* BM_AdjustContrast(int batches, int width, int height) {
// BM_AdjustContrast_cpu_1_299_299 179084 340186 2181 751.9M items/s
// BM_AdjustContrast_gpu_32_299_299 85276 123665 4189 2.9G items/s
BM_AdjustContrastDev(cpu, 1, 299, 299);
+#if GOOGLE_CUDA
BM_AdjustContrastDev(gpu, 32, 299, 299);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_AdjustContrastDev(sycl, 32, 299, 299);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc
index d65a34fd73..ba100b32e7 100644
--- a/tensorflow/core/kernels/colorspace_op.cc
+++ b/tensorflow/core/kernels/colorspace_op.cc
@@ -35,6 +35,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif
template <typename Device, typename T>
class RGBToHSVOp : public OpKernel {
@@ -146,4 +149,16 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(T) \
+ REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T"), \
+ RGBToHSVOp<SYCLDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T"), \
+ HSVToRGBOp<SYCLDevice, T>);
+TF_CALL_float(REGISTER_SYCL);
+TF_CALL_double(REGISTER_SYCL);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 203a9a9f24..64c06786bc 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -112,15 +112,14 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .HostMemory("pred"), \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"),\
SwitchOp)
-REGISTER_SYCL_KERNEL(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);
#define REGISTER_SYCL_REF_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
@@ -128,12 +127,41 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
-REGISTER_SYCL_REF_SWITCH(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
+TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
-#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_REF_SWITCH
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false")\
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"),\
+ SwitchOp)
+
+REGISTER_SYCL_HOST_KERNEL(bool);
+REGISTER_SYCL_HOST_KERNEL(string);
+REGISTER_SYCL_HOST_KERNEL(int32);
+
+#define REGISTER_SYCL_HOST_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false") \
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+REGISTER_SYCL_HOST_REF_KERNEL(int32);
+REGISTER_SYCL_HOST_REF_KERNEL(bool);
+REGISTER_SYCL_HOST_REF_KERNEL(string);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
class RefSelectOp : public OpKernel {
@@ -233,13 +261,13 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
+#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Merge") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp)
+ MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
@@ -248,9 +276,10 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp)
+ MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
@@ -280,6 +309,30 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Merge") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
+ MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
+ MergeOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
void EnterOp::Compute(OpKernelContext* context) {
if (IsRefType(context->input_dtype(0))) {
context->forward_ref_input_to_ref_output(0, 0);
@@ -306,7 +359,7 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
+#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
@@ -345,7 +398,7 @@ REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
-#endif
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -394,30 +447,25 @@ REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
-#define REGISTER_SYCL_REF_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
-REGISTER_SYCL_REF_KERNEL(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
-
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
-// Special GPU kernels for int32 and string.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Exit") \
.Device(DEVICE_SYCL) \
@@ -507,31 +555,19 @@ REGISTER_GPU_HOST_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("NextIteration") \
- .Device(DEVICE_SYCL) \
- .HostMemory("data") \
- .HostMemory("output") \
- .TypeConstraint<type>("T"), \
- NextIterationOp)
- REGISTER_SYCL_KERNEL(bool);
- TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
-#define REGISTER_SYCL_REF_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
- .Device(DEVICE_SYCL) \
- .HostMemory("data") \
- .HostMemory("output") \
- .TypeConstraint<type>("T"), \
- NextIterationOp)
- REGISTER_SYCL_REF_KERNEL(bool);
- TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+ NextIterationOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
-#undef REGISTER_SYCL_REF_KERNEL
-// Special GPU kernels for int32 and string.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("NextIteration") \
.Device(DEVICE_SYCL) \
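
The control_flow_ops.cc changes above add SYCL host-memory registrations for Switch, RefSwitch, Merge/RefMerge, and NextIteration, mirroring the existing GPU host-memory kernels for int32, bool, and string. To make the pattern concrete, here is roughly what REGISTER_SYCL_HOST_KERNEL(bool) from the Switch hunk expands to; every input and output of the kernel is pinned to host memory:

// Approximate expansion of REGISTER_SYCL_HOST_KERNEL(bool) as defined in the
// Switch hunk above.
REGISTER_KERNEL_BUILDER(Name("Switch")
                            .Device(DEVICE_SYCL)
                            .HostMemory("data")
                            .HostMemory("pred")
                            .HostMemory("output_false")
                            .HostMemory("output_true")
                            .TypeConstraint<bool>("T"),
                        SwitchOp);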
diff --git a/tensorflow/core/kernels/cwise_op_add_2.cc b/tensorflow/core/kernels/cwise_op_add_2.cc
index 5d3385b0ed..5dea00e95c 100644
--- a/tensorflow/core/kernels/cwise_op_add_2.cc
+++ b/tensorflow/core/kernels/cwise_op_add_2.cc
@@ -22,10 +22,11 @@ namespace tensorflow {
// sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__.
#if !defined(__ANDROID_TYPES_SLIM__)
-REGISTER5(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
- complex128, string);
+REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
+ uint8, complex128, string);
#if GOOGLE_CUDA
-REGISTER3(BinaryOp, GPU, "Add", functor::add, int64, complex64, complex128);
+REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
+ complex128);
#endif // GOOGLE_CUDA
#endif // !defined(__ANDROID_TYPES_SLIM__)
diff --git a/tensorflow/core/kernels/cwise_op_cosh.cc b/tensorflow/core/kernels/cwise_op_cosh.cc
new file mode 100644
index 0000000000..bca99a4f89
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_cosh.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Cosh", functor::cosh, float, double,
+ complex64, complex128);
+
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cosh") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::cosh<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double);
+#endif  // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc
index 5aaf2b5b4b..61079ebab3 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc
@@ -19,7 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128);
+DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
+ complex128);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc
new file mode 100644
index 0000000000..267a381d1a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(cosh, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc
new file mode 100644
index 0000000000..f8329e50d6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(sinh, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_invert.cc b/tensorflow/core/kernels/cwise_op_invert.cc
index c84ee6894e..df2c02e42e 100644
--- a/tensorflow/core/kernels/cwise_op_invert.cc
+++ b/tensorflow/core/kernels/cwise_op_invert.cc
@@ -20,7 +20,7 @@ REGISTER6(UnaryOp, CPU, "Invert", functor::invert, int8, int16, int32, int64,
uint8, uint16);
#ifdef TENSORFLOW_USE_SYCL
-REGISTER(UnaryOp, SYCL, "Invert", functor::invert, int8, int16, int32, int64,
+REGISTER6(UnaryOp, SYCL, "Invert", functor::invert, int8, int16, int32, int64,
uint8, uint16);
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_sinh.cc b/tensorflow/core/kernels/cwise_op_sinh.cc
new file mode 100644
index 0000000000..055f0b12e1
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sinh.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Sinh", functor::sinh, float, double,
+ complex64, complex128);
+
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sinh") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::sinh<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+#endif  // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Sinh", functor::sinh, float, double);
+#endif  // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 97bdc5e878..c11d6cfabb 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -484,6 +484,12 @@ template <typename T>
struct sign : base<T, Eigen::internal::scalar_sign_op<T> > {};
template <typename T>
+struct sinh : base<T, Eigen::internal::scalar_sinh_op<T> > {};
+
+template <typename T>
+struct cosh : base<T, Eigen::internal::scalar_cosh_op<T> > {};
+
+template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T> > {};
template <typename T>
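
The two functors above are all the cwise machinery needs for the new Sinh and Cosh ops: UnaryOp applies the wrapped Eigen functor element-wise, and the REGISTERn helpers in cwise_ops_common.h stamp out one kernel per listed type. Assuming those helpers behave here as they do for the other cwise ops, the CPU registration in cwise_op_sinh.cc reduces, for float, to roughly:

// Illustrative expansion for a single type; the real registration comes from
// REGISTER4(UnaryOp, CPU, "Sinh", functor::sinh, float, double, complex64,
// complex128) in cwise_op_sinh.cc.
REGISTER_KERNEL_BUILDER(
    Name("Sinh").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    UnaryOp<CPUDevice, functor::sinh<float>>);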
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index 08ae787c86..135d635514 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -165,20 +165,6 @@ class DynamicStitchOp : public OpKernel {
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
-#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
- REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .HostMemory("indices") \
- .HostMemory("data") \
- .HostMemory("merged"), \
- DynamicStitchOp<type>)
-
-TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
-#undef REGISTER_DYNAMIC_STITCH_SYCL
-#endif // TENSORFLOW_USE_SYCL
-
#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type) \
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
@@ -194,4 +180,17 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices") \
+ .HostMemory("data") \
+ .HostMemory("merged"), \
+ DynamicStitchOp<type>)
+
+TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
+#undef REGISTER_DYNAMIC_STITCH_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index 6431c6540e..46eaf3d9e7 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -547,14 +547,14 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
.HostMemory("indices")
.Device(DEVICE_GPU),
MapStageOp<true>);
-#endif
+#endif // GOOGLE_CUDA
+
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapStage").HostMemory("key").Device(DEVICE_SYCL),
MapStageOp<false>);
REGISTER_KERNEL_BUILDER(
Name("OrderedMapStage").HostMemory("key").Device(DEVICE_SYCL),
MapStageOp<true>);
-
#endif // TENSORFLOW_USE_SYCL
template <bool Ordered>
@@ -661,6 +661,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
.Device(DEVICE_GPU),
MapPeekOp<true>);
#endif
+
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_SYCL),
@@ -724,8 +725,8 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
.HostMemory("indices")
.Device(DEVICE_GPU),
MapUnstageNoKeyOp<true>);
-
#endif
+
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
.HostMemory("key")
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index dc6b88e953..ddcf241277 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -206,10 +206,15 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Mkl needs the entities in its native format.
// So create temporary tensors along with buffers to
// convert the received entities.
- Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor;
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor,
+ mkl_tmp_buf_trans_input;
// This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst
- mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor,
- &mkl_tmp_out_backprop_buf_tensor);
+ mkl_context.MklPrepareInputs(context, data_format_,
+ input_in_mkl_format,
+ out_backprop_in_mkl_format,
+ &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_out_backprop_buf_tensor,
+ &mkl_tmp_buf_trans_input);
// Final conv-grad-filter should be in TF layout.
Tensor* grad_filter;
@@ -307,34 +312,58 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
- void MklPrepareInputs(OpKernelContext* context,
+ void MklPrepareInputs(OpKernelContext* context, TensorFormat format,
+ bool input_in_mkl_format,
+ bool out_backprop_in_mkl_format,
Tensor* mkl_tmp_input_buf_tensor,
- Tensor* mkl_tmp_out_backprop_buf_tensor) {
+ Tensor* mkl_tmp_out_backprop_buf_tensor,
+ Tensor* mkl_tmp_buf_trans_input) {
bool mkl_convert_input, mkl_convert_out_backprop;
dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop;
- dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop;
+ dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop,
+ mkl_lt_trans_input;
void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop;
+ void *mkl_buf_input, *mkl_buf_out_backprop;
mkl_prim_convert_input = nullptr;
mkl_prim_convert_out_backprop = nullptr;
mkl_lt_internal_input = nullptr;
mkl_lt_internal_out_backprop = nullptr;
+ mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_out_backprop = nullptr;
+ mkl_buf_input = nullptr;
+ mkl_buf_out_backprop = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- void* mkl_buf_input =
- const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+    if (!input_in_mkl_format && format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
+ in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
+ MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
+ mkl_buf_input = const_cast<void*>(static_cast<const void*>(
+ mkl_tmp_buf_trans_input->flat<float>().data()));
+ size_t strides[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
+ strides), E_SUCCESS);
+    } else {
+ mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ mkl_lt_trans_input = lt_input;
+ }
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
&mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
if (mkl_convert_input) {
- CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
- mkl_lt_internal_input),
- E_SUCCESS);
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+ mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -343,26 +372,30 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
+ if (!input_in_mkl_format && format == FORMAT_NHWC)
+ dnnLayoutDelete_F32(mkl_lt_trans_input);
+
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
const Tensor& out_backprop = MklGetInput(context, 2);
- void* mkl_buf_out_backprop = const_cast<void*>(
- static_cast<const void*>(out_backprop.flat<T>().data()));
+ mkl_buf_out_backprop = const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data()));
+
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
prim_conv_bwdfilter,
dnnResourceDiffDst),
E_SUCCESS);
mkl_convert_out_backprop =
- !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop);
+ !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop,
+ lt_out_backprop);
if (mkl_convert_out_backprop) {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
- lt_out_backprop,
- mkl_lt_internal_out_backprop),
+ lt_out_backprop, mkl_lt_internal_out_backprop),
E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor,
- lt_out_backprop, &mkl_buf_convert_out_backprop);
+ mkl_lt_internal_out_backprop, &mkl_buf_convert_out_backprop);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
mkl_buf_out_backprop,
mkl_buf_convert_out_backprop),
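
The NHWC handling added above repeats a sequence that appears in several MKL kernels in this patch: create the layout the primitive expects, compare it with the user-supplied layout, and convert through a temporary buffer only if they differ. A hypothetical helper (not part of the patch; the name and signature are invented for illustration) that captures the same sequence of dnn* calls:

// Sketch of the recurring MKL layout-conversion pattern used above. Returns
// a pointer to data in internal_layout: either the original user buffer (if
// the layouts already match) or a freshly converted temporary buffer.
static void* ConvertToInternalLayoutIfNeeded(OpKernelContext* context,
                                             dnnLayout_t user_layout,
                                             dnnLayout_t internal_layout,
                                             void* user_buf,
                                             Tensor* tmp_tensor) {
  if (dnnLayoutCompare_F32(internal_layout, user_layout)) {
    return user_buf;  // Layouts match; no conversion needed.
  }
  dnnPrimitive_t convert = nullptr;
  void* converted_buf = nullptr;
  CHECK_EQ(dnnConversionCreate_F32(&convert, user_layout, internal_layout),
           E_SUCCESS);
  AllocTmpBuffer(context, tmp_tensor, internal_layout, &converted_buf);
  CHECK_EQ(dnnConversionExecute_F32(convert, user_buf, converted_buf),
           E_SUCCESS);
  dnnDelete_F32(convert);
  return converted_buf;
}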
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 76b9f1798d..df49e03f31 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -267,12 +267,15 @@ class MklConv2DOp : public OpKernel {
mkl_context.MklCreateInputLayouts(context);
+ // Temp tensor used to allocate tmp buffers
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
- mkl_tmp_bias_buf_tensor; // Temp tensor used to allocate tmp
- // buffers
- mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor,
+ mkl_tmp_bias_buf_tensor, mkl_tmp_buf_trans_input;
+ mkl_context.MklPrepareConvolutionInputs(context, data_format_,
+ input_in_mkl_format,
+ &mkl_tmp_input_buf_tensor,
&mkl_tmp_filter_buf_tensor,
- &mkl_tmp_bias_buf_tensor);
+ &mkl_tmp_bias_buf_tensor,
+ &mkl_tmp_buf_trans_input);
// Execute convolution
CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
@@ -323,39 +326,59 @@ class MklConv2DOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
void MklPrepareConvolutionInputs(OpKernelContext* context,
+ TensorFormat format,
+ bool input_in_mkl_format,
Tensor* mkl_tmp_input_buf_tensor,
Tensor* mkl_tmp_filter_buf_tensor,
- Tensor* mkl_tmp_bias_buf_tensor) {
+ Tensor* mkl_tmp_bias_buf_tensor,
+ Tensor* mkl_tmp_buf_trans_input) {
bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
mkl_prim_convert_input;
dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
- mkl_lt_internal_input;
+ mkl_lt_internal_input, mkl_lt_trans_input;
void *mkl_buf_convert_input, *mkl_buf_convert_filter,
- *mkl_buf_convert_bias;
+ *mkl_buf_convert_bias, *mkl_buf_input;
mkl_prim_convert_filter = nullptr;
mkl_prim_convert_bias = nullptr;
mkl_prim_convert_input = nullptr;
mkl_lt_internal_filter = nullptr;
mkl_lt_internal_bias = nullptr;
mkl_lt_internal_input = nullptr;
+ mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_filter = nullptr;
mkl_buf_convert_bias = nullptr;
+ mkl_buf_input = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- void* mkl_buf_input =
- const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ if (!input_in_mkl_format && format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
+ in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
+ MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
+ mkl_buf_input = const_cast<void*>(static_cast<const void*>(
+ mkl_tmp_buf_trans_input->flat<float>().data()));
+ size_t strides[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
+ strides), E_SUCCESS);
+ } else {
+ mkl_buf_input = const_cast<void*>(
+ static_cast<const void*>(input.flat<T>().data()));
+ mkl_lt_trans_input = lt_input;
+ }
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
prim_fwd, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
if (mkl_convert_input) {
- CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
- mkl_lt_internal_input),
- E_SUCCESS);
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+ mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -364,6 +387,8 @@ class MklConv2DOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
+ if (!input_in_mkl_format && format == FORMAT_NHWC)
+ dnnLayoutDelete_F32(mkl_lt_trans_input);
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 070aeff49f..07a7e6b5da 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -22,9 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,6 +30,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
@@ -66,11 +66,10 @@ class MklLRNOp : public OpKernel {
explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
- OP_REQUIRES(
- context,
- FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
- errors::InvalidArgument("depth_radius = ", depth_radius64,
- " larger than int max"));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
depth_radius_ = static_cast<size_t>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
@@ -93,10 +92,9 @@ class MklLRNOp : public OpKernel {
: input.dims();
OP_REQUIRES(context, mkl_context.in_dims == 4,
errors::InvalidArgument("input must be 4-dimensional"));
- OP_REQUIRES(
- context,
- FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
- errors::InvalidArgument("argument to LRN too large"));
+ OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
if (!input_in_mkl_format) {
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
@@ -104,15 +102,6 @@ class MklLRNOp : public OpKernel {
return;
}
- // TODO(inteltf) MKL will support depth radius not equal to 2 in the future
- if (depth_radius_ != 2) {
- Tensor converted_tensor =
- ConvertMklToTF<T>(context, input, mkl_context.input_shape);
- mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
- beta_, converted_tensor);
- return;
- }
-
if (input_in_mkl_format) {
// MKL supports normalization over channel dimension only
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
@@ -345,11 +334,10 @@ class MklLRNGradOp : public OpKernel {
explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
- OP_REQUIRES(
- context,
- FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
- errors::InvalidArgument("depth_radius = ", depth_radius64,
- " larger than int max"));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
depth_radius_ = static_cast<int>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
@@ -553,6 +541,9 @@ class MklLRNGradOp : public OpKernel {
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_bdw_input, lrn_bwd,
dnnResourceDiffDst),
E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_bwd,
+ dnnResourceSrc),
+ E_SUCCESS);
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
if (ingrad_in_mkl_format) {
@@ -581,44 +572,37 @@ class MklLRNGradOp : public OpKernel {
}
}
-// Although MKL documentation for LRN does not specify setting/getting
-// of dnnResourceSrc and dnnResourceDst, Caffe code sets dnnResourceSrc.
-// So we set dnnResourceSrc here. But we do not know why we are setting
-// dnnResourceDst.
-#if 0
- // NOTE: The code below is kept just so that we know how we should handle
- // dnnResourceSrc if the primitive layout for dnnResourceSrc was supported.
-
- if (!dnnLayoutCompare_F32(lt_internal_input,
- static_cast<dnnLayout_t>inimage_shape.GetCurLayout())) {
- AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
- &res_lrn_bwd[dnnResourceSrc]);
- inimage_shape.GetConvertedFlatData(lt_internal_input,
- user_fwd_input,
- res_lrn_bwd[dnnResourceSrc]);
- } else {
- res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
- }
-#endif
-
- // Since we cannot get expected layout for dnnResourceSrc, we construct
- // buffer using
- // MKL format if input is in MKL format.
- if (inimage_shape.IsMklTensor()) {
- AllocTmpBuffer(context, mkl_tmp_image_buf_tensor,
- (dnnLayout_t)inimage_shape.GetCurLayout(),
- &res_lrn_bwd[dnnResourceSrc]);
+ bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+ if (inimage_in_mkl_format) {
+ if (!dnnLayoutCompare_F32(
+ lt_internal_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+ &res_lrn_bwd[dnnResourceSrc]);
+ ingrad_shape.GetConvertedFlatData(lt_internal_input, user_fwd_input,
+ res_lrn_bwd[dnnResourceSrc]);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
} else {
- res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
- }
+ if (!dnnLayoutCompare_F32(
+ lt_internal_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
+ CHECK_EQ(dnnConversionCreate_F32(
+ &convert_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()),
+ lt_internal_input),
+ E_SUCCESS);
- // Same comment as above.
- if (outimage_shape.IsMklTensor()) {
- AllocTmpBuffer(context, mkl_tmp_outimage_buf_tensor,
- (dnnLayout_t)outimage_shape.GetCurLayout(),
- &res_lrn_bwd[dnnResourceDst]);
- } else {
- res_lrn_bwd[dnnResourceDst] = user_fwd_output;
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+ &res_lrn_bwd[dnnResourceSrc]);
+ CHECK_EQ(dnnConversionExecute_F32(convert_input, user_fwd_input,
+ res_lrn_bwd[dnnResourceSrc]),
+ E_SUCCESS);
+ dnnDelete_F32(convert_input);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
}
res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer;
@@ -628,8 +612,6 @@ class MklLRNGradOp : public OpKernel {
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
// copy.
void MklDefaultToEigen(OpKernelContext* context) {
- // CHECK(false);
-
Tensor in_grads;
Tensor in_image;
Tensor out_image;
@@ -709,7 +691,7 @@ class MklLRNGradOp : public OpKernel {
Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
depth * depth, shard);
}
-
+
// release mkl resources
void Mklcleanup() {
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 10d2937584..fabecc39a8 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -184,38 +184,31 @@ class MklReluGradOp : public OpKernel {
dnnLayout_t lt_input, lt_grad;
void MklPrepareReluGradInputs(OpKernelContext* context,
- Tensor* mkl_tmp_grad_buf_tensor,
Tensor* mkl_tmp_input_buf_tensor) {
- dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad;
- dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad;
-
const Tensor& g = MklGetInput(context, 0);
const Tensor& a = MklGetInput(context, 1);
-
- void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
- void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
- dnnPrimitive_t cv_input_to_grad = NULL;
- Tensor mkl_tmp_buf_tensor;
+ void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
void* mkl_buffer_convert = nullptr;
+ dnnPrimitive_t cv_input_to_grad = nullptr;
// if input and grad are not in the same layout, do a conversion between
// them.
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
- AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
&mkl_buffer_convert);
- CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
- E_SUCCESS);
-
- CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
+ CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input,
+ lt_grad), E_SUCCESS);
+ CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
mkl_buffer_convert),
E_SUCCESS);
relu_res[dnnResourceSrc] = mkl_buffer_convert;
dnnDelete_F32(cv_input_to_grad);
} else {
- relu_res[dnnResourceSrc] = user_i;
+ relu_res[dnnResourceSrc] = buf_input;
}
- relu_res[dnnResourceDiffDst] = user_g;
+ void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+ relu_res[dnnResourceDiffDst] = buf_grad;
}
void MklCreateInputLayouts(OpKernelContext* context) {
@@ -317,9 +310,8 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.lt_grad, mkl_context.lt_grad,
negative_slope),
E_SUCCESS);
- Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
- mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor,
- &mkl_tmp_input_buf_tensor);
+ Tensor mkl_tmp_input_buf_tensor;
+ mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);
if (input_is_mkl ||
grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc
index 588d6874dd..b4aae67ca6 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.cc
+++ b/tensorflow/core/kernels/mkl_tfconv_op.cc
@@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -44,10 +45,11 @@ class MklToTfOp : public OpKernel {
explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
+ has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
}
void Compute(OpKernelContext* context) override {
- // 1. Check that input tensor is in MKL format.
+ // Check that input tensor is in MKL format.
const Tensor& input_tensor = MklGetInput(context, 0);
MklShape input_shape;
GetMklShape(context, 0, &input_shape);
@@ -68,9 +70,12 @@ class MklToTfOp : public OpKernel {
CHECK_EQ(op_data_type, output_data_type);
TensorShape output_shape;
- for (size_t i = 0; i < input_shape.GetDimension(); i++) {
+ size_t ndims = input_shape.GetDimension();
+ size_t* in_sizes = new size_t[ndims];
+ for (size_t i = 0; i < ndims; i++) {
// Outermost to innermost dimension
output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]);
+ in_sizes[i] = input_shape.GetSizes()[i];
}
// Allocate output tensor.
@@ -78,17 +83,41 @@ class MklToTfOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
- // 3. Get input and output layout pointers.
- dnnLayout_t output_layout =
- static_cast<dnnLayout_t>(input_shape.GetTfLayout());
+    // If the data format is NHWC, first convert the MKL tensor to an NCHW
+    // buffer, then transpose NCHW -> NHWC into the output.
+ dnnLayout_t lt_trans_input = nullptr;
+ Tensor mkl_tmp_trans_input_buf_tensor;
+ void* buf_trans_input = nullptr;
+ bool input_fmt_nhwc = input_shape.IsTensorInNHWCFormat();
+ if (input_fmt_nhwc && ndims == 4 && has_avx512f_) {
+ size_t strides_nchw[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides_nchw, in_sizes);
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&lt_trans_input, ndims, in_sizes, strides_nchw),
+ E_SUCCESS);
+ AllocTmpBuffer(context, &mkl_tmp_trans_input_buf_tensor, lt_trans_input,
+ &buf_trans_input);
+ } else {
+ lt_trans_input = static_cast<dnnLayout_t>(input_shape.GetTfLayout());
+ buf_trans_input =
+ static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
+ }
- // 4. Execute DNNConversion.
+ // Execute DNNConversion.
void* input_buffer =
static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data()));
- void* output_buffer =
- static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
- input_shape.GetConvertedFlatData(output_layout, input_buffer,
- output_buffer);
+ input_shape.GetConvertedFlatData(lt_trans_input, input_buffer,
+ buf_trans_input);
+ // NCHW -> NHWC, if data format is NHWC
+ if (input_fmt_nhwc && ndims == 4 && has_avx512f_) {
+ dnnLayoutDelete_F32(lt_trans_input);
+ TensorShape nhwc_shape = ShapeFromFormat(
+ FORMAT_NHWC, in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ MklNCHWToNHWC(mkl_tmp_trans_input_buf_tensor, &output_tensor);
+ }
+
+ delete[] in_sizes;
VLOG(1) << "MKLToTFConversion complete successfully.";
}
@@ -99,6 +128,9 @@ class MklToTfOp : public OpKernel {
/// Data type of the operation
DataType op_data_type;
+
+  /// True if the CPU supports the AVX512F instruction set.
+ bool has_avx512f_ = false;
};
///////////////////////////////////////////////////////////
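
The NHWC fast path above builds an NCHW layout from the tensor sizes before handing the buffer to MKL, then transposes NCHW -> NHWC into the output. As a self-contained illustration of the stride arithmetic such a layout implies (hypothetical helper, not the GetStridesFromSizes used in the patch):

#include <cstddef>

// For a contiguously laid-out NCHW tensor of shape {N, C, H, W}, element
// (n, c, h, w) lives at offset n*C*H*W + c*H*W + h*W + w, so the per-
// dimension strides are as computed below (illustrative only).
void ComputeNchwStrides(const size_t dims[4] /* {N, C, H, W} */,
                        size_t strides[4]) {
  strides[0] = dims[1] * dims[2] * dims[3];  // stride over N
  strides[1] = dims[2] * dims[3];            // stride over C
  strides[2] = dims[3];                      // stride over H
  strides[3] = 1;                            // stride over W
}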
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 9ffe71e031..dc95f67ff0 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -90,20 +90,24 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes,
return intersection_area / (area_i + area_j - intersection_area);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
- const Tensor& scores, const Tensor& max_output_size,
+void DoNonMaxSuppressionOp(OpKernelContext* context,
+ const Tensor& boxes,
+ const Tensor& scores,
+ const Tensor& max_output_size,
const float iou_threshold) {
OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
if (!context->status().ok()) {
return;
}
- const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
+ const int output_size =
+ std::min(max_output_size.scalar<int>()(), num_boxes);
+ typename TTypes<float, 2>::ConstTensor boxes_data =
+ boxes.tensor<float, 2>();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
@@ -123,7 +127,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
for (int j = i + 1; j < num_boxes; ++j) {
if (active[j]) {
float iou =
- ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
+ ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
if (iou > iou_threshold) {
active[j] = false;
num_active--;
@@ -141,7 +145,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
}
-} // namespace
+} // namespace
template <typename Device>
class NonMaxSuppressionOp : public OpKernel {
@@ -163,8 +167,7 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_);
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_);
}
private:
@@ -175,7 +178,8 @@ template <typename Device>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
- : OpKernel(context) {}
+ : OpKernel(context) {
+ }
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
@@ -190,14 +194,14 @@ class NonMaxSuppressionV2Op : public OpKernel {
max_output_size.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
- errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
- iou_threshold.shape().DebugString()));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+ errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+ iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val);
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_val);
}
};
diff --git a/tensorflow/core/kernels/priority_queue.cc b/tensorflow/core/kernels/priority_queue.cc
index 894ad3c9a0..4c406fc1ed 100644
--- a/tensorflow/core/kernels/priority_queue.cc
+++ b/tensorflow/core/kernels/priority_queue.cc
@@ -339,7 +339,7 @@ void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
for (; s > 0; --s) {
if (attempt->tuple.empty()) {
// Only allocate tuple when we have something to dequeue
- // so we don't use exceessive memory when there are many
+ // so we don't use excessive memory when there are many
// blocked dequeue attempts waiting.
attempt->tuple.reserve(num_components());
for (int i = 0; i < num_components(); ++i) {
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index d78c6d2639..c5e3164145 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -48,6 +48,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape")
ShapeOp<int64>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_bool(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
REGISTER_KERNEL_BUILDER(Name("Shape")
@@ -102,7 +103,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("out_type"),
ShapeOp<int64>);
-#endif
+#endif // GOOGLE_CUDA
// ShapeN ---------------------------------------
REGISTER_KERNEL_BUILDER(Name("ShapeN")
@@ -152,9 +153,9 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("out_type"),
ShapeNOp<int64>);
-#endif
+#endif // GOOGLE_CUDA
-#if TENSORFLOW_USE_SYCL
+#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("ShapeN") \
.Device(DEVICE_SYCL) \
@@ -170,11 +171,9 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN")
ShapeNOp<int64>)
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_bool(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("ShapeN")
.Device(DEVICE_SYCL)
.HostMemory("input")
@@ -202,13 +201,9 @@ REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
.TypeConstraint<type>("T") \
.HostMemory("output"), \
RankOp);
-REGISTER_SYCL_KERNEL(float);
-REGISTER_SYCL_KERNEL(double);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-// A special GPU kernel for int32 and bool.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Rank")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
@@ -250,7 +245,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank")
.HostMemory("input")
.HostMemory("output"),
RankOp);
-#endif
+#endif // GOOGLE_CUDA
// Size ------------------------------------------
REGISTER_KERNEL_BUILDER(Name("Size")
@@ -299,7 +294,7 @@ REGISTER_KERNEL_BUILDER(Name("Size")
.HostMemory("input")
.HostMemory("output"),
SizeOp<int64>);
-#endif
+#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
@@ -315,13 +310,10 @@ REGISTER_KERNEL_BUILDER(Name("Size")
.TypeConstraint<int64>("out_type") \
.HostMemory("output"), \
SizeOp<int64>);
-REGISTER_SYCL_KERNEL(float);
-REGISTER_SYCL_KERNEL(double);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_bool(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Size")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
@@ -336,7 +328,7 @@ REGISTER_KERNEL_BUILDER(Name("Size")
.HostMemory("input")
.HostMemory("output"),
SizeOp<int64>);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
// ExpandDims ------------------------------------
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
@@ -365,7 +357,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.HostMemory("dim")
.HostMemory("output"),
ExpandDimsOp);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
@@ -375,9 +367,8 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.TypeConstraint<int32>("Tdim") \
.HostMemory("dim"), \
ExpandDimsOp);
-REGISTER_SYCL_KERNEL(float)
-REGISTER_SYCL_KERNEL(double)
-
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_bool(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
@@ -388,7 +379,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.HostMemory("dim")
.HostMemory("output"),
ExpandDimsOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
// Squeeze ---------------------------------------
REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);
@@ -411,26 +402,23 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
.HostMemory("input")
.HostMemory("output"),
SqueezeOp);
-#endif
+#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Squeeze").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Squeeze").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SqueezeOp);
-REGISTER_SYCL_KERNEL(float);
-REGISTER_SYCL_KERNEL(double);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_bool(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Squeeze")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("input")
.HostMemory("output"),
SqueezeOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index ee6f9a28cd..d46701749b 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -118,6 +118,43 @@ static void SharedValidation(OpKernelContext* context,
}
}
+// Code factored out of SliceOp::Compute so that MklSliceOp can reuse this
+// generic logic.
+template <typename T>
+static void SharedSliceCommonCases(OpKernelContext* context,
+ TensorShape* output_shape,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ Tensor** result,
+ bool* done) {
+ bool is_identity = true;
+ bool slice_dim0 = true;
+ *done = false;
+
+ SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin,
+ size);
+ if (!context->status().ok()) return;
+ const Tensor& input = context->input(0);
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ *done = true;
+ return;
+ }
+
+ if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), (*begin)[0],
+ (*size)[0])) {
+ VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
+ CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
+ context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
+ *done = true;
+ return;
+ }
+
+ OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result));
+}
+
+
template <typename Device, typename T>
class SliceOp : public OpKernel {
public:
@@ -125,29 +162,89 @@ class SliceOp : public OpKernel {
void Compute(OpKernelContext* context) override {
TensorShape output_shape;
- bool is_identity = true;
- bool slice_dim0 = true;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> size;
- SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin,
- &size);
- if (!context->status().ok()) return;
+ Tensor* result = nullptr;
+ bool done = false;
+ SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
+ &done);
+ if (!context->status().ok() || done == true) return;
+
const Tensor& input = context->input(0);
- if (is_identity) {
- VLOG(1) << "Slice identity";
- context->set_output(0, input);
- return;
+ const int input_dims = input.dims();
+
+ if (output_shape.num_elements() > 0) {
+ if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
+ DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ auto input = context->input(0).tensor<T, 2>();
+ auto output = result->tensor<T, 2>();
+ // TODO(agarwal): Consider multi-threading this loop for cases where
+ // size[0] is very large.
+ for (int i = 0; i < size[0]; ++i) {
+ const int64 row = begin[0] + i;
+ if (i + 1 < size[0]) {
+ port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
+ }
+ memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
+ }
+ return;
+ }
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ HandleCase<NDIM>(context, begin, size, result); \
+ return; \
+ }
+
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+ HANDLE_DIM(6);
+ HANDLE_DIM(7);
+
+#undef HANDLE_DIM
+
+ OP_REQUIRES(context, false, errors::Unimplemented(
+ "SliceOp : Unhandled input dimensions"));
}
+ }
- if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], size[0])) {
- VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
- CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
- context->set_output(0, input.Slice(begin[0], begin[0] + size[0]));
- return;
+ private:
+ template <int NDIM>
+ void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size, Tensor* result) {
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
+ for (int i = 0; i < NDIM; ++i) {
+ indices[i] = begin[i];
+ sizes[i] = size[i];
}
+ functor::Slice<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes);
+ }
+};
+
+#ifdef INTEL_MKL
+template <typename Device, typename T>
+class MklSliceOp : public OpKernel {
+ public:
+ explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ TensorShape output_shape;
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
Tensor* result = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+ bool done = false;
+ SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
+ &done);
+ if (!context->status().ok() || done == true) return;
+
+ const Tensor& input = context->input(0);
const int input_dims = input.dims();
if (output_shape.num_elements() > 0) {
@@ -189,9 +286,123 @@ class SliceOp : public OpKernel {
}
private:
+  // Helper function for DoesSliceShapeDifferInOnly1D. Returns true if the
+  // following criterion holds for slice_dim: the slice indices are 0 in
+  // every dim except slice_dim, and the sizes of all dimensions of the slice
+  // match the sizes of the corresponding input dimensions except for
+  // slice_dim. Otherwise, returns false.
+ bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
+ const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size,
+ int slice_dim) {
+ for (int dim = 0; dim < 4; dim++) {
+ if (dim != slice_dim &&
+ (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Is 'input' tensor being sliced over a single dimension out of 4?
+ //
+  // This check is applicable in the context of a Slice of a 4-D tensor in
+  // NHWC or NCHW format over the channel dimension.
+  //
+  // If the slice indices are 0 in all dims except one dimension, and the
+  // sizes of all dimensions of the slice equal the sizes of the corresponding
+  // input dimensions except for that one dimension, then we are slicing over
+  // a single dimension.
+  //
+  // Returns true if slicing over a single dimension, and sets slice_dim
+  // to the index of the dimension that satisfies the criterion.
+ bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
+ const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size,
+ int* slice_dim) {
+ for (int dim = 0; dim < 4; dim++) {
+ if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
+ *slice_dim = dim;
+ return true;
+ }
+ }
+ return false;
+ }
+
template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ void HandleCase(OpKernelContext* context,
+ const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
+ int slice_dim = -1;
+ TensorShape in_shape = context->input(0).shape();
+ // Special case for handling 4-D tensor slice when shape of the slice
+ // differs from the input tensor in only 1 out of 4 dimensions.
+    // This case arises in the context of a Slice of a 4-D tensor in NHWC or
+    // NCHW format over the channel dimension.
+ if (NDIM == 4 &&
+ DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
+ size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
+ in_shape.dim_size(2) *
+ in_shape.dim_size(3),
+ (size_t) in_shape.dim_size(2) *
+ in_shape.dim_size(3),
+ (size_t) in_shape.dim_size(3),
+ (size_t) 1
+ };
+
+ size_t out_strides[4] = { (size_t) size[1] * size[2] * size[3],
+ (size_t) size[2] * size[3],
+ (size_t) size[3],
+ (size_t) 1 };
+
+ T *in_buf = const_cast<T*>(const_cast<const T*>(
+ context->input(0).flat<T>().data()));
+ T *op_buf = result->flat<T>().data();
+
+ if (slice_dim == 1) {
+ /* data format = NCHW */
+
+ #pragma omp parallel for
+ for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ T *ip = in_buf + (d0 * in_strides[0]);
+ T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
+ #pragma omp parallel for
+ for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ T *ip1 = ip + (d1 * in_strides[1]);
+ T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
+ // For NCHW, H and W will be contiguous. So we can copy
+ // both with one memcpy.
+ memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
+ sizeof(T) * in_strides[1]);
+ }
+ }
+ return;
+ } else if (slice_dim == 3) {
+ /* data_format = NHWC */
+
+ #pragma omp parallel for
+ for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ T *ip = in_buf + (d0 * in_strides[0]);
+ T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
+ #pragma omp parallel for
+ for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ T *ip1 = ip + (d1 * in_strides[1]);
+ T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
+ #pragma omp parallel for
+ for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
+ T *ip2 = ip1 + (d2 * in_strides[2]);
+ T *ip3 = ip2 + begin[3];
+ T *op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
+ T *op3 = op2;
+ memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
+ sizeof(T) * size[3]);
+ }
+ }
+ }
+ return;
+ }
+    // If slice_dim is neither 1 nor 3, fall back to the Eigen implementation.
+ }
+
Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
@@ -204,6 +415,7 @@ class SliceOp : public OpKernel {
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
+#endif
// Forward declarations of the functor specializations for declared in the
// sharded source files.
@@ -233,6 +445,7 @@ DECLARE_FOR_N(bfloat16);
#undef DECLARE_CPU_SPEC
} // namespace functor
+#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -244,8 +457,21 @@ DECLARE_FOR_N(bfloat16);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
REGISTER_SLICE(bfloat16);
+#undef REGISTER_SLICE
+#else
+#define REGISTER_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("Slice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size"), \
+ MklSliceOp<CPUDevice, type>)
+TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
+TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
+REGISTER_SLICE(bfloat16);
#undef REGISTER_SLICE
+#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/sparse_reduce_op.cc b/tensorflow/core/kernels/sparse_reduce_op.cc
new file mode 100644
index 0000000000..9e60791f97
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_reduce_op.cc
@@ -0,0 +1,341 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/sparse_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+// TODO(b/31496047): Fix non-standard include order.
+#include <numeric> // clang-format off
+
+using tensorflow::sparse::SparseTensor;
+using tensorflow::gtl::ArraySlice;
+
+namespace tensorflow {
+
+struct ReduceDetails {
+ // The dimensions to call Reorder() with.
+ std::vector<int64> reorder_dims;
+
+ // The dimensions to call group() with after Reorder().
+ std::vector<int64> group_by_dims;
+
+ // The shape after reduction.
+ TensorShape reduced_shape;
+};
+
+// Compute common reduce parameters that'll be used for SparseTensor
+// reductions. Usage:
+// ReduceDetails reduction = SparseTensorReduceHelper(sp, axes, keep_dims);
+// sp.Reorder(reduction.reorder_dims);
+// for (const auto& g : sp.group(reduction.group_by_dims)) {
+// ...
+// }
+// // Set output shape to reduction.reduced_shape.
+ReduceDetails SparseTensorReduceHelper(const SparseTensor &sp,
+ gtl::ArraySlice<int32> axes_slice,
+ bool keep_dims) {
+ ReduceDetails reduction;
+
+ std::vector<int32> reduction_axes(axes_slice.begin(), axes_slice.end());
+ int ndims = sp.dims();
+ for (int64 i = 0; i < reduction_axes.size(); ++i) {
+ reduction_axes[i] = (reduction_axes[i] + ndims) % ndims;
+ }
+ std::sort(reduction_axes.begin(), reduction_axes.end());
+
+ // (0) Calculate the grouping dimensions:
+ // group_by_dims == {0, .., NDIMS-1} \ reduction_axes.
+ std::vector<int64> perm(ndims);
+ std::iota(perm.begin(), perm.end(), 0);
+
+ // Requires perm and reduction_axes to be sorted; group_by_dims will be
+ // sorted as well.
+ std::set_difference(
+ perm.begin(), perm.end(), reduction_axes.begin(), reduction_axes.end(),
+ std::inserter(reduction.group_by_dims, reduction.group_by_dims.begin()));
+
+ // Now append the rest of the axes (the complement of group_by_dims);
+ // the result is used by Reorder().
+ reduction.reorder_dims = reduction.group_by_dims;
+ std::set_difference(perm.begin(), perm.end(), reduction.group_by_dims.begin(),
+ reduction.group_by_dims.end(),
+ std::back_inserter(reduction.reorder_dims));
+
+ // (1) Calculate the shape after reduction.
+ auto sp_shape = sp.shape();
+ std::vector<int64> out_dim_sizes;
+ if (keep_dims) {
+ out_dim_sizes.reserve(ndims);
+ auto beg = reduction.group_by_dims.begin();
+ auto end = reduction.group_by_dims.end();
+ for (int d = 0; d < ndims; ++d) {
+ if (std::find(beg, end, d) == end) {
+ out_dim_sizes.push_back(1); // A reduced axis.
+ } else {
+ out_dim_sizes.push_back(sp_shape[d]);
+ }
+ }
+ } else {
+ out_dim_sizes = sp.PickDims(reduction.group_by_dims);
+ }
+
+ reduction.reduced_shape = TensorShape(out_dim_sizes);
+ return reduction;
+}
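To make the bookkeeping above concrete, here is a minimal standalone illustration (not part of the kernel; the rank and axes are made up) of what the two std::set_difference calls produce for a rank-3 tensor reduced over axis 1: group_by_dims ends up as {0, 2} and reorder_dims as {0, 2, 1}.

// Standalone illustration of the set_difference bookkeeping used above.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

int main() {
  const int ndims = 3;
  std::vector<long long> reduction_axes = {1};  // already normalized + sorted

  std::vector<long long> perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);  // {0, 1, 2}

  // group_by_dims == {0, .., NDIMS-1} \ reduction_axes == {0, 2}.
  std::vector<long long> group_by_dims;
  std::set_difference(perm.begin(), perm.end(), reduction_axes.begin(),
                      reduction_axes.end(), std::back_inserter(group_by_dims));

  // reorder_dims == group_by_dims followed by its complement == {0, 2, 1}.
  std::vector<long long> reorder_dims = group_by_dims;
  std::set_difference(perm.begin(), perm.end(), group_by_dims.begin(),
                      group_by_dims.end(), std::back_inserter(reorder_dims));

  for (long long d : group_by_dims) std::cout << d << ' ';  // 0 2
  std::cout << '\n';
  for (long long d : reorder_dims) std::cout << d << ' ';   // 0 2 1
  std::cout << '\n';
  return 0;
}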
+
+Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) {
+ // indices and values are validated in SparseTensor ctor.
+ if (!TensorShapeUtils::IsVector(shape_t->shape())) {
+ return errors::InvalidArgument(
+ "Expected input_shape to be a vector; got shape: ",
+ shape_t->shape().DebugString());
+ }
+ if (!TensorShapeUtils::IsScalar(reduction_axes_t->shape()) &&
+ !TensorShapeUtils::IsVector(reduction_axes_t->shape())) {
+ return errors::InvalidArgument(
+ "Expected reduction_axes to be a scalar or a vector; got shape: ",
+ reduction_axes_t->shape().DebugString());
+ }
+
+ const auto reduction_axes_flat = reduction_axes_t->flat<int32>();
+ for (int64 i = 0; i < reduction_axes_flat.size(); i++) {
+ int32 axis = reduction_axes_flat(i);
+ if (axis < -shape_t->NumElements() || axis >= shape_t->NumElements()) {
+ return errors::InvalidArgument("Invalid reduction dimension ", axis,
+ ", for input with ",
+ shape_t->NumElements(), " dimensions.");
+ }
+ }
+
+ return Status::OK();
+}
+
+struct SumOp {
+ template <typename T>
+ static void Run(OpKernelContext *ctx, typename TTypes<T>::Scalar &s,
+ const typename TTypes<T>::UnalignedVec &v) {
+ s.device(ctx->eigen_cpu_device()) = v.sum();
+ }
+ static StringPiece Name() {
+ return "sum";
+ }
+};
+
+struct MaxOp {
+ template <typename T>
+ static void Run(OpKernelContext *ctx, typename TTypes<T>::Scalar &s,
+ const typename TTypes<T>::UnalignedVec &v) {
+ s.device(ctx->eigen_cpu_device()) = v.maximum();
+ }
+ static StringPiece Name() {
+ return "max";
+ }
+};
+
+template <typename T, typename Op>
+class SparseReduceOp : public OpKernel {
+ public:
+ explicit SparseReduceOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+ OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+ OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+ OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+ // TODO(zongheng): we will call Reorder() below, which will modify
+ // in-place the underlying indices and values buffers. To avoid
+ // surprises of this kernel being stateful, we work around the above by
+ // making deep copies here. Remove this if/when we change Reorder()'s
+ // semantics.
+ const auto shape_vec = shape_t->vec<int64>();
+ SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+ TensorShape(shape_vec));
+ ReduceDetails reduction = SparseTensorReduceHelper(
+ sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+ Tensor *out_values;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, reduction.reduced_shape, &out_values));
+ auto out_flat = out_values->flat<T>();
+ out_flat.setZero();
+
+ Tensor tmp_reduced_val;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({}), &tmp_reduced_val));
+ auto reduced_val = tmp_reduced_val.scalar<T>();
+
+ // Compute strides, and use them to convert coords to a flat index. The
+ // coordinates returned by .group() have the same ndims as group_by_dims.
+ gtl::InlinedVector<int64, 8> output_strides(reduction.group_by_dims.size());
+ if (!output_strides.empty()) { // Do this iff we don't reduce all.
+ output_strides.back() = 1;
+ for (int d = output_strides.size() - 2; d >= 0; --d) {
+ output_strides[d] =
+ output_strides[d + 1] * shape_vec(reduction.group_by_dims[d + 1]);
+ }
+ }
+
+ auto CoordinatesToFlatIndex = [](ArraySlice<int64> coords,
+ ArraySlice<int64> strides) {
+ if (strides.empty()) { // Reduce all.
+ return 0LL;
+ }
+ CHECK_EQ(coords.size(), strides.size());
+ int64 idx = 0;
+ for (int i = 0; i < coords.size(); ++i) {
+ idx += coords[i] * strides[i];
+ }
+ return idx;
+ };
+
+ // Each group maps one-to-one onto a value in the reduced tensor.
+ // g.group() provides the coordinates of a particular reduced value.
+ sp.Reorder<T>(reduction.reorder_dims);
+ for (const auto &g : sp.group(reduction.group_by_dims)) {
+ Op::template Run<T>(ctx, reduced_val, g.template values<T>());
+ const int64 idx = CoordinatesToFlatIndex(g.group(), output_strides);
+ out_flat(idx) = reduced_val();
+ VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+ << "; idx: " << idx << "; group " << Op::Name() << ": "
+ << reduced_val();
+ }
+ }
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
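The dense output above is addressed by flattening each group's kept coordinates with row-major strides. A minimal sketch of that arithmetic, with made-up sizes: if the kept dimensions have sizes {4, 5}, the strides are {5, 1}, and a group at coordinates (2, 3) writes to flat index 2*5 + 3 = 13.

// Minimal sketch of the stride/flat-index computation used by the dense
// reduction (illustrative values; not tied to any particular SparseTensor).
#include <iostream>
#include <vector>

long long CoordinatesToFlatIndex(const std::vector<long long>& coords,
                                 const std::vector<long long>& strides) {
  if (strides.empty()) return 0;  // reduce-all case: single output element
  long long idx = 0;
  for (size_t i = 0; i < coords.size(); ++i) idx += coords[i] * strides[i];
  return idx;
}

int main() {
  // Kept (non-reduced) dimension sizes after the reduction.
  const std::vector<long long> kept_sizes = {4, 5};

  // Row-major strides: back() == 1, then multiply up the shape.
  std::vector<long long> strides(kept_sizes.size());
  strides.back() = 1;
  for (int d = static_cast<int>(strides.size()) - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * kept_sizes[d + 1];
  }

  std::cout << CoordinatesToFlatIndex({2, 3}, strides) << "\n";  // prints 13
  return 0;
}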
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceSum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceOp<T, SumOp>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceMax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceOp<T, MaxOp>)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+template <typename T, typename Op>
+class SparseReduceSparseOp : public OpKernel {
+ public:
+ explicit SparseReduceSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+ OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+ OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+ OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+ SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+ TensorShape(shape_t->vec<int64>()));
+ ReduceDetails reduction = SparseTensorReduceHelper(
+ sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+ sp.Reorder<T>(reduction.reorder_dims);
+ // Count nnzs in the output SparseTensor.
+ int64 nnz = 0;
+ auto iter = sp.group(reduction.group_by_dims);
+ for (auto it = iter.begin(); it != iter.end(); ++it) {
+ nnz++;
+ }
+
+ Tensor *out_indices_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(
+ 0, TensorShape({nnz, reduction.reduced_shape.dims()}),
+ &out_indices_t));
+ typename TTypes<int64>::Matrix out_indices_mat =
+ out_indices_t->matrix<int64>();
+ // For keep_dims: indices for reduced dims stay zero; only grouped dims are
+ // set below.
+ out_indices_mat.setZero();
+
+ Tensor *out_values_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(1, TensorShape({nnz}), &out_values_t));
+ auto out_flat = out_values_t->flat<T>();
+
+ Tensor tmp_reduced_val;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({}), &tmp_reduced_val));
+ auto reduced_val = tmp_reduced_val.scalar<T>();
+ int64 i = 0;
+ for (const auto &g : sp.group(reduction.group_by_dims)) {
+ Op::template Run<T>(ctx, reduced_val, g.template values<T>());
+ std::vector<int64> group = g.group();
+ for (int64 j = 0; j < group.size(); j++) {
+ if (keep_dims_) {
+ out_indices_mat(i, reduction.group_by_dims[j]) = group[j];
+ } else {
+ out_indices_mat(i, j) = group[j];
+ }
+ }
+ out_flat(i) = reduced_val();
+ i++;
+ VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+ << "; group " << Op::Name() << ": "
+ << reduced_val();
+ }
+
+ Tensor *out_shape_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(
+ 2, TensorShape({reduction.reduced_shape.dims()}),
+ &out_shape_t));
+ auto out_shape_flat = out_shape_t->flat<int64>();
+ auto out_dim_sizes = reduction.reduced_shape.dim_sizes();
+ std::copy(out_dim_sizes.begin(), out_dim_sizes.end(), &out_shape_flat(0));
+ }
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
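As a concrete illustration of the index handling above (example shapes only, not from the patch): reducing a rank-3 SparseTensor over axis 1 with keep_dims=true emits, for a group at kept coordinates (i0, i2), the full-rank output index [i0, 0, i2], because out_indices_mat is zero-initialized and only the positions listed in group_by_dims are filled in; with keep_dims=false the same group emits the rank-2 index [i0, i2].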
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceSumSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceSparseOp<T, SumOp>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceMaxSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceSparseOp<T, MaxOp>)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index b4698a8053..2db3e5ef77 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -40,6 +40,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
class Stack : public ResourceBase {
public:
@@ -182,6 +185,10 @@ class StackOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp);
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"),
StackOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_SYCL).HostMemory("handle"),
+ StackOp);
+#endif // TENSORFLOW_USE_SYCL
template <typename Device>
class StackPushOp : public AsyncOpKernel {
@@ -213,7 +220,11 @@ class StackPushOp : public AsyncOpKernel {
static constexpr int kCopyThreshold = 2048;
static constexpr double kOccupancy = 0.7;
if (swap_memory_ && !alloc_attrs.on_host() &&
- std::is_same<Device, GPUDevice>::value &&
+ ( std::is_same<Device, GPUDevice>::value
+#ifdef TENSORFLOW_USE_SYCL
+ || std::is_same<Device, SYCLDevice>::value
+#endif // TENSORFLOW_USE_SYCL
+ ) &&
tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) {
DeviceContext* device_ctxt = ctx->op_device_context();
auto device = static_cast<tensorflow::Device*>(ctx->device());
@@ -289,6 +300,31 @@ REGISTER_GPU_HOST_KERNEL(bool);
#undef REGISTER_GPU_HOST_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("StackPush") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
+ StackPushOp<SYCLDevice>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("StackPush") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .HostMemory("elem") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ StackPushOp<SYCLDevice>)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(bool);
+#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class StackPopOp : public AsyncOpKernel {
public:
explicit StackPopOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}
@@ -359,6 +395,31 @@ REGISTER_GPU_HOST_KERNEL(bool);
#undef REGISTER_GPU_HOST_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("StackPop") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("elem_type"), \
+ StackPopOp)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("StackPop") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .HostMemory("elem") \
+ .TypeConstraint<type>("elem_type"), \
+ StackPopOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(bool);
+
+#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class StackCloseOp : public OpKernel {
public:
explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -376,5 +437,8 @@ class StackCloseOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp);
REGISTER_KERNEL_BUILDER(
Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp);
-
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("StackClose").Device(DEVICE_SYCL).HostMemory("handle"), StackCloseOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
index 630fcb76f3..5c89eaef5f 100644
--- a/tensorflow/core/kernels/topk_op.cc
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -93,7 +93,7 @@ class TopK : public OpKernel {
rows_by_one.set(0, num_rows);
#else
Eigen::array<int, 1> reduce_on_cols = {1};
- Eigen::array<int, 1> rows_by_one = {static_cast<int>(num_rows), 1};
+ Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif
values.device(d) =
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
index 124cf14dd2..f1ab770eeb 100644
--- a/tensorflow/core/kernels/transpose_functor.h
+++ b/tensorflow/core/kernels/transpose_functor.h
@@ -132,6 +132,13 @@ template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out);
+
+#ifdef TENSORFLOW_USE_SYCL
+// For SYCL, let's always go through Eigen.
+template <typename Device, typename T>
+void TransposeSYCL(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+#endif // TENSORFLOW_USE_SYCL
} // namespace internal
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 75ed76a697..d3305fb83a 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -233,10 +233,7 @@ Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
.TypeConstraint<int32>("Tperm") \
.HostMemory("perm"), \
TransposeSyclOp);
-REGISTER(float);
-REGISTER(bool);
-REGISTER(int32);
+TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index dbd7de7ce0..1980f758fc 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -22,7 +22,7 @@ namespace tensorflow {
/*
* TypedConditionalAccumulatorBase is a templated companion of
- * ConditionalAccumulatorBase which allows for subclassses to use different
+ * ConditionalAccumulatorBase which allows for subclasses to use different
* types for the input gradients. (See ConditionalAccumulator and
* SparseConditionalAccumulator.)
*
diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h
index 8ba4b09143..2ff8b9c7d1 100644
--- a/tensorflow/core/lib/gtl/optional.h
+++ b/tensorflow/core/lib/gtl/optional.h
@@ -656,7 +656,7 @@ class optional : private internal_optional::optional_data<T>,
constexpr const T& reference() const { return *this->pointer(); }
T& reference() { return *(this->pointer()); }
- // T constaint checks. You can't have an optional of nullopt_t, in_place_t or
+ // T constraint checks. You can't have an optional of nullopt_t, in_place_t or
// a reference.
static_assert(
!std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index a530d286f7..9a58a31757 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -155,6 +155,26 @@ Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);
+Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
+ {{"dx"}, "Mul", {"dy", "cosh"}}, // dy * cosh(x)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sinh", SinhGrad);
+
+Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
+ {{"dx"}, "Mul", {"dy", "sinh"}}, // dy * sinh(x)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Cosh", CoshGrad);
+
Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 38813b3f2b..aa9706a328 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -495,6 +495,26 @@ TEST_F(MathGradTest, Log1p) {
test::ExpectClose(ans, dx);
}
+TEST_F(MathGradTest, Sinh) {
+ auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+ TensorShape({2, 3}));
+ auto g = [](float x) { return std::cosh(x); };
+ auto dx = test::AsTensor<float>(
+ {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+ auto ans = SymGrad("Sinh", x);
+ test::ExpectClose(ans, dx);
+}
+
+TEST_F(MathGradTest, Cosh) {
+ auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+ TensorShape({2, 3}));
+ auto g = [](float x) { return std::sinh(x); };
+ auto dx = test::AsTensor<float>(
+ {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+ auto ans = SymGrad("Cosh", x);
+ test::ExpectClose(ans, dx);
+}
+
TEST_F(MathGradTest, Tanh) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3fe543a239..30d6987707 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -293,6 +293,14 @@ Computes natural logarithm of (1 + x) element-wise.
I.e., \\(y = \log_e (1 + x)\\).
)doc");
+REGISTER_OP("Sinh").UNARY_COMPLEX().Doc(R"doc(
+Computes hyperbolic sine of x element-wise.
+)doc");
+
+REGISTER_OP("Cosh").UNARY_COMPLEX().Doc(R"doc(
+Computes hyperbolic cosine of x element-wise.
+)doc");
+
REGISTER_OP("Tanh").UNARY_COMPLEX().Doc(R"doc(
Computes hyperbolic tangent of `x` element-wise.
)doc");
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 70302c3886..3a25fd15da 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -831,11 +831,13 @@ a different filter to each input channel (expanding from 1 channel to
`channel_multiplier` channels for each), then concatenates the results
together. Thus, the output has `in_channels * channel_multiplier` channels.
+```
for k in 0..in_channels-1
for q in 0..channel_multiplier-1
output[b, i, j, k * channel_multiplier + q] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
filter[di, dj, k, q]
+```
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b122b5a992..92ca6ea367 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -5185,6 +5185,30 @@ op {
summary: "Computes cos of x element-wise."
}
op {
+ name: "Cosh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes hyperbolic cosine of x element-wise."
+}
+op {
name: "CountUpTo"
input_arg {
name: "ref"
@@ -22793,6 +22817,30 @@ op {
summary: "Computes sin of x element-wise."
}
op {
+ name: "Sinh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes hyperbolic sine of x element-wise."
+}
+op {
name: "Size"
input_arg {
name: "input"
@@ -27259,6 +27307,33 @@ op {
is_stateful: true
}
op {
+ name: "LMDBReader"
+ output_arg {
+ name: "reader_handle"
+ description: "The handle to reference the Reader."
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ description: "If non-empty, this reader is placed in the given container.\nOtherwise, a default container is used."
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ description: "If non-empty, this reader is named in the given bucket\nwith this shared_name. Otherwise, the node name is used instead."
+ }
+ summary: "A Reader that outputs the records from a LMDB database."
+ is_stateful: true
+}
+op {
name: "TakeDataset"
input_arg {
name: "input_dataset"
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 9722f0ee9a..6aca2c3b01 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -710,6 +710,75 @@ a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`.
b: `ndims`-D Tensor. With shape `a_shape`.
)doc");
+REGISTER_OP("SparseReduceMax")
+ .Input("input_indices: int64")
+ .Input("input_values: T")
+ .Input("input_shape: int64")
+ .Input("reduction_axes: int32")
+ .Attr("keep_dims: bool = False")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Computes the max of elements across dimensions of a SparseTensor.
+
+This Op takes a SparseTensor and is the sparse counterpart to
+`tf.reduce_max()`. In particular, this Op also returns a dense `Tensor`
+instead of a sparse one.
+
+Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained
+with length 1.
+
+If `reduction_axes` has no entries, all dimensions are reduced, and a tensor
+with a single element is returned. Additionally, the axes can be negative,
+in which case they are interpreted according to Python indexing rules.
+
+input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+ SparseTensor, possibly not in canonical ordering.
+input_values: 1-D. `N` non-empty values corresponding to `input_indices`.
+input_shape: 1-D. Shape of the input SparseTensor.
+reduction_axes: 1-D. Length-`K` vector containing the reduction axes.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: `R-K`-D. The reduced Tensor.
+)doc");
+
+REGISTER_OP("SparseReduceMaxSparse")
+ .Input("input_indices: int64")
+ .Input("input_values: T")
+ .Input("input_shape: int64")
+ .Input("reduction_axes: int32")
+ .Attr("keep_dims: bool = False")
+ .Output("output_indices: int64")
+ .Output("output_values: T")
+ .Output("output_shape: int64")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Computes the max of elements across dimensions of a SparseTensor.
+
+This Op takes a SparseTensor and is the sparse counterpart to
+`tf.reduce_max()`. In contrast to SparseReduceMax, this Op returns a
+SparseTensor.
+
+Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained
+with length 1.
+
+If `reduction_axes` has no entries, all dimensions are reduced, and a tensor
+with a single element is returned. Additionally, the axes can be negative,
+in which case they are interpreted according to Python indexing rules.
+
+input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+ SparseTensor, possibly not in canonical ordering.
+input_values: 1-D. `N` non-empty values corresponding to `input_indices`.
+input_shape: 1-D. Shape of the input SparseTensor.
+reduction_axes: 1-D. Length-`K` vector containing the reduction axes.
+keep_dims: If true, retain reduced dimensions with length 1.
+)doc");
+
REGISTER_OP("SparseReduceSum")
.Input("input_indices: int64")
.Input("input_values: T")
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc
index 096c77c6e3..99691ecfb9 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils.cc
@@ -89,7 +89,7 @@ Status RetryingUtils::DeleteWithRetries(
bool is_retried = false;
return RetryingUtils::CallWithRetries(
[delete_func, &is_retried]() {
- const auto& status = delete_func();
+ const Status status = delete_func();
if (is_retried && status.code() == error::NOT_FOUND) {
return Status::OK();
}
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index cf05aece39..e476a84a13 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -171,7 +171,7 @@ message ExecutorOpts {
};
message RunGraphRequest {
- // session_handle is the the master-generated unique id for this session.
+ // session_handle is the master-generated unique id for this session.
// If session_handle is non-empty, it must be the same as used when
// registering the graph. If it is empty, a single global namespace is used to
// search for the graph_handle.
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index d30d7819fc..0e5611e359 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -24,7 +24,7 @@ limitations under the License.
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX "-rc2"
+#define TF_VERSION_SUFFIX ""
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 6a37256ea9..67468bdc3f 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "third_party/mkl/include/mkl_service.h"
-
+#include "third_party/mkl/include/mkl_trans.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -616,6 +616,42 @@ inline void ForwarMklTensorInToOut(OpKernelContext* context,
}
}
+ // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
+ // out.
+inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
+ const float* buf_in = input.flat<float>().data();
+ float* buf_out = (*output)->flat<float>().data();
+
+ int64 N = input.dim_size(0);
+ int64 H = input.dim_size(1);
+ int64 W = input.dim_size(2);
+ int64 C = input.dim_size(3);
+ int64 stride_n = H*W*C;
+# pragma omp parallel for num_threads(16)
+ for (int64 n = 0; n < N; ++n) {
+ mkl_somatcopy('R', 'T', H*W, C, 1, buf_in + n*stride_n, C,
+ buf_out + n*stride_n, H*W);
+ }
+}
+
+ // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
+ // out.
+inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
+ const float* buf_in = input.flat<float>().data();
+ float* buf_out = (*output)->flat<float>().data();
+
+ int64 N = (*output)->dim_size(0);
+ int64 H = (*output)->dim_size(1);
+ int64 W = (*output)->dim_size(2);
+ int64 C = (*output)->dim_size(3);
+ int64 stride_n = H*W*C;
+# pragma omp parallel for num_threads(16)
+ for (int64 n = 0; n < N; ++n) {
+ mkl_somatcopy('R', 'T', C, H*W, 1, buf_in + n*stride_n, H*W,
+ buf_out + n*stride_n, C);
+ }
+}
+
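Both helpers above delegate the per-image work to mkl_somatcopy, which out-of-place transposes a row-major (H*W) x C matrix into a C x (H*W) one (and back). As a reference for what that call computes, a plain-loop NHWC-to-NCHW conversion would look roughly like this sketch (illustrative only; the patch itself uses the MKL call):

// Reference-only sketch of the NHWC -> NCHW conversion performed per image
// by mkl_somatcopy('R', 'T', H*W, C, 1, ...). Not part of the patch.
#include <vector>

void NHWCToNCHWNaive(const float* in, float* out, long long N, long long H,
                     long long W, long long C) {
  const long long hw = H * W;
  const long long stride_n = hw * C;
  for (long long n = 0; n < N; ++n) {
    const float* src = in + n * stride_n;   // (H*W) x C, row-major
    float* dst = out + n * stride_n;        // C x (H*W), row-major
    for (long long p = 0; p < hw; ++p) {    // p indexes the (h, w) position
      for (long long c = 0; c < C; ++c) {
        dst[c * hw + p] = src[p * C + c];
      }
    }
  }
}

int main() {
  const long long N = 1, H = 2, W = 2, C = 3;
  std::vector<float> nhwc(N * H * W * C), nchw(nhwc.size());
  for (size_t i = 0; i < nhwc.size(); ++i) nhwc[i] = static_cast<float>(i);
  NHWCToNCHWNaive(nhwc.data(), nchw.data(), N, H, W, C);
  return 0;
}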
namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";