Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/BUILD | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt | 12
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt | 18
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt | 69
-rw-r--r--  tensorflow/core/framework/bfloat16_test.cc | 12
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc | 30
-rw-r--r--  tensorflow/core/framework/common_shape_fns_test.cc | 106
-rw-r--r--  tensorflow/core/framework/numeric_types.h | 49
-rw-r--r--  tensorflow/core/framework/op_def_builder_test.cc | 15
-rw-r--r--  tensorflow/core/framework/types.cc | 24
-rw-r--r--  tensorflow/core/grappler/costs/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 157
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h | 37
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc | 91
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc | 2
-rw-r--r--  tensorflow/core/grappler/grappler_item_builder.cc | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 6
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 107
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 106
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 367
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.h | 21
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 159
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 321
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 203
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/static_schedule.cc | 4
-rw-r--r--  tensorflow/core/kernels/BUILD | 21
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc | 93
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc | 97
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.h | 16
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops_3d.cc | 4
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 113
-rw-r--r--  tensorflow/core/kernels/conv_ops.h | 10
-rw-r--r--  tensorflow/core/kernels/conv_ops_3d.cc | 3
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h | 12
-rw-r--r--  tensorflow/core/kernels/conv_ops_test.cc | 4
-rw-r--r--  tensorflow/core/kernels/cwise_op_asinh.cc | 2
-rw-r--r--  tensorflow/core/kernels/dataset.cc | 140
-rw-r--r--  tensorflow/core/kernels/dataset.h | 155
-rw-r--r--  tensorflow/core/kernels/dataset_utils.cc | 2
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/filter_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/flat_map_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/group_by_window_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/interleave_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/map_and_batch_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/map_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/mkl_batch_matmul_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/multinomial_op.cc | 51
-rw-r--r--  tensorflow/core/kernels/multinomial_op.h | 2
-rw-r--r--  tensorflow/core/kernels/multinomial_op_gpu.cu.cc | 30
-rw-r--r--  tensorflow/core/kernels/nn_ops_test.cc | 2
-rw-r--r--  tensorflow/core/kernels/padded_batch_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/parallel_map_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/quantized_conv_ops.cc | 13
-rw-r--r--  tensorflow/core/kernels/random_dataset_op.cc | 154
-rw-r--r--  tensorflow/core/kernels/reduction_ops_min.cc | 1
-rw-r--r--  tensorflow/core/kernels/reduction_ops_test.cc | 5
-rw-r--r--  tensorflow/core/kernels/scan_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc | 63
-rw-r--r--  tensorflow/core/kernels/serialize_sparse_op.cc | 177
-rw-r--r--  tensorflow/core/kernels/softmax_op_functor.h | 33
-rw-r--r--  tensorflow/core/kernels/strided_slice_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/strided_slice_op_gpu.cu.cc | 1
-rw-r--r--  tensorflow/core/kernels/tensor_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/tensor_slice_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/zip_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/lib/core/arena.cc | 18
-rw-r--r--  tensorflow/core/lib/math/math_util.h | 17
-rw-r--r--  tensorflow/core/lib/math/math_util_test.cc | 29
-rw-r--r--  tensorflow/core/lib/monitoring/collected_metrics.h | 1
-rw-r--r--  tensorflow/core/lib/monitoring/collection_registry.h | 6
-rw-r--r--  tensorflow/core/lib/monitoring/gauge.h | 33
-rw-r--r--  tensorflow/core/lib/monitoring/gauge_test.cc | 22
-rw-r--r--  tensorflow/core/lib/monitoring/metric_def.h | 13
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 32
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 18
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 125
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 146
-rw-r--r--  tensorflow/core/ops/random_ops.cc | 11
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc | 5
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc | 42
-rw-r--r--  tensorflow/core/ops/state_ops.cc | 56
-rw-r--r--  tensorflow/core/platform/cloud/curl_http_request_test.cc | 2
-rw-r--r--  tensorflow/core/platform/cloud/file_block_cache.cc | 4
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc | 10
-rw-r--r--  tensorflow/core/profiler/g3doc/options.md | 11
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_node.cc | 28
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_node.h | 10
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_show_test.cc | 37
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_stats_test.cc | 105
-rw-r--r--  tensorflow/core/profiler/tfprof_log.proto | 5
102 files changed, 3008 insertions, 1105 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index bd7617fa96..5bcb87d2d1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1016,7 +1016,7 @@ filegroup(
cc_library(
name = "android_tensorflow_lib_lite",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + if_not_android_mips_and_mips64(["-Os"]),
+ copts = tf_copts(android_optimization_level_override = None),
linkopts = ["-lz"],
tags = [
"manual",
@@ -1106,8 +1106,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib_selective_registration",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + [
- "-Os",
+ copts = tf_copts(android_optimization_level_override = None) + [
"-DSUPPORT_SELECTIVE_REGISTRATION",
],
tags = [
@@ -1129,8 +1128,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib_selective_registration_nortti",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + tf_opts_nortti_if_android() + [
- "-Os",
+ copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [
"-DSUPPORT_SELECTIVE_REGISTRATION",
],
tags = [
@@ -1210,7 +1208,7 @@ cc_library(
"framework/tensor_testutil.h",
"util/reporter.h",
],
- copts = tf_copts() + ["-Os"],
+ copts = tf_copts(android_optimization_level_override = None),
tags = [
"manual",
"notap",
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
index 6522ce976f..070d6adb97 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
@@ -26,7 +26,7 @@ END
description: <<END
1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
- `data_format`, see below for details.
+`data_format`, see below for details.
END
}
attr {
@@ -45,6 +45,16 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors."
description: <<END
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
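The `dilations` semantics above ("k-1 skipped cells") mean the filter's effective extent grows to `(filter_size - 1) * dilation + 1` without adding any weights. A minimal sketch of that arithmetic, in plain Python (illustrative only, not part of this patch):

```python
def effective_extent(filter_size, dilation):
    # k-1 skipped cells between each pair of filter taps
    return (filter_size - 1) * dilation + 1

# a 3-tap filter with dilation 2 covers 5 input cells
assert effective_extent(3, 2) == 5
```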
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
index 4ea3374dbb..ff2d9d71db 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
@@ -53,5 +53,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
index 4420073e38..2de38b4263 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
@@ -52,5 +52,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
index 8f3cd4493c..d26564097e 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
@@ -36,6 +36,16 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 3-D convolution given 5-D `input` and `filter` tensors."
description: <<END
In signal processing, cross-correlation is a measure of similarity of
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
index 6f9b917237..937c9c8ead 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
@@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes the gradients of 3-D convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
index 19aba156d5..414e418dc5 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
@@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes the gradients of 3-D convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
index cc10ebe923..3c313f7be6 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
@@ -23,6 +23,16 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors."
description: <<END
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
index 9126be2afa..e66aa3b707 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
@@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of depthwise convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
index f1d16858db..f501ad21b3 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
@@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of depthwise convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
index 00e96c8a15..dfaa531cbc 100644
--- a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
@@ -14,4 +14,47 @@ The `dtype` of the serialized `SparseTensor` objects.
END
}
summary: "Deserialize `SparseTensor` objects."
+ description: <<END
+The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
+the last dimension stores serialized `SparseTensor` objects and the other N
+dimensions (N >= 0) correspond to a batch. The ranks of the original
+`SparseTensor` objects must all match. When the final `SparseTensor` is
+created, its rank is the rank of the incoming `SparseTensor` objects plus N;
+the sparse tensors have been concatenated along new dimensions, one for each
+batch.
+
+The output `SparseTensor` object's shape values for the original dimensions
+are the max across the input `SparseTensor` objects' shape values for the
+corresponding dimensions. The new dimensions match the size of the batch.
+
+The input `SparseTensor` objects' indices are assumed ordered in
+standard lexicographic order. If this is not the case, after this
+step run `SparseReorder` to restore index ordering.
+
+For example, if the serialized input is a `[2 x 3]` matrix representing two
+original `SparseTensor` objects:
+
+ index = [ 0]
+ [10]
+ [20]
+ values = [1, 2, 3]
+ shape = [50]
+
+and
+
+ index = [ 2]
+ [10]
+ values = [4, 5]
+ shape = [30]
+
+then the final deserialized `SparseTensor` will be:
+
+ index = [0 0]
+ [0 10]
+ [0 20]
+ [1 2]
+ [1 10]
+ values = [1, 2, 3, 4, 5]
+ shape = [2 50]
+END
}
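The shape rule in the new description can be checked in a couple of lines: the new leading dimension is the batch size, and each original dimension is the elementwise max across the inputs. In plain Python (illustrative only, not part of this patch):

```python
# shapes of the two original SparseTensors from the example above
shapes = [[50], [30]]
batched = [len(shapes)] + [max(dims) for dims in zip(*shapes)]
assert batched == [2, 50]  # the deserialized shape shown above
```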
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
index b19bbeab12..d18bafdce9 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
@@ -55,6 +55,16 @@ END
The type of padding algorithm to use.
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 2D convolution given quantized 4D input and filter tensors."
description: <<END
The inputs are quantized tensors where the lowest value represents the real
diff --git a/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
new file mode 100644
index 0000000000..0466b40f85
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "RandomDataset"
+ in_arg {
+ name: "seed"
+ description: <<END
+A scalar seed for the random number generator. If either seed or
+seed2 is set to be non-zero, the random number generator is seeded
+by the given seed. Otherwise, a random seed is used.
+END
+ }
+ in_arg {
+ name: "seed2"
+ description: <<END
+A second scalar seed to avoid seed collision.
+END
+ }
+ summary: "Creates a Dataset that returns pseudorandom numbers."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
new file mode 100644
index 0000000000..b07ee9fda9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ResourceScatterNdUpdate"
+ in_arg {
+ name: "ref"
+ description: <<END
+A resource handle. Must be from a VarHandleOp.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A Tensor. Must be one of the following types: int32, int64.
+A tensor of indices into ref.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A Tensor. Must have the same type as ref. A tensor of updated
+values to add to ref.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+An optional bool. Defaults to True. If True, the assignment will
+be protected by a lock; otherwise the behavior is undefined,
+but may exhibit less contention.
+END
+ }
+ summary: "Applies sparse `updates` to individual values or slices within a given"
+ description: <<END
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements of a rank-1 tensor
+with 8 elements. In Python, that update would look like this:
+
+```python
+ ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(update))
+```
+
+The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+END
+}
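The `updates` shape constraint quoted above reduces to `indices.shape[:-1] + ref.shape[K:]` with `K = indices.shape[-1]`. A quick sanity check of the worked example, in plain Python (illustrative only, not part of this patch):

```python
ref_shape, indices_shape = [8], [4, 1]  # from the example above
K = indices_shape[-1]                   # length of the innermost index
updates_shape = indices_shape[:-1] + ref_shape[K:]
assert updates_shape == [4]             # one scalar update per index row
```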
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 6e45338751..17e6209f8e 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -104,6 +105,17 @@ TEST(Bfloat16Test, Conversion) {
}
}
+TEST(Bfloat16Test, Epsilon) {
+ EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
+ EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
+ bfloat16(1.0f)));
+}
+
+TEST(Bfloat16Test, Negate) {
+ EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f)));
+ EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f)));
+}
+
static void BM_FloatToBFloat16(int iters) {
testing::StopTiming();
static const int N = 32 << 20;
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ea66863bed..036e3473b1 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -397,6 +397,15 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(
CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
+ std::vector<int32> dilations;
+ TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
+
+ if (dilations.size() != 4) {
+ return errors::InvalidArgument(
+ "Conv2D requires the dilation attribute to contain 4 values, but got: ",
+ dilations.size());
+ }
+
std::vector<int32> strides;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
@@ -410,6 +419,8 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
+ const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
+ const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
@@ -447,12 +458,12 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
DimensionHandle output_rows, output_cols;
- TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[0],
- filter_rows_dim, stride_rows,
- padding, &output_rows));
- TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[1],
- filter_cols_dim, stride_cols,
- padding, &output_cols));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
+ c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
+ padding, &output_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
+ c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
+ padding, &output_cols));
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(
@@ -1307,6 +1318,9 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
Status ScatterNdUpdateShape(InferenceContext* c) {
ShapeHandle input_shape = c->input(0);
+ if (c->input_handle_shapes_and_types(0) != nullptr) {
+ input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
+ }
ShapeHandle indices_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
ShapeHandle updates_shape;
@@ -1361,7 +1375,9 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
}
}
- c->set_output(0, input_shape);
+ if (c->input_handle_shapes_and_types(0) == nullptr) {
+ c->set_output(0, input_shape);
+ }
return Status::OK();
}
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index ec9746b2af..5f3e5ad457 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -423,6 +423,15 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
.Finalize(&op.node_def));
};
+ // Invalid rank for input
+ INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
+ // Invalid rank for filter
+ INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
+
+ // Invalid value for strides
+ set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO");
+ INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]");
+
// 1x1 filter
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
@@ -443,11 +452,6 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO");
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
- // Invalid rank for input
- INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
- // Invalid rank for filter
- INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
-
// Unknown dims in the critical fields lead to partial inference.
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]");
@@ -538,6 +542,98 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]");
}
+TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) {
+ ShapeInferenceTestOp op("Conv2D");
+ auto set_op = [&op](const std::vector<int32>& dilations,
+ const std::vector<int32>& strides, const string& padding,
+ const string& data_format) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("dilations", dilations)
+ .Attr("strides", strides)
+ .Attr("padding", padding)
+ .Attr("data_format", data_format)
+ .Finalize(&op.node_def));
+ };
+
+ // Invalid rank for dilation
+ set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]");
+
+ // Invalid value for dilation
+ set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]");
+
+ // Tests for NHWC
+ // 1x1 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 1x1 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]");
+
+ // 1x1 filter, 2x1 dilations, 2x2 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 3x3 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
+
+ // 3x3 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
+
+ // 3x3 filter, 1x2 dilations, 2x2 strides
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]");
+
+ // Tests for NCHW
+ // 1x1 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
+
+ // 1x1 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]");
+
+ // 1x1 filter, 2x1 dilations, 2x2 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]");
+
+ // 3x3 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
+
+ // 3x3 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
+
+ // 3x3 filter, 1x2 dilations, 2x2 strides
+ set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]");
+
+ // Some tests for "SAME" padding
+
+ // 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+
+ // 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+
+ // 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+}
+
TEST(CommonShapeFnsTest, Conv3DShapeTest) {
ShapeInferenceTestOp op("Conv3D");
auto set_op = [&op](const std::vector<int32>& strides,
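The expected shapes in the dilated test cases above follow from the `VALID` output formula once the dilated filter extent is substituted in; a quick cross-check in plain Python (illustrative only, not part of this patch):

```python
def valid_out(in_size, filter_size, dilation, stride):
    # dilated filter extent, then the usual VALID output-size formula
    extent = (filter_size - 1) * dilation + 1
    return (in_size - extent) // stride + 1

# "[1,5,5,1];[3,3,1,1]" with 2x1 dilations and 1x1 strides -> "[d0_0,1,3,d1_3]"
assert valid_out(5, 3, 2, 1) == 1  # rows, dilation 2
assert valid_out(5, 3, 1, 1) == 3  # cols, dilation 1
```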
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 2b080e13fd..bdd5af064b 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -58,7 +58,7 @@ struct bfloat16 {
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
: bfloat16(static_cast<float>(val)) {}
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+ EIGEN_DEVICE_FUNC explicit operator float() const {
float result;
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
@@ -89,6 +89,10 @@ struct bfloat16 {
return static_cast<int>(float(*this));
}
+ EIGEN_DEVICE_FUNC explicit operator long() const {
+ return static_cast<long>(float(*this));
+ }
+
EIGEN_DEVICE_FUNC explicit operator char() const {
return static_cast<char>(float(*this));
}
@@ -121,15 +125,48 @@ struct bfloat16 {
return static_cast<double>(float(*this));
}
+ static bfloat16 epsilon() {
+ bfloat16 x;
+ x.value = 0x3c00; // 0x1.0p-7
+ return x;
+ }
+
uint16_t value;
};
-inline bool operator==(const bfloat16 a, const bfloat16 b) {
- return a.value == b.value;
+inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
-
-inline bool operator!=(const bfloat16 a, const bfloat16 b) {
- return a.value != b.value;
+inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) - static_cast<float>(b));
+}
+inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) * static_cast<float>(b));
+}
+inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+inline bfloat16 operator-(bfloat16 a) {
+ a.value ^= 0x8000;
+ return a;
+}
+inline bool operator<(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) < static_cast<float>(b);
+}
+inline bool operator<=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) <= static_cast<float>(b);
+}
+inline bool operator==(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) == static_cast<float>(b);
+}
+inline bool operator!=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) != static_cast<float>(b);
+}
+inline bool operator>(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) > static_cast<float>(b);
+}
+inline bool operator>=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) >= static_cast<float>(b);
}
} // end namespace tensorflow
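Two of the additions above are easiest to see at the bit level: `epsilon()` is the value whose bfloat16 pattern is `0x3c00` (`2**-7`, the spacing just above 1.0 with a 7-bit mantissa), and unary minus only flips the sign bit. A quick check in plain Python (illustrative only, not part of this patch):

```python
import struct

def bf16_bits(f):
    # a bfloat16 is the upper 16 bits of the float32 bit pattern
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16

assert bf16_bits(2.0 ** -7) == 0x3C00              # epsilon()
assert bf16_bits(-3.0) == bf16_bits(3.0) ^ 0x8000  # operator-(bfloat16)
```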
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
index c1511ebe34..9b24e3aa00 100644
--- a/tensorflow/core/framework/op_def_builder_test.cc
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -124,22 +124,23 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
- "DT_QINT32, DT_UINT32, DT_UINT64] } } }");
+ "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }");
ExpectSuccess(
b().Attr("a:{numbertype, variant}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
- "DT_QINT32, DT_UINT32, DT_UINT64, DT_VARIANT] } } }");
+ "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }");
ExpectSuccess(b().Attr("a:realnumbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
- "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64] } } }");
+ "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
+ "DT_BFLOAT16] } } }");
ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
- "DT_VARIANT, DT_STRING] } } }");
+ "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
@@ -216,12 +217,14 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
b().Attr("a:list(realnumbertype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
- "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64] } } }");
+ "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64"
+ "] } } }");
ExpectSuccess(
b().Attr("a:list({realnumbertype, variant})"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
- "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64, DT_VARIANT] } } }");
+ "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, "
+ "DT_VARIANT] } } }");
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index faae19585d..48849f9dda 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -206,18 +206,18 @@ string DataTypeSliceString(const DataTypeSlice types) {
}
DataTypeVector AllTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
- DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
- DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
- DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT,
- DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+ DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
+ DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
+ DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT,
+ DT_UINT32, DT_UINT64, DT_BFLOAT16};
}
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
DataTypeVector RealNumberTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16,
- DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16,
+ DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64, DT_BFLOAT16};
}
DataTypeVector QuantizedTypes() {
@@ -227,14 +227,14 @@ DataTypeVector QuantizedTypes() {
DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8,
- DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF};
+ DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF, DT_BFLOAT16};
}
DataTypeVector NumberTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32,
- DT_UINT8, DT_UINT16, DT_INT16, DT_INT8,
- DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8,
- DT_QINT32, DT_HALF, DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8,
+ DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128,
+ DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF, DT_UINT32,
+ DT_UINT64, DT_BFLOAT16};
}
#elif defined(__ANDROID_TYPES_FULL__)
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index f02cb51038..f1edbbb602 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index dd389de636..ec44d11bdd 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
@@ -264,6 +265,79 @@ bool IsEnterWithQueue(const Node& node) {
return false;
}
+bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
+ if (proto.unknown_rank()) {
+ return true;
+ }
+ for (const auto& dim : proto.dim()) {
+ if (dim.size() < 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void VerboseLogUnknownDimensionSources(
+ const Graph& graph,
+ const std::map<string, std::vector<OpInfo::TensorProperties>>&
+ input_properties_map,
+ const std::map<string, std::vector<OpInfo::TensorProperties>>&
+ output_properties_map) {
+ if (!VLOG_IS_ON(2)) {
+ return;
+ }
+
+ VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
+
+ // Find all nodes in the graph for which we
+ // do not have any unknown dimensions in their inputs, but
+ // we have some unknown dimensions in their outputs.
+ for (const Node* const node : graph.nodes()) {
+ if (node->num_outputs() == 0) {
+ continue;
+ }
+
+ const auto& input_properties = input_properties_map.at(node->name());
+ const auto& output_properties = output_properties_map.at(node->name());
+
+ bool has_unknown_inputs = false;
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (HasAnyUnknownDimensions(input_properties[i].shape())) {
+ has_unknown_inputs = true;
+ break;
+ }
+ }
+
+ if (has_unknown_inputs) {
+ continue;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ if (HasAnyUnknownDimensions(output_properties[i].shape())) {
+ string inputs = "input_shapes=[";
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ inputs +=
+ PartialTensorShape::DebugString(input_properties[i].shape());
+ }
+ inputs += "]";
+
+ string outputs = "output_shapes=[";
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ outputs +=
+ PartialTensorShape::DebugString(output_properties[i].shape());
+ }
+ outputs += "]";
+
+ VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op()
+ << ", " << inputs << ", " << outputs;
+
+ // don't log again for this node
+ break;
+ }
+ }
+ }
+}
+
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -312,9 +386,15 @@ class SymbolicShapeRefiner {
Status UpdateNode(const Node* node, bool relax, bool* refined) {
return shape_refiner_->UpdateNode(node, relax, refined);
}
- Status SetShape(const Node* node, int output_port,
- shape_inference::ShapeHandle shape) {
- return shape_refiner_->SetShape(node, output_port, shape);
+ Status SetUnknownShape(const Node* node, int output_port) {
+ shape_inference::ShapeHandle shape =
+ GetUnknownOutputShape(node, output_port);
+ InferenceContext* ctx = GetContext(node);
+ if (ctx == nullptr) {
+ return errors::InvalidArgument("Missing context");
+ }
+ ctx->set_output(output_port, shape);
+ return Status::OK();
}
struct ShapeId {
@@ -646,6 +726,23 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
+Status GraphProperties::OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const {
+ auto it = fed_ports.find(node->name());
+ Status status;
+ if (it != fed_ports.end()) {
+ // It is possible to feed node output ports with tensors of any shape: as a
+ // result, the shape of a fed port is completely unknown.
+ for (const int output_port : it->second) {
+ status.Update(shape_refiner->SetUnknownShape(node, output_port));
+ }
+ new_shapes->push(node);
+ }
+ return status;
+}
+
// Manually propagate the input shape for Enter nodes and update any Merge node
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
@@ -673,9 +770,10 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
-Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
- bool relax, const Node* n,
- TopoQueue* new_shapes) {
+Status GraphProperties::UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const {
if (n->IsEnter()) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
@@ -695,7 +793,9 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
}
}
}
- return Status::OK();
+ // Nodes can be fed with any shape. The TensorFlow shape inference code can't
+ // handle this properly, so overwrite its behavior here.
+ return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes);
}
// Propagates the shapes in the transitive fan-out of <new_shapes>.
@@ -703,6 +803,7 @@ Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
incorrect shape functions. The algorithm should converge in at most
@@ -728,8 +829,8 @@ Status GraphProperties::PropagateShapes(
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
const Node* fanout = e->dst();
- TF_RETURN_IF_ERROR(
- UpdateShapes(shape_refiner, relax, fanout, new_shapes));
+ TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
+ fanout, new_shapes));
}
}
}
@@ -803,7 +904,7 @@ Status GraphProperties::UpdateResource(
return Status::OK();
}
-Status GraphProperties::InferStatically() {
+Status GraphProperties::InferStatically(bool assume_valid_feeds) {
Graph graph(OpRegistry::Global());
FunctionLibraryDefinition function_library(graph.op_registry(),
item_.graph.library());
@@ -820,11 +921,21 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
+ std::unordered_map<string, std::unordered_set<int>> fed_ports;
+ if (!assume_valid_feeds) {
+ for (const auto& feed : item_.feed) {
+ int port_index = 0;
+ string node_name = ParseNodeName(feed.first, &port_index);
+ fed_ports[node_name].insert(port_index);
+ }
+ }
+
// List the resources and the nodes using them. Also collect the Enter and
// Merge nodes.
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes;
+ std::unordered_set<const Node*> fed_nodes;
int num_loops = 0;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
@@ -841,6 +952,9 @@ Status GraphProperties::InferStatically() {
} else if (node->IsNextIteration()) {
++num_loops;
}
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ fed_nodes.insert(node);
+ }
}
SymbolicShapeRefiner refiner(&shape_refiner);
@@ -855,15 +969,22 @@ Status GraphProperties::InferStatically() {
// Force the propagation of shapes of Enter nodes manually (the Enter shape
// function always forwards an UnknownShape).
for (const Node* node : enter_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
}
// Seed the propagation of shapes through merge nodes.
for (const Node* node : merge_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
+ }
+ // Also seed the propagation of shapes in the fanout of fed nodes.
+ for (const Node* node : fed_nodes) {
+ TF_RETURN_IF_ERROR(
+ OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
}
// Propagate shapes normally.
- TF_RETURN_IF_ERROR(
- PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
+ TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
+ fed_ports, num_loops));
}
// Track shapes globally across the graph.
@@ -874,6 +995,10 @@ Status GraphProperties::InferStatically() {
if (!node_ctx) {
continue;
}
+ // Skip any information that comes from fed nodes.
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ continue;
+ }
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second)
.ok()) {
@@ -948,6 +1073,10 @@ Status GraphProperties::InferStatically() {
}
}
+ // Help trace the unknown dimensions to their origins.
+ VerboseLogUnknownDimensionSources(graph, input_properties_,
+ output_properties_);
+
return Status::OK();
}
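The fed-ports map built in `InferStatically` relies on `ParseNodeName` to split a feed string into a node name and an output port. Roughly, and assuming the usual `"name:port"` convention with port 0 as the default (a Python sketch, not the actual grappler helper):

```python
def parse_node_name(feed_name):
    # "x:1" -> ("x", 1); a bare "x" means output port 0
    name, _, port = feed_name.partition(':')
    return name, int(port) if port else 0

assert parse_node_name("x") == ("x", 0)
assert parse_node_name("x:1") == ("x", 1)
```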
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 95bc5044d0..6fc53a7f2e 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -34,12 +34,19 @@ class TopoQueue;
// nodes, and potentially a set of nodes to feed.
class GraphProperties {
public:
- // Factory method for creating a GrapplerShapes from a MetaGraphDef.
- // Returns nullptr if the given meta_graph cannot be converted.
explicit GraphProperties(const GrapplerItem& item) : item_(item) {}
- Status InferStatically();
+ // Infer the shapes through abstract interpretation. Feed information can be
+ // incorrect, so it should be discarded to ensure correctness of the analysis.
+ // However, it can help infer shapes in the fanout of fed nodes (even though
+ // the correctness of these shapes can't be guaranteed), so in some cases
+ // (such as simulation or scheduling) it makes sense to keep these shapes.
+ Status InferStatically(bool assume_valid_feeds);
+ // Infer the shapes by running the graph on the specified cluster and recording
+ // the shapes of the processed tensors.
Status InferDynamically(Cluster* cluster);
+ // Extract the properties from a cost graph. For testing only, since there is
+ // no way to ensure that the cost graph matches the item.
Status InferFromCostGraph(const CostGraphDef& cost_graph);
// Stores `item_.graph` with the inferred output shapes to `output_graph_def`.
@@ -65,12 +72,6 @@ class GraphProperties {
OpInfo::TensorProperties*);
private:
- // Inputs
- GrapplerItem item_;
- std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
- std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
- const std::vector<OpInfo::TensorProperties> missing_properties_;
-
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status MergeEnqueueShapesAndTypes(
@@ -99,17 +100,31 @@ class GraphProperties {
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
const Node* node, bool relax,
TopoQueue* new_shapes);
+ // Process a node that is used to feed the model.
+ Status OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const;
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
- static Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax,
- const Node* n, TopoQueue* new_shapes);
+ Status UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const;
+
+ // Data members
+ GrapplerItem item_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
+ const std::vector<OpInfo::TensorProperties> missing_properties_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index c11af5777a..cc40ff2cfc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -73,7 +73,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
CHECK(fake_input.NextItem(&item));
GraphProperties properties(item);
- Status s = properties.InferStatically();
+ Status s = properties.InferStatically(true);
TF_CHECK_OK(s);
for (const auto& node : item.graph.node()) {
@@ -179,7 +179,7 @@ TEST_F(GraphPropertiesTest, Variables) {
{
GraphProperties static_properties(item);
- TF_CHECK_OK(static_properties.InferStatically());
+ TF_CHECK_OK(static_properties.InferStatically(false));
const auto props = static_properties.GetOutputProperties("Var");
EXPECT_EQ(1, props.size());
@@ -219,7 +219,7 @@ TEST_F(GraphPropertiesTest, VarHandles) {
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("VarRead");
EXPECT_EQ(1, props.size());
@@ -286,7 +286,7 @@ TEST_F(GraphPropertiesTest, Queues) {
TF_CHECK_OK(root.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props1 = properties.GetOutputProperties("Dequeue1");
ASSERT_EQ(1, props1.size());
@@ -335,7 +335,7 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
"merge_without_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
@@ -377,7 +377,7 @@ TEST_F(GraphPropertiesTest, WhileLoop) {
"while_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -435,7 +435,7 @@ TEST_F(GraphPropertiesTest, NestedLoop) {
"nested_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -498,7 +498,7 @@ TEST_F(GraphPropertiesTest, LoopsAndQueues) {
"loops_and_queues.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -556,7 +556,7 @@ TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
"loops_and_resource_vars.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -608,7 +608,7 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) {
"queues_and_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -657,7 +657,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
item.fetch.push_back("init_restore");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto restore_props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& restore_prop = restore_props[0];
@@ -704,7 +704,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
item.fetch.push_back("init2");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& prop = props[0];
@@ -732,7 +732,7 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
"simple_function.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1");
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
@@ -766,7 +766,7 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
EXPECT_EQ(2, shape_a.dim_size());
@@ -822,7 +822,7 @@ TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
GraphProperties properties(item);
// This function should return OK, since it doesn't validate the colocation
// constraints internally.
- TF_EXPECT_OK(properties.InferStatically());
+ TF_EXPECT_OK(properties.InferStatically(false));
}
TEST_F(GraphPropertiesTest, ShapeTracking) {
@@ -842,7 +842,7 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
@@ -851,6 +851,65 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
}
+TEST_F(GraphPropertiesTest, FedNodes) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ {
+ // Conservative shape analysis: the shape of fed ports should be unknown
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(false);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Const") {
+ continue;
+ }
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+
+ if (node.name() == "x") {
+ // x is fed: its input should have a known shape, while its output
+ // doesn't
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(1, in_prop.shape().dim_size());
+ EXPECT_EQ(2, in_prop.shape().dim(0).size());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+ } else if (node.op() == "Square" || node.op() == "AddN") {
+ // These nodes are in the fanout of x: their shapes should be unknown.
+ EXPECT_TRUE(in_prop.shape().unknown_rank());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+ }
+ }
+ }
+ {
+ // Optimistic shape analysis: the shape of fed ports should be derived from
+ // the shape of the fanin.
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(true);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Square" || node.op() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
+ }
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
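
Note on the API change exercised throughout these tests: GraphProperties::InferStatically() now takes a boolean selecting between conservative and optimistic treatment of fed nodes. A minimal sketch of the two modes, assuming only the signature visible in this patch:

    // Conservative vs. optimistic static shape inference (sketch).
    #include "tensorflow/core/grappler/costs/graph_properties.h"
    #include "tensorflow/core/grappler/grappler_item.h"

    tensorflow::Status InferShapes(
        const tensorflow::grappler::GrapplerItem& item,
        bool assume_valid_feeds) {
      tensorflow::grappler::GraphProperties properties(item);
      // false: anything downstream of a fed node keeps an unknown shape,
      //        since a feed may carry a tensor of any shape at runtime.
      // true:  fed ports are assumed to match the shapes inferred from their
      //        fanin, as the VirtualScheduler does below.
      return properties.InferStatically(assume_valid_feeds);
    }
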
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index e5e1ee3292..6640de668d 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -122,7 +122,7 @@ Status VirtualScheduler::Init() {
// Construct graph properties.
Status status;
if (use_static_shapes_) {
- status = graph_properties_.InferStatically();
+ status = graph_properties_.InferStatically(true);
} else {
status = graph_properties_.InferDynamically(cluster_);
}
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 36c7f92c49..da99777bbc 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -173,7 +173,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
<< ", skipping this input.";
return nullptr;
}
- LOG(INFO) << "Will use feed node " << feed_name;
+ VLOG(1) << "Will use feed node " << feed_name;
new_item->feed.emplace_back(feed_name, Tensor());
}
@@ -188,7 +188,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
<< ", skipping this input";
return nullptr;
}
- LOG(INFO) << "Will use fetch node " << name;
+ VLOG(1) << "Will use fetch node " << name;
new_item->fetch.push_back(name);
}
}
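
The LOG(INFO) → VLOG(1) changes above demote the per-feed/fetch messages to verbose logging. A sketch of the gating, assuming TensorFlow's standard logging macros and the usual TF_CPP_MIN_VLOG_LEVEL environment variable:

    #include <string>
    #include "tensorflow/core/platform/logging.h"

    void AnnounceFeed(const std::string& feed_name) {
      // Printed only when verbose logging is enabled, e.g. by running the
      // process with TF_CPP_MIN_VLOG_LEVEL=1; silent by default.
      VLOG(1) << "Will use feed node " << feed_name;
    }
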
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5d9eb8e0b1..7b4ed10e7e 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -96,6 +96,7 @@ cc_library(
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
@@ -332,6 +333,11 @@ tf_cc_test(
deps = [
":layout_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3cfc4f61e4..efe8ac05a3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -253,6 +253,30 @@ bool IsNumberType(DataType dtype) {
const char kOutputShapesAttr[] = "_output_shapes";
+PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
+ int output_pos;
+ string node_name = ParseNodeName(input, &output_pos);
+ const NodeDef* input_node = node_map.GetNode(node_name);
+ return input_node->attr().at(kOutputShapesAttr).list().shape(output_pos);
+}
+
+bool ShapesEqual(const string& input_x, const string& input_y,
+ const NodeMap& node_map) {
+ PartialTensorShape x_shape = GetInputShape(input_x, node_map);
+ PartialTensorShape y_shape = GetInputShape(input_y, node_map);
+ if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
+ x_shape.dims() != y_shape.dims()) {
+ return false;
+ }
+ for (int i = 0; i < x_shape.dims(); ++i) {
+ if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
+ x_shape.dim_size(i) != y_shape.dim_size(i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
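
GetInputShape reads the _output_shapes attribute of the producing port, and ShapesEqual only accepts shapes that are fully defined and identical: unknown rank or any -1 dimension counts as a mismatch. A standalone sketch of the same predicate over PartialTensorShape:

    #include "tensorflow/core/framework/partial_tensor_shape.h"

    // True only when both shapes are fully defined and identical; unknown
    // rank or any -1 dimension conservatively compares as "not equal".
    bool FullyDefinedAndEqual(const tensorflow::PartialTensorShape& x,
                              const tensorflow::PartialTensorShape& y) {
      if (x.unknown_rank() || y.unknown_rank() || x.dims() != y.dims()) {
        return false;
      }
      for (int i = 0; i < x.dims(); ++i) {
        if (x.dim_size(i) == -1 || x.dim_size(i) != y.dim_size(i)) {
          return false;
        }
      }
      return true;
    }
    // FullyDefinedAndEqual({2, 3}, {2, 3})   -> true
    // FullyDefinedAndEqual({2, -1}, {2, -1}) -> false
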
@@ -868,8 +892,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// multiplication over addition to hoist common factors out of aggregate nodes
// where all the inputs are Mul nodes. This pattern occurs frequently in
// regularization terms for the gradients during training.
- // TODO(rmlarsen): Check shapes and enable for AddN.
- if (IsAdd(*node) && NumNonControlInputs(*node) > 1 &&
+ // For example, we can rewrite an expression of the form:
+ // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+ // to the following:
+ // Mul(x, AddN(y1, y2, y3, ... yn))
+ if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
!OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
@@ -899,24 +926,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
- // In this case we have an expression of the form
- // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
- // that can be rewritten as
- // Mul(x, AddN(y1, y2, y3, ... yn))
-
- // 1. Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"),
- node_map_->GetNode(node->input(0)));
- NodeDef* new_add_node = AddNode(StrCat(node->name(), "_hoist_add"), node);
- new_mul_node->set_device(node->device());
- new_mul_node->set_input(0, common_factor);
- node_map_->AddOutput(common_factor, new_mul_node->name());
- new_mul_node->set_input(1, new_add_node->name());
- node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
-
- // 2. Hoist non-shared factors up into the new AddN node.
- nodes_to_simplify->PushBack(new_add_node);
- for (int i = 0; i < node->input_size(); ++i) {
+
+ // Gather up the non-shared factors (the y's in the example).
+ // Unless the aggregation is Add, we have to make sure that all the y's
+ // have the same shape since the other aggregation ops do not support
+ // broadcasting.
+ std::vector<string> unique_factors;
+ unique_factors.reserve(node->input_size());
+ bool shapes_match = true;
+ for (int i = 0; i < node->input_size() && shapes_match; ++i) {
const string& input = node->input(i);
if (IsControlInput(input)) {
break;
@@ -924,15 +942,41 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* mul_node = node_map_->GetNode(input);
const int unique_factor_index =
mul_node->input(0) == common_factor ? 1 : 0;
- const string unique_factor = mul_node->input(unique_factor_index);
- new_add_node->set_input(i, unique_factor);
+ unique_factors.push_back(mul_node->input(unique_factor_index));
+ if (i > 0 && !IsAdd(*node)) {
+ shapes_match = ShapesEqual(unique_factors.front(),
+ unique_factors.back(), *node_map_);
+ }
}
- // 4. Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
+ if (shapes_match) {
+ // 1. Use a copy of the first Mul node for the outer multiplication.
+ NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"),
+ node_map_->GetNode(node->input(0)));
+ NodeDef* new_add_node =
+ AddNode(StrCat(node->name(), "_hoist_add"), node);
+ new_mul_node->set_device(node->device());
+ new_mul_node->set_input(0, common_factor);
+ node_map_->AddOutput(common_factor, new_mul_node->name());
+ new_mul_node->set_input(1, new_add_node->name());
+ node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
+
+ // 2. Hoist non-shared factors up into the new AddN node.
+ nodes_to_simplify->PushBack(new_add_node);
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input = node->input(i);
+ if (IsControlInput(input)) {
+ break;
+ }
+ new_add_node->set_input(i, unique_factors[i]);
+ }
- return new_mul_node->name();
+ // 3. Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+ {new_add_node});
+
+ return new_mul_node->name();
+ }
}
}
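
The guarded rewrite above is distributivity: AddN(Mul(x, y1), ..., Mul(x, yn)) becomes Mul(x, AddN(y1, ..., yn)). The shapes_match check exists because AddN, unlike Add, does not broadcast, so the hoisted y's must agree exactly. A scalar sanity check of the identity:

    #include <cassert>
    #include <numeric>
    #include <vector>

    int main() {
      const double x = 3.0;
      const std::vector<double> ys = {1.5, -2.0, 4.0};
      double sum_of_products = 0.0;  // AddN(Mul(x, y1), ..., Mul(x, yn))
      for (double y : ys) sum_of_products += x * y;
      const double hoisted =         // Mul(x, AddN(y1, ..., yn))
          x * std::accumulate(ys.begin(), ys.end(), 0.0);
      assert(sum_of_products == hoisted);  // both 10.5 (exact in binary)
      return 0;
    }
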
@@ -1064,13 +1108,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
int num_frames;
TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
&frame_map_, &num_frames));
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- graph_properties_.reset(new GraphProperties(item));
- // Shapes are only needed in aggressive mode.
- TF_RETURN_IF_ERROR(graph_properties_->InferStatically());
- TF_RETURN_IF_ERROR(
- graph_properties_->AnnotateOutputShapes(optimized_graph_));
- }
+ graph_properties_.reset(new GraphProperties(item));
+ // Shape information is now needed unconditionally by the rewrites below.
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
DedupComputations();
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e8a18ff9d9..80f42694d9 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -32,6 +32,21 @@ string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
+void VerifyGraphsMatch(const GraphDef& original_graph,
+ const GraphDef& optimized_graph, int line) {
+ EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
+ for (int i = 0; i < original_graph.node_size(); ++i) {
+ const NodeDef& original = original_graph.node(i);
+ const NodeDef& optimized = optimized_graph.node(i);
+ EXPECT_EQ(original.name(), optimized.name()) << line;
+ EXPECT_EQ(original.op(), optimized.op()) << line;
+ EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
+ for (int j = 0; j < original.input_size(); ++j) {
+ EXPECT_EQ(original.input(j), optimized.input(j)) << line;
+ }
+ }
+}
+
class ArithmeticOptimizerTest : public ::testing::Test {};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -44,18 +59,7 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
-
- EXPECT_EQ(item.graph.node_size(), output.node_size());
- for (int i = 0; i < item.graph.node_size(); ++i) {
- const NodeDef& original = item.graph.node(i);
- const NodeDef& optimized = output.node(i);
- EXPECT_EQ(original.name(), optimized.name());
- EXPECT_EQ(original.op(), optimized.op());
- EXPECT_EQ(original.input_size(), optimized.input_size());
- for (int j = 0; j < original.input_size(); ++j) {
- EXPECT_EQ(original.input(j), optimized.input(j));
- }
- }
+ VerifyGraphsMatch(item.graph, output, __LINE__);
}
TEST_F(ArithmeticOptimizerTest, OpDedupping) {
@@ -398,39 +402,51 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
}
TEST_F(ArithmeticOptimizerTest, HoistFactor) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
- Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
- Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
- Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
- Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
- Output add = ops::Add(s.WithOpName("add"), mul1, mul2);
- Output id = ops::Identity(s.WithOpName("id"), add);
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- ArithmeticOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(9, output.node_size());
- const NodeDef& new_add = output.node(8);
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
- EXPECT_EQ("y1", new_add.input(0));
- EXPECT_EQ("y2", new_add.input(1));
- const NodeDef& new_mul = output.node(7);
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
- EXPECT_EQ("x", new_mul.input(0));
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
- const NodeDef& new_id = output.node(6);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+ for (bool matching_shapes : {true, false}) {
+ for (bool use_addn : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+ Output y2 = matching_shapes
+ ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
+ : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
+ Output id =
+ use_addn ? ops::Identity(s.WithOpName("id"),
+ ops::AddN(s.WithOpName("add"), {mul1, mul2}))
+ : ops::Identity(s.WithOpName("id"),
+ ops::Add(s.WithOpName("add"), mul1, mul2));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ if (use_addn && !matching_shapes) {
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+ } else {
+ EXPECT_EQ(9, output.node_size());
+ const NodeDef& new_add = output.node(8);
+ EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
+ EXPECT_EQ("y1", new_add.input(0));
+ EXPECT_EQ("y2", new_add.input(1));
+ const NodeDef& new_mul = output.node(7);
+ EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
+ EXPECT_EQ("x", new_mul.input(0));
+ EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
+ const NodeDef& new_id = output.node(6);
+ EXPECT_EQ("id", new_id.name());
+ EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+ }
+ }
+ }
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index c77b2badf4..e0f39c2931 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -30,13 +30,16 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/bcast.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
namespace tensorflow {
namespace grappler {
@@ -95,7 +98,38 @@ class DeviceSimple : public DeviceBase {
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};
+template <typename T>
+bool AllValuesAre(const TensorProto& tensor, const T& value) {
+ // TensorProto represents the content of the tensor in either <type>_val or
+ // tensor_content.
+ typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
+ checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
+ if (!tensor_values->empty()) {
+ for (const T& tensor_value : *tensor_values) {
+ if (tensor_value != value) {
+ return false;
+ }
+ }
+ return true;
+ }
+ const auto tensor_content_size = tensor.tensor_content().size();
+ if (tensor_content_size > 0) {
+ CHECK_EQ(0, tensor_content_size % sizeof(T));
+ std::vector<T> raw_values(tensor_content_size / sizeof(T));
+ port::CopyToArray(tensor.tensor_content(),
+ reinterpret_cast<char*>(raw_values.data()));
+ for (int i = 0; i < tensor_content_size / sizeof(T); ++i) {
+ if (raw_values[i] != value) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
} // namespace
+
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
DeviceBase* cpu_device)
: opt_level_(opt_level), cpu_device_(cpu_device) {
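
AllValuesAre above must handle the two encodings a TensorProto can use: the typed repeated field (float_val, int_val, ...) or the packed tensor_content bytes. A simplified standalone sketch of the same two-path check; the parameters are stand-ins for the proto fields, not the real proto API:

    #include <cstring>
    #include <string>
    #include <vector>

    template <typename T>
    bool AllValuesAreSketch(const std::vector<T>& typed_values,  // <type>_val
                            const std::string& tensor_content,   // packed bytes
                            const T& value) {
      if (!typed_values.empty()) {
        for (const T& v : typed_values)
          if (v != value) return false;
        return true;
      }
      if (!tensor_content.empty() && tensor_content.size() % sizeof(T) == 0) {
        std::vector<T> raw(tensor_content.size() / sizeof(T));
        std::memcpy(raw.data(), tensor_content.data(), tensor_content.size());
        for (const T& v : raw)
          if (v != value) return false;
        return true;
      }
      return false;  // neither representation is populated
    }
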
@@ -190,14 +224,21 @@ Status ConvertShapeToConstant(const string& op, const DataType& type,
return Status::OK();
}
-Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
- const GraphProperties& properties) {
+bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
+ if (!IsConstant(node)) {
+ return false;
+ }
+ // If the node is fed it's not constant anymore.
+ return feed_nodes_.find(node.name()) == feed_nodes_.end();
+}
+
+Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies: there is
// no need to process these, so only iterate over the nodes of the input
// graph.
- const int node_count = graph_.node_size();
+ const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
- NodeDef& node = *graph_.mutable_node(i);
+ NodeDef& node = *graph_->mutable_node(i);
const string op = node.op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
@@ -241,7 +282,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// cases where the shape/rank/size would have been run in
// the original graph. Additional inputs are extra control
string ctrl_dep =
- AddControlDependency(node.input(0), &graph_, node_map_.get());
+ AddControlDependency(node.input(0), graph_, node_map_.get());
node.set_input(0, ctrl_dep);
node_map_->AddOutput(NodeName(ctrl_dep), node.name());
} else {
@@ -256,7 +297,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
AddPrefixToNodeName(strings::StrCat(node.name(), "-", j),
kConstantFoldingConst);
if (node_map_->GetNode(const_name) == nullptr) {
- NodeDef* added_node = graph_.add_node();
+ NodeDef* added_node = graph_->add_node();
added_node->set_name(const_name);
added_node->set_op("Const");
added_node->set_device(node.device());
@@ -267,7 +308,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// We add a control dependency to the original ShapeN node,
// so that the node will only be run if all inputs of the
// original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node.name(), &graph_,
+ string ctrl_dep = AddControlDependency(node.name(), graph_,
node_map_.get());
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
@@ -285,6 +326,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
return Status::OK();
}
+namespace {
bool ShapesEqual(const TensorShapeProto& shape1,
const TensorShapeProto& shape2) {
if (shape1.unknown_rank() || shape2.unknown_rank()) {
@@ -297,11 +339,13 @@ bool ShapesEqual(const TensorShapeProto& shape1,
if (shape1.dim(i).size() != shape2.dim(i).size()) {
return false;
}
+ if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
+ return false;
+ }
}
return true;
}
-namespace {
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -344,9 +388,9 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
if (shape_node1 == nullptr ||
- (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
+ (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
shape_node2 == nullptr ||
- (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
+ (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
return Status::OK();
}
int64 min_id = 0;
@@ -392,13 +436,13 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
strings::StrCat(node.name(), "-", j), kConstantFoldingConst);
out[j] = node_map_->GetNode(const_name);
if (out[j] == nullptr) {
- out[j] = graph_.add_node();
+ out[j] = graph_->add_node();
Tensor value(type, TensorShape({0}));
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
out[j]->set_device(node.device());
node_map_->AddNode(const_name, out[j]);
string ctrl_dep =
- AddControlDependency(node.name(), &graph_, node_map_.get());
+ AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
}
@@ -426,7 +470,7 @@ Status ConstantFolding::MaterializeReductionIndices(
return Status::OK();
}
const NodeDef* indices = node_map_->GetNode(node->input(1));
- if (!indices || IsConstant(*indices)) {
+ if (!indices || IsReallyConstant(*indices)) {
// The reduction indices are already constant, there's nothing to do.
return Status::OK();
}
@@ -479,7 +523,7 @@ Status ConstantFolding::MaterializeReductionIndices(
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
- NodeDef* reduction_indices = graph_.add_node();
+ NodeDef* reduction_indices = graph_->add_node();
Tensor value(dtype, TensorShape({rank}));
for (int i = 0; i < rank; ++i) {
if (dtype == DT_INT32) {
@@ -491,7 +535,7 @@ Status ConstantFolding::MaterializeReductionIndices(
*reduction_indices = CreateNodeDef(const_name, TensorValue(&value));
reduction_indices->set_device(node->device());
string ctrl_dep =
- AddControlDependency(node->input(1), &graph_, node_map_.get());
+ AddControlDependency(node->input(1), graph_, node_map_.get());
*reduction_indices->add_input() = ctrl_dep;
node_map_->AddNode(const_name, reduction_indices);
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
@@ -504,10 +548,10 @@ Status ConstantFolding::MaterializeReductionIndices(
}
Status ConstantFolding::MaterializeConstants(
- const GrapplerItem& item, const GraphProperties& properties) {
- const int node_count = graph_.node_size();
+ const GraphProperties& properties) {
+ const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
- NodeDef& node = *graph_.mutable_node(i);
+ NodeDef& node = *graph_->mutable_node(i);
const string& op = node.op();
if (op == "BroadcastGradientArgs") {
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
@@ -523,24 +567,23 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (node.input().empty()) {
return false;
}
-
// Skips nodes that must be preserved except whitelisted nodes.
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
return false;
}
-
- // Skips ops that don't benefit from folding.
- const string& op = node.op();
- // Skip constants, they're already folded
- if (op == "Const") {
+ // Skip control flow nodes, they can't be folded
+ if (ModifiesFrameInfo(node)) {
return false;
}
- // Skip constrol flow nodes, they can't be folded
- if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" ||
- op == "NextIteration" || op == "RefNextIteration") {
+ // Skip constants, they're already folded
+ if (IsConstant(node)) {
return false;
}
+
+ // Skips ops that don't benefit from folding.
+ const string& op = node.op();
+
if (op.find("Placeholder") == 0) {
return false;
}
@@ -594,7 +637,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (!input_node) {
return false;
}
- bool is_const = IsConstant(*input_node);
+ bool is_const = IsReallyConstant(*input_node);
if (!is_const && !is_merge) {
return false;
}
@@ -612,6 +655,36 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return true;
}
+namespace {
+
+#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
+ case DTYPE: \
+ t->add_##NAME##_val(static_cast<TYPE>(value)); \
+ break;
+
+Status CreateConstantTensorAttrValue(DataType type, double value,
+ const TensorShapeProto& shape,
+ AttrValue* attr_tensor) {
+ TensorProto* t = attr_tensor->mutable_tensor();
+ *t->mutable_tensor_shape() = shape;
+ switch (type) {
+ SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+ SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+ SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
+ default:
+ return errors::InvalidArgument("Unsupported type: ", type);
+ }
+ return Status::OK();
+}
+
+#undef SET_TENSOR_VAL_CASE
+} // namespace
+
// static
NodeDef ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor) {
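
CreateConstantTensorAttrValue and its SET_TENSOR_VAL_CASE macro build an AttrValue holding a single splatted scalar of the requested dtype. A hedged usage sketch (the helper is file-local to constant_folding.cc in this patch; names as introduced above):

    // Build the value attribute for a zero constant of a given shape, as
    // ReplaceAddOrMulWithConstant does further down.
    tensorflow::AttrValue MakeZeroAttr(
        const tensorflow::TensorShapeProto& shape) {
      tensorflow::AttrValue attr;
      TF_CHECK_OK(CreateConstantTensorAttrValue(tensorflow::DT_FLOAT, 0.0,
                                                shape, &attr));
      // attr.tensor() now carries float_val(0) == 0 and tensor_shape == shape.
      return attr;
    }
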
@@ -652,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name,
POPULATE_TENSOR_PROTO(tensor, t, int64, int64)
} else if (tensor->dtype() == DT_INT32) {
POPULATE_TENSOR_PROTO(tensor, t, int32, int)
+ } else if (tensor->dtype() == DT_INT16) {
+ POPULATE_TENSOR_PROTO(tensor, t, int16, int)
+ } else if (tensor->dtype() == DT_INT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, int8, int)
+ } else if (tensor->dtype() == DT_UINT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, uint8, int)
+ } else if (tensor->dtype() == DT_BOOL) {
+ POPULATE_TENSOR_PROTO(tensor, t, bool, bool)
}
}
if (optimized) {
@@ -720,7 +801,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
break;
}
const NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
return Status(error::INVALID_ARGUMENT,
strings::StrCat("Can't fold ", node.name(), ", its ", input,
" isn't constant"));
@@ -774,7 +855,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
continue;
}
NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
continue;
}
bool valid_input = true;
@@ -955,8 +1036,8 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
Status ConstantFolding::FoldGraph(GraphDef* output) {
std::unordered_set<string> processed_nodes;
std::deque<NodeDef*> queue;
- for (int i = 0; i < graph_.node_size(); i++) {
- auto node = graph_.mutable_node(i);
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
if (IsFoldable(*node)) {
queue.push_back(node);
}
@@ -995,7 +1076,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
output->mutable_node()->DeleteSubrange(last + 1,
output->node_size() - last - 1);
- for (const auto& node : graph_.node()) {
+ for (const auto& node : graph_->node()) {
// If no fetch nodes is provided, we conservatively
// keep all nodes in the original graph in case users need to fetch
// their values.
@@ -1016,7 +1097,7 @@ bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const {
if (IsReduction(node)) {
CHECK_LE(2, node.input_size());
const NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
- if (IsConstant(*reductions_indices)) {
+ if (IsReallyConstant(*reductions_indices)) {
TensorVector output;
Status s = EvaluateNode(*reductions_indices, TensorVector(), &output);
if (!s.ok()) {
@@ -1040,7 +1121,7 @@ bool ConstantFolding::IsSimplifiableReshape(
}
CHECK_LE(2, node.input_size());
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
- if (!IsConstant(*new_shape)) {
+ if (!IsReallyConstant(*new_shape)) {
return false;
}
TensorVector outputs;
@@ -1090,8 +1171,107 @@ bool ConstantFolding::IsSimplifiableReshape(
return shape.IsCompatibleWith(new_dims);
}
+#define IS_VALUE_CASE(DTYPE, VALUE) \
+ case DTYPE: \
+ return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
+ node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
+
+#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
+#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
+
+bool ConstantFolding::IsOnes(const NodeDef& node) const {
+ if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
+ return false;
+ }
+ if (node.op() == "OnesLike") {
+ return true;
+ }
+ if (node.op() != "Const") {
+ return false;
+ }
+ const auto dtype = node.attr().at("dtype").type();
+ switch (dtype) {
+ // IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_FLOAT);
+ IS_ONES_CASE(DT_DOUBLE);
+ IS_ONES_CASE(DT_UINT8);
+ IS_ONES_CASE(DT_INT8);
+ IS_ONES_CASE(DT_UINT16);
+ IS_ONES_CASE(DT_INT16);
+ IS_ONES_CASE(DT_INT32);
+ IS_ONES_CASE(DT_INT64);
+ IS_ONES_CASE(DT_COMPLEX64);
+ IS_ONES_CASE(DT_COMPLEX128);
+ default:
+ LOG(ERROR) << "Unexpected type " << DataTypeString(dtype);
+ return false;
+ }
+ return false;
+}
+
+bool ConstantFolding::IsZeros(const NodeDef& node) const {
+ if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
+ return false;
+ }
+ if (node.op() == "ZerosLike") {
+ return true;
+ }
+ if (!IsConstant(node)) {
+ return false;
+ }
+ const auto dtype = node.attr().at("dtype").type();
+ switch (dtype) {
+ // IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_FLOAT);
+ IS_ZEROS_CASE(DT_DOUBLE);
+ IS_ZEROS_CASE(DT_UINT8);
+ IS_ZEROS_CASE(DT_INT8);
+ IS_ZEROS_CASE(DT_UINT16);
+ IS_ZEROS_CASE(DT_INT16);
+ IS_ZEROS_CASE(DT_INT32);
+ IS_ZEROS_CASE(DT_INT64);
+ IS_ZEROS_CASE(DT_COMPLEX64);
+ IS_ZEROS_CASE(DT_COMPLEX128);
+ default:
+ LOG(ERROR) << "Unexpected type " << DataTypeString(dtype);
+ return false;
+ }
+ return false;
+}
+
+void ConstantFolding::ReplaceAddOrMulWithIdentity(int input_to_forward,
+ NodeDef* node) {
+ node->set_op("Identity");
+ // Propagate the designated input through the identity.
+ node->mutable_input()->SwapElements(0, input_to_forward);
+ // Add all other inputs as control dependencies.
+ for (int i = 1; i < node->input_size(); ++i) {
+ node->set_input(i, AsControlDependency(node->input(i)));
+ }
+ graph_modified_ = true;
+}
+
+Status ConstantFolding::ReplaceAddOrMulWithConstant(
+ double value, const TensorShapeProto& shape, NodeDef* node) {
+ AttrValue tensor_attr;
+ TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(node->attr().at("T").type(),
+ value, shape, &tensor_attr));
+ node->mutable_attr()->insert({"value", tensor_attr});
+ node->set_op("Const");
+ // Convert all inputs to control dependencies.
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) {
+ break;
+ }
+ node->set_input(i, AsControlDependency(node->input(i)));
+ }
+ graph_modified_ = true;
+ return Status::OK();
+}
+
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties) {
+ const GraphProperties& properties,
+ bool use_shape_info) {
for (auto& node : *output->mutable_node()) {
if (IsSimplifiableReduction(node)) {
// Replace the reduction node with an identity node, that can be further
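
ReplaceAddOrMulWithIdentity keeps the forwarded operand as the sole data input and demotes every other input to a control dependency (the ^ prefix), so the rewritten node still runs only after all of its original inputs. A standalone sketch of that input rewrite:

    #include <string>
    #include <utility>
    #include <vector>

    // Forward one input, turn the rest into control dependencies; mirrors
    // ReplaceAddOrMulWithIdentity's use of grappler's AsControlDependency().
    std::vector<std::string> ForwardInput(std::vector<std::string> inputs,
                                          int input_to_forward) {
      std::swap(inputs[0], inputs[input_to_forward]);
      for (size_t i = 1; i < inputs.size(); ++i) {
        inputs[i] = "^" + inputs[i];
      }
      return inputs;
    }
    // ForwardInput({"x", "zeros"}, 1) -> {"zeros", "^x"}:
    // Mul(x, zeros) becomes Identity(zeros) with a control edge from x.
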
@@ -1116,10 +1296,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
*node.add_input() = input;
}
}
- // It's possible to feed a placeholder with a tensor that doesn't have the
- // proper shape, and reshape this tensor later on. Therefore only remove
- // reshapes in graphs that don't have placeholders.
- if (IsSimplifiableReshape(node, properties)) {
+ const bool safe_to_use_shapes =
+ use_shape_info &&
+ (feed_nodes_.empty() || opt_level_ == RewriterConfig::AGGRESSIVE);
+ if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) {
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
DataType output_type = node.attr().at("T").type();
node.set_op("Identity");
@@ -1134,6 +1314,63 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
*node.add_input() = input;
}
}
+
+ // Simplify multiplication by ones or zeros, and addition of zeros.
+ bool is_mul = IsMul(node);
+ bool is_add = IsAdd(node);
+ if (opt_level_ == RewriterConfig::AGGRESSIVE && use_shape_info &&
+ (is_mul || is_add) && properties.HasInputProperties(node.name()) &&
+ properties.HasOutputProperties(node.name())) {
+ const NodeDef* x = node_map_->GetNode(node.input(0));
+ const NodeDef* y = node_map_->GetNode(node.input(1));
+ if (x == nullptr || y == nullptr) {
+ return errors::InvalidArgument("Invalid inputs to node: ",
+ node.DebugString());
+ }
+ const TensorShapeProto& output_shape =
+ properties.GetOutputProperties(node.name())[0].shape();
+ const TensorShapeProto& x_shape =
+ properties.GetInputProperties(node.name())[0].shape();
+
+ // Simplify multiplication by or addition of zeros.
+ const bool x_is_zero = IsZeros(*x);
+ const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ if (x_is_zero && x_matches_output_shape) {
+ // 0 * y = 0 or 0 + y = y.
+ ReplaceAddOrMulWithIdentity(is_mul ? 0 : 1, &node);
+ continue;
+ }
+ const TensorShapeProto& y_shape =
+ properties.GetInputProperties(node.name())[1].shape();
+ const bool y_is_zero = IsZeros(*y);
+ const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ if (y_is_zero && y_matches_output_shape) {
+ // x * 0 = 0 or x + 0 = x.
+ ReplaceAddOrMulWithIdentity(is_mul ? 1 : 0, &node);
+ continue;
+ }
+
+ if (is_mul) {
+ // Simplify multiplication by zeros where the output shape does not
+ // match the shape of the zero input.
+ if (x_is_zero || y_is_zero) {
+ TF_RETURN_IF_ERROR(
+ ReplaceAddOrMulWithConstant(0, output_shape, &node));
+ continue;
+ }
+
+ // Simplify multiplication by ones.
+ if (IsOnes(*x) && y_matches_output_shape) {
+ // 1 * y = y.
+ ReplaceAddOrMulWithIdentity(1, &node);
+ continue;
+ } else if (IsOnes(*y) && x_matches_output_shape) {
+ // x * 1 = x.
+ ReplaceAddOrMulWithIdentity(0, &node);
+ continue;
+ }
+ }
+ }
}
return Status::OK();
}
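
The block above applies the neutral/absorbing-element identities, but only forwards an operand when its shape already equals the broadcasted output shape; for multiplication by a broadcast zero it materializes a zero constant of the output shape instead. A scalar sanity check of the identities in play:

    #include <cassert>

    int main() {
      const float x = 7.0f;
      assert(x * 0.0f == 0.0f);  // Mul by zeros -> constant zeros
      assert(x * 1.0f == x);     // Mul by ones  -> forward the other operand
      assert(x + 0.0f == x);     // Add of zeros -> forward the other operand
      // Shape caveat: for Mul of a [1,2] tensor with [1,1] zeros, the output
      // shape is [1,2], so the zeros input cannot simply be forwarded; the
      // optimizer emits a [1,2] zero constant (see mul5/mul6 in the
      // NeutralElement test below).
      return 0;
    }
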
@@ -1141,7 +1378,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output) {
- node_map_.reset(new NodeMap(&graph_));
+ node_map_.reset(new NodeMap(graph_));
nodes_whitelist_.clear();
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
// has a single fanout, it would be rewritten as a constant with the same
@@ -1158,36 +1395,34 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
GraphProperties properties(item);
- const bool has_feed = !item.feed.empty();
- bool needs_shapes = !has_feed || opt_level_ == RewriterConfig::AGGRESSIVE;
- Status s = errors::Unknown(
- "The graph properties are needed but were not initialized");
- if (needs_shapes) {
- s = properties.InferStatically();
- }
-
- if (!has_feed && s.ok()) {
- // Only use static shape information when there is no feed in the
- // graph. That's because it's possible to feed a placeholder with a tensor
- // of any shape, which could make the static information inconsistent with
- // the shapes actually fed.
- TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
- }
- if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
- TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
+ // It's possible to feed a placeholder with a tensor of any shape: make sure
+ // that the shape inference deals with this conservatively unless we're in
+ // aggressive mode.
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ Status s = properties.InferStatically(assume_valid_feeds);
+ const bool can_use_shape_info = s.ok();
+
+ if (can_use_shape_info) {
+ TF_RETURN_IF_ERROR(MaterializeShapes(properties));
+
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ TF_RETURN_IF_ERROR(MaterializeConstants(properties));
+ }
}
TF_RETURN_IF_ERROR(FoldGraph(output));
- if (!has_feed && s.ok()) {
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
- }
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+
return Status::OK();
}
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
nodes_to_preserve_ = item.NodesToPreserve();
+ for (const auto& feed : item.feed) {
+ feed_nodes_.insert(NodeName(feed.first));
+ }
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
@@ -1200,13 +1435,13 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
*output = item.graph;
int64 node_count;
do {
- graph_.Swap(output);
- item_to_optimize.graph = graph_;
+ graph_modified_ = false;
+ item_to_optimize.graph.Swap(output);
+ graph_ = &item_to_optimize.graph;
*output = GraphDef();
- node_count = graph_.node_size();
+ node_count = graph_->node_size();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
- } while (output->node_size() != node_count);
-
+ } while (graph_modified_ || output->node_size() != node_count);
*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
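
The driver now loops on an explicit graph_modified_ flag as well as the node count, because rewrites such as ReplaceAddOrMulWithIdentity change the graph without changing its size. A self-contained sketch of the resulting fixed-point iteration, with hypothetical helpers standing in for RunOptimizationPass and the GraphDef bookkeeping:

    #include <cstdio>

    static int g_nodes = 10;               // pretend graph size
    int CurrentNodeCount() { return g_nodes; }
    bool RunOnePass() {                    // hypothetical optimization sweep
      if (g_nodes > 8) { --g_nodes; return true; }
      return false;
    }

    int main() {
      int node_count;
      bool graph_modified;
      do {
        node_count = CurrentNodeCount();
        graph_modified = RunOnePass();
        // Re-run while a pass rewired nodes in place (graph_modified) even
        // if it removed none (node count unchanged).
      } while (graph_modified || CurrentNodeCount() != node_count);
      std::printf("converged at %d nodes\n", CurrentNodeCount());
      return 0;
    }
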
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index f04f413c10..3bb9926338 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -51,16 +51,16 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
- Status MaterializeShapes(const GrapplerItem& item,
- const GraphProperties& properties);
+ bool IsReallyConstant(const NodeDef& node) const;
+
+ Status MaterializeShapes(const GraphProperties& properties);
Status MaterializeBroadcastGradientArgs(const NodeDef& node,
const GraphProperties& properties);
Status MaterializeReductionIndices(NodeDef* node,
const GraphProperties& properties);
- Status MaterializeConstants(const GrapplerItem& item,
- const GraphProperties& properties);
+ Status MaterializeConstants(const GraphProperties& properties);
bool IsFoldable(const NodeDef& node) const;
Status EvaluateNode(const NodeDef& node,
@@ -72,12 +72,19 @@ class ConstantFolding : public GraphOptimizer {
Status FoldNode(NodeDef* node, GraphDef* output_graph);
+ bool IsOnes(const NodeDef& node) const;
+ bool IsZeros(const NodeDef& node) const;
+ void ReplaceAddOrMulWithIdentity(int input_to_forward, NodeDef* node);
+ Status ReplaceAddOrMulWithConstant(double value,
+ const TensorShapeProto& shape,
+ NodeDef* node);
Status FoldGraph(GraphDef* output);
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
+ Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
@@ -88,11 +95,13 @@ class ConstantFolding : public GraphOptimizer {
std::unique_ptr<DeviceBase> owned_device_;
std::unique_ptr<ResourceMgr> resource_mgr_;
- GraphDef graph_;
+ GraphDef* graph_;
std::unique_ptr<NodeMap> node_map_;
std::unordered_set<string> nodes_to_preserve_;
std::unordered_set<string> nodes_whitelist_;
+ std::unordered_set<string> feed_nodes_;
bool has_fetch_;
+ bool graph_modified_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b2d9b02c68..32a691d3ee 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -77,11 +77,166 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
+TEST_F(ConstantFoldingTest, NeutralElement) {
+ for (bool use_const : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({1, 2})));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({1, 2})));
+ Output zeros =
+ !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
+ : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2});
+ Output zeros_broadcast =
+ ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1});
+ Output ones = !use_const
+ ? ops::OnesLike(s.WithOpName("ones"), x)
+ : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2});
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
+ Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
+ Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
+ Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_broadcast);
+ Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_broadcast, y);
+ Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+ Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
+ Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(14, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
+ if (use_const) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ } else {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^x", node.input(1));
+ }
+ } else if (name == "mul2") {
+ if (use_const) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^y", node.input(0));
+ } else {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^y", node.input(1));
+ }
+ } else if (name == "mul3") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else if (name == "mul4") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else if (name == "mul5") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(1, t.float_val_size());
+ EXPECT_EQ(0, t.float_val(0));
+ EXPECT_EQ(2, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ EXPECT_EQ(2, t.tensor_shape().dim(1).size());
+ } else if (name == "mul6") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^y", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(1, t.float_val_size());
+ EXPECT_EQ(0, t.float_val(0));
+ EXPECT_EQ(2, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ EXPECT_EQ(2, t.tensor_shape().dim(1).size());
+ } else if (name == "add1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "add2") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ }
+ }
+ }
+}
+
+TEST_F(ConstantFoldingTest, CreateConstNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+#define MAKE_TEST_GRAPH(TYPE) \
+ Output TYPE##_const = \
+ ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
+ Output TYPE##_mul = \
+ ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \
+ Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)
+
+ MAKE_TEST_GRAPH(float);
+ MAKE_TEST_GRAPH(double);
+ MAKE_TEST_GRAPH(int64);
+ MAKE_TEST_GRAPH(int32);
+ MAKE_TEST_GRAPH(int16);
+ MAKE_TEST_GRAPH(int8);
+ MAKE_TEST_GRAPH(uint8);
+#undef MAKE_TEST_GRAPH
+
+ Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
+ Output bool_and =
+ ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
+ Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(24, output.node_size());
+ for (const NodeDef& node : output.node()) {
+#define CHECK_RESULT(TYPE, FIELD) \
+ if (node.name() == #TYPE "_mul") { \
+ EXPECT_EQ(5, \
+ node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
+ EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \
+ EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \
+ }
+
+ CHECK_RESULT(float, float);
+ CHECK_RESULT(double, double);
+ CHECK_RESULT(int64, int64);
+ CHECK_RESULT(int32, int);
+ CHECK_RESULT(int16, int);
+ CHECK_RESULT(int8, int);
+ CHECK_RESULT(uint8, int);
+#undef CHECK_RESULT
+
+ if (node.name() == "bool_and") {
+ EXPECT_EQ(5,
+ node.attr().at("value").tensor().tensor_shape().dim(0).size());
+ EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
+ EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
+ }
+ }
+}
+
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Const(s.WithOpName("a"), 10, {3});
+ Output a = ops::Const(s.WithOpName("a"), 10, {5});
auto b = ops::Unique(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), {b.y});
Output d = ops::Identity(s.WithOpName("d"), {b.idx});
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index d5563e9d4c..e9436638f0 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <deque>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -69,6 +70,8 @@ std::set<string> GetOpsFormatSupported() {
return ops_format_supported;
}
+// TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+// because it degrades performance in some cases.
std::set<string> GetOpsFormatAgnostic() {
std::set<string> ops_format_agnostic = {"Add",
"AddN",
@@ -88,7 +91,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Split",
"SquaredDifference",
"Squeeze",
- "Sub"};
+ /*"Sum",*/ "Sub"};
return ops_format_agnostic;
}
@@ -186,33 +189,6 @@ class GraphProcessor {
return node;
}
- NodeDef* AddNodeReductionConst(const string& name, const string& device) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
-
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({3}));
- std::vector<int> axis = {0, 2, 3};
- for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
- tensor.flat<int>()(i) = axis[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- string device_name;
- if (device.empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node);
- } else {
- device_name = device;
- }
- node->set_device(device_name);
- return node;
- }
-
const VirtualPlacer& virtual_placer_;
const std::unordered_set<string>& nodes_to_preserve_;
GraphDef* graph_;
@@ -370,10 +346,20 @@ class NodeProcessor : public GraphProcessor {
LOG(ERROR) << "Failed to parse TensorProto.";
}
if (tensor.dims() == 1) {
- int c = tensor.flat<int>()(3);
- tensor.flat<int>()(3) = tensor.flat<int>()(2);
- tensor.flat<int>()(2) = tensor.flat<int>()(1);
- tensor.flat<int>()(1) = c;
+ if (tensor.flat<int>().size() == 4) {
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(3) = tensor.flat<int>()(2);
+ tensor.flat<int>()(2) = tensor.flat<int>()(1);
+ tensor.flat<int>()(1) = c;
+ } else if (tensor.flat<int>().size() == 3) {
+ tensor.flat<int>()(0) = 0;
+ tensor.flat<int>()(1) = 2;
+ tensor.flat<int>()(2) = 3;
+ } else {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Unsupported tensor size: ",
+ tensor.flat<int>().size()));
+ }
} else if (tensor.dims() == 2) {
for (int i = 0; i < 2; i++) {
int c = tensor.matrix<int>()(3, i);
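
The 4-element branch reorders a per-dimension attribute from NHWC to NCHW order (N,H,W,C → N,C,H,W); the new 3-element branch assumes the vector held the NHWC reduction axes {0,1,2} and rewrites them to the NCHW equivalents {0,2,3}. A standalone check of the 4-element rotation:

    #include <array>
    #include <cassert>

    // new[1] = old[3] (C), new[2] = old[1] (H), new[3] = old[2] (W),
    // mirroring the in-place swaps above.
    std::array<int, 4> NHWCToNCHW(std::array<int, 4> v) {
      const int c = v[3];
      v[3] = v[2];
      v[2] = v[1];
      v[1] = c;
      return v;
    }

    int main() {
      // Per-dimension values for {N=1, H=224, W=224, C=3}.
      assert((NHWCToNCHW({1, 224, 224, 3}) ==
              std::array<int, 4>{1, 3, 224, 224}));
      return 0;
    }
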
@@ -394,7 +380,9 @@ class NodeProcessor : public GraphProcessor {
Status UpdateAttrValueOfInput(int input_index) {
auto input_node = node_map_->GetNode(node_->input(input_index));
// We created a copy of the node, so that we don't modify the original node,
- // which might be used elsewhere.
+ // which might be used elsewhere. Note that the copy also carries over any
+ // control dependency input, so that when this node is inside a loop,
+ // added_node stays in the same frame as node_.
NodeDef* added_node = graph_->add_node();
*added_node = *input_node;
string base_name = strings::StrCat(node_->name(), "-", input_node->name());
@@ -411,6 +399,14 @@ class NodeProcessor : public GraphProcessor {
return input_pos;
}
+ virtual std::set<int> GetOutputPos() const {
+ // For most nodes, no need to process control nodes or nodes that use an
+ // output other than the first output: only the first output is of
+ // 4D NCHW/NHWC format and thus relevant here.
+ std::set<int> output_pos = {0};
+ return output_pos;
+ }
+
NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
const string& const_name, DataType data_type,
const TensorShapeProto& input_shape,
@@ -476,37 +472,28 @@ class NodeProcessor : public GraphProcessor {
auto outputs = node_map_->GetOutputs(node_->name());
string const_name = GetOrAddNodePermNCHWToNHWC();
for (const auto& output : outputs) {
- string base_name = strings::StrCat(node_->name(), "-", output->name());
- string node_name =
- AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
- // TODO(yaozhang): handle the rare case where node A is connected to more
- // than one input of node B.
- auto it = std::find_if(output->mutable_input()->begin(),
- output->mutable_input()->end(),
- [this](const string& input) {
- string node_name = NodeName(input);
- return node_name.compare(node_->name()) == 0;
- });
- if (it == output->mutable_input()->end()) {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Expect ", node_->name(),
- " to be an input of ", output->name()));
- }
- int output_pos = NodePosition(*it);
- // No need to process control nodes or nodes that use an output
- // other than the first output: only the first output is of 4D NCHW/NHWC
- // format and thus relevant here.
- if (output_pos != 0) {
- continue;
+ for (int i = 0; i < output->input_size(); i++) {
+ auto& input = *output->mutable_input(i);
+ int input_port;
+ string input_name = ParseNodeName(input, &input_port);
+ auto output_pos = GetOutputPos();
+ if (input_name == node_->name() &&
+ output_pos.find(input_port) != output_pos.end()) {
+ string base_name =
+ strings::StrCat(node_->name(), "-", output->name(), "-", i);
+ string node_name =
+ AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
+ AddNodeTranspose(
+ node_name, input, const_name, node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(0), false);
+ input = node_name;
+ node_map_->AddOutput(node_->name(), node_name);
+ node_map_->AddOutput(node_name, output->name());
+ }
}
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- AddNodeTranspose(
- node_name, node_->name(), const_name, node_->attr().at("T").type(),
- node_->attr().at("_output_shapes").list().shape(0), false);
- *it = node_name;
- node_map_->UpdateOutput(node_->name(), output->name(), node_name);
- node_map_->AddOutput(node_name, output->name());
+ node_map_->RemoveOutput(node_->name(), output->name());
}
return Status::OK();
}
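
The rewritten loop walks every input string of every consumer, so a node feeding several input slots of the same consumer now gets a transpose on each edge (the old find_if handled only the first, per the removed TODO). It relies on input strings carrying an optional output port; a simplified sketch of the "name:port" convention the ParseNodeName call decodes, assuming grappler's usual port 0 default and ^-prefixed control inputs:

    #include <cassert>
    #include <string>

    // Parse "name", "name:3", or "^name" (control input). Simplified
    // stand-in for grappler's ParseNodeName in grappler/utils.h.
    std::string ParseInput(const std::string& input, int* port) {
      if (!input.empty() && input[0] == '^') {
        *port = -1;  // control edge: carries no data
        return input.substr(1);
      }
      const size_t colon = input.find(':');
      *port = colon == std::string::npos
                  ? 0
                  : std::stoi(input.substr(colon + 1));
      return input.substr(0, colon);
    }

    int main() {
      int port;
      assert(ParseInput("conv2d", &port) == "conv2d" && port == 0);
      assert(ParseInput("split:2", &port) == "split" && port == 2);
      assert(ParseInput("^barrier", &port) == "barrier" && port == -1);
      return 0;
    }
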
@@ -775,24 +762,52 @@ class AgnosticNodeProcessor : public NodeProcessor {
bool IsNodeAfterNCHWToNHWC() const {
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- auto node = node_map_->GetNode(node_->name());
- while (node->input_size() > 0) {
- int data_input_pos = 0;
- if (IsConcatV1(*node) || IsSplit(*node)) {
- data_input_pos = 1;
- }
- node = node_map_->GetNode(node->input(data_input_pos));
- if (IsNodeNCHWToNHWC(node->name())) {
+ std::deque<NodeDef*> queue;
+ auto first_node_pos = DataInputPos(*node_);
+ for (const auto& pos : first_node_pos) {
+ auto input_node = node_map_->GetNode(node_->input(pos));
+ queue.push_back(input_node);
+ }
+ // In most cases this loop exits after a single iteration, since the graph
+ // is already topologically sorted.
+ while (!queue.empty()) {
+ NodeDef* current_node = queue.front();
+ queue.pop_front();
+ if (IsNodeNCHWToNHWC(current_node->name())) {
return true;
}
- bool connected =
- ops_format_agnostic.find(node->op()) != ops_format_agnostic.end();
- if (!connected) {
- return false;
+ // We only continue searching if the path is connected through
+ // format-agnostic nodes.
+ if (ops_format_agnostic.find(current_node->op()) !=
+ ops_format_agnostic.end()) {
+ auto current_node_pos = DataInputPos(*current_node);
+ for (const auto& pos : current_node_pos) {
+ auto input_node = node_map_->GetNode(current_node->input(pos));
+ queue.push_back(input_node);
+ }
}
}
return false;
}
+
+ private:
+ std::vector<int> DataInputPos(const NodeDef& node) const {
+ std::vector<int> pos;
+ if (IsSplit(node)) {
+ return {1};
+ }
+ if (IsConcatV1(node)) {
+ return {1};
+ }
+ if (IsAdd(node) || IsMul(node) || IsRealDiv(node) ||
+ IsSquaredDifference(node) || IsSub(node)) {
+ return {0, 1};
+ }
+ if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
+ return {0};
+ }
+ return {};
+ }
};
class AddNProcessor : public AgnosticNodeProcessor {
@@ -815,42 +830,49 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
public:
explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
: AgnosticNodeProcessor(opt_cxt) {
- is_4d_with_vector_ = Is4DOperateWithVector();
+ is_4d_with_vector_ = IsNDOperateWithMD(4, 1);
}
protected:
bool ShouldProcess() const override {
+    // TODO(yaozhang): Support IsNDOperateWithMD(1, 4), where the first input
+    // is a vector and the second input is a 4D tensor, and update
+    // CustomizedProcessing() accordingly.
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() &&
- (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
- Is4DOperateWithVector()) &&
+ (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
+ IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4)) &&
IsOnGPU();
}
std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos = {0};
- if (Is4DOperateWithND(4)) {
+ std::vector<int> input_pos;
+ auto input0 = node_map_->GetNode(node_->input(0));
+ auto input1 = node_map_->GetNode(node_->input(1));
+ if (IsDimsFour(*input0)) {
+ input_pos.push_back(0);
+ }
+ if (IsDimsFour(*input1)) {
input_pos.push_back(1);
}
return input_pos;
}
- bool Is4DOperateWithND(int n) const {
+ bool IsDimsFour(const NodeDef& node) const {
+ return NodeProcessor::IsDimsFour(node) || IsNodeNCHWToNHWC(node.name());
+ }
+
+ bool IsNDOperateWithMD(int n, int m) const {
auto input0 = node_map_->GetNode(node_->input(0));
auto input1 = node_map_->GetNode(node_->input(1));
if (input0 && input1) {
- return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
- ((n == 4)
- ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name()))
- : IsDimsN(*input1, n));
+ bool input0_is_n = (n == 4) ? IsDimsFour(*input0) : IsDimsN(*input0, n);
+ bool input1_is_m = (m == 4) ? IsDimsFour(*input1) : IsDimsN(*input1, m);
+ return input0_is_n && input1_is_m;
}
return false;
}
- bool Is4DOperateWithScalar() const { return Is4DOperateWithND(0); }
-
- bool Is4DOperateWithVector() const { return Is4DOperateWithND(1); }
-
NodeDef* AddNodeShapeConst(const string& name, int num_channels) {
NodeDef* node = graph_->add_node();
node_map_->AddNode(name, node);
@@ -948,7 +970,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- string concat_const_name = GetOrAddNodeConcatConst();
+ string concat_const_name = AddNodeConcatConst()->name();
node_map_->AddOutput(concat_const_name, node_->name());
*node_->mutable_input(axis_node_pos_) = concat_const_name;
return Status::OK();
@@ -956,8 +978,14 @@ class ConcatProcessor : public AgnosticNodeProcessor {
bool IsAlongDimC() const {
auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ if (!IsConstant(*axis_node)) {
+ return false;
+ }
if (axis_node->attr().find("value") != axis_node->attr().end()) {
- return axis_node->attr().at("value").tensor().int_val(0) == 3;
+ auto tensor = axis_node->attr().at({"value"}).tensor();
+ if (tensor.tensor_shape().dim_size() == 0 && tensor.int_val_size() == 1) {
+ return tensor.int_val(0) == 3;
+ }
}
return false;
}
@@ -965,28 +993,18 @@ class ConcatProcessor : public AgnosticNodeProcessor {
int axis_node_pos_;
private:
- NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
- const string& device) {
- auto const_node = AddNodeConstScalar(
- strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
- // This is to ensure the concat node and the const node are
- // in the same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeConcatConst() {
- string const_name;
- if (is_in_frame_) {
- int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
- auto const_node = AddNodeConcatConst(
- node_->name(), NodeName(node_->input(value_node_pos)),
- node_->device());
- const_name = const_node->name();
- } else {
- const_name = kConcatConst;
- }
- return const_name;
+ NodeDef* AddNodeConcatConst() {
+ auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+    // We create a copy of the axis node so that we don't modify the original
+    // node, which might be used elsewhere. Note that the copy also carries
+    // over the control dependency input in case this node is inside a loop,
+    // ensuring added_node stays in the same frame as node_.
+ auto added_node = graph_->add_node();
+ *added_node = *axis_node;
+ added_node->set_name(strings::StrCat(kConcatConst, "-", node_->name()));
+ added_node->mutable_attr()->at({"value"}).mutable_tensor()->set_int_val(0,
+ 1);
+ return added_node;
}
};
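AddNodeConcatConst clones the existing axis constant instead of minting a fresh scalar, so device placement, dtype, and any control-dependency input (which keeps the constant in the same loop frame as node_) carry over automatically; only the payload changes from the NHWC channel axis 3 to the NCHW channel axis 1. A hedged sketch of that clone-and-patch step on a bare NodeDef, assuming a scalar int32 "value" attr (MakeNCHWAxisConst is an illustrative name, not part of the change):

    NodeDef MakeNCHWAxisConst(const NodeDef& nhwc_axis, const string& new_name) {
      NodeDef nchw_axis = nhwc_axis;  // copies device, attrs, control inputs
      nchw_axis.set_name(new_name);
      // Rewrite the scalar payload: NHWC channel axis 3 -> NCHW channel axis 1.
      nchw_axis.mutable_attr()->at("value").mutable_tensor()->set_int_val(0, 1);
      return nchw_axis;
    }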
@@ -1036,6 +1054,16 @@ class SplitProcessor : public AgnosticNodeProcessor {
return input_pos;
}
+ std::set<int> GetOutputPos() const override {
+ std::set<int> output_pos{0};
+ if (HasAttribute(*node_, "num_split").ok()) {
+ for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
+ output_pos.insert(i);
+ }
+ }
+ return output_pos;
+ }
+
Status CustomizedProcessing() override {
string split_const_name = AddNodeSplitConst()->name();
node_map_->AddOutput(split_const_name, node_->name());
@@ -1073,7 +1101,7 @@ class SplitProcessor : public AgnosticNodeProcessor {
// We created a copy of the node, so that we don't modify the original node,
// which might be used elsewhere. Note that this copy also copies the
// control dependency input in the case this node is inside a loop,
- // to ensure added_node is in the same frame with the Split node.
+    // to ensure added_node is in the same frame as node_.
NodeDef* added_node = graph_->add_node();
*added_node = *dim_node;
added_node->set_name(strings::StrCat(kSplitConst, "-", node_->name()));
@@ -1329,20 +1357,21 @@ class SumProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
- Status CustomizedProcessing() override {
- node_map_->AddOutput(kReductionConst, node_->name());
- *node_->mutable_input(1) = GetOrAddNodeReductionConst();
- return Status::OK();
- }
+ Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
private:
bool IsAlongDimNHW() const {
- NodeDef* node = node_map_->GetNode(node_->input(1));
+ NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
+ if (!IsConstant(*reduction_indices)) {
+ return false;
+ }
Tensor tensor;
- if (node->attr().find({"value"}) == node->attr().end()) {
+ if (reduction_indices->attr().find({"value"}) ==
+ reduction_indices->attr().end()) {
return false;
}
- auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+ auto success =
+ tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
if (!success) {
LOG(ERROR) << "Failed to parse TensorProto.";
return false;
@@ -1356,29 +1385,6 @@ class SumProcessor : public AgnosticNodeProcessor {
}
return false;
}
-
- NodeDef* AddNodeReductionConst(const string& suffix,
- const string& depended_node,
- const string& device) {
- auto const_node = GraphProcessor::AddNodeReductionConst(
- strings::StrCat(kReductionConst, "-", suffix), device);
- // This is to ensure the Sum node and the const node are in the
- // same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeReductionConst() {
- string const_name;
- if (is_in_frame_) {
- auto const_node = AddNodeReductionConst(
- node_->name(), NodeName(node_->input(0)), node_->device());
- const_name = const_node->name();
- } else {
- const_name = kReductionConst;
- }
- return const_name;
- }
};
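IsAlongDimNHW only accepts a constant 1-D reduction-indices tensor holding exactly {0, 1, 2}, i.e. a reduction over N, H, and W of an NHWC tensor; after the layout flip the equivalent NCHW reduction is over {0, 2, 3}, which is what UpdateAttrValueOfInput(1) writes back (see the expected tensor in the disabled Sum test below). A standalone sketch of the membership check, assuming the "value" attr already parsed into a Tensor (IsNHWReduction is an illustrative name):

    // True iff `indices` is exactly the 1-D int32 vector {0, 1, 2}.
    bool IsNHWReduction(const Tensor& indices) {
      if (indices.dims() != 1 || indices.NumElements() != 3) return false;
      auto flat = indices.flat<int32>();
      return flat(0) == 0 && flat(1) == 1 && flat(2) == 2;
    }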
class DataLayoutOptimizer : GraphProcessor {
@@ -1409,18 +1415,10 @@ class DataLayoutOptimizer : GraphProcessor {
return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
}
- NodeDef* AddNodeConcatConst() {
- return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
- }
-
NodeDef* AddNodeGatherAxisConst() {
return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
}
- NodeDef* AddNodeReductionConst() {
- return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
- }
-
  // Expand all nodes that are in NHWC but support NCHW or are layout agnostic.
Status Expand() {
int node_size_original = graph_->node_size();
@@ -1474,9 +1472,7 @@ class DataLayoutOptimizer : GraphProcessor {
if (graph_->node_size() > node_size_original) {
NodeDef* n = AddNodePermNHWCToNCHW();
n = AddNodePermNCHWToNHWC();
- n = AddNodeConcatConst();
n = AddNodeGatherAxisConst();
- n = AddNodeReductionConst();
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
@@ -1620,27 +1616,20 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
virtual_placer_.reset(new VirtualPlacer(cluster));
nodes_to_preserve_ = item.NodesToPreserve();
GraphProperties graph_properties(item);
- auto status = graph_properties.InferStatically();
+ auto status = graph_properties.InferStatically(false);
if (!status.ok()) {
*output = item.graph;
return status;
}
TuningConfig config;
- config.no_gemm = false;
+ config.no_gemm = true;
+  // TODO(yaozhang): Enable tuning with various TuningConfig choices with
+ // the measurement-based estimator.
status = Tune(item, graph_properties, config, output);
- // This is based on an empirical observation that if the introduced Transpose
- // nodes is more than 30, not using GEMM implementation would result in better
- // performance.
- if (status.ok() && GetNumTranspose(*output) > 30) {
- config.no_gemm = true;
- status = Tune(item, graph_properties, config, output);
- }
-
if (!status.ok()) {
*output = item.graph;
}
-
return status;
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 8c89f6744b..363b4c3fd8 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -298,6 +298,39 @@ TEST_F(LayoutOptimizerTest, Connectivity) {
EXPECT_EQ(node_i2_output->input(0), "i1");
}
+TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto i1 = ops::Identity(s.WithOpName("i1"), conv);
+ auto i2 = ops::Identity(s.WithOpName("i2"), i1);
+ auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
+ auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
+ auto i3 = ops::Identity(s.WithOpName("i3"), sub);
+ auto i4 = ops::Identity(s.WithOpName("i4"), i3);
+ auto i5 = ops::Identity(s.WithOpName("i5"), i4);
+ auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
+ auto i6 = ops::Identity(s.WithOpName("i6"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  // Make the graph not topologically sorted to test the handling of multi-hop
+  // connectivity (here we say two nodes are connected if all nodes between
+  // them are layout agnostic). If the graph were already in topological
+  // order, the problem would be easier: the layout optimizer would only need
+  // to check single-hop connectivity.
+ NodeMap node_map_original(&item.graph);
+ auto node_i1 = node_map_original.GetNode("i1");
+ auto node_mul = node_map_original.GetNode("mul");
+ node_mul->Swap(node_i1);
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map_output(&output);
+ auto mul_node = node_map_output.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "scalar_mul");
+ EXPECT_EQ(mul_node->input(1), "i5");
+}
+
TEST_F(LayoutOptimizerTest, PreserveFetch) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 3, 2, "VALID");
@@ -495,7 +528,175 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
auto split_node = node_map.GetNode("split");
EXPECT_EQ(split_node->input(0), "i1");
EXPECT_EQ(split_node->input(1),
- "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split");
+ "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split-1");
+}
+
+TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat =
+ ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split:1");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "split:1");
+ EXPECT_EQ(concat_node->input(3), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Concat) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Sum) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto reduction_indices =
+ ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
+ auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
+ auto o = ops::Identity(s.WithOpName("o"), sum);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+  // because it yields worse performance in some cases.
+ /*
+ NodeMap node_map(&output);
+ auto sum_node = node_map.GetNode("sum");
+ EXPECT_EQ(sum_node->input(0), "Conv2D");
+ EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
+ auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
+ Tensor tensor;
+ EXPECT_TRUE(
+ tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor tensor_expected(DT_INT32, {3});
+ test::FillValues<int>(&tensor_expected, {0, 2, 3});
+ test::ExpectTensorEqual<int>(tensor_expected, tensor);
+ */
+}
+
+TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "scalar");
+ EXPECT_EQ(mul_node->input(1), "Conv2D");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "scalar");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto i = ops::Identity(s.WithOpName("i"), conv);
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "i");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "LayoutOptimizerReshapeNHWCToNCHW-mul-vector");
+ auto mul_const = node_map.GetNode("LayoutOptimizerReshapeConst-mul-vector");
+ Tensor tensor;
+ EXPECT_TRUE(
+ tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor tensor_expected(DT_INT32, {4});
+ test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
+ test::ExpectTensorEqual<int>(tensor_expected, tensor);
+}
+
+TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
+ auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+  // TODO(yaozhang): Support a vector as the first input and a 4D tensor as
+  // the second input for BinaryOpProcessor.
+ EXPECT_EQ(mul_node->input(0), "vector");
+ EXPECT_EQ(mul_node->input(1),
+ "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-mul-1");
}
} // namespace
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 7c44ce15c6..a2a2680c4f 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -716,7 +716,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
{
// Estimate the size of the data to swap for each node.
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
for (auto& swap : nodes_to_swap) {
const NodeDef* node = swap.first;
std::vector<OpInfo::TensorProperties> props =
diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc
index 6ce6deef2c..450e853407 100644
--- a/tensorflow/core/grappler/optimizers/static_schedule.cc
+++ b/tensorflow/core/grappler/optimizers/static_schedule.cc
@@ -86,7 +86,7 @@ Status EstimateEarliestExecutionTimes(
name_map.clear();
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
@@ -154,7 +154,7 @@ Status EstimateRequiredTimes(
}
}
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 21411097e8..dcffb28513 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3923,7 +3923,11 @@ tf_kernel_library(
"scatter_nd_op.h",
"scatter_nd_op_gpu.cu.cc",
],
- deps = STATE_DEPS + [":dense_update_functor"],
+ deps = STATE_DEPS + [
+ ":dense_update_functor",
+ ":training_op_helpers",
+ ":variable_ops",
+ ],
)
tf_kernel_library(
@@ -5833,11 +5837,11 @@ cc_library(
srcs = ["dataset.cc"],
hdrs = ["dataset.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/util/tensor_bundle",
],
)
@@ -6125,6 +6129,18 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "random_dataset_op",
+ srcs = ["random_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "range_dataset_op",
srcs = ["range_dataset_op.cc"],
deps = [
@@ -6291,6 +6307,7 @@ tf_kernel_library(
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
":prefetch_dataset_op",
+ ":random_dataset_op",
":range_dataset_op",
":reader_dataset_ops",
":repeat_dataset_op",
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 3d2bb57aff..1791c51096 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -194,7 +194,23 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current Eigen and libxsmm implementations do not "
+ "yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -262,6 +278,7 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -290,7 +307,23 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current libxsmm and customized CPU implementations do "
+ "not yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -459,6 +492,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -510,10 +544,30 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ int stride_h = GetTensorDim(strides_, data_format_, 'H');
+ int stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
@@ -546,13 +600,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// do not support striding on the batch or depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
- stride_rows, stride_cols, padding_, filter_backprop,
- data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ filter_backprop, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
bool use_cudnn_;
@@ -566,38 +623,46 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& out_backprop, const Tensor& input, int row_stride,
- int col_stride, const Padding& padding, Tensor* filter_backprop,
- TensorFormat data_format) {
+ const Tensor& out_backprop, const Tensor& input, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* filter_backprop, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
+ std::vector<int32> dilations(4, 1);
+ dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
+ dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;
+
std::vector<int32> strides(4, 1);
strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
TensorShape filter_shape = filter_backprop->shape();
ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
"Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
input.shape(), filter_shape, out_backprop.shape(),
- strides, padding, data_format, &dims));
+ dilations, strides, padding, data_format, &dims));
+ // TODO(yangzihao): The padding computations should be done in
+ // GetWindowedOutputSize() functions.
const int padding_rows =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
+ (dims.spatial_dims[0].filter_size - 1) *
+ dims.spatial_dims[0].dilation +
+ 1 - dims.spatial_dims[0].input_size);
const int padding_cols =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
+ (dims.spatial_dims[1].filter_size - 1) *
+ dims.spatial_dims[1].dilation +
+ 1 - dims.spatial_dims[1].input_size);
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -730,7 +795,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
+ .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
+ .set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -821,6 +888,8 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d28f6b4d10..736241a029 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -198,7 +198,23 @@ class Conv2DFastBackpropInputOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] && dilations_[3]),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current Eigen and libxsmm implementations do not "
+ "yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -268,6 +284,7 @@ class Conv2DFastBackpropInputOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -296,7 +313,23 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current libxsmm and customized CPU implementations do "
+ "not yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -532,6 +565,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -586,10 +620,30 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
"specify 4 dimensions"));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ int stride_h = GetTensorDim(strides_, data_format_, 'H');
+ int stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
@@ -622,12 +676,16 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
// do not support striding on the batch or depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
- stride_rows, stride_cols, padding_, in_backprop, data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ in_backprop, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
bool use_cudnn_;
@@ -641,39 +699,48 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& out_backprop, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* in_backprop,
- TensorFormat data_format) {
+ const Tensor& out_backprop, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* in_backprop, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
std::vector<int32> strides(4, 1);
- strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
- strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
+ std::vector<int32> dilations(4, 1);
+ auto input_h = GetTensorDimIndex(data_format, 'H');
+ auto input_w = GetTensorDimIndex(data_format, 'W');
+ strides[input_h] = row_stride;
+ strides[input_w] = col_stride;
+ dilations[input_h] = row_dilation;
+ dilations[input_w] = col_dilation;
TensorShape input_shape = in_backprop->shape();
const TensorShape& filter_shape = filter.shape();
ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
"Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
input_shape, filter_shape, out_backprop.shape(),
- strides, padding, data_format, &dims));
+ dilations, strides, padding, data_format, &dims));
+ // TODO(yangzihao): The padding computations should be done in
+ // GetWindowedOutputSize() functions.
const int padding_rows =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
+ (dims.spatial_dims[0].filter_size - 1) *
+ dims.spatial_dims[0].dilation +
+ 1 - dims.spatial_dims[0].input_size);
const int padding_cols =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
+ (dims.spatial_dims[1].filter_size - 1) *
+ dims.spatial_dims[1].dilation +
+ 1 - dims.spatial_dims[1].input_size);
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -789,7 +856,9 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
+ .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
+ .set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -875,6 +944,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index e068fb8684..535586d53a 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -175,15 +175,17 @@ template <typename Device, typename T>
struct LaunchConv2DBackpropInputOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
- int row_stride, int col_stride, const Padding& padding,
- Tensor* in_backprop, TensorFormat data_format);
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format);
};
template <typename Device, typename T>
struct LaunchConv2DBackpropFilterOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};
@@ -191,8 +193,9 @@ struct LaunchConv2DBackpropFilterOp {
template <typename T>
struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
@@ -200,7 +203,8 @@ template <typename T>
struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index c2d24d1f12..4d0f1ab317 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -645,6 +645,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
{{input_size[0], input_size[1], input_size[2]}},
out_depth,
{{filter_size[0], filter_size[1], filter_size[2]}},
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ /*dilations=*/{{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
@@ -1011,6 +1014,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
{{input_size[0], input_size[1], input_size[2]}},
out_depth,
{{filter_size[0], filter_size[1], filter_size[2]}},
+ {{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index bb67113fb0..ba40c428e4 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -112,7 +112,8 @@ struct LaunchGeneric {
template <typename T>
struct LaunchConv2DOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
+ const Tensor& input, const Tensor& filter,
+ int /*row_dilation*/, int /*col_dilation*/, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
@@ -133,8 +134,10 @@ class LaunchDeepConvOp {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
- Tensor* output, TensorFormat data_format) {
+ int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
+ int /*dilation_cols*/, int /*stride_rows*/,
+ int /*stride_cols*/, Tensor* /*output*/,
+ TensorFormat /*data_format*/) {
return false;
}
};
@@ -147,9 +150,11 @@ class LaunchDeepConvOp<CPUDevice, float> {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
+ int out_cols, int out_depth, int dilation_rows,
+ int dilation_cols, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
- if (data_format != FORMAT_NHWC ||
+ if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
+ dilation_cols != 1 ||
!CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
in_depth, out_depth, out_rows, out_cols)) {
return false;
@@ -187,7 +192,8 @@ class LaunchXsmmConvOp {
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
-                  int out_cols, int out_depth, int stride_rows, int stride_cols,
-                  Tensor* output, TensorFormat data_format) {
+                  int out_cols, int out_depth, int dilation_rows,
+                  int dilation_cols, int stride_rows, int stride_cols,
+                  Tensor* output, TensorFormat data_format) {
return false;
}
};
@@ -199,7 +205,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
+ int out_cols, int out_depth, int dilation_rows,
+ int dilation_cols, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
auto num_threads =
ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
@@ -228,11 +235,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
- if (!CanUseXsmmConv2D(desc, data_format)) {
- return false;
- }
-
- if (!CanUseXsmmConv2D(desc, data_format)) {
+ if (dilation_rows != 1 || dilation_cols != 1 ||
+ !CanUseXsmmConv2D(desc, data_format)) {
return false;
}
@@ -251,6 +255,7 @@ template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -259,15 +264,35 @@ class Conv2DOp : public BinaryOp<T> {
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
OP_REQUIRES(context, strides_.size() == 4,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+ const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
+ const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, stride_n == 1 && stride_c == 1,
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
@@ -334,18 +359,22 @@ class Conv2DOp : public BinaryOp<T> {
errors::InvalidArgument("batch is too large"));
const int batch = static_cast<int>(batch_raw);
- // For now we take the stride from the second and third dimensions only (we
- // do not support striding on the batch or depth dimension).
+ // For now we take the stride and dilation from the second and third
+ // dimensions only (we do not support striding or dilation on the batch or
+ // depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
+
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
- padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
- padding_, &out_cols, &pad_cols));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows,
+ stride_rows, padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols,
+ stride_cols, padding_, &out_cols, &pad_cols));
TensorShape out_shape =
ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
@@ -361,6 +390,8 @@ class Conv2DOp : public BinaryOp<T> {
<< ", filter_rows = " << filter_rows
<< ", stride_rows = " << stride_rows
<< ", stride_cols = " << stride_cols
+ << ", dilation_rows = " << dilation_rows
+ << ", dilation_cols = " << dilation_cols
<< ", out_depth = " << out_depth;
// If there is nothing to compute, return.
@@ -372,7 +403,8 @@ class Conv2DOp : public BinaryOp<T> {
if (LaunchXsmmConvOp<Device, T>::Run(
context, input, filter, batch, input_rows, input_cols, in_depth,
filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, stride_rows, stride_cols, output, data_format_)) {
+ out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
+ output, data_format_)) {
return;
}
#endif
@@ -380,15 +412,18 @@ class Conv2DOp : public BinaryOp<T> {
if (LaunchDeepConvOp<Device, T>::Run(
context, input, filter, batch, input_rows, input_cols, in_depth,
filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, stride_rows, stride_cols, output, data_format_)) {
+ out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
+ output, data_format_)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_rows, stride_cols, padding_, output, data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ output, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
bool use_cudnn_;
Padding padding_;
@@ -443,9 +478,9 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
template <typename T>
void LaunchConv2DOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input_param, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
- TensorFormat data_format) {
+ const Tensor& input_param, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* output, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
@@ -461,8 +496,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
Tensor input = input_param;
- if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
- col_stride == 1 && data_format == FORMAT_NHWC) {
+ if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
+ col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
+ data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
const uint64 k = filter.dim_size(2);
@@ -487,7 +523,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
}
return;
} else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) && padding == VALID &&
+ filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+ col_dilation == 1 && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
@@ -530,17 +567,19 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
const int64 patch_cols = filter.dim_size(1);
if (padding == SAME) {
// Total padding on rows and cols is
- // Pr = (R' - 1) * S + Kr - R
- // Pc = (C' - 1) * S + Kc - C
+ // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
+ // Pc = (C' - 1) * S + (Kc - 1) * Dc + 1 - C
// where (R', C') are output dimensions, (R, C) are input dimensions, S
- // is stride, (Kr, Kc) are filter dimensions.
+ // is stride, (Dr, Dc) are dilations, (Kr, Kc) are filter dimensions.
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
// we pad more on the right and bottom than on the top and left.
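+      // Illustrative numbers (not part of the original change): R = 7, S = 1
+      // (so R' = 7), Kr = 3 and Dr = 2 give Pr = (7 - 1) * 1 + (3 - 1) * 2 +
+      // 1 - 7 = 4, split as Pr / 2 = 2 on one side and Pr - Pr / 2 = 2 on
+      // the other.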
padding_rows =
- std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ std::max<int>(0, (out_rows - 1) * row_stride +
+ (patch_rows - 1) * row_dilation + 1 - in_rows);
padding_cols =
- std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+ std::max<int>(0, (out_cols - 1) * col_stride +
+ (patch_cols - 1) * col_dilation + 1 - in_cols);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
if (rows_odd || cols_odd) {
@@ -605,7 +644,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_input_feature_map_count(filter.dim_size(2))
.set_output_feature_map_count(filter.dim_size(3));
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(row_stride)
+ conv_desc.set_vertical_dilation_rate(row_dilation)
+ .set_horizontal_dilation_rate(col_dilation)
+ .set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -652,6 +693,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
out_depths, // out_depths
{{patch_rows, // filter_rows
patch_cols}}, // filter_cols
+ {{row_dilation, // dilation_rows
+ col_dilation}}, // dilation_cols
{{row_stride, // stride_rows
col_stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index e29271dff2..09a3b78776 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -34,8 +34,9 @@ class OpKernelContext;
template <typename Device, typename T>
struct LaunchConv2DOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
@@ -43,8 +44,9 @@ struct LaunchConv2DOp {
template <typename T>
struct LaunchConv2DOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 37cb67bc51..39202d7334 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -377,6 +377,9 @@ struct LaunchConvOp<GPUDevice, T> {
{{in_planes, in_rows, in_cols}},
out_depth,
{{filter_planes, filter_rows, filter_cols}},
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ /*dilations=*/{{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{pad_planes, pad_rows, pad_cols}},
dtype,
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index c852dc9991..6f82698596 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -91,13 +91,14 @@ class ConvParameters {
using SpatialArray = gtl::InlinedVector<int64, 3>;
ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
int64 out_depths, const SpatialArray& filter,
- const SpatialArray& stride, const SpatialArray& padding,
- DataType dtype, int device_id)
+ const SpatialArray& dilation, const SpatialArray& stride,
+ const SpatialArray& padding, DataType dtype, int device_id)
: batch_(batch),
in_depths_(in_depths),
out_depths_(out_depths),
in_(in),
filter_(filter),
+ dilation_(dilation),
stride_(stride),
padding_(padding),
dtype_(dtype),
@@ -107,6 +108,7 @@ class ConvParameters {
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
hash_code_ = Hash64Combine(hash_code_, out_depths);
for (int64 val : filter) hash_code_ = Hash64Combine(hash_code_, val);
+ for (int64 val : dilation) hash_code_ = Hash64Combine(hash_code_, val);
for (int64 val : stride) hash_code_ = Hash64Combine(hash_code_, val);
for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
hash_code_ = Hash64Combine(hash_code_, dtype);
@@ -128,6 +130,7 @@ class ConvParameters {
"(", str_util::Join(in_, ", "), "), ",
out_depths_, ", ",
"(", str_util::Join(filter_, ", "), "), ",
+ "(", str_util::Join(dilation_, ", "), "), ",
"(", str_util::Join(stride_, ", "), "), ",
"(", str_util::Join(padding_, ", "), "), ",
dtype_, ", ",
@@ -154,11 +157,11 @@ class ConvParameters {
protected:
using ParameterDataType =
std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray,
- SpatialArray, DataType, int>;
+ SpatialArray, SpatialArray, DataType, int>;
ParameterDataType get_data_as_tuple() const {
return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_,
- stride_, padding_, dtype_, device_id_);
+ dilation_, stride_, padding_, dtype_, device_id_);
}
uint64 hash_code_;
@@ -169,6 +172,7 @@ class ConvParameters {
int64 out_depths_;
SpatialArray in_;
SpatialArray filter_;
+ SpatialArray dilation_;
SpatialArray stride_;
SpatialArray padding_;
DataType dtype_;
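
ConvParameters now folds the dilation vector into both the hash and the comparison tuple, so cudnn autotune results are no longer shared between convolutions that differ only in dilation. A sketch of the same combine pattern; Hash64Combine here is a stand-in mixer, not TensorFlow's implementation:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Stand-in for Hash64Combine: a boost-style 64-bit mix.
    uint64_t Hash64Combine(uint64_t a, uint64_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    }

    struct ConvKey {
      int64_t batch, in_depths, out_depths;
      std::vector<int64_t> in, filter, dilation, stride, padding;

      uint64_t Hash() const {
        uint64_t h = Hash64Combine(batch, in_depths);
        for (int64_t v : in) h = Hash64Combine(h, v);
        h = Hash64Combine(h, out_depths);
        for (int64_t v : filter) h = Hash64Combine(h, v);
        for (int64_t v : dilation) h = Hash64Combine(h, v);  // newly keyed field
        for (int64_t v : stride) h = Hash64Combine(h, v);
        for (int64_t v : padding) h = Hash64Combine(h, v);
        return h;
      }
    };

    int main() {
      ConvKey a{32, 3, 64, {224, 224}, {3, 3}, {1, 1}, {1, 1}, {0, 0}};
      ConvKey b = a;
      b.dilation = {2, 2};  // same conv except dilation: must key differently
      std::printf("hashes differ: %d\n", a.Hash() != b.Hash());
    }
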
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index ea54d6cf6c..666bca265c 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -43,6 +43,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) {
128, // out_depths
{{3, // filter_rows
3}}, // filter_cols
+ {{1, // dilation_rows
+ 1}}, // dilation_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
@@ -60,6 +62,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) {
768, // out_depths
{{3, // filter_rows
3}}, // filter_cols
+ {{1, // dilation_rows
+ 1}}, // dilation_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc
index e6e1b83b30..0aec6aac34 100644
--- a/tensorflow/core/kernels/cwise_op_asinh.cc
+++ b/tensorflow/core/kernels/cwise_op_asinh.cc
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
index fcfa2956f7..0972129787 100644
--- a/tensorflow/core/kernels/dataset.cc
+++ b/tensorflow/core/kernels/dataset.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/kernels/dataset.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+
namespace tensorflow {
namespace {
@@ -70,6 +73,143 @@ class DatasetVariantWrapper {
} // namespace
+Status GraphDefBuilderWrapper::AddDataset(
+ const GraphDatasetBase* dataset,
+ const std::vector<std::pair<size_t, Node*>>& inputs,
+ const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
+ const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
+ Node** output) {
+ const string& op_type_name = dataset->op_name();
+ std::unique_ptr<const GraphDefBuilder::Options> opts(
+ new GraphDefBuilder::Options(b_->opts()));
+ // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
+  // attributes defined. It would be nice to have a consistent pattern.
+ bool has_output_types_attr = HasAttr(op_type_name, "output_types");
+ bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
+ if (has_output_shapes_attr) {
+ opts.reset(new GraphDefBuilder::Options(
+ opts->WithAttr("output_shapes", dataset->output_shapes())));
+ }
+ if (has_output_types_attr) {
+ opts.reset(new GraphDefBuilder::Options(
+ opts->WithAttr("output_types", dataset->output_dtypes())));
+ }
+ for (auto attr : attrs) {
+ opts.reset(
+ new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
+ }
+ if (opts->HaveError()) {
+ return errors::Internal("AddDataset: Failed to build Options with error ",
+ opts->StatusToString());
+ }
+ NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
+ opts->op_registry());
+ {
+ size_t total_size = inputs.size() + list_inputs.size();
+ auto inputs_iter = inputs.begin();
+ auto list_inputs_iter = list_inputs.begin();
+ for (int i = 0; i < total_size; i++) {
+ if (inputs_iter != inputs.end() && inputs_iter->first == i) {
+ node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
+ inputs_iter++;
+ } else if (list_inputs_iter != list_inputs.end() &&
+ list_inputs_iter->first == i) {
+ std::vector<NodeBuilder::NodeOut> nodeout_inputs;
+ nodeout_inputs.reserve(list_inputs_iter->second.size());
+ for (Node* n : list_inputs_iter->second) {
+ nodeout_inputs.emplace_back(n);
+ }
+ node_builder.Input(nodeout_inputs);
+ list_inputs_iter++;
+ } else {
+ return errors::InvalidArgument("No input found for index ", i);
+ }
+ }
+ }
+ *output = opts->FinalizeBuilder(&node_builder);
+ if (*output == nullptr) {
+ return errors::Internal("AddDataset: Failed to build ", op_type_name,
+ " op with error ", opts->StatusToString());
+ }
+ return Status::OK();
+}
+
+Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
+ const string& function_name) {
+ if (b_->HasFunction(function_name)) {
+    LOG(INFO) << "Function with name " << function_name << " already exists in"
+              << " the graph. It will not be added again.";
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
+ const FunctionLibraryDefinition* flib_def =
+ ctx->function_library()->GetFunctionLibraryDefinition();
+ const FunctionDef* f_def = flib_def->Find(function_name);
+ if (f_def == nullptr) {
+ return errors::InvalidArgument("Unable to find FunctionDef for ",
+ function_name, " in the registry.");
+ }
+ FunctionDefLibrary def;
+ *def.add_function() = *f_def;
+ const string gradient_func = flib_def->FindGradient(function_name);
+ if (!gradient_func.empty()) {
+ GradientDef* g_def = def.add_gradient();
+ g_def->set_function_name(function_name);
+ g_def->set_gradient_func(gradient_func);
+ }
+ TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
+
+ // Recursively add functions in inputs of function_name.
+ for (const NodeDef& node_def : f_def->node_def()) {
+ const OpRegistrationData* op_reg_data = nullptr;
+ TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
+ if (op_reg_data->is_function_op) {
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
+ }
+ // Recursively add functions in attrs of this NodeDef.
+ for (const auto& pair : node_def.attr()) {
+ TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
+ }
+ }
+
+ // Recursively add functions in attrs of function_name.
+ for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
+ TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
+ }
+ return Status::OK();
+}
+
+void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Const",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
+}
+
+bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
+ const string& attr_name) const {
+ const OpDef* op_def = nullptr;
+ Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
+ if (!s.ok() || op_def == nullptr) {
+ return false;
+ }
+ return HasAttr(op_def, attr_name);
+}
+
+Status GraphDatasetBase::Serialize(OpKernelContext* ctx,
+ string* serialized_graph_def,
+ string* output_node) const {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* node = nullptr;
+ TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
+ *output_node = node->name();
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ graph_def.SerializeToString(serialized_graph_def);
+ return Status::OK();
+}
+
Status GetDatasetFromVariantTensor(const Tensor& tensor,
DatasetBase** out_dataset) {
if (!(tensor.dtype() == DT_VARIANT ||
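
The out-of-line AddDataset above merges position-tagged single inputs and position-tagged list inputs into one ordered argument sequence for the NodeBuilder. A stripped-down sketch of that merge using strings in place of Node pointers, purely for illustration:

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    // Each input remembers the argument position it must occupy; singles and
    // lists are merged back into one ordered sequence, as in AddDataset above.
    bool MergeInputs(
        const std::vector<std::pair<size_t, std::string>>& singles,
        const std::vector<std::pair<size_t, std::vector<std::string>>>& lists,
        std::vector<std::string>* out) {
      const size_t total = singles.size() + lists.size();
      auto s = singles.begin();
      auto l = lists.begin();
      for (size_t i = 0; i < total; ++i) {
        if (s != singles.end() && s->first == i) {
          out->push_back(s->second);
          ++s;
        } else if (l != lists.end() && l->first == i) {
          for (const auto& n : l->second) out->push_back(n);
          ++l;
        } else {
          return false;  // no input registered for position i
        }
      }
      return true;
    }

    int main() {
      std::vector<std::string> args;
      // Position 0: a scalar input; position 1: a variadic argument list.
      MergeInputs({{0, "input_dataset"}}, {{1, {"arg0", "arg1"}}}, &args);
      for (const auto& a : args) std::printf("%s\n", a.c_str());
    }
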
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index afbebb0692..504a88a309 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -19,12 +19,13 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
@@ -59,6 +60,12 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class GraphDatasetBase;
+class Node;
+
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -110,10 +117,8 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
- Node** output) {
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -125,77 +130,23 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
- std::vector<std::pair<size_t, NodeBuilder::NodeOut>> enumerated_inputs(
- inputs.size());
+ std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
enumerated_inputs[i] = std::make_pair(i, inputs[i]);
}
return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}
- template <class DatasetType>
Status AddDataset(
- const DatasetType* dataset,
- const std::vector<std::pair<size_t, NodeBuilder::NodeOut>>& inputs,
- const std::vector<
- std::pair<size_t, gtl::ArraySlice<NodeBuilder::NodeOut>>>&
- list_inputs,
+ const GraphDatasetBase* dataset,
+ const std::vector<std::pair<size_t, Node*>>& inputs,
+ const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
- Node** output) {
- const string& op_type_name = dataset->op_name();
- std::unique_ptr<const GraphDefBuilder::Options> opts(
- new GraphDefBuilder::Options(b_->opts()));
- // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
- // attributes defined. It will be nice to have a consistent pattern.
- bool has_output_types_attr = HasAttr(op_type_name, "output_types");
- bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
- if (has_output_shapes_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_shapes", dataset->output_shapes())));
- }
- if (has_output_types_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_types", dataset->output_dtypes())));
- }
- for (auto attr : attrs) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr(attr.first, attr.second)));
- }
- if (opts->HaveError()) {
- return errors::Internal("AddDataset: Failed to build Options with error ",
- opts->StatusToString());
- }
- NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
- opts->op_registry());
- {
- size_t total_size = inputs.size() + list_inputs.size();
- auto inputs_iter = inputs.begin();
- auto list_inputs_iter = list_inputs.begin();
- for (int i = 0; i < total_size; i++) {
- if (inputs_iter != inputs.end() && inputs_iter->first == i) {
- node_builder.Input(inputs_iter->second);
- inputs_iter++;
- } else if (list_inputs_iter != list_inputs.end() &&
- list_inputs_iter->first == i) {
- node_builder.Input(list_inputs_iter->second);
- list_inputs_iter++;
- } else {
- return errors::InvalidArgument("No input found for index ", i);
- }
- }
- }
- *output = opts->FinalizeBuilder(&node_builder);
- if (*output == nullptr) {
- return errors::Internal("AddDataset: Failed to build ", op_type_name,
- " op with error ", opts->StatusToString());
- }
- return Status::OK();
- }
+ Node** output);
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
@@ -203,50 +154,7 @@ class GraphDefBuilderWrapper {
// name `function_name` is not found in the FunctionLibraryDefinition, returns
// an InvalidArgumentError. If the function with name `function_name` or any
// of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(OpKernelContext* ctx, const string& function_name) {
- if (b_->HasFunction(function_name)) {
- LOG(INFO) << "Function with name " << function_name << "already exists in"
- << " the graph. It will not be added again.";
- return Status::OK();
- }
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
- const FunctionLibraryDefinition* flib_def =
- ctx->function_library()->GetFunctionLibraryDefinition();
- const FunctionDef* f_def = flib_def->Find(function_name);
- if (f_def == nullptr) {
- return errors::InvalidArgument("Unable to find FunctionDef for ",
- function_name, " in the registry.");
- }
- FunctionDefLibrary def;
- *def.add_function() = *f_def;
- const string gradient_func = flib_def->FindGradient(function_name);
- if (!gradient_func.empty()) {
- GradientDef* g_def = def.add_gradient();
- g_def->set_function_name(function_name);
- g_def->set_gradient_func(gradient_func);
- }
- TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
-
- // Recursively add functions in inputs of function_name.
- for (const NodeDef& node_def : f_def->node_def()) {
- const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
- if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
- }
- // Recursively add functions in attrs of this NodeDef.
- for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
- }
- }
-
- // Recursively add functions in attrs of function_name.
- for (auto iter = f_def->attr().begin(); iter != f_def->attr().end();
- iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
- }
- return Status::OK();
- }
+ Status AddFunction(OpKernelContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -254,11 +162,7 @@ class GraphDefBuilderWrapper {
}
private:
- void AddTensorInternal(const Tensor& val, Node** output) {
- *output = ops::SourceOp(
- "Const",
- b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
- }
+ void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(OpKernelContext* ctx,
const string& function_name) const {
@@ -294,14 +198,7 @@ class GraphDefBuilderWrapper {
HasAttr(op_def, "output_shapes");
}
- bool HasAttr(const string& op_type_name, const string& attr_name) const {
- const OpDef* op_def = nullptr;
- Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
- if (!s.ok() || op_def == nullptr) {
- return false;
- }
- return HasAttr(op_def, attr_name);
- }
+ bool HasAttr(const string& op_type_name, const string& attr_name) const;
bool HasAttr(const OpDef* op_def, const string& attr_name) const {
for (auto attr : op_def->attr()) {
@@ -548,17 +445,7 @@ class GraphDatasetBase : public DatasetBase {
private:
Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
- string* output_node) const {
- GraphDefBuilder b;
- DatasetGraphDefBuilder db(&b);
- Node* node = nullptr;
- TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
- *output_node = node->name();
- GraphDef graph_def;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- graph_def.SerializeToString(serialized_graph_def);
- return Status::OK();
- }
+ string* output_node) const;
const string op_name_;
};
diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc
index cd58c80912..bd20e20cad 100644
--- a/tensorflow/core/kernels/dataset_utils.cc
+++ b/tensorflow/core/kernels/dataset_utils.cc
@@ -32,7 +32,7 @@ Status MakeIteratorFromInputElement(
// is always 0, so a negative random step ID should suffice.
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [captured_func, ctx](const string& name) {
+ opts.step_id, [captured_func](const string& name) {
captured_func->resource_manager()->Cleanup(name).IgnoreError();
});
opts.step_container = &step_container;
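
This capture-list change, repeated in the dataset kernels below, drops `ctx` from cleanup lambdas that never use it. The step container's cleanup can run outside the scope that owned the OpKernelContext, so capturing the raw pointer, even unused, invites a dangling reference. A minimal sketch of the pattern, with stand-in types:

    #include <cstdio>
    #include <functional>

    struct Context {};  // stand-in for OpKernelContext

    // The returned cleanup may outlive the caller's stack frame, so it captures
    // only what its body actually needs; the Context pointer stays out.
    std::function<void(const char*)> MakeCleanup(Context* /*ctx*/) {
      return [](const char* name) { std::printf("cleaning up %s\n", name); };
    }

    int main() {
      Context ctx;
      auto cleanup = MakeCleanup(&ctx);
      cleanup("resource0");  // safe even if ctx were gone by now
    }
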
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 2759ecb2f1..a5fd07fbe1 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -373,8 +373,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// If in_depth==1, this operation is just a standard convolution, so
// invoke that op.
if (std::is_same<T, float>::value && in_depth == 1) {
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_, stride_, padding_, output, data_format_);
+ /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+ padding_, output, data_format_);
return;
}
diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc
index e4d80e4ce3..67417d467d 100644
--- a/tensorflow/core/kernels/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/filter_dataset_op.cc
@@ -95,7 +95,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
@@ -149,7 +149,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc
index ac1689e5bf..8fe8489371 100644
--- a/tensorflow/core/kernels/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/flat_map_dataset_op.cc
@@ -102,7 +102,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc
index 8644bcf9b5..604555a560 100644
--- a/tensorflow/core/kernels/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc
@@ -169,7 +169,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts.step_id = CapturedFunction::generate_step_id();
opts.runner = ctx->runner();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_key_func_->resource_manager()
->Cleanup(name)
@@ -198,7 +198,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts2.step_id = CapturedFunction::generate_step_id();
opts2.runner = ctx->runner();
ScopedStepContainer step_container2(
- opts2.step_id, [this, ctx](const string& name) {
+ opts2.step_id, [this](const string& name) {
dataset()
->captured_window_size_func_->resource_manager()
->Cleanup(name)
@@ -257,7 +257,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts.step_id = CapturedFunction::generate_step_id();
opts.runner = ctx->runner();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_reduce_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc
index cbee68b2db..833e8cb9c5 100644
--- a/tensorflow/core/kernels/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/interleave_dataset_op.cc
@@ -126,7 +126,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
diff --git a/tensorflow/core/kernels/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/map_and_batch_dataset_op.cc
index ad1e356dbd..9bd66e681f 100644
--- a/tensorflow/core/kernels/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/map_and_batch_dataset_op.cc
@@ -239,8 +239,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// to unblock a consumer.
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [this, ctx](const string& name) {
+ ScopedStepContainer* step_container =
+ new ScopedStepContainer(opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc
index 4ba09bc335..29899a987e 100644
--- a/tensorflow/core/kernels/map_dataset_op.cc
+++ b/tensorflow/core/kernels/map_dataset_op.cc
@@ -100,7 +100,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types(
captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments(
+ std::vector<Node*> other_arguments(
captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
@@ -146,7 +146,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 138acdf298..9fee94f946 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
#if defined(INTEL_MKL)
#include <vector>
#include "mkl_cblas.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -72,10 +73,10 @@ class BatchMatMulMkl : public OpKernel {
TensorShape out_shape;
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
- errors::InvalidArgument("lhs.dim(", i, ") and rhs.dim(", i,
- ") must be the same: ",
- lhs.shape().DebugString(), " vs ",
- rhs.shape().DebugString()));
+ errors::InvalidArgument(
+ "lhs.dim(", i, ") and rhs.dim(", i,
+ ") must be the same: ", lhs.shape().DebugString(), " vs ",
+ rhs.shape().DebugString()));
out_shape.AddDim(lhs.dim_size(i));
}
auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
@@ -109,7 +110,7 @@ class BatchMatMulMkl : public OpKernel {
const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);
-
+
std::vector<MKL_INT> m_array(batch_size, M);
std::vector<MKL_INT> n_array(batch_size, N);
std::vector<MKL_INT> k_array(batch_size, K);
@@ -128,7 +129,7 @@ class BatchMatMulMkl : public OpKernel {
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
-
+
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
&k_array[0], &a_array[0], &lda_array[0], &b_array[0],
&ldb_array[0], &c_array[0], &ldc_array[0], 1,
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 8c0109f5c8..d086abb247 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -40,7 +40,7 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T>::ConstMatrix logits,
@@ -49,11 +49,11 @@ struct MultinomialFunctor {
typename TTypes<float>::Flat scratch, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output);
+ typename TTypes<OutputType>::Matrix output);
};
-template <typename T>
-struct MultinomialFunctor<CPUDevice, T> {
+template <typename T, typename OutputType>
+struct MultinomialFunctor<CPUDevice, T, OutputType> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<T>::ConstMatrix logits,
typename TTypes<float>::Flat /* noises */,
@@ -61,7 +61,7 @@ struct MultinomialFunctor<CPUDevice, T> {
typename TTypes<float>::Flat /* scratch */, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output) {
+ typename TTypes<OutputType>::Matrix output) {
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
// The implementation only parallelizes by batch.
@@ -128,7 +128,7 @@ struct MultinomialFunctor<CPUDevice, T> {
} // namespace functor
// Samples from a multinomial distribution.
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
public:
explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -195,11 +195,11 @@ class MultinomialOp : public OpKernel {
if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
auto rng =
generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
- functor::MultinomialFunctor<Device, T>()(
+ functor::MultinomialFunctor<Device, T, OutputType>()(
ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
batch_size, num_classes, num_samples, rng,
- samples_t->matrix<int64>());
+ samples_t->matrix<OutputType>());
}
}
@@ -209,10 +209,17 @@ class MultinomialOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp);
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("Multinomial").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
- MultinomialOp<CPUDevice, TYPE>);
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ MultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ MultinomialOp<CPUDevice, TYPE, int64>);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -220,12 +227,20 @@ TF_CALL_double(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_GPU) \
- .HostMemory("num_samples") \
- .TypeConstraint<TYPE>("T"), \
- MultinomialOp<GPUDevice, TYPE>)
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ MultinomialOp<GPUDevice, TYPE, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ MultinomialOp<GPUDevice, TYPE, int64>)
+
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
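
The rewritten REGISTER macro fans each logits type out to two kernel instantiations selected by the new output_dtype attr. A compact sketch of that dispatch shape, with Register as a hypothetical stand-in for the kernel registry:

    #include <cstdio>

    // Hypothetical stand-in for REGISTER_KERNEL_BUILDER.
    void Register(const char* op, const char* t, const char* out_dtype) {
      std::printf("registered %s<T=%s, output_dtype=%s>\n", op, t, out_dtype);
    }

    // One macro invocation now yields one kernel per output dtype, mirroring
    // the int32/int64 pair registered above.
    #define REGISTER(TYPE)                        \
      Register("Multinomial", #TYPE, "DT_INT32"); \
      Register("Multinomial", #TYPE, "DT_INT64");

    int main() {
      REGISTER(float);
      REGISTER(double);
    }
    #undef REGISTER
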
diff --git a/tensorflow/core/kernels/multinomial_op.h b/tensorflow/core/kernels/multinomial_op.h
index af5e81f219..6e41060aa4 100644
--- a/tensorflow/core/kernels/multinomial_op.h
+++ b/tensorflow/core/kernels/multinomial_op.h
@@ -21,7 +21,7 @@ namespace tensorflow {
namespace functor {
// Generic helper functor for the Multinomial Op.
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor;
} // namespace functor
diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
index 19b4f3ca55..5cc5877cce 100644
--- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
@@ -37,20 +37,22 @@ using GPUDevice = Eigen::GpuDevice;
// Kernel for Multinomial op. Data is interpreted to have the following shapes:
// scores: [B, S, C]; maxima: [B, S]; output: [B, S].
+template <typename OutputType>
__global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
const int32 num_samples, const float* scores,
- const float* maxima, int64* output) {
+ const float* maxima, OutputType* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int maxima_idx = index / num_classes;
if (ldg(maxima + maxima_idx) == ldg(scores + index)) {
- CudaAtomicMax(reinterpret_cast<uint64*>(output + maxima_idx),
- static_cast<uint64>(index % num_classes));
+ using UnsignedOutputType = typename std::make_unsigned<OutputType>::type;
+ CudaAtomicMax(reinterpret_cast<UnsignedOutputType*>(output + maxima_idx),
+ static_cast<UnsignedOutputType>(index % num_classes));
}
}
}
-template <typename T>
-struct MultinomialFunctor<GPUDevice, T> {
+template <typename T, typename OutputType>
+struct MultinomialFunctor<GPUDevice, T, OutputType> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T>::ConstMatrix logits,
typename TTypes<float>::Flat noises,
@@ -58,7 +60,7 @@ struct MultinomialFunctor<GPUDevice, T> {
typename TTypes<float>::Flat maxima, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output) {
+ typename TTypes<OutputType>::Matrix output) {
// Uniform, [0, 1).
typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
@@ -111,11 +113,17 @@ struct MultinomialFunctor<GPUDevice, T> {
};
// Explicit instantiation of the GPU functors.
-template struct MultinomialFunctor<GPUDevice, Eigen::half>;
-template struct MultinomialFunctor<GPUDevice, float>;
-template struct MultinomialFunctor<GPUDevice, double>;
-template struct MultinomialFunctor<GPUDevice, int32>;
-template struct MultinomialFunctor<GPUDevice, int64>;
+template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
+template struct MultinomialFunctor<GPUDevice, float, int32>;
+template struct MultinomialFunctor<GPUDevice, double, int32>;
+template struct MultinomialFunctor<GPUDevice, int32, int32>;
+template struct MultinomialFunctor<GPUDevice, int64, int32>;
+
+template struct MultinomialFunctor<GPUDevice, Eigen::half, int64>;
+template struct MultinomialFunctor<GPUDevice, float, int64>;
+template struct MultinomialFunctor<GPUDevice, double, int64>;
+template struct MultinomialFunctor<GPUDevice, int32, int64>;
+template struct MultinomialFunctor<GPUDevice, int64, int64>;
} // namespace functor
} // namespace tensorflow
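
MultinomialKernel resolves ties by taking an atomic max over the class indices whose score equals the row maximum; the signed output pointer is reinterpreted to the matching unsigned width because the CUDA atomic-max helper operates on unsigned words, which is safe here since class indices are non-negative. The same argmax-by-atomic-max idea in portable C++, as a sketch rather than the CUDA path:

    #include <atomic>
    #include <cstdint>
    #include <cstdio>
    #include <thread>
    #include <vector>

    // Keep the largest published value; the CUDA kernel does the same per
    // output slot with CudaAtomicMax.
    void AtomicMax(std::atomic<int64_t>* slot, int64_t value) {
      int64_t cur = slot->load();
      while (cur < value && !slot->compare_exchange_weak(cur, value)) {
      }
    }

    int main() {
      const std::vector<float> scores = {0.1f, 0.9f, 0.9f, 0.3f};
      const float maximum = 0.9f;
      std::atomic<int64_t> winner{-1};
      std::vector<std::thread> workers;
      for (int64_t c = 0; c < static_cast<int64_t>(scores.size()); ++c) {
        workers.emplace_back([&, c] {
          if (scores[c] == maximum) AtomicMax(&winner, c);  // ties -> max index
        });
      }
      for (auto& w : workers) w.join();
      std::printf("sampled class = %lld\n",
                  static_cast<long long>(winner.load()));
    }
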
diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc
index 0db7c63b8b..a841291ddd 100644
--- a/tensorflow/core/kernels/nn_ops_test.cc
+++ b/tensorflow/core/kernels/nn_ops_test.cc
@@ -653,6 +653,8 @@ BM_ConvFloatDepthwiseFwd(32, 7, 7, 1024, 1, 1024, 3, 3, 1, SAME, conv6);
// Benchmarks with different stride and padding options.
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, SAME, conv7);
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, VALID, conv8);
+BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9);
+BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10);
#define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \
diff --git a/tensorflow/core/kernels/padded_batch_dataset_op.cc b/tensorflow/core/kernels/padded_batch_dataset_op.cc
index 7c28d955e1..cef5bde156 100644
--- a/tensorflow/core/kernels/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/padded_batch_dataset_op.cc
@@ -242,7 +242,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
- std::vector<NodeBuilder::NodeOut> padded_shapes;
+ std::vector<Node*> padded_shapes;
padded_shapes.reserve(padded_shapes_.size());
for (int i = 0; i < padded_shapes_.size(); i++) {
Node* node;
@@ -254,7 +254,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padded_shapes.emplace_back(node);
}
- std::vector<NodeBuilder::NodeOut> padding_values;
+ std::vector<Node*> padding_values;
padding_values.reserve(padding_values_.size());
for (const Tensor& t : padding_values_) {
Node* node;
diff --git a/tensorflow/core/kernels/parallel_map_dataset_op.cc b/tensorflow/core/kernels/parallel_map_dataset_op.cc
index 2be87f4bde..b9175fe904 100644
--- a/tensorflow/core/kernels/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/parallel_map_dataset_op.cc
@@ -195,8 +195,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [this, ctx](const string& name) {
+ ScopedStepContainer* step_container =
+ new ScopedStepContainer(opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index 3b0764bb9b..f83998e0c1 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -457,6 +457,19 @@ class QuantizedConv2DOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ std::vector<int32> dilations;
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
+ OP_REQUIRES(context, dilations.size() == 4,
+ errors::InvalidArgument("Dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
+ errors::InvalidArgument(
+                    "Current implementation only supports a dilation rate of 1 "
+ "in the row and column dimensions."));
+ OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
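
These new checks mirror the stride validation just above them: the quantized path parses the dilations attr but accepts only the identity rate. The same shape of validation as a standalone sketch, assuming the NHWC attr layout:

    #include <cstdio>
    #include <stdexcept>
    #include <vector>

    // NHWC attr layout: {batch, rows, cols, depth}. Only all-ones dilations
    // pass, matching the OP_REQUIRES chain above.
    void ValidateDilations(const std::vector<int>& d) {
      if (d.size() != 4)
        throw std::invalid_argument("dilations must specify 4 dimensions");
      if (d[0] != 1 || d[3] != 1)
        throw std::invalid_argument("no dilation in batch/depth dimensions");
      if (d[1] != 1 || d[2] != 1)
        throw std::invalid_argument("quantized conv supports only dilation 1");
    }

    int main() {
      ValidateDilations({1, 1, 1, 1});
      std::puts("dilations accepted");
    }
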
diff --git a/tensorflow/core/kernels/random_dataset_op.cc b/tensorflow/core/kernels/random_dataset_op.cc
new file mode 100644
index 0000000000..03d481a593
--- /dev/null
+++ b/tensorflow/core/kernels/random_dataset_op.cc
@@ -0,0 +1,154 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/dataset.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class RandomDatasetOp : public DatasetOpKernel {
+ public:
+ explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ int64 seed;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
+
+ int64 seed2;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
+
+ // By TensorFlow convention, passing 0 for both seeds indicates
+    // that the generator should be seeded non-deterministically.
+ if (seed == 0 && seed2 == 0) {
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+
+ *output = new Dataset(ctx, seed, seed2);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
+ : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Random")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() override {
+ return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
+ ")::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* seed = nullptr;
+ Node* seed2 = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ parent_generator_(dataset()->seed_, dataset()->seed2_),
+ generator_(&parent_generator_) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ Tensor value_tensor(cpu_allocator(), DT_INT64, {});
+ value_tensor.scalar<int64>()() = Random();
+ out_tensors->emplace_back(std::move(value_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
+ num_random_samples_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
+ &num_random_samples_));
+ parent_generator_ =
+ random::PhiloxRandom(dataset()->seed_, dataset()->seed2_);
+ generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
+ &parent_generator_);
+ generator_.Skip(num_random_samples_);
+ return Status::OK();
+ }
+
+ private:
+ random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ num_random_samples_++;
+ auto out = generator_();
+ return out;
+ }
+ mutex mu_;
+ random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
+ random::SingleSampleAdapter<random::PhiloxRandom> generator_
+ GUARDED_BY(mu_);
+ int64 num_random_samples_ GUARDED_BY(mu_) = 0;
+ };
+
+ const int64 seed_;
+ const int64 seed2_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
+ RandomDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
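
RandomDataset restores iterator state by recording only how many samples were drawn, then re-seeding and fast-forwarding; this works because Philox is counter-based, so Skip is cheap and the stream depends only on the (seed, counter) pair. A toy counter-based generator showing the same save/restore idea; the mixing arithmetic is illustrative, not Philox:

    #include <cstdint>
    #include <cstdio>

    // Toy counter-based generator: output depends only on (seed, counter), so
    // restoring state is just re-seeding and skipping ahead.
    struct CounterRng {
      uint64_t seed;
      uint64_t counter = 0;
      explicit CounterRng(uint64_t s) : seed(s) {}
      uint64_t Next() {
        uint64_t x = seed + 0x9e3779b97f4a7c15ULL * ++counter;
        x ^= x >> 31;  // cheap mix, stand-in for the Philox rounds
        return x * 0xbf58476d1ce4e5b9ULL;
      }
      void Skip(uint64_t n) { counter += n; }
    };

    int main() {
      CounterRng rng(42);
      for (int i = 0; i < 3; ++i) rng.Next();  // draw 3 samples, then "save"
      const uint64_t num_random_samples = 3;

      CounterRng restored(42);           // RestoreInternal: re-seed...
      restored.Skip(num_random_samples); // ...and skip what was consumed
      std::printf("streams match: %d\n", rng.Next() == restored.Next());
    }
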
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index 807ac0a456..5c537c5b9c 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -50,6 +50,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc
index 9bbe993a2f..fe8ea59f1b 100644
--- a/tensorflow/core/kernels/reduction_ops_test.cc
+++ b/tensorflow/core/kernels/reduction_ops_test.cc
@@ -174,6 +174,11 @@ static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
}
BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) {
+ ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y);
+}
+BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192);
+
static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
}
diff --git a/tensorflow/core/kernels/scan_dataset_op.cc b/tensorflow/core/kernels/scan_dataset_op.cc
index 76c219f1ae..bc52322022 100644
--- a/tensorflow/core/kernels/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/scan_dataset_op.cc
@@ -132,7 +132,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 484932ab01..98c0181afb 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#endif // GOOGLE_CUDA
#include "tensorflow/core/kernels/scatter_nd_op.h"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -28,6 +29,8 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -83,7 +86,10 @@ class ScatterNdUpdateOp : public OpKernel {
const DataType dt = DataTypeToEnum<T>::v();
const DataType dt_ref = DataTypeToEnum<T>::ref();
const DataType index_t = DataTypeToEnum<Index>::v();
- if (IsRefType(c->input_type(0))) {
+ dtype_ = c->input_type(0);
+ if (c->input_type(0) == DT_RESOURCE) {
+ // TODO(apassos): what to validate here?
+ } else if (IsRefType(c->input_type(0))) {
OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
} else {
@@ -93,7 +99,16 @@ class ScatterNdUpdateOp : public OpKernel {
}
void Compute(OpKernelContext* c) override {
- if (use_exclusive_lock_) {
+ if (dtype_ == DT_RESOURCE) {
+ if (use_exclusive_lock_) {
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ mutex_lock m(*v->mu());
+ DoCompute(c);
+ } else {
+ DoCompute(c);
+ }
+ } else if (use_exclusive_lock_) {
// If we're here, it means the input type is a ref.
DCHECK(IsRefType(c->input_dtype(0)));
// Hold mutex while we apply updates
@@ -105,6 +120,7 @@ class ScatterNdUpdateOp : public OpKernel {
}
private:
+ DataType dtype_;
bool use_exclusive_lock_;
void DoCompute(OpKernelContext* c) {
@@ -113,7 +129,20 @@ class ScatterNdUpdateOp : public OpKernel {
Tensor params;
TensorShape params_shape;
- if (IsRefType(c->input_dtype(0))) {
+ if (dtype_ == DT_RESOURCE) {
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ Tensor* t = v->tensor();
+ if (!use_exclusive_lock_) {
+        // We're not holding the lock in the outer scope, so we need it here.
+ mutex_lock m(*v->mu());
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
+ } else {
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
+ }
+ params = *t;
+ params_shape = params.shape();
+ } else if (IsRefType(c->input_dtype(0))) {
params = c->mutable_input(0, use_exclusive_lock_);
params_shape = params.shape();
c->forward_ref_input_to_ref_output(0, 0);
@@ -159,6 +188,16 @@ class ScatterNdUpdateOp : public OpKernel {
.TypeConstraint<index_type>("Tindices"), \
ScatterNdUpdateOp<dev##Device, type, index_type, op>)
+#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
+ dev, name, op) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices") \
+ .HostMemory("ref"), \
+ ScatterNdUpdateOp<dev##Device, type, index_type, op>)
+
#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
@@ -167,6 +206,11 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
+#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
+ op); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
+
#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
scatter_nd_op::UpdateOp::ADD); \
@@ -178,9 +222,11 @@ class ScatterNdUpdateOp : public OpKernel {
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
-#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
- scatter_nd_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
+ scatter_nd_op::UpdateOp::ASSIGN); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
+ type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
@@ -281,8 +327,7 @@ Status ValidateUpdateShape(const TensorShape& params_shape,
}
template <typename Index>
-Status PrepareAndValidateInputs(OpKernelContext* c,
- const TensorShape& params_shape,
+Status PrepareAndValidateInputs(const TensorShape& params_shape,
const Tensor& indices, const Tensor& updates,
int64* slice_dim, Index* num_updates,
Index* slice_size) {
@@ -396,7 +441,7 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
Index num_updates;
Index slice_size;
TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
- c, shape, indices, updates, &slice_dim, &num_updates, &slice_size));
+ shape, indices, updates, &slice_dim, &num_updates, &slice_size));
IndexFlattener<Device, Index> index_flattener;
auto indices_flat = index_flattener(c, indices);
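
For resource variables, ScatterNdUpdateOp now takes the variable's mutex for the whole update when use_locking is set, and otherwise serializes only the copy-on-write preparation before scattering into the buffer. A sketch of those two locking paths with stand-in types:

    #include <cstdio>
    #include <mutex>
    #include <vector>

    struct Var {  // stand-in for the resource Var
      std::mutex mu;
      std::vector<float> tensor;
    };

    void PrepareToUpdate(Var* v) { /* e.g. uniquify the buffer before writes */ }
    void ApplyScatter(Var* v, int index, float value) { v->tensor[index] = value; }

    void Compute(Var* v, bool use_exclusive_lock) {
      if (use_exclusive_lock) {
        std::lock_guard<std::mutex> l(v->mu);  // held across the whole update
        PrepareToUpdate(v);
        ApplyScatter(v, 0, 1.0f);
      } else {
        {
          std::lock_guard<std::mutex> l(v->mu);  // held only while preparing
          PrepareToUpdate(v);
        }
        ApplyScatter(v, 0, 1.0f);
      }
    }

    int main() {
      Var v;
      v.tensor.assign(4, 0.0f);
      Compute(&v, /*use_exclusive_lock=*/true);
      std::printf("v.tensor[0] = %g\n", v.tensor[0]);
    }
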
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index cfb86904d5..f4159da229 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -409,186 +409,11 @@ class DeserializeSparseOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-template <typename T>
-class DeserializeManySparseOp : public OpKernel {
- public:
- explicit DeserializeManySparseOp(OpKernelConstruction* context)
- : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor& serialized_sparse = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(serialized_sparse.shape()),
- errors::InvalidArgument(
- "Serialized sparse should be a matrix but received shape ",
- serialized_sparse.shape().DebugString()));
- OP_REQUIRES(
- context, serialized_sparse.shape().dim_size(1) == 3,
- errors::InvalidArgument(
- "Serialized sparse should have 3 columns but received shape ",
- serialized_sparse.shape().DebugString()));
-
- int num_sparse_tensors = serialized_sparse.shape().dim_size(0);
-
- OP_REQUIRES(
- context, num_sparse_tensors > 0,
- errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
- "but input matrix has 0 rows"));
-
- std::vector<Tensor> indices_to_concat;
- std::vector<Tensor> values_to_concat;
- std::vector<TensorShape> shapes_to_concat;
-
- const auto& serialized_sparse_t = serialized_sparse.matrix<string>();
-
- for (int i = 0; i < num_sparse_tensors; ++i) {
- Tensor output_indices(DT_INT64);
- Tensor output_values(DataTypeToEnum<T>::value);
- Tensor output_shape(DT_INT64);
- TensorProto proto_indices;
- TensorProto proto_values;
- TensorProto proto_shape;
-
- OP_REQUIRES(
- context,
- ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)),
- errors::InvalidArgument("Could not parse serialized_sparse[", i,
- ", 0]"));
- OP_REQUIRES(context,
- ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
- errors::InvalidArgument("Could not parse serialized_sparse[",
- i, ", 1]"));
- OP_REQUIRES(context,
- ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
- errors::InvalidArgument("Could not parse serialized_sparse[",
- i, ", 2]"));
-
- OP_REQUIRES(context, output_indices.FromProto(proto_indices),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 0] (indices)"));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
- errors::InvalidArgument(
- "Expected serialized_sparse[", i,
- ", 0] to represent an index matrix but received shape ",
- output_indices.shape().DebugString()));
- OP_REQUIRES(context, output_values.FromProto(proto_values),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 1] (values)"));
- OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
- errors::InvalidArgument(
- "Expected serialized_sparse[", i,
- ", 1] to represent a values vector but received shape ",
- output_values.shape().DebugString()));
- OP_REQUIRES(context, output_shape.FromProto(proto_shape),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 2] (shape)"));
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(output_shape.shape()),
- errors::InvalidArgument("Expected serialized_sparse[", i,
- ", 1] to be a shape vector but its shape is ",
- output_shape.shape().DebugString()));
-
- OP_REQUIRES(
- context, DataTypeToEnum<T>::value == output_values.dtype(),
- errors::InvalidArgument(
- "Requested SparseTensor of type ",
- DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
- "].values.dtype() == ", DataTypeString(output_values.dtype())));
-
- int64 num_entries = output_indices.dim_size(0);
- OP_REQUIRES(context, num_entries == output_values.dim_size(0),
- errors::InvalidArgument(
- "Expected row counts of SparseTensor[", i,
- "].indices and SparseTensor[", i,
- "].values to match but they do not: ", num_entries,
- " vs. ", output_values.dim_size(0)));
- int rank = output_indices.dim_size(1);
- OP_REQUIRES(
- context, rank == output_shape.dim_size(0),
- errors::InvalidArgument("Expected column counts of SparseTensor[", i,
- "].indices to match size of SparseTensor[", i,
- "].shape "
- "but they do not: ",
- rank, " vs. ", output_shape.dim_size(0)));
-
- // Now we expand each SparseTensors' indices and shape by
- // prefixing a dimension
- Tensor expanded_indices(
- DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
- Tensor expanded_shape(DT_INT64,
- TensorShape({1 + output_shape.dim_size(0)}));
- const auto& output_indices_t = output_indices.matrix<int64>();
- const auto& output_shape_t = output_shape.vec<int64>();
- auto expanded_indices_t = expanded_indices.matrix<int64>();
- auto expanded_shape_t = expanded_shape.vec<int64>();
- expanded_indices_t.chip<1>(0).setZero();
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
- expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
- expanded_shape_t(0) = 1;
- std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
-
- TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
-
- indices_to_concat.push_back(expanded_indices);
- values_to_concat.push_back(output_values);
- shapes_to_concat.push_back(expanded_tensor_shape);
- }
-
- int rank = -1;
- for (int i = 0; i < num_sparse_tensors; ++i) {
- if (rank < 0) rank = shapes_to_concat[i].dims();
- OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
- errors::InvalidArgument(
- "Inconsistent rank across SparseTensors: rank prior to "
- "SparseTensor[",
- i, "] was: ", rank, " but rank of SparseTensor[", i,
- "] is: ", shapes_to_concat[i].dims()));
- }
-
- // SparseTensor::Concat requires consistent shape for all but the
- // primary order dimension (dimension 0 in this case). So we get
- // the maximum value across all the input SparseTensors for each
- // dimension and use that.
- TensorShape preconcat_shape(shapes_to_concat[0]);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- for (int d = 0; d < rank; ++d) {
- preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
- shapes_to_concat[i].dim_size(d)));
- }
- }
-
- // Dimension 0 is the primary dimension.
- gtl::InlinedVector<int64, 8> std_order(rank);
- std::iota(std_order.begin(), std_order.end(), 0);
-
- std::vector<SparseTensor> tensors_to_concat;
- tensors_to_concat.reserve(num_sparse_tensors);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i],
- preconcat_shape, std_order);
- }
-
- SparseTensor output = SparseTensor::Concat<T>(tensors_to_concat);
-
- Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
-
- std::copy_n(output.shape().data(), output.dims(),
- final_output_shape.vec<int64>().data());
-
- context->set_output(0, output.indices());
- context->set_output(1, output.values());
- context->set_output(2, final_output_shape);
- }
-};
-
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
- DeserializeManySparseOp<type>)
+ DeserializeSparseOp<type>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
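
The loop removed above expanded each `SparseTensor` before concatenation: every index row gains a leading batch coordinate of 0 and the dense shape gains a leading 1, so `SparseTensor::Concat` can join the inputs along dimension 0. A minimal sketch of that step with plain STL containers (illustrative only; names are assumed, not the Tensor/Eigen code above):

```cpp
#include <cstdint>
#include <vector>

// Expand [num_entries x rank] indices to [num_entries x (rank + 1)] by
// prepending a zero batch coordinate, and prepend 1 to the dense shape.
// SparseTensor::Concat can then join the inputs along dimension 0.
void ExpandForConcat(const std::vector<std::vector<int64_t>>& indices,
                     const std::vector<int64_t>& shape,
                     std::vector<std::vector<int64_t>>* expanded_indices,
                     std::vector<int64_t>* expanded_shape) {
  for (const auto& row : indices) {
    std::vector<int64_t> expanded{0};  // new leading (batch) coordinate
    expanded.insert(expanded.end(), row.begin(), row.end());
    expanded_indices->push_back(std::move(expanded));
  }
  expanded_shape->push_back(1);  // new leading dimension of size 1
  expanded_shape->insert(expanded_shape->end(), shape.begin(), shape.end());
}
```
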
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index 1f38bdce8c..d3a267ed87 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -64,23 +64,21 @@ struct SoftmaxEigenImpl {
one_by_class.set(1, num_classes);
#endif
// shifted_logits = logits - max(logits along classes);
- auto shifted_logits = (logits -
- logits.maximum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ auto shifted_logits = (logits - logits.maximum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
if (log) {
// Calculate the log of the softmax
// softmax = logits - max(logits along classes);
softmax.device(d) = shifted_logits;
// softmax = softmax - log(sum(exp(softmax along classes)));
- softmax.device(d) = (softmax -
- softmax.exp()
- .sum(along_class)
- .eval()
- .reshape(batch_by_one)
- .log()
- .broadcast(one_by_class));
+ softmax.device(d) = (softmax - softmax.exp()
+ .sum(along_class)
+ .log()
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
} else {
// NOTE(touts): If you modify this implementation please run
// the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
@@ -88,12 +86,11 @@ struct SoftmaxEigenImpl {
// softmax = exp(logits - max(logits along classes));
softmax.device(d) = shifted_logits.exp();
// softmax = softmax * (1 / sum(softmax along classes));
- softmax.device(d) = (softmax *
- softmax.sum(along_class)
- .inverse()
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ softmax.device(d) = (softmax * softmax.sum(along_class)
+ .inverse()
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
}
}
};
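
The reindented expressions implement the usual numerically stable softmax. A minimal scalar sketch of the same computation (illustrative only — the functor above vectorizes this over a batch with Eigen):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable softmax for one non-empty row of logits: subtracting the row max
// before exp() avoids overflow (the shifted_logits step above); the final
// loop mirrors softmax * (1 / sum(softmax along classes)).
std::vector<float> Softmax(const std::vector<float>& logits) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_logit);
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}
```
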
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 8fc40db3cc..73b6d4cf6a 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -427,6 +427,7 @@ REGISTER_STRIDED_SLICE(bfloat16);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
index a8487f49f4..8ca27e3b92 100644
--- a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
@@ -53,6 +53,7 @@ typedef Eigen::GpuDevice GPUDevice;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
DEFINE_GPU_KERNELS(int32);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc
index fe53434d17..5cf9931188 100644
--- a/tensorflow/core/kernels/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_dataset_op.cc
@@ -70,7 +70,7 @@ class TensorDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> components;
+ std::vector<Node*> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
index e85f59b584..19d4816ff3 100644
--- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
@@ -86,7 +86,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> components;
+ std::vector<Node*> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index 9381915ae9..31e5737f62 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -80,7 +80,7 @@ class ZipDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> input_graph_nodes;
+ std::vector<Node*> input_graph_nodes;
input_graph_nodes.reserve(inputs_.size());
for (const auto& input : inputs_) {
Node* input_node;
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc
index 2a04f7bd39..55e481d0e6 100644
--- a/tensorflow/core/lib/core/arena.cc
+++ b/tensorflow/core/lib/core/arena.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
@@ -113,24 +114,11 @@ void Arena::MakeNewBlock(const uint32 alignment) {
CHECK(SatisfyAlignment(alignment));
}
-// The following simple numeric routines also exist in util/math/mathutil.h
-// but we don't want to depend on that library.
-
-// Euclid's algorithm for Greatest Common Denominator.
-static uint32 GCD(uint32 x, uint32 y) {
- while (y != 0) {
- uint32 r = x % y;
- x = y;
- y = r;
- }
- return x;
-}
-
static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
if (a > b) {
- return (a / GCD(a, b)) * b;
+ return (a / MathUtil::GCD<uint32>(a, b)) * b;
} else if (a < b) {
- return (b / GCD(b, a)) * a;
+ return (b / MathUtil::GCD<uint32>(b, a)) * a;
} else {
return a;
}
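
The `a / GCD(a, b) * b` ordering is deliberate: dividing before multiplying keeps the intermediate value no larger than the final LCM. A standalone sketch of the identity (assumed helper names, independent of the arena and MathUtil code):

```cpp
#include <cstdint>

// Euclid's algorithm, as in MathUtil::GCD<uint32>.
uint32_t Gcd(uint32_t x, uint32_t y) {
  while (y != 0) {
    const uint32_t r = x % y;
    x = y;
    y = r;
  }
  return x;
}

// lcm(a, b) = a / gcd(a, b) * b; dividing first keeps the intermediate
// value bounded by the result, unlike computing a * b up front.
uint32_t Lcm(uint32_t a, uint32_t b) {
  if (a == 0 || b == 0) return 0;
  return a / Gcd(a, b) * b;
}
```
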
diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h
index 6f279865e7..9e71598622 100644
--- a/tensorflow/core/lib/math/math_util.h
+++ b/tensorflow/core/lib/math/math_util.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_MATH_MATH_UTIL_H_
#define TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#include <type_traits>
+
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -59,6 +61,9 @@ class MathUtil {
template <typename IntegralType, bool ceil>
static IntegralType CeilOrFloorOfRatio(IntegralType numerator,
IntegralType denominator);
+
+ template <typename IntegralType>
+ static IntegralType GCD(IntegralType x, IntegralType y);
};
// ---- CeilOrFloorOfRatio ----
@@ -107,6 +112,18 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
}
}
+template <typename IntegralType>
+IntegralType MathUtil::GCD(IntegralType a, IntegralType b) {
+ static_assert(std::is_unsigned<IntegralType>::value,
+ "signed GCD not supported!");
+ while (b != 0) {
+ IntegralType r = a % b;
+ a = b;
+ b = r;
+ }
+ return a;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_LIB_MATH_MATH_UTIL_H_
diff --git a/tensorflow/core/lib/math/math_util_test.cc b/tensorflow/core/lib/math/math_util_test.cc
index eaf8c31a43..a96e5467c3 100644
--- a/tensorflow/core/lib/math/math_util_test.cc
+++ b/tensorflow/core/lib/math/math_util_test.cc
@@ -195,4 +195,33 @@ TEST(MathUtil, CeilOfRatio) {
#endif
}
+struct GCDTestCase {
+ unsigned int x;
+ unsigned int y;
+ unsigned int gcd;
+};
+
+TEST(MathUtil, GCD) {
+ std::vector<GCDTestCase> testcases({
+ {10, 20, 10}, //
+ {27, 8, 1}, //
+ {4, 3, 1}, //
+ {6, 8, 2}, //
+ {5, 0, 5}, //
+ {5, 5, 5}, //
+ {0, 0, 0} //
+ });
+
+ for (const auto& tc : testcases) {
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.x, tc.y));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.y, tc.x));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.x, tc.y));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.y, tc.x));
+ }
+
+ const uint64 biggish_prime = 1666666667;
+ EXPECT_EQ(biggish_prime,
+ MathUtil::GCD<uint64>(biggish_prime * 3, biggish_prime * 4));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h
index fbef25619f..acdb0d86ed 100644
--- a/tensorflow/core/lib/monitoring/collected_metrics.h
+++ b/tensorflow/core/lib/monitoring/collected_metrics.h
@@ -88,6 +88,7 @@ struct Point {
ValueType value_type;
int64 int64_value;
string string_value;
+ bool bool_value;
HistogramProto histogram_value;
// start_timestamp and end_timestamp indicate the time period over which this
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 113d37e07d..2c8e250c56 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -225,6 +225,12 @@ inline void CollectValue(const string& value, Point* const point) {
}
template <>
+inline void CollectValue(const bool& value, Point* const point) {
+ point->value_type = ValueType::kBool;
+ point->bool_value = value;
+}
+
+template <>
inline void CollectValue(const HistogramProto& value, Point* const point) {
point->value_type = ValueType::kHistogram;
// This is inefficient. If and when we hit snags, we can change the API to do
diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h
index 75471cfb22..ec978a9193 100644
--- a/tensorflow/core/lib/monitoring/gauge.h
+++ b/tensorflow/core/lib/monitoring/gauge.h
@@ -86,8 +86,29 @@ class GaugeCell<int64> {
TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell);
};
+// Explicit specialization of GaugeCell<bool>. Compared to the primary
+// template, it uses atomic values as opposed to mutex. This class is
+// thread-safe.
+template <>
+class GaugeCell<bool> {
+ public:
+ explicit GaugeCell(bool value) : value_(value) {}
+ ~GaugeCell() {}
+
+ // Atomically sets the value.
+ void Set(bool value);
+
+ // Retrieves the current value.
+ bool value() const;
+
+ private:
+ std::atomic<bool> value_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell);
+};
+
-// A stateful class for updating a gauge-like metric. Allowed ValueType are
-// int64 and string.
+// A stateful class for updating a gauge-like metric. Allowed ValueTypes are
+// int64, string and bool.
//
// This class encapsulates a set of values (or a single value for a label-less
// metric). Each value is identified by a tuple of labels. The class allows the
@@ -117,6 +138,9 @@ class Gauge {
//
// auto* integer_gauge = Gauge<int64, 0>::New("/tensorflow/integer_gauge",
// "Integer gauge")
+ //
+ // auto* bool_gauge = Gauge<bool, 0>::New("/tensorflow/bool_gauge",
+ // "Bool gauge")
template <typename... MetricDefArgs>
static Gauge* New(MetricDefArgs&&... metric_def_args);
@@ -172,12 +196,17 @@ inline void GaugeCell<int64>::Set(int64 value) { value_ = value; }
inline int64 GaugeCell<int64>::value() const { return value_; }
+inline void GaugeCell<bool>::Set(bool value) { value_ = value; }
+
+inline bool GaugeCell<bool>::value() const { return value_; }
+
template <typename ValueType, int NumLabels>
template <typename... MetricDefArgs>
Gauge<ValueType, NumLabels>* Gauge<ValueType, NumLabels>::New(
MetricDefArgs&&... metric_def_args) {
static_assert(std::is_same<ValueType, int64>::value ||
- std::is_same<ValueType, string>::value,
+ std::is_same<ValueType, string>::value ||
+ std::is_same<ValueType, bool>::value,
"Gauge only allows int64 and string types.");
return new Gauge<ValueType, NumLabels>(
MetricDef<MetricKind::kGauge, ValueType, NumLabels>(
diff --git a/tensorflow/core/lib/monitoring/gauge_test.cc b/tensorflow/core/lib/monitoring/gauge_test.cc
index f98cfe2a3b..c8f673db38 100644
--- a/tensorflow/core/lib/monitoring/gauge_test.cc
+++ b/tensorflow/core/lib/monitoring/gauge_test.cc
@@ -87,6 +87,28 @@ TEST(GaugeOfStringValue, GetCell) {
EXPECT_EQ("bar", same_cell->value());
}
+auto* bool_gauge =
+ Gauge<bool, 0>::New("/tensorflow/test/bool_gauge", "Gauge of bool value.");
+
+TEST(GaugeOfBoolValue, InitializedWithFalseValue) {
+ EXPECT_EQ(false, bool_gauge->GetCell()->value());
+}
+
+TEST(GaugeOfBoolValue, GetCell) {
+ auto* cell = bool_gauge->GetCell();
+ EXPECT_EQ(false, cell->value());
+
+ cell->Set(true);
+ EXPECT_EQ(true, cell->value());
+
+ auto* same_cell = bool_gauge->GetCell();
+  EXPECT_EQ(true, same_cell->value());
+
+ same_cell->Set(false);
+ EXPECT_EQ(false, cell->value());
+ EXPECT_EQ(false, same_cell->value());
+}
+
} // namespace
} // namespace monitoring
} // namespace tensorflow
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index a7f14f9c94..f046842618 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -28,16 +28,16 @@ namespace monitoring {
// The different metric kinds available.
//
// Gauge indicates that the metric's values are instantaneous measurements of a
-// (typically) continuously varying quantity or a string value. Examples: a
-// process's current heap size, a queue's current length, the name of the binary
-// used by a process.
+// (typically) continuously varying value. Examples: a process's current heap
+// size, a queue's current length, the name of the binary used by a process,
+// whether a task is complete.
//
// Cumulative indicates that the metric's values represent non-negative changes
// over specified time periods. Example: the number of rpc calls to a service.
enum class MetricKind : int { kGauge = 0, kCumulative };
// The type of the metric values.
-enum class ValueType : int { kInt64 = 0, kHistogram, kString };
+enum class ValueType : int { kInt64 = 0, kHistogram, kString, kBool };
// Everything in the internal namespace is implementation details. Do not depend
// on this.
@@ -61,6 +61,11 @@ inline ValueType GetValueType<string>() {
return ValueType::kString;
}
+template <>
+inline ValueType GetValueType<bool>() {
+ return ValueType::kBool;
+}
+
} // namespace internal
// Abstract base class for a metric definition.
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 9fa6423d59..6f4ea09206 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -724,8 +724,8 @@ REGISTER_OP("OnesLike")
.Input("x: T")
.Output("y: T")
.Attr(
- "T: {float, double, int8, uint8, int16, uint16, int32, int64, "
- "complex64, complex128, bool}")
+ "T: {bfloat16, float, double, int8, uint8, int16, uint16, int32, "
+ "int64, complex64, complex128, bool}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns a tensor of ones with the same shape and type as x.
@@ -738,7 +738,7 @@ y: a tensor of the same shape and type as x but filled with ones.
REGISTER_OP("Diag")
.Input("diagonal: T")
.Output("output: T")
- .Attr("T: {float, double, int32, int64, complex64, complex128}")
+ .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
@@ -776,7 +776,7 @@ diagonal: Rank k tensor where k is at most 1.
REGISTER_OP("DiagPart")
.Input("input: T")
.Output("diagonal: T")
- .Attr("T: {float, double, int32, int64, complex64, complex128}")
+ .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
if (!c->RankKnown(in)) {
@@ -1059,9 +1059,8 @@ REGISTER_OP("Reverse")
.Input("dims: bool")
.Output("output: T")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
- "double, complex64, "
- "complex128, string}")
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
+ "float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle dims;
@@ -1137,9 +1136,8 @@ REGISTER_OP("ReverseV2")
.Output("output: T")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
- "double, complex64, "
- "complex128, string}")
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, bfloat16, "
+ "float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle axis;
@@ -1834,7 +1832,7 @@ this operation.
REGISTER_OP("CheckNumerics")
.Input("tensor: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("message: string")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
@@ -4565,12 +4563,12 @@ REGISTER_OP("Bitcast")
.Output("output: type")
// All supported dtypes are listed here to include qint16 and quint16.
.Attr(
- "T: {float, double, int64, int32, uint8, uint16, int8, int16,"
+ "T: {bfloat16, float, double, int64, int32, uint8, uint16, int8, int16,"
" complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
" half}")
.Attr(
- "type: {float, double, int64, int32, uint8, uint16, int8, int16,"
- " complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
+ "type: {bfloat16, float, double, int64, int32, uint8, uint16, int8, "
+ "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
" half}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
@@ -4782,7 +4780,7 @@ REGISTER_OP("QuantizeAndDequantize")
.Attr("input_min: float = 0")
.Attr("input_max: float = 0")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Deprecated(22, "Replaced by QuantizeAndDequantizeV2")
.Doc(R"doc(
@@ -4798,7 +4796,7 @@ REGISTER_OP("QuantizeAndDequantizeV2")
.Attr("num_bits: int = 8")
.Attr("range_given: bool = false")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -4877,7 +4875,7 @@ REGISTER_OP("QuantizeAndDequantizeV3")
.Attr("signed_input: bool = true")
.Attr("range_given: bool = true")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 6bf226e7a5..be41531347 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -469,6 +469,24 @@ stop: corresponds to stop in python's xrange().
step: corresponds to step in python's xrange().
)doc");
+REGISTER_OP("RandomDataset")
+ .Input("seed: int64")
+ .Input("seed2: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a Dataset that returns pseudorandom numbers.
+
+seed: A scalar seed for the random number generator. If either seed or
+ seed2 is set to be non-zero, the random number generator is seeded
+ by the given seed. Otherwise, a random seed is used.
+seed2: A second scalar seed to avoid seed collision.
+)doc");
+
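
A hedged sketch of the seed policy the doc describes — a (0, 0) pair requests nondeterministic seeding, anything else is reproducible (assumed helper names, not the kernel's actual code):

```cpp
#include <cstdint>
#include <random>

// If both seeds are zero, draw nondeterministic seeds; otherwise use the
// given pair as-is so the dataset's sequence is reproducible.
void ResolveSeeds(int64_t seed, int64_t seed2, int64_t* out1, int64_t* out2) {
  if (seed == 0 && seed2 == 0) {
    std::random_device rd;
    *out1 = static_cast<int64_t>(rd());
    *out2 = static_cast<int64_t>(rd());
  } else {
    *out1 = seed;
    *out2 = seed2;
  }
}
```
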
REGISTER_OP("ShuffleDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index df75caca37..45ebfa203b 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -85,7 +85,7 @@ REGISTER_OP("BatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {half, float, double, int32, complex64, complex128}")
+ .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn([](InferenceContext* c) {
@@ -184,7 +184,7 @@ _HostCast requires its input and produces its output in host memory.
REGISTER_OP("Abs")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes the absolute value of a tensor.
@@ -210,29 +210,31 @@ value is computed as \\( \sqrt{a^2 + b^2}\\).
)doc");
// Declares cwise unary operations signature: 't -> 't
-#define UNARY() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}") \
+#define UNARY() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr( \
+ "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_REAL() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double}") \
+#define UNARY_REAL() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr("T: {half, bfloat16, float, double}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_COMPLEX() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double, complex64, complex128}") \
+#define UNARY_COMPLEX() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_GRADIENT_COMPLEX() \
- Input("y: T") \
- .Input("dy: T") \
- .Output("z: T") \
- .Attr("T: {half, float, double, complex64, complex128}") \
+#define UNARY_GRADIENT_COMPLEX() \
+ Input("y: T") \
+ .Input("dy: T") \
+ .Output("z: T") \
+ .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
REGISTER_OP("Neg")
@@ -481,7 +483,7 @@ Computes atan of x element-wise.
REGISTER_OP("IsNan")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are NaN.
@@ -494,7 +496,7 @@ Equivalent to np.isnan
REGISTER_OP("IsInf")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are Inf.
@@ -507,7 +509,7 @@ Equivalent to np.isinf
REGISTER_OP("IsFinite")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are finite.
@@ -520,7 +522,9 @@ Equivalent to np.isfinite
REGISTER_OP("Sign")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns an element-wise indication of the sign of a number.
@@ -533,7 +537,7 @@ For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
REGISTER_OP("Floor")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns element-wise largest integer not greater than x.
@@ -542,7 +546,7 @@ Returns element-wise largest integer not greater than x.
REGISTER_OP("Ceil")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
-Returns element-wise smallest integer in not less than x.
+Returns element-wise smallest integer not less than x.
@@ -551,7 +555,7 @@ Returns element-wise smallest integer in not less than x.
REGISTER_OP("Rint")
.Input("x: T")
.Output("y: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns element-wise integer closest to x.
@@ -569,22 +573,23 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
// Declares cwise binary operations signature: 't, 't -> 't.
-#define BINARY_MORE() \
- Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " \
- "complex64, complex128}")
+#define BINARY_MORE() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \
+ "int64, complex64, complex128}")
-#define BINARY_FEWER() \
- Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, int32, int64, complex64, complex128}")
+#define BINARY_FEWER() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "complex128}")
REGISTER_OP("Add")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
- "complex128, string}")
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "complex64, complex128, string}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns x + y element-wise.
@@ -600,8 +605,8 @@ REGISTER_OP("AddV2")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
- "complex128}")
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "complex64, complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.SetIsAggregate()
.SetIsCommutative()
@@ -757,7 +762,7 @@ REGISTER_OP("Maximum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
@@ -788,7 +793,7 @@ REGISTER_OP("Minimum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
@@ -802,7 +807,7 @@ REGISTER_OP("Mod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. This emulates C semantics in that
@@ -817,7 +822,7 @@ REGISTER_OP("FloorMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. When `x < 0` xor `y < 0` is
@@ -832,7 +837,7 @@ REGISTER_OP("TruncateMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. This emulates C semantics in that
@@ -847,7 +852,9 @@ REGISTER_OP("Pow")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Computes the power of one value to another.
@@ -946,7 +953,7 @@ REGISTER_OP("Atan2")
.Input("y: T")
.Input("x: T")
.Output("z: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
@@ -1064,15 +1071,15 @@ Returns the truth value of (x >= y) element-wise.
// --------------------------------------------------------------------------
-#define EQUALITY_COMPARISON() \
- Input("x: T") \
- .Input("y: T") \
- .Output("z: bool") \
- .SetIsCommutative() \
- .Attr( \
- "T: {half, float, double, uint8, int8, int16, int32, int64, " \
- "complex64, " \
- "quint8, qint8, qint32, string, bool, complex128}") \
+#define EQUALITY_COMPARISON() \
+ Input("x: T") \
+ .Input("y: T") \
+ .Output("z: bool") \
+ .SetIsCommutative() \
+ .Attr( \
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \
+ "int64, complex64, quint8, qint8, qint32, string, bool, " \
+ "complex128}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
REGISTER_OP("Equal")
@@ -1291,7 +1298,7 @@ REGISTER_OP("MatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {half, float, double, int32, complex64, complex128}")
+ .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
.SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc(
Multiply the matrix "a" by the matrix "b".
@@ -1811,10 +1818,11 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("UnsortedSegmentSum")
.Input("data: T")
.Input("segment_ids: Tindices")
- .Input("num_segments: int32")
+ .Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along segments of a tensor.
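
The new `Tnumsegments` attr only widens the `num_segments` input to int32 or int64; the reduction itself is unchanged and amounts to a scatter-add, sketched here for 1-D data (illustrative, assumed names):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// UnsortedSegmentSum on 1-D data: out[segment_ids[i]] += data[i], with
// num_segments fixing the output length; segments that receive no ids
// stay zero. Assumes 0 <= segment_ids[i] < num_segments.
std::vector<float> UnsortedSegmentSum(const std::vector<float>& data,
                                      const std::vector<int64_t>& segment_ids,
                                      int64_t num_segments) {
  std::vector<float> out(static_cast<std::size_t>(num_segments), 0.0f);
  for (std::size_t i = 0; i < data.size(); ++i) {
    out[static_cast<std::size_t>(segment_ids[i])] += data[i];
  }
  return out;
}
```
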
@@ -1849,10 +1857,11 @@ output: Has same shape as data, except for the first `segment_ids.rank`
REGISTER_OP("UnsortedSegmentMax")
.Input("data: T")
.Input("segment_ids: Tindices")
- .Input("num_segments: int32")
+ .Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: realnumbertype")
.Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the Max along segments of a tensor.
@@ -2103,7 +2112,7 @@ REGISTER_OP("Range")
.Input("limit: Tidx")
.Input("delta: Tidx")
.Output("output: Tidx")
- .Attr("Tidx: {float, double, int32, int64} = DT_INT32")
+ .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -2158,7 +2167,7 @@ REGISTER_OP("LinSpace")
.Input("stop: T")
.Input("num: Tidx")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 654e890b57..102de94787 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -73,7 +73,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::AvgPoolShape)
.Doc(R"doc(
Performs average pooling on the input.
@@ -101,7 +101,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -300,7 +300,7 @@ REGISTER_OP("FusedBatchNormV2")
.Output("batch_variance: U")
.Output("reserve_space_1: U")
.Output("reserve_space_2: U")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
@@ -393,7 +393,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Output("offset_backprop: U")
.Output("reserve_space_3: U")
.Output("reserve_space_4: U")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
@@ -508,11 +508,12 @@ REGISTER_OP("Conv2D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -546,7 +547,7 @@ filter: A 4-D tensor of shape
output: A 4-D tensor. The dimension order is determined by the value of
`data_format`, see below for details.
strides: 1-D tensor of length 4. The stride of the sliding window for each
- dimension of `input`. The dimension order is determined by the value of
+ dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
@@ -554,6 +555,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("Conv2DBackpropInput")
@@ -561,11 +567,12 @@ REGISTER_OP("Conv2DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -589,10 +596,15 @@ padding: The type of padding algorithm to use.
output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient
w.r.t. the input of the convolution.
data_format: Specify the data format of the input and output data. With the
- default format "NHWC", the data is stored in the order of:
- [batch, in_height, in_width, in_channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, in_channels, in_height, in_width].
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
@@ -603,11 +615,12 @@ REGISTER_OP("Conv2DBackpropFilter")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -632,10 +645,15 @@ output: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
the `filter` input of the convolution.
data_format: Specify the data format of the input and output data. With the
- default format "NHWC", the data is stored in the order of:
- [batch, in_height, in_width, in_channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, in_channels, in_height, in_width].
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
namespace {
@@ -819,10 +837,11 @@ REGISTER_OP("DepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape)
.Doc(R"doc(
Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.
@@ -845,7 +864,6 @@ for k in 0..in_channels-1
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
-horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
+horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
-
strides: 1-D of length 4. The stride of the sliding window for each dimension
of `input`.
padding: The type of padding algorithm to use.
@@ -854,6 +872,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
@@ -861,10 +884,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -892,6 +916,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
output: 4-D with shape according to `data_format`. For example, if
`data_format` is 'NHWC', output shape is `[batch, in_height,
in_width, in_channels]`. Gradient w.r.t. the input of the
@@ -903,10 +932,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -935,6 +965,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
output: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
the `filter` input of the convolution.
@@ -945,10 +980,11 @@ REGISTER_OP("Conv3D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn(shape_inference::Conv3DShape)
.Doc(R"doc(
Computes a 3-D convolution given 5-D `input` and `filter` tensors.
@@ -970,6 +1006,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("Conv3DBackpropInput")
@@ -1032,10 +1073,11 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -1061,6 +1103,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
@@ -1069,10 +1116,11 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -1098,6 +1146,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
@@ -1110,7 +1163,7 @@ REGISTER_OP("AvgPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D average pooling on the input.
@@ -1137,7 +1190,7 @@ REGISTER_OP("AvgPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -1172,7 +1225,7 @@ REGISTER_OP("MaxPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float}")
+ .Attr("T: {bfloat16, float}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D max pooling on the input.
@@ -1200,8 +1253,8 @@ REGISTER_OP("MaxPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float} = DT_FLOAT")
- .Attr("TInput: {float} = DT_FLOAT")
+ .Attr("T: {bfloat16, float} = DT_FLOAT")
+ .Attr("TInput: {bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
})
@@ -1266,7 +1319,7 @@ data_format: The data format of the input and output data. With the
REGISTER_OP("L2Loss")
.Input("t: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
L2 Loss.
@@ -1288,7 +1341,7 @@ REGISTER_OP("LRN")
.Attr("bias: float = 1.0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.5")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
})
@@ -1323,7 +1376,7 @@ REGISTER_OP("LRNGrad")
.Attr("bias: float = 1.0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.5")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
@@ -1349,8 +1402,8 @@ output: The gradients for LRN.
REGISTER_OP("MaxPool")
.Attr(
- "T: {float, double, int32, int64, uint8, int16, int8, uint16, "
- "half, qint8} = DT_FLOAT")
+ "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
+ "uint16, qint8} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
@@ -1376,8 +1429,8 @@ output: The max pooled output tensor.
REGISTER_OP("MaxPoolV2")
.Attr(
- "T: {float, double, int32, int64, uint8, int16, int8, uint16, "
- "half, qint8} = DT_FLOAT")
+ "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
+ "uint16, qint8} = DT_FLOAT")
.Attr(GetPaddingAttrString())
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Input("input: T")
@@ -1860,7 +1913,7 @@ backprops: The gradients:
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
@@ -1873,7 +1926,7 @@ REGISTER_OP("EluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the exponential linear (Elu) operation.
@@ -1887,7 +1940,7 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
REGISTER_OP("Selu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
@@ -1900,7 +1953,7 @@ REGISTER_OP("SeluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the scaled exponential linear (Selu) operation.
@@ -1962,7 +2015,7 @@ backprops: The gradients: `gradients / (1 + abs(features)) ** 2`.
REGISTER_OP("Softmax")
.Input("logits: T")
.Output("softmax: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
@@ -1982,7 +2035,7 @@ softmax: Same shape as `logits`.
REGISTER_OP("LogSoftmax")
.Input("logits: T")
.Output("logsoftmax: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
@@ -2004,7 +2057,7 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
.Input("labels: T")
.Output("loss: T")
.Output("backprop: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
@@ -2033,7 +2086,7 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
.Input("labels: Tlabels")
.Output("loss: T")
.Output("backprop: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("Tlabels: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle features;
@@ -2613,6 +2666,7 @@ REGISTER_OP("QuantizedConv2D")
.Attr("out_type: quantizedtype = DT_QINT32")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
ShapeHandle unused;
@@ -2641,7 +2695,11 @@ min_filter: The float value that the lowest quantized filter value represents.
max_filter: The float value that the highest quantized filter value represents.
min_output: The float value that the lowest quantized output value represents.
max_output: The float value that the highest quantized output value represents.
-
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("QuantizedMaxPool")
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 2429171fa9..31d9c82e53 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -29,7 +29,7 @@ REGISTER_OP("RandomUniform")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -87,7 +87,7 @@ REGISTER_OP("RandomStandardNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -115,7 +115,7 @@ REGISTER_OP("ParameterizedTruncatedNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -145,7 +145,7 @@ REGISTER_OP("TruncatedNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -201,10 +201,11 @@ REGISTER_OP("Multinomial")
.SetIsStateful()
.Input("logits: T")
.Input("num_samples: int32")
- .Output("output: int64")
+ .Output("output: output_dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: realnumbertype")
+ .Attr("output_dtype: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle logits_shape;
ShapeHandle unused;
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index cdfbec85cf..bf9e673e8e 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -204,7 +204,10 @@ Status VariableShapeShapeFn(InferenceContext* c) {
if (handle_data == nullptr || handle_data->empty()) {
return errors::InvalidArgument("Handle doesn't have shape information.");
}
- c->set_output(0, (*handle_data)[0].shape);
+ ShapeHandle var_shape = (*handle_data)[0].shape;
+ int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
+ : InferenceContext::kUnknownDim;
+ c->set_output(0, c->Vector(rank));
return Status::OK();
}
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 8414519f0b..772e2531dc 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -256,6 +256,48 @@ REGISTER_OP("DeserializeSparse")
.Doc(R"doc(
Deserialize `SparseTensor` objects.
+The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
+the last dimension stores serialized `SparseTensor` objects and the other N
+dimensions (N >= 0) correspond to a batch. The ranks of the original
+`SparseTensor` objects must all match. When the final `SparseTensor` is
+created, its rank is the rank of the incoming `SparseTensor` objects plus N;
+the sparse tensors have been concatenated along new dimensions, one for each
+batch.
+
+The output `SparseTensor` object's shape values for the original dimensions
+are the max across the input `SparseTensor` objects' shape values for the
+corresponding dimensions. The new dimensions match the size of the batch.
+
+The input `SparseTensor` objects' indices are assumed ordered in
+standard lexicographic order. If this is not the case, after this
+step run `SparseReorder` to restore index ordering.
+
+For example, if the serialized input is a `[2 x 3]` matrix representing two
+original `SparseTensor` objects:
+
+ index = [ 0]
+ [10]
+ [20]
+ values = [1, 2, 3]
+ shape = [50]
+
+and
+
+ index = [ 2]
+ [10]
+ values = [4, 5]
+ shape = [30]
+
+then the final deserialized `SparseTensor` will be:
+
+ index = [0 0]
+ [0 10]
+ [0 20]
+ [1 2]
+ [1 10]
+ values = [1, 2, 3, 4, 5]
+ shape = [2 50]
+
serialized_sparse: The serialized `SparseTensor` objects. The last dimension
must have 3 columns.
dtype: The `dtype` of the serialized `SparseTensor` objects.
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index da5f091e9f..5b1f5d2477 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -513,6 +513,62 @@ output_ref: Same as ref. Returned as a convenience for operations that want to
use the updated values after the update is done.
)doc");
+REGISTER_OP("ResourceScatterNdUpdate")
+ .Input("ref: resource")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
+ .Doc(R"doc(
+Applies sparse `updates` to individual values or slices within a given
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements of a rank-1 tensor
+with 8 elements. In Python, that update would look like this:
+
+```python
+ ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+    print(sess.run(update))
+```
+
+The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+
+ref: A resource handle. Must be from a VarHandleOp.
+indices: A Tensor. Must be one of the following types: int32, int64.
+ A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated
+  values to store in ref.
+use_locking: An optional bool. Defaults to True. If True, the assignment will
+ be protected by a lock; otherwise the behavior is undefined,
+ but may exhibit less contention.
+)doc");
+
REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc
index 6c0f081852..d476a1a4db 100644
--- a/tensorflow/core/platform/cloud/curl_http_request_test.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc
@@ -263,7 +263,6 @@ TEST(CurlHttpRequestTest, GetRequest) {
std::vector<char> scratch;
scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end());
- StringPiece result;
scratch.reserve(100);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
@@ -594,7 +593,6 @@ TEST(CurlHttpRequestTest, ErrorReturnsNoResponse) {
std::vector<char> scratch;
scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end());
- StringPiece result;
scratch.reserve(100);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc
index a472ae52fc..e1afc7b308 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache.cc
@@ -181,7 +181,9 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
- return errors::OutOfRange("EOF at offset ", offset);
+ return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
+                              " at position ", pos, " with data size ",
+ data.size());
}
auto begin = data.begin();
if (offset > pos) {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 54d38fe962..45e9b05092 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -697,6 +697,9 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
+ VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
+ << offset << " of size: " << out->size();
+
if (out->size() < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
@@ -706,6 +709,8 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
}
+ VLOG(2) << "Successful integrity check for: gs://" << bucket << "/"
+ << object << " @ " << offset;
}
}
@@ -868,6 +873,11 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated));
TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec)));
+ VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- "
+ << " length: " << stat->length
+ << "; mtime_nsec: " << stat->mtime_nsec
+ << "; updated: " << updated;
+
stat->is_directory = false;
return Status::OK();
};
diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md
index 4c73e372e3..dd12f76d6f 100644
--- a/tensorflow/core/profiler/g3doc/options.md
+++ b/tensorflow/core/profiler/g3doc/options.md
@@ -60,11 +60,14 @@ Currently, profiler only tracks the allocation of memory. As a result, the
accumulated memory request is usually larger than the peak memory of the overall
model.
-bytes: The memory allocations requested by the operation.
-peak_bytes: The peak requested memory (not de-allocated) by the operation.
-residual_bytes: The memory requested by the operation and not de-allocated
+It's recommended to generate a timeline to see the allocator's memory usage
+over time.
+
+`bytes`: The memory allocations requested by the operation.
+`peak_bytes`: The peak memory requested (not yet de-allocated) by the operation.
+`residual_bytes`: The memory requested by the operation and not de-allocated
when Compute finishes.
-output_bytes: The memory output by the operation. It's not necessarily requested
+`output_bytes`: The memory output by the operation. It's not necessarily requested
by the current operation. For example, it can be a tensor
forwarded from input to output, with in-place mutation.
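
A sketch of how these columns can be pulled up from the Python profiler API
(builder and select-key names assumed from the TF 1.x
`tf.profiler.ProfileOptionBuilder`; the toy matmul graph exists only to give
the profiler something to trace):

```python
import tensorflow as tf

# Run a trivial graph with full tracing so the profiler has memory stats.
a = tf.random_normal([1000, 1000])
b = tf.matmul(a, a)
run_meta = tf.RunMetadata()
with tf.Session() as sess:
    sess.run(b,
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_meta)

builder = tf.profiler.ProfileOptionBuilder
opts = (builder(builder.time_and_memory())
        .select(['bytes', 'peak_bytes', 'residual_bytes', 'output_bytes'])
        .build())
tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
                    cmd='scope', options=opts)
```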
diff --git a/tensorflow/core/profiler/internal/tfprof_node.cc b/tensorflow/core/profiler/internal/tfprof_node.cc
index 671b65d708..5cd1050bcc 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.cc
+++ b/tensorflow/core/profiler/internal/tfprof_node.cc
@@ -139,6 +139,25 @@ void ExecStep::AddMemoryStats(const string& dev,
exec_.accelerator_persistent_bytes() +
step_stat.memory_stats().device_persistent_memory_size());
}
+
+ // TODO(xpan): Make this more accurate:
+  // High level: Memory tracking is suspicious and requires a large-scale
+  // cleanup.
+  // Investigate the memory usage difference between CPU/GPU with OpViewTest.
+ //
+ // 1. OpKernelConstruction::allocate_xxx is not traced. Below, we only
+ // discuss OpKernelContext-related allocations.
+ // 2. allocate_output calls allocate_tensor, which is properly tracked in
+ // 'NodeExecStats.memory'.
+ // 3. allocate_temp is only tracked through record_xxx_temp. It appears
+ // in 'NodeExecStats.memory_stats'.
+ // 4. allocate_persistent calls allocate_tensor, which is properly tracked
+ // in 'NodeExecStats.memory'. However, there is no way to count it as
+ // persistent now.
+ // 5. record_xxx_persistent is called when allocate_persistent
+ // is not used and hence tracks some complementary bytes. It appears in
+  //    'NodeExecStats.memory_stats'. It's suspicious, but we should use it
+  //    for now since it covers the constant op.
int64 residual_bytes = 0;
int64 requested_bytes = 0;
int64 peak_bytes = 0;
@@ -147,6 +166,15 @@ void ExecStep::AddMemoryStats(const string& dev,
requested_bytes += mem.total_bytes();
peak_bytes += mem.peak_bytes();
}
+ residual_bytes +=
+ exec_.host_persistent_bytes() + exec_.accelerator_persistent_bytes();
+ requested_bytes += exec_.host_persistent_bytes() +
+ exec_.accelerator_persistent_bytes() +
+ exec_.host_temp_bytes() + exec_.accelerator_temp_bytes();
+ peak_bytes += exec_.host_persistent_bytes() +
+ exec_.accelerator_persistent_bytes() + exec_.host_temp_bytes() +
+ exec_.accelerator_temp_bytes();
+
exec_.set_requested_bytes(requested_bytes);
exec_.set_residual_bytes(residual_bytes);
exec_.set_peak_bytes(peak_bytes);
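
With this change, host/accelerator temp and persistent bytes are folded into
`requested_bytes` and `peak_bytes` (and persistent bytes into
`residual_bytes`), which is what the updated test expectations below reflect:
Conv2D's `requested_bytes` grows from 9472 to 18176 (+8704 bytes of
temp/persistent allocations), Conv2D_1's from 5120 to 9728 (+4608), and the
two 1280-byte variables now appear as well, so the root total grows from
14592 to 30464 (8704 + 4608 + 2 * 1280 = 15872).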
diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h
index e2d0563a07..77c14cb792 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.h
+++ b/tensorflow/core/profiler/internal/tfprof_node.h
@@ -593,17 +593,11 @@ class TFGraphNode {
int64 accelerator_persistent_bytes() const {
int64 persistent_bytes = 0;
for (const auto& exec : execs_) {
- persistent_bytes += exec.second.accelerator_persistent_bytes();
+ persistent_bytes = std::max(persistent_bytes,
+ exec.second.accelerator_persistent_bytes());
}
return persistent_bytes;
}
- int64 host_persistent_bytes(int64 step) const {
- auto exec = execs_.find(step);
- if (exec == execs_.end()) {
- return 0;
- }
- return exec->second.host_persistent_bytes();
- }
const std::map<int32, std::pair<int64, uint64>>& output_memory(
int64 step) const {
auto exec = execs_.find(step);
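
Taking the max rather than the sum across steps here presumably reflects that
persistent (e.g. variable) allocations live across steps, so summing the
per-step values would count the same bytes repeatedly.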
diff --git a/tensorflow/core/profiler/internal/tfprof_show_test.cc b/tensorflow/core/profiler/internal/tfprof_show_test.cc
index 1f19f8c322..98773ae19e 100644
--- a/tensorflow/core/profiler/internal/tfprof_show_test.cc
+++ b/tensorflow/core/profiler/internal/tfprof_show_test.cc
@@ -105,12 +105,13 @@ TEST_F(TFProfShowTest, DumpScopeMode) {
"node name | # parameters | # float_ops | requested bytes | peak bytes | "
"residual bytes | output bytes | total execution time | accelerator "
"execution time | cpu execution time\n_TFProfRoot (--/451 params, --/0 "
- "flops, --/0B, --/0B, --/0B, --/2.56KB, --/13us, --/0us, --/13us)\n DW "
- "(3x3x3x6, 162/162 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, "
- "1.28KB/1.28KB, 2us/2us, 0us/0us, 2us/2us)\n DW2 (2x2x6x12, 288/288 "
- "params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, 1.28KB/1.28KB, 11us/11us, "
- "0us/0us, 11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, "
- "0B/0B, 0B/0B, 0us/0us, 0us/0us, 0us/0us)\n",
+ "flops, --/2.56KB, --/2.56KB, --/2.56KB, --/2.56KB, --/13us, --/0us, "
+ "--/13us)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 1.28KB/1.28KB, "
+ "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 2us/2us, 0us/0us, "
+ "2us/2us)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.28KB/1.28KB, "
+ "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 11us/11us, 0us/0us, "
+ "11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, "
+ "0B/0B, 0us/0us, 0us/0us, 0us/0us)\n",
dump_str);
EXPECT_EQ(dump_str, TestToFromProto("scope", opts));
@@ -178,22 +179,22 @@ TEST_F(TFProfShowTest, DumpOpMode) {
EXPECT_EQ(
"nodename|requestedbytes|totalexecutiontime|acceleratorexecutiontime|"
"cpuexecutiontime|#parameters|#float_ops|opoccurrence(run|defined)|"
- "inputshapes\nVariableV20B(0.00%,0.00%),13us(100.00%,0.26%),0us(100.00%,"
- "0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_ops(100.00%"
- ",0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:13us\n\nAdd0B("
- "0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99.71%,0.00%),"
- "0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_type:0:1,"
- "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12,\t1:1\t("
- "run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t(run*0|"
- "defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0.00%),"
- "0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_ops("
- "100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_"
+ "inputshapes\nVariableV22.56KB(100.00%,8.40%),13us(100.00%,0.26%),0us("
+ "100.00%,0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_"
+ "ops(100.00%,0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:"
+ "13us\n\nAdd0B(0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99."
+ "71%,0.00%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_"
+ "type:0:1,\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12,"
+ "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t("
+ "run*0|defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0."
+ "00%),0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_"
+ "ops(100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_"
"time:0us\ninput_type:0:2x2x6x12,\t1:2x2x6x12\t(run*0|defined*1)\texec_"
"time:0us\ninput_type:0:3x3x3x6,\t1:3x3x3x6\t(run*0|defined*1)\texec_"
"time:0us\n\nConst0B(0.00%,0.00%),2us(99.74%,0.04%),0us(100.00%,0.00%),"
"2us(99.71%,0.04%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),1|"
- "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D14.59KB("
- "100.00%,100.00%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99."
+ "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D27.90KB("
+ "91.60%,91.60%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99."
"67%,98.77%),0params(0.00%,0.00%),10.44kfloat_ops(100.00%,100.00%),2|"
"2\n\ninput_type:0:2x3x3x6,\t1:2x2x6x12\t(run*1|defined*1)\texec_time:"
"597us\ninput_type:0:2x6x6x3,\t1:3x3x3x6\t(run*1|defined*1)\texec_time:4."
diff --git a/tensorflow/core/profiler/internal/tfprof_stats_test.cc b/tensorflow/core/profiler/internal/tfprof_stats_test.cc
index 2f2101d76b..b86a83cb1b 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats_test.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats_test.cc
@@ -89,21 +89,27 @@ TEST_F(TFProfStatsTest, CustomOpType) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
- "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: "
- "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n "
- "total_exec_micros: 2\n total_parameters: 162\n devices: "
+ "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: "
+ "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: "
+ "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n "
+ " total_requested_bytes: 1280\n total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n "
"total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: "
- "288\n total_exec_micros: 11\n total_parameters: 288\n devices: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n "
+ "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n "
+ "total_exec_micros: 11\n total_requested_bytes: 1280\n "
+ "total_parameters: 288\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n "
"total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n "
- "total_parameters: 1\n total_definition_count: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n "
+ "parameters: 1\n total_parameters: 1\n total_definition_count: "
"1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: "
- "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n",
+ "2\ntotal_definition_count: 3\ntotal_peak_bytes: "
+ "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -119,21 +125,27 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
- "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: "
- "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n "
- "total_exec_micros: 2\n total_parameters: 162\n devices: "
+ "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: "
+ "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: "
+ "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n "
+ " total_requested_bytes: 1280\n total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n "
"total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: "
- "288\n total_exec_micros: 11\n total_parameters: 288\n devices: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n "
+ "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n "
+ "total_exec_micros: 11\n total_requested_bytes: 1280\n "
+ "total_parameters: 288\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n "
"total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n "
- "total_parameters: 1\n total_definition_count: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n "
+ "parameters: 1\n total_parameters: 1\n total_definition_count: "
"1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: "
- "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n",
+ "2\ntotal_definition_count: 3\ntotal_peak_bytes: "
+ "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -150,7 +162,7 @@ TEST_F(TFProfStatsTest, TestGraph) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: "
+ "30464\ntotal_parameters: 451\nchildren {\n name: "
"\"DW/Initializer/random_normal/mul\"\n children {\n name: "
"\"DW/Initializer/random_normal/RandomStandardNormal\"\n children {\n "
" name: \"DW/Initializer/random_normal/shape\"\n "
@@ -166,7 +178,7 @@ TEST_F(TFProfStatsTest, TestGraph) {
"4\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: "
"404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: "
"6\ntotal_definition_count: 32\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -181,9 +193,9 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n "
- "exec_micros: 4292\n requested_bytes: 9472\n total_exec_micros: 4292\n "
- " total_requested_bytes: 9472\n devices: "
+ "30464\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n "
+ "exec_micros: 4292\n requested_bytes: 18176\n total_exec_micros: "
+ "4292\n total_requested_bytes: 18176\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 5832\n "
"total_float_ops: 5832\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 6\n "
@@ -194,11 +206,11 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
"6\n }\n }\n }\n accelerator_exec_micros: 226\n "
"cpu_exec_micros: 4066\n total_accelerator_exec_micros: 226\n "
"total_cpu_exec_micros: 4066\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 5888\n residual_bytes: 768\n "
- "output_bytes: 768\n total_peak_bytes: 5888\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 14592\n residual_bytes: 768\n "
+ " output_bytes: 768\n total_peak_bytes: 14592\n total_residual_bytes: "
"768\n total_output_bytes: 768\n}\nchildren {\n name: \"Conv2D_1\"\n "
- "exec_micros: 597\n requested_bytes: 5120\n total_exec_micros: 597\n "
- "total_requested_bytes: 5120\n devices: "
+ "exec_micros: 597\n requested_bytes: 9728\n total_exec_micros: 597\n "
+ "total_requested_bytes: 9728\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n "
"total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 3\n "
@@ -209,12 +221,12 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
"12\n }\n }\n }\n accelerator_exec_micros: 178\n "
"cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n "
"total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n "
- "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n "
+ "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: "
"512\n total_output_bytes: 512\n}\ntotal_float_ops: "
"10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: "
"4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -231,9 +243,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 597\ntotal_requested_bytes: "
- "5120\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n "
- "requested_bytes: 5120\n total_exec_micros: 597\n "
- "total_requested_bytes: 5120\n devices: "
+ "9728\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n "
+ "requested_bytes: 9728\n total_exec_micros: 597\n "
+ "total_requested_bytes: 9728\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n "
"total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 3\n "
@@ -244,12 +256,12 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
"12\n }\n }\n }\n accelerator_exec_micros: 178\n "
"cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n "
"total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n "
- "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n "
+ "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: "
"512\n total_output_bytes: 512\n}\ntotal_float_ops: "
"4608\ntotal_accelerator_exec_micros: 178\ntotal_cpu_exec_micros: "
"419\ntotal_run_count: 1\ntotal_definition_count: 2\ntotal_peak_bytes: "
- "4096\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n",
+ "8704\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -265,8 +277,9 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n "
- "exec_micros: 2\n parameters: 162\n total_exec_micros: 2\n "
+ "30464\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n "
+ "exec_micros: 2\n requested_bytes: 1280\n parameters: 162\n "
+ "total_exec_micros: 2\n total_requested_bytes: 1280\n "
"total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n tensor_value {\n dtype: "
"DT_FLOAT\n value_double: -0.000534315\n value_double: "
@@ -351,11 +364,13 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
"value_double: 0.000374641\n value_double: -0.00149603\n "
"value_double: -0.000317367\n value_double: -0.000417829\n }\n "
"cpu_exec_micros: 2\n total_cpu_exec_micros: 2\n run_count: 1\n "
- "total_run_count: 1\n total_definition_count: 10\n output_bytes: "
- "1280\n total_output_bytes: 1280\n}\ntotal_float_ops: "
- "10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: "
- "4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "total_run_count: 1\n total_definition_count: 10\n peak_bytes: 1280\n "
+ "residual_bytes: 1280\n output_bytes: 1280\n total_peak_bytes: 1280\n "
+ "total_residual_bytes: 1280\n total_output_bytes: "
+ "1280\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: "
+ "404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: "
+ "6\ntotal_definition_count: 35\ntotal_peak_bytes: "
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
}
diff --git a/tensorflow/core/profiler/tfprof_log.proto b/tensorflow/core/profiler/tfprof_log.proto
index f92301133a..b49bdf64ac 100644
--- a/tensorflow/core/profiler/tfprof_log.proto
+++ b/tensorflow/core/profiler/tfprof_log.proto
@@ -124,9 +124,10 @@ message ExecProfile {
int64 residual_bytes = 9;
// Total bytes output by the op (not necessarily requested by the op).
int64 output_bytes = 10;
- // Total temporary bytes allocated and released by the op.
+ // NOTE: Please don't depend on the following 4 fields yet. Due to
+ // TensorFlow internal tracing issues, the numbers can be quite wrong.
+ // TODO(xpan): Fix the TensorFlow internal tracing.
int64 host_temp_bytes = 11;
- // Total persistent bytes (e.g. variable) allocated by the op.
int64 host_persistent_bytes = 12;
int64 accelerator_temp_bytes = 13;
int64 accelerator_persistent_bytes = 14;