path: root/tensorflow/core
author    Vijay Vasudevan <vrv@google.com>  2016-02-17 11:42:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-02-17 12:56:41 -0800
commit    fe056f0b5e52db86766761f5e6446a89c1aa3938 (patch)
tree      68bce0e257d181a3fa37f83c97fdff0fdad877fc /tensorflow/core
parent    19d632338f983e02dd0268b931e9cced03b74805 (diff)
Merge changes from github.
Change: 114882676
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  2
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc  8
-rw-r--r--  tensorflow/core/framework/allocator.h  2
-rw-r--r--  tensorflow/core/framework/function.h  6
-rw-r--r--  tensorflow/core/framework/op_kernel.h  4
-rw-r--r--  tensorflow/core/graph/costmodel.h  2
-rw-r--r--  tensorflow/core/graph/graph.h  2
-rw-r--r--  tensorflow/core/graph/graph_def_builder.h  4
-rw-r--r--  tensorflow/core/graph/graph_partition.h  2
-rw-r--r--  tensorflow/core/kernels/avgpooling_op.h  2
-rw-r--r--  tensorflow/core/kernels/constant_op.cc  9
-rw-r--r--  tensorflow/core/kernels/cwise_op_abs.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_add.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_div.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_equal_to.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_greater.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_greater_equal.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_less.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_less_equal.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_maximum.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_minimum.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_sign.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_sub.cc  2
-rw-r--r--  tensorflow/core/kernels/identity_op.cc  2
-rw-r--r--  tensorflow/core/kernels/pad_op.cc  2
-rw-r--r--  tensorflow/core/kernels/range_sampler.cc  2
-rw-r--r--  tensorflow/core/kernels/range_sampler.h  2
-rw-r--r--  tensorflow/core/kernels/reshape_op.cc  2
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc  184
-rw-r--r--  tensorflow/core/kernels/scatter_op.h  48
-rw-r--r--  tensorflow/core/kernels/scatter_op_gpu.cu.cc  108
-rw-r--r--  tensorflow/core/kernels/scatter_op_test.cc  2
-rw-r--r--  tensorflow/core/kernels/shape_ops.cc  12
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.cc  8
-rw-r--r--  tensorflow/core/lib/core/command_line_flags.cc  2
-rw-r--r--  tensorflow/core/lib/gtl/inlined_vector.h  2
-rw-r--r--  tensorflow/core/lib/jpeg/jpeg_handle.h  2
-rw-r--r--  tensorflow/core/lib/random/distribution_sampler.cc  2
-rw-r--r--  tensorflow/core/lib/random/random_distributions.h  4
-rw-r--r--  tensorflow/core/lib/random/random_distributions_test.cc  4
-rw-r--r--  tensorflow/core/lib/strings/numbers.cc  2
-rw-r--r--  tensorflow/core/lib/strings/numbers.h  2
-rw-r--r--  tensorflow/core/lib/strings/strcat.h  2
-rw-r--r--  tensorflow/core/ops/image_ops.cc  30
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  16
-rw-r--r--  tensorflow/core/ops/state_ops.cc  5
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl  6
-rw-r--r--  tensorflow/core/platform/default/build_config/BUILD  2
-rw-r--r--  tensorflow/core/platform/default/mutex.h  13
-rw-r--r--  tensorflow/core/platform/default/thread_annotations.h  9
-rw-r--r--  tensorflow/core/public/version.h  2
-rw-r--r--  tensorflow/core/util/cuda_kernel_helper.h  54
-rw-r--r--  tensorflow/core/util/events_writer.cc  2
53 files changed, 458 insertions, 142 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 6506b06096..57eb5999b2 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -535,7 +535,7 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
pending_feeds.erase(id);
}
- // Initialize the stack with the fecth nodes.
+ // Initialize the stack with the fetch nodes.
std::vector<const Node*> stack;
for (const string& fetch : fetches) {
TensorId id(ParseTensorName(fetch));
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 3b61dc8697..53645a2061 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -304,7 +304,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
ExpandInlineFunctions(lib_, g);
EXPECT_EQ(e2, DebugString(g));
- // Get rid of redunant Identity nodes.
+ // Get rid of redundant Identity nodes.
RemoveIdentityNodes(g);
const char* e3 = R"P(
(n2:float) -> (n42:float) {
@@ -683,7 +683,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
{{"a"}, "Square", {"x"}, {{"T", T}}},
// 1
FDH::Const("o", 1),
- // A bunch of extra arithmatic that y doesn't depend on
+ // A bunch of extra arithmetic that y doesn't depend on
{{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
{{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
{{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
@@ -722,7 +722,7 @@ TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
// Nodes
{// variable
{{"v"}, "Variable", {}, {{"dtype", T}, {"shape", TensorShape({})}}},
- // read the variable. Shouln't be removed.
+ // read the variable. Shouldn't be removed.
{{"v_read"}, "Identity", {"v"}, {{"T", T}}},
// returns v + v
{{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});
@@ -761,7 +761,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{{"a"}, "Square", {"x"}, {{"T", T}}},
// 1
FDH::Const("o", 1),
- // A bunch of extra arithmatic that y doesn't depend on
+ // A bunch of extra arithmetic that y doesn't depend on
{{"x1"}, "Identity", {"a"}, {{"T", T}}},
{{"x2"}, "Identity", {"x1"}, {{"T", T}}},
{{"x3"}, "Identity", {"x2"}, {{"T", T}}},
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index b8e4a1f783..f904d5be8d 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -105,7 +105,7 @@ class Allocator {
// Returns true if this allocator tracks the sizes of allocations.
// RequestedSize and AllocatedSize must be overridden if
- // TracksAlloctionSizes is overridden to return true.
+ // TracksAllocationSizes is overridden to return true.
virtual bool TracksAllocationSizes() { return false; }
// Returns true if this allocator requires tensors with 0 elements
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 503e578fac..8231b91a3f 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -160,10 +160,10 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
-// InstatiateFunction calls "get_function" to find signatures of other
+// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.
-// Placeholders in "fdef" is substitued based on "attr_values" here.
+// Placeholders in "fdef" is substituted based on "attr_values" here.
typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
InstantiateAttrValueSlice;
@@ -329,7 +329,7 @@ class FunctionLibraryRuntime {
// std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a
-// definition of a brain function which computate the gradient for the
+// definition of a brain function which compute the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index a8c2d58b92..dc98044804 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -473,7 +473,7 @@ class OpKernelContext {
// will never use eigen_gpu_device. It seems better to have
// ensure_eigen_gpu_device fall through and regenerate the
// nullptr every time an OpKernelContext is instantiated, than
- // to do an unneccessary allocation of a dummy eigen GPU
+ // to do an unnecessary allocation of a dummy eigen GPU
// device for CPU device Ops.
eigen_gpu_device = device->MakeGpuDevice();
}
@@ -1037,7 +1037,7 @@ typedef ::tensorflow::KernelDefBuilder Name;
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
::tensorflow::register_kernel::kernel_builder.Build(), \
- +[](::tensorflow::OpKernelConstruction* context) \
+ [](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
void* GlobalKernelRegistry();
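
The macro change above drops a unary plus that was applied to a captureless lambda. A leading `+` forces the lambda to convert to a plain function pointer; without it, the lambda is passed as its own closure type and converts implicitly where a function pointer is expected. A self-contained sketch of the idiom (all names here are illustrative, not TensorFlow's):

    // A captureless lambda has a unique closure type but converts to a
    // function pointer; unary '+' forces that conversion explicitly.
    #include <type_traits>

    using Fn = int (*)(int);

    int main() {
      auto doubler = [](int x) { return x * 2; };
      Fn implicit_ptr = doubler;   // implicit conversion, as in the new code
      Fn explicit_ptr = +doubler;  // the '+' idiom removed by this commit
      static_assert(std::is_same<decltype(+doubler), Fn>::value, "");
      return implicit_ptr(1) + explicit_ptr(2);  // 2 + 4 = 6
    }
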
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index f86fc6a141..037e6d4684 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -95,7 +95,7 @@ class CostModel {
// Check that an estimate is available for every OP node in graph.
void CheckInitialized(const Graph& graph) const;
- // Helper routines to encapsulate static estimatation heuristics
+ // Helper routines to encapsulate static estimation heuristics
// Compute an estimate of the time to copy "b" bytes over the network,
// given a fixed cost of "network_latency_millis" milliseconds and
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index b4b91e115e..4ad2a306b2 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -15,7 +15,7 @@ limitations under the License.
// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
-// compuations. The basic model is a DAG (directed acyclic graph) with
+// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges represent dependencies, indicating the target may only be
// executed once the source has completed; and
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index 2a212bbc49..ec28343668 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -37,7 +37,7 @@ namespace tensorflow {
// node_builder.Input(input);
// return opts.FinalizeBuilder(&node_builder);
// }
-// } // namspace ops
+// } // namespace ops
//
// // Or, alternatively:
// namespace ops {
@@ -45,7 +45,7 @@ namespace tensorflow {
// static const string kOpName = "Identity";
// return UnaryOp(kOpName, input, opts);
// }
-// } // namspace ops
+// } // namespace ops
//
// You call it like:
// GraphDefBuilder b;
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
index 4ae0133977..5c69af0144 100644
--- a/tensorflow/core/graph/graph_partition.h
+++ b/tensorflow/core/graph/graph_partition.h
@@ -40,7 +40,7 @@ struct PartitionOptions {
// A function that returns the incarnation of a device given the
// device's fullname. If not found, GetIncarnationFunc should return
- // kIlledgalIncarnation.
+ // kIllegalIncarnation.
static const uint64 kIllegalIncarnation = 0;
typedef std::function<uint64(const string&)> GetIncarnationFunc;
GetIncarnationFunc get_incarnation = nullptr;
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
index a275a32c3b..128ab2b3b9 100644
--- a/tensorflow/core/kernels/avgpooling_op.h
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -41,7 +41,7 @@ struct SpatialAvgPooling {
typedef Eigen::GpuDevice GPUDevice;
-// Lauch a custom GPU kernels from Yanqing for the avgpooling backward operation
+// Launch a custom GPU kernels from Yanqing for the avgpooling backward operation
// that works NHWC data formats.
// Arguments:
// top_diff: backprop to the output of the pooling layer
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index c4f8298e27..5f5d7cf8db 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -85,6 +85,7 @@ void HostConstantOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, tensor_);
}
+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -93,6 +94,7 @@ REGISTER_KERNEL_BUILDER(Name("Const")
.HostMemory("output")
.TypeConstraint<int32>("dtype"),
HostConstantOp);
+#endif
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
@@ -178,10 +180,6 @@ REGISTER_KERNEL(GPU, int16);
REGISTER_KERNEL(GPU, int64);
// Currently we do not support filling strings and complex64 on GPU
-#endif // GOOGLE_CUDA
-
-#undef REGISTER_KERNEL
-
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -192,6 +190,9 @@ REGISTER_KERNEL_BUILDER(Name("Fill")
.HostMemory("value")
.HostMemory("output"),
FillOp<CPUDevice, int32>);
+#endif
+
+#undef REGISTER_KERNEL
template <typename Device, typename T>
class ZerosLikeOp : public OpKernel {
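
This hunk, like the cwise_op_*, identity_op, pad_op, reshape_op, and shape_ops hunks below, moves the int32 host-memory kernel registration inside the `#if GOOGLE_CUDA` block, so a GPU-device registration is compiled only when CUDA support is in the build. A minimal self-contained sketch of the guard pattern (only `GOOGLE_CUDA` is the real build flag; the function names are placeholders):

    // Compile GPU-facing registrations only in CUDA builds; in a CPU-only
    // build the guarded code is dropped entirely by the preprocessor.
    #include <cstdio>

    void RegisterCpuKernels() { std::puts("cpu kernels"); }
    #if GOOGLE_CUDA
    void RegisterGpuKernels() { std::puts("gpu kernels"); }
    #endif

    void RegisterAll() {
      RegisterCpuKernels();
    #if GOOGLE_CUDA
      RegisterGpuKernels();  // before this commit, in effect outside the guard
    #endif
    }

    int main() { RegisterAll(); }
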
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
index ca61a94391..1b976c7210 100644
--- a/tensorflow/core/kernels/cwise_op_abs.cc
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -23,7 +23,6 @@ REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU),
#endif
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Abs", functor::abs, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -34,5 +33,6 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
.HostMemory("y")
.TypeConstraint<int32>("T"),
UnaryOp<CPUDevice, functor::abs<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc
index 49d9e0c507..8a9c1979e5 100644
--- a/tensorflow/core/kernels/cwise_op_add.cc
+++ b/tensorflow/core/kernels/cwise_op_add.cc
@@ -20,7 +20,6 @@ REGISTER8(BinaryOp, CPU, "Add", functor::add, float, double, int32, int64, int8,
int16, complex64, string);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Add")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::add<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 979db4c50c..e97b6b4360 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16,
#if GOOGLE_CUDA
REGISTER5(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16,
int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Div")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::div<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc
index 28801a49d6..1b744445ff 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to.cc
@@ -21,7 +21,6 @@ REGISTER9(BinaryOp, CPU, "Equal", functor::equal_to, float, double, uint8, int8,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Equal", functor::equal_to, float, double, uint8, int8,
int16, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Equal")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::equal_to<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index 7f746745bf..f860ea19ac 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Greater", functor::greater, float, double, int32,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Greater", functor::greater, float, double, int64,
uint8, int8, int16);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Greater")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index adf06121a2..465c5a22ae 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float, double,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float, double,
int64, uint8, int8, int16);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater_equal<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 1710304703..e7acfa091f 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Less", functor::less, float, double, int32, int64,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Less", functor::less, float, double, int64, uint8,
int8, int16);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Less")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 65f79b5799..3175dcae5e 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "LessEqual", functor::less_equal, float, double, int32,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "LessEqual", functor::less_equal, float, double, int64,
uint8, int8, int16);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less_equal<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc
index 732e6b48f2..a2fc480674 100644
--- a/tensorflow/core/kernels/cwise_op_maximum.cc
+++ b/tensorflow/core/kernels/cwise_op_maximum.cc
@@ -20,7 +20,6 @@ REGISTER4(BinaryOp, CPU, "Maximum", functor::maximum, float, double, int32,
int64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Maximum", functor::maximum, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Maximum")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::maximum<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc
index 017daca4cc..0c6797ec86 100644
--- a/tensorflow/core/kernels/cwise_op_minimum.cc
+++ b/tensorflow/core/kernels/cwise_op_minimum.cc
@@ -20,7 +20,6 @@ REGISTER4(BinaryOp, CPU, "Minimum", functor::minimum, float, double, int32,
int64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Minimum", functor::minimum, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Minimum")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::minimum<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
index 94f7ddd3b2..5b5ad90207 100644
--- a/tensorflow/core/kernels/cwise_op_sign.cc
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -19,7 +19,6 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sign", functor::sign, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -30,5 +29,6 @@ REGISTER_KERNEL_BUILDER(Name("Sign")
.HostMemory("y")
.TypeConstraint<int32>("T"),
UnaryOp<CPUDevice, functor::sign<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index 8858db9e5f..b3727ec361 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -20,7 +20,6 @@ REGISTER5(BinaryOp, CPU, "Sub", functor::sub, float, double, int32, int64,
complex64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, float, double, int64);
-#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Sub")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::sub<int32>>);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
index 22a26e8310..38a46583bd 100644
--- a/tensorflow/core/kernels/identity_op.cc
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -47,6 +47,7 @@ REGISTER_GPU_KERNEL(bfloat16);
#undef REGISTER_GPU_KERNEL
+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -56,5 +57,6 @@ REGISTER_KERNEL_BUILDER(Name("Identity")
.HostMemory("output")
.TypeConstraint<int32>("T"),
IdentityOp);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 286b74ca64..f3ad98ab0a 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -170,7 +170,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
PadOp<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
-#endif // GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -182,5 +181,6 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
.HostMemory("paddings")
.HostMemory("output"),
PadOp<CPUDevice, int32>);
+#endif
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc
index 58bd103c80..40be4bdb20 100644
--- a/tensorflow/core/kernels/range_sampler.cc
+++ b/tensorflow/core/kernels/range_sampler.cc
@@ -60,7 +60,7 @@ namespace {
// We use batch_size and num_tries, where num_tries is the observed number of
// tries it took to get batch_size unique values.
//
-// Assuming (falsely) that the nubmer of tries to get a batch of batch_size
+// Assuming (falsely) that the number of tries to get a batch of batch_size
// distinct values is _always_ num_tries, the probability that the value
// is in a batch is (1 - (1-p)^num_tries)
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
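
For reference, the probability named in the comment above, 1 - (1-p)^num_tries, can be computed stably via `expm1`/`log1p`; a self-contained sketch follows (that the real helper uses this exact formulation is an assumption; only the formula comes from the comment):

    #include <cmath>

    // P(value appears in the batch) = 1 - (1 - p)^num_tries.
    // Since (1-p)^n = exp(n * log1p(-p)), the result is
    // -expm1(n * log1p(-p)), which keeps precision when p is tiny
    // and (1.0f - p) would round to 1.
    float BatchInclusionProbability(float p, int num_tries) {
      return -std::expm1(num_tries * std::log1p(-p));
    }
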
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
index b2a1d44da1..2513f50b37 100644
--- a/tensorflow/core/kernels/range_sampler.h
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -65,7 +65,7 @@ class RangeSampler {
// Expected counts for the elements of the returned "batch" are reported
// in the aligned array "batch_expected_count".
//
- // The user can optionally provide "extras", containg values in the range.
+ // The user can optionally provide "extras", containing values in the range.
// The expected counts for the extras are reported in the aligned array
// "extras_expected_count".
//
diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc
index 2413623684..1ee959c026 100644
--- a/tensorflow/core/kernels/reshape_op.cc
+++ b/tensorflow/core/kernels/reshape_op.cc
@@ -30,6 +30,7 @@ REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"),
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -40,5 +41,6 @@ REGISTER_KERNEL_BUILDER(Name("Reshape")
.HostMemory("output")
.TypeConstraint<int32>("T"),
ReshapeOp);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index afc03efbca..30fd105b5f 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -15,6 +15,8 @@ limitations under the License.
// See docs in ../ops/state_ops.cc.
+#include "tensorflow/core/kernels/scatter_op.h"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -23,33 +25,38 @@ limitations under the License.
namespace tensorflow {
-enum class UpdateOp { ASSIGN, ADD, SUB };
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
-template <UpdateOp Op>
+template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
-struct Assign<UpdateOp::ASSIGN> {
+struct Assign<scatter_op::UpdateOp::ASSIGN> {
template <typename Params, typename Update>
static void Run(Params p, Update u) {
p = u;
}
};
template <>
-struct Assign<UpdateOp::ADD> {
+struct Assign<scatter_op::UpdateOp::ADD> {
template <typename Params, typename Update>
static void Run(Params p, Update u) {
p += u;
}
};
template <>
-struct Assign<UpdateOp::SUB> {
+struct Assign<scatter_op::UpdateOp::SUB> {
template <typename Params, typename Update>
static void Run(Params p, Update u) {
p -= u;
}
};
-template <class T, typename Index, UpdateOp op>
+} // namespace
+
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
public:
// QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
@@ -108,85 +115,136 @@ class ScatterUpdateOp : public OpKernel {
"updates.shape ", Tupdates.shape().DebugString(),
", indices.shape ", Tindices.shape().DebugString(),
", params.shape ", Tparams.shape().DebugString()));
- const Index N = Tindices.NumElements();
// We always return the input ref.
c->forward_ref_input_to_ref_output(0, 0);
+ const Index N = Tindices.NumElements();
if (N > 0) {
- const Index first_dim_size = Tparams.dim_size(0);
- // Validate all the indices are in range
- auto Tindices_vec = Tindices.flat<Index>();
- for (Index i = 0; i < N; i++) {
- const Index index = Tindices_vec(i);
- OP_REQUIRES(c, index >= 0 && index < first_dim_size,
- errors::InvalidArgument(
- strings::StrCat("Index ", index, " at offset ", i,
- " in indices is out of range")));
- }
+ auto Tindices_flat = Tindices.flat<Index>();
auto Tparams_flat = Tparams.flat_outer_dims<T>();
auto Tupdates_flat =
Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
- for (Index i = 0; i < N; i++) {
- // Copy last Ndim-1 dimensions of Tupdates[i] to
- // Tparams[Tindices[i]]
- Assign<op>::Run(Tparams_flat.template chip<0>(Tindices_vec(i)),
- Tupdates_flat.template chip<0>(i));
- }
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ functor(c, c->template eigen_device<Device>(),
+ Tparams_flat, Tupdates_flat, Tindices_flat);
}
}
};
-#define REGISTER_SCATTER_UPDATE(type, index_type) \
- REGISTER_KERNEL_BUILDER( \
- Name("ScatterUpdate") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- ScatterUpdateOp<type, index_type, UpdateOp::ASSIGN>);
+namespace functor {
+// Implementation of update functor for CPU.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<CPUDevice, T, Index, op> {
+ void operator()(OpKernelContext* c, const CPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ typename TTypes<T>::ConstMatrix updates,
+ typename TTypes<Index>::ConstFlat indices) {
+ Index N = indices.size();
+ // Validate all the indices are in range
+ Index first_dim_size = params.dimension(0);
+ for (Index i = 0; i < N; i++) {
+ const Index index = indices(i);
+ OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ }
+ for (Index i = 0; i < N; i++) {
+ // Copy last Ndim-1 dimensions of Tupdates[i] to
+ // Tparams[Tindices[i]]
+ Assign<op>::Run(params.template chip<0>(indices(i)),
+ updates.template chip<0>(i));
+ }
+ }
+};
+} // namespace functor
-#define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32)
-#define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64)
+#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterUpdateOp<dev##Device, type, index_type, op>)
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT32);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT64);
+#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
+ REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
+ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#undef REGISTER_SCATTER_UPDATE_INT64
-#undef REGISTER_SCATTER_UPDATE_INT32
-#undef REGISTER_SCATTER_UPDATE
+#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+ REGISTER_SCATTER_KERNEL( \
+ type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL( \
+ type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
-#define REGISTER_SCATTER_ADD(type, index_type) \
- REGISTER_KERNEL_BUILDER(Name("ScatterAdd") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- ScatterUpdateOp<type, index_type, UpdateOp::ADD>);
+#define REGISTER_SCATTER_UPDATE(type, dev) \
+ REGISTER_SCATTER_KERNEL( \
+ type, dev, "ScatterUpdate", scatter_op::UpdateOp::ASSIGN);
-#define REGISTER_SCATTER_ADD_INT32(type) REGISTER_SCATTER_ADD(type, int32)
-#define REGISTER_SCATTER_ADD_INT64(type) REGISTER_SCATTER_ADD(type, int64)
+// Registers CPU kernels.
+#define REGISTER_SCATTER_ADD_SUB_CPU(type) \
+ REGISTER_SCATTER_ADD_SUB(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT32);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT64);
+#define REGISTER_SCATTER_UPDATE_CPU(type) \
+ REGISTER_SCATTER_UPDATE(type, CPU);
-#undef REGISTER_SCATTER_ADD_INT32
-#undef REGISTER_SCATTER_ADD_INT64
-#undef REGISTER_SCATTER_ADD
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
+
+// Registers GPU kernels.
+#if GOOGLE_CUDA
+#define REGISTER_SCATTER_ADD_SUB_GPU(type) \
+ REGISTER_SCATTER_ADD_SUB(type, GPU);
-#define REGISTER_SCATTER_SUB(type, index_type) \
- REGISTER_KERNEL_BUILDER(Name("ScatterSub") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- ScatterUpdateOp<type, index_type, UpdateOp::SUB>);
+#define REGISTER_SCATTER_UPDATE_GPU(type) \
+ REGISTER_SCATTER_UPDATE(type, GPU);
-#define REGISTER_SCATTER_SUB_INT32(type) REGISTER_SCATTER_SUB(type, int32)
-#define REGISTER_SCATTER_SUB_INT64(type) REGISTER_SCATTER_SUB(type, int64)
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT32);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT64);
+#endif // GOOGLE_CUDA
-#undef REGISTER_SCATTER_SUB_INT64
-#undef REGISTER_SCATTER_SUB_INT32
-#undef REGISTER_SCATTER_SUB
+#undef REGISTER_SCATTER_ADD
+#undef REGISTER_SCATTER_ADD_SUB
+#undef REGISTER_SCATTER_ADD_SUB_CPU
+#undef REGISTER_SCATTER_ADD_SUB_GPU
+#undef REGISTER_SCATTER_UPDATE
+#undef REGISTER_SCATTER_UPDATE_CPU
+#undef REGISTER_SCATTER_UPDATE_GPU
+#undef REGISTER_SCATTER_KERNEL
+#undef REGISTER_SCATTER_KERNEL_INDEX
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+#define DECLARE_GPU_SPECS_OP(T, Index, op) \
+ template <> \
+ void ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ typename TTypes<T>::ConstMatrix updates, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+
+#define DECLARE_GPU_SPECS_INDEX(T, Index) \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPECS_INDEX(T, int32); \
+ DECLARE_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_INDEX
+#undef DECLARE_GPU_SPECS_OP
+
+} // namespace functor
+#endif // GOOGLE_CUDA
} // namespace tensorflow
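
The rewrite above is the standard TensorFlow split between a device-agnostic OpKernel and a device-specialized functor: `ScatterUpdateOp` now calls `functor::ScatterFunctor<Device, ...>`, whose CPU specialization lives in this file and whose GPU specializations are forward-declared here and instantiated in scatter_op_gpu.cu.cc. A self-contained sketch of that dispatch pattern with stand-in types (none of these names are TensorFlow's):

    #include <cstdio>

    struct CpuDevice {};  // stand-in for Eigen::ThreadPoolDevice
    struct GpuDevice {};  // stand-in for Eigen::GpuDevice

    // Primary template is only declared; each device supplies a definition.
    template <typename Device, typename T>
    struct ScaleFunctor;

    // CPU body, compiled by the host compiler (as in scatter_op.cc).
    template <typename T>
    struct ScaleFunctor<CpuDevice, T> {
      void operator()(T* data, int n, T factor) const {
        for (int i = 0; i < n; ++i) data[i] *= factor;
      }
    };

    // A GpuDevice specialization would be declared here and defined in a
    // .cu.cc translation unit compiled by nvcc, mirroring the
    // DECLARE_GPU_SPECS / DEFINE_GPU_SPECS macros in this commit.

    // The op stays generic over the device, like ScatterUpdateOp.
    template <typename Device, typename T>
    void RunScale(T* data, int n, T factor) {
      ScaleFunctor<Device, T>()(data, n, factor);
    }

    int main() {
      float v[3] = {1, 2, 3};
      RunScale<CpuDevice>(v, 3, 2.0f);
      std::printf("%g %g %g\n", v[0], v[1], v[2]);  // prints: 2 4 6
    }
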
diff --git a/tensorflow/core/kernels/scatter_op.h b/tensorflow/core/kernels/scatter_op.h
new file mode 100644
index 0000000000..b7c7df97a7
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op.h
@@ -0,0 +1,48 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SCATTER_OP_H_
+#define TENSORFLOW_KERNELS_SCATTER_OP_H_
+
+// Functor definitions for Scatter ops, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace scatter_op {
+
+enum class UpdateOp { ASSIGN, ADD, SUB };
+
+} // namespace scatter_op
+
+namespace functor {
+
+// Functor used by ScatterOp to do the computations.
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor {
+ void operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<T>::Matrix params,
+ typename TTypes<T>::ConstMatrix updates,
+ typename TTypes<Index>::ConstFlat indices);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SCATTER_OP_H_
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
new file mode 100644
index 0000000000..6ef23419ab
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -0,0 +1,108 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/scatter_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterOpCustomKernel(
+ T* params, const T* updates, const Index* indices,
+ Index first_dim_size, Index updates_size, Index indices_size) {
+ Index update_block = updates_size / indices_size;
+ CUDA_1D_KERNEL_LOOP(i, updates_size) {
+ int indices_i = i / update_block;
+ int updates_i = i;
+ int param_first_index = indices[indices_i];
+ if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+ // Ignore indices that are out of range.
+ continue;
+ }
+ int params_i = param_first_index * update_block + (i % update_block);
+ switch (op) {
+ case scatter_op::UpdateOp::ASSIGN: {
+ params[params_i] = ldg(updates + updates_i);
+ break;
+ }
+ case scatter_op::UpdateOp::ADD: {
+ CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
+ break;
+ }
+ case scatter_op::UpdateOp::SUB: {
+ CudaAtomicSub(params + params_i, ldg(updates + updates_i));
+ break;
+ }
+ }
+ }
+}
+
+namespace functor {
+// Specialization for a GPU device.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<GPUDevice, T, Index, op> {
+ void operator()(OpKernelContext* c, const GPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ typename TTypes<T>::ConstMatrix updates,
+ typename TTypes<Index>::ConstFlat indices) {
+ // TODO: Implement indices range check. The hardest part is with returning
+ // a value after the range check, as we do not want to do device to host
+ // memcpy during a stream.
+ const Index first_dim_size = params.dimension(0);
+ const Index indices_size = indices.size();
+ const Index updates_size = updates.size();
+ CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
+ ScatterOpCustomKernel<T,Index,op>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ params.data(), updates.data(), indices.data(),
+ first_dim_size, updates_size, indices_size);
+ }
+};
+
+} // namespace functor
+
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+
+#define DEFINE_GPU_SPECS(T) \
+ DEFINE_GPU_SPECS_INDEX(T, int32); \
+ DEFINE_GPU_SPECS_INDEX(T, int64);
+
+DEFINE_GPU_SPECS(float);
+DEFINE_GPU_SPECS(double);
+// TODO: The following fails to compile.
+// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+
+#undef DEFINE_GPU_SPECS
+#undef DEFINE_GPU_SPECS_INDEX
+#undef DEFINE_GPU_SPECS_OP
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
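
`CUDA_1D_KERNEL_LOOP` above comes from cuda_kernel_helper.h (also extended by this commit, per the diffstat) and names the grid-stride loop idiom: a launch of any size covers all `updates_size` elements because each thread strides by the grid's total thread count. A sketch of the usual formulation of such a macro (the exact TensorFlow definition is not shown in this diff):

    // Grid-stride loop: thread t handles elements t, t + stride,
    // t + 2*stride, ... where stride = blockDim.x * gridDim.x, so
    // coverage does not depend on the launch size matching n.
    #define GRID_STRIDE_LOOP(i, n)                                 \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
           i += blockDim.x * gridDim.x)

    template <typename T>
    __global__ void ScaleKernel(T* data, int n, T factor) {
      GRID_STRIDE_LOOP(index, n) { data[index] *= factor; }
    }
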
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index cb6abbe75f..1d6285882c 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -243,6 +243,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
testing::StopTiming();
const int kRows = 10000000 / embedding_size;
std::vector<float> values;
+ values.reserve(kRows);
for (int i = 0; i < kRows * embedding_size; i++) {
values.push_back(i);
}
@@ -270,6 +271,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
while (iters-- > 0) {
Status s = bm.RunOpKernel();
}
+ testing::StopTiming();
}
static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index c2ba5ac275..b2c8be925a 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -43,6 +43,7 @@ class ShapeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Shape").Device(DEVICE_CPU).HostMemory("output"),
ShapeOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Shape") \
.Device(DEVICE_GPU) \
@@ -61,6 +62,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape")
.HostMemory("output")
.TypeConstraint<int32>("T"),
ShapeOp);
+#endif
class ShapeNOp : public OpKernel {
public:
@@ -82,6 +84,7 @@ class ShapeNOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ShapeN").Device(DEVICE_CPU).HostMemory("output"),
ShapeNOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("ShapeN") \
.Device(DEVICE_GPU) \
@@ -100,6 +103,7 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN")
.HostMemory("output")
.TypeConstraint<int32>("T"),
ShapeNOp);
+#endif
class RankOp : public OpKernel {
public:
@@ -118,6 +122,7 @@ class RankOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
RankOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Rank") \
.Device(DEVICE_GPU) \
@@ -143,6 +148,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank")
.HostMemory("input")
.HostMemory("output"),
RankOp);
+#endif
class SizeOp : public OpKernel {
public:
@@ -162,6 +168,7 @@ class SizeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Size").Device(DEVICE_CPU).HostMemory("output"),
SizeOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Size") \
.Device(DEVICE_GPU) \
@@ -180,6 +187,7 @@ REGISTER_KERNEL_BUILDER(Name("Size")
.HostMemory("input")
.HostMemory("output"),
SizeOp);
+#endif
class ExpandDimsOp : public OpKernel {
public:
@@ -225,6 +233,7 @@ class ExpandDimsOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"),
ExpandDimsOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("ExpandDims") \
.Device(DEVICE_GPU) \
@@ -241,6 +250,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.HostMemory("dim")
.HostMemory("output"),
ExpandDimsOp);
+#endif
class SqueezeOp : public OpKernel {
public:
@@ -313,6 +323,7 @@ class SqueezeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);
+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Squeeze").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
@@ -329,5 +340,6 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
.HostMemory("input")
.HostMemory("output"),
SqueezeOp);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 1aa5ae0ab5..69b53500a1 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -67,7 +67,7 @@ static const int N = 128;
// Note that all the data/indices of all the blocks are stored in the same
// vectors respectively. To identify block boundaries, we store the block
// offsets using index3_offset/index_offset. If there are n blocks in the slice,
-// index3_offset and index_offset have n entires. The indices for the ith block
+// index3_offset and index_offset have n entries. The indices for the ith block
// are the values in the following range:
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
// index_offset.
@@ -475,7 +475,7 @@ class SparseMatMulOp : public OpKernel {
if (!a_is_sparse_ && !b_is_sparse_) {
// Fallback to Eigen contract.
// Note that we currently don't optimize the case where only right is
- // sparse. That can generally be handled by tranposing the order of the
+ // sparse. That can generally be handled by transposing the order of the
// matmul.
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
@@ -540,7 +540,7 @@ class SparseMatMulOp : public OpKernel {
// Encodes "mat" using a sparse representation and stores that in
// "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
// "slice_num_cols", each grid element is converted into a SparseSlice and
- // stored in mat_slices. "slice_block_size" is used to perform futher column
+ // stored in mat_slices. "slice_block_size" is used to perform further column
// blocking of each slice.
static inline BlockingCounter* CreateSparseSlices(
const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
@@ -776,7 +776,7 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
*KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
*NR = right.dimension(1);
if (*KR * *NR > mem) {
- // 4096 may be enough to ammortize the cost of writes.
+ // 4096 may be enough to amortize the cost of writes.
*KR = std::min<int>(*KR, 4096);
}
// Use sizes that are multiples of K and 256.
diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc
index 26e495a520..1cfc01b6e2 100644
--- a/tensorflow/core/lib/core/command_line_flags.cc
+++ b/tensorflow/core/lib/core/command_line_flags.cc
@@ -43,7 +43,7 @@ bool StringToValue<string>(const string& content, string* value) {
// Return OK if the argument is used. It store the extracted value into the
// matching flag.
// Return NOT_FOUND if the argument is not recognized.
-// Retrun INVALID_ARGUMENT if the command is recognized, but fails to extract
+// Return INVALID_ARGUMENT if the command is recognized, but fails to extract
// its value.
template <typename T>
Status ParseArgument(const string& argument) {
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index 6b2cba809c..1f35c0cab8 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -374,7 +374,7 @@ class InlinedVector {
// Moves srcs[0,n-1] contents to dst[0,n-1].
static void Move(const T* src, size_t n, T* dst) {
- for (int i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
new (dst + i) T(std::move(*(src + i)));
}
}
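
The index-type fix above is the usual signed/unsigned mismatch repair: with a `size_t` bound, an `int` loop counter draws a -Wsign-compare warning and would overflow before reaching the bound on very large inputs. A minimal illustration (hypothetical function, not from the library):

    #include <cstddef>

    // Loop counter matches the bound's type, as in the fixed Move():
    // comparing int against size_t is a sign-compare warning, and an
    // int index overflows for n > INT_MAX.
    void ZeroFill(char* dst, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) dst[i] = 0;
    }
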
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
index 04b092f053..4beca30c3f 100644
--- a/tensorflow/core/lib/jpeg/jpeg_handle.h
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// This file declares the functions and structures for memory I/O with libjpeg
-// These functions are not meant to be used directly, see jpeg_mem.h isntead.
+// These functions are not meant to be used directly, see jpeg_mem.h instead.
#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc
index 1fc1397499..3daaa1447c 100644
--- a/tensorflow/core/lib/random/distribution_sampler.cc
+++ b/tensorflow/core/lib/random/distribution_sampler.cc
@@ -42,7 +42,7 @@ DistributionSampler::DistributionSampler(
std::vector<int> low;
low.reserve(n);
- // compute propotional weights
+ // compute proportional weights
for (int i = 0; i < n; i++) {
double p = (weights[i] * n) / sum;
pr[i] = p;
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index 9f510342d7..a2ee5c96aa 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -192,7 +192,7 @@ class SingleSampleAdapter {
// each invocation. It needs to define kResultElementCount for the
// sample count for each invocation, and ResultType for actual
// returned sample type.
-// RealType: the data type of the real numberes that will be returned by the
+// RealType: the data type of the real numbers that will be returned by the
// distribution. This could be either float or double for now.
// This class is meant to be implemented through specialization. The default
// is not defined by design.
@@ -259,7 +259,7 @@ class NormalDistribution<Generator, double> {
// each invocation. It needs to define kResultElementCount for the
// sample count for each invocation, and ResultType for actual
// returned sample type.
-// RealType: the data type of the real numberes that will be returned by the
+// RealType: the data type of the real numbers that will be returned by the
// distribution. This could be either float or double for now.
// This class is meant to be implemented through specialization. The default
// is not defined by design.
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
index 4d3e4a5cdc..13398b838f 100644
--- a/tensorflow/core/lib/random/random_distributions_test.cc
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -36,7 +36,7 @@ namespace {
static constexpr float kZLimit = 6.0;
// A utility function to fill the given array with samples from the given
-// distribution, using the single adatper of the underlying generator
+// distribution, using the single adapter of the underlying generator
template <class Distribution>
void FillRandomsWithSingles(PhiloxRandom gen,
typename Distribution::ResultElementType* p,
@@ -87,7 +87,7 @@ bool CheckSamplesMoments(const std::vector<T>& samples,
break;
}
// moments[i] store the i-th order measured moments.
- // bypass std::vector::opeartor[] because they are too slow in the debug
+ // bypass std::vector::operator[] because they are too slow in the debug
// mode, given the large number of samples.
moments_data[i] += moment;
++moments_sample_count_data[i];
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 859a654e36..778545b44b 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -33,7 +33,7 @@ char* FastInt32ToBufferLeft(int32 i, char* buffer) {
if (i < 0) {
*buffer++ = '-';
// We need to do the negation in modular (i.e., "unsigned")
- // arithmetic; MSVC++ apprently warns for plain "-u", so
+ // arithmetic; MSVC++ apparently warns for plain "-u", so
// we write the equivalent expression "0 - u" instead.
u = 0 - u;
}
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index f1924ccf93..4dd0bcdec4 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -77,7 +77,7 @@ char* FloatToBuffer(float i, char* buffer);
string FpToString(Fprint fp);
// Attempt to parse a fingerprint in the form encoded by FpToString. If
-// successsful, stores the fingerprint in *fp and returns true. Otherwise,
+// successful, stores the fingerprint in *fp and returns true. Otherwise,
// returns false.
bool StringToFp(const string& s, Fprint* fp);
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 02632a6b2b..33b6028153 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -30,7 +30,7 @@ limitations under the License.
// The AlphaNum type was designed to be used as the parameter type for StrCat().
// Any routine accepting either a string or a number may accept it.
// The basic idea is that by accepting a "const AlphaNum &" as an argument
-// to your function, your callers will automagically convert bools, integers,
+// to your function, your callers will automatically convert bools, integers,
// and floating point values to strings for you.
//
// NOTE: Use of AlphaNum outside of the //strings package is unsupported except
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index d9c3d8d2c9..99c90a811b 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -446,22 +446,20 @@ Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
height of the underlying image.
-Example:
-
-```
-# Generate a single distorted bounding box.
-begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
- tf.shape(image),
- bounding_boxes=bounding_boxes)
-
-# Draw the bounding box in an image summary.
-image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
- bbox_for_draw)
-tf.image_summary('images_with_box', image_with_box)
-
-# Employ the bounding box to distort the image.
-distorted_image = tf.slice(image, begin, size)
-```
+For example,
+
+ # Generate a single distorted bounding box.
+ begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+ tf.shape(image),
+ bounding_boxes=bounding_boxes)
+
+ # Draw the bounding box in an image summary.
+ image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+ bbox_for_draw)
+ tf.image_summary('images_with_box', image_with_box)
+
+ # Employ the bounding box to distort the image.
+ distorted_image = tf.slice(image, begin, size)
Note that if no bounding box information is available, setting
`use_image_if_no_bounding_boxes = true` will assume there is a single implicit
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 644bb546b3..1a1e019230 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -7122,22 +7122,27 @@ op {
name: "SampleDistortedBoundingBox"
input_arg {
name: "image_size"
+ description: "1-D, containing `[height, width, channels]`."
type_attr: "T"
}
input_arg {
name: "bounding_boxes"
+ description: "3-D with shape `[batch, N, 4]` describing the N bounding boxes\nassociated with the image."
type: DT_FLOAT
}
output_arg {
name: "begin"
+ description: "1-D, containing `[offset_height, offset_width, 0]`. Provide as input to\n`tf.slice`."
type_attr: "T"
}
output_arg {
name: "size"
+ description: "1-D, containing `[target_height, target_width, -1]`. Provide as input to\n`tf.slice`."
type_attr: "T"
}
output_arg {
name: "bboxes"
+ description: "3-D with shape `[1, 1, 4]` containing the distorted bounding box.\nProvide as input to `tf.image.draw_bounding_boxes`."
type: DT_FLOAT
}
attr {
@@ -7159,6 +7164,7 @@ op {
default_value {
i: 0
}
+ description: "If either `seed` or `seed2` are set to non-zero, the random number\ngenerator is seeded by the given `seed`. Otherwise, it is seeded by a random\nseed."
}
attr {
name: "seed2"
@@ -7166,6 +7172,7 @@ op {
default_value {
i: 0
}
+ description: "A second seed to avoid seed collision."
}
attr {
name: "min_object_covered"
@@ -7173,6 +7180,7 @@ op {
default_value {
f: 0.1
}
+ description: "The cropped area of the image must contain at least this\nfraction of any bounding box supplied."
}
attr {
name: "aspect_ratio_range"
@@ -7183,6 +7191,7 @@ op {
f: 1.33
}
}
+ description: "The cropped area of the image must have an aspect ratio =\nwidth / height within this range."
}
attr {
name: "area_range"
@@ -7193,6 +7202,7 @@ op {
f: 1
}
}
+ description: "The cropped area of the image must contain a fraction of the\nsupplied image within in this range."
}
attr {
name: "max_attempts"
@@ -7200,6 +7210,7 @@ op {
default_value {
i: 100
}
+ description: "Number of attempts at generating a cropped region of the image\nof the specified constraints. After `max_attempts` failures, return the entire\nimage."
}
attr {
name: "use_image_if_no_bounding_boxes"
@@ -7207,9 +7218,10 @@ op {
default_value {
b: false
}
+ description: "Controls behavior if no bounding boxes supplied.\nIf true, assume an implicit bounding box covering the whole input. If false,\nraise an error."
}
summary: "Generate a single randomly distorted bounding box for an image."
- description: "Bounding box annotations are often supplied in addition to ground-truth labels\nin image recognition or object localization tasks. A common technique for\ntraining such a system is to randomly distort an image while preserving\nits content, i.e. *data augmentation*. This Op outputs a randomly distorted\nlocalization of an object, i.e. bounding box, given an `image_size`,\n`bounding_boxes` and a series of constraints.\n\nThe output of this Op is a single bounding box that may be used to crop the\noriginal image. The output is returned as 3 tensors: `begin`, `size` and\n`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the\nimage. The latter may be supplied to `tf.image.draw_bounding_box` to visualize\nwhat the bounding box looks like.\n\nBounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The\nbounding box coordinates are floats in `[0.0, 1.0]` relative to the width and\nheight of the underlying image."
+  description: "Bounding box annotations are often supplied in addition to ground-truth labels\nin image recognition or object localization tasks. A common technique for\ntraining such a system is to randomly distort an image while preserving\nits content, i.e. *data augmentation*. This Op outputs a randomly distorted\nlocalization of an object, i.e. bounding box, given an `image_size`,\n`bounding_boxes` and a series of constraints.\n\nThe output of this Op is a single bounding box that may be used to crop the\noriginal image. The output is returned as 3 tensors: `begin`, `size` and\n`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the\nimage. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize\nwhat the bounding box looks like.\n\nBounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The\nbounding box coordinates are floats in `[0.0, 1.0]` relative to the width and\nheight of the underlying image.\n\nFor example,\n\n    # Generate a single distorted bounding box.\n    begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(\n        tf.shape(image),\n        bounding_boxes=bounding_boxes)\n\n    # Draw the bounding box in an image summary.\n    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),\n                                                  bbox_for_draw)\n    tf.image_summary(\'images_with_box\', image_with_box)\n\n    # Employ the bounding box to distort the image.\n    distorted_image = tf.slice(image, begin, size)\n\nNote that if no bounding box information is available, setting\n`use_image_if_no_bounding_boxes = true` will assume there is a single implicit\nbounding box covering the whole image. If `use_image_if_no_bounding_boxes` is\nfalse and no bounding boxes are supplied, an error is raised."
is_stateful: true
}
op {
@@ -7482,7 +7494,7 @@ op {
description: "If True, the assignment will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse updates to a variable reference."
- description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] = updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] = updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf `indices` contains duplicate entries, lexicographically later entries\noverride earlier entries.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterUpdate.png\" alt>\n</div>"
+  description: "This operation computes\n\n    # Scalar indices\n    ref[indices, ...] = updates[...]\n\n    # Vector indices (for each i)\n    ref[indices[i], ...] = updates[i, ...]\n\n    # High rank indices (for each i, ..., j)\n    ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf a value in `ref` is to be updated more than once, because there are\nduplicate entries in `indices`, the order in which the updates happen\nfor each value is undefined.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterUpdate.png\" alt>\n</div>"
}
op {
name: "SegmentMax"
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index a5a3d14a07..3bec698acc 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -182,8 +182,9 @@ This operation computes
This operation outputs `ref` after the update is done.
This makes it easier to chain operations that need to use the reset value.
-If `indices` contains duplicate entries, lexicographically later entries
-override earlier entries.
+If a value in `ref` is to be updated more than once, because there are
+duplicate entries in `indices`, the order in which the updates happen
+for each value is undefined.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
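
To make the revised wording concrete, here is a minimal serial sketch of the update rule for a rank-1 `ref` (illustrative C++ only; `ScatterUpdateSketch` is a hypothetical name, not the actual kernel):

    #include <vector>

    // Serial reference semantics for ScatterUpdate on a rank-1 ref.
    // The real op may apply the per-index assignments concurrently, so
    // when `indices` contains duplicates, which write survives is undefined.
    void ScatterUpdateSketch(std::vector<float>* ref,
                             const std::vector<int>& indices,
                             const std::vector<float>& updates) {
      for (size_t i = 0; i < indices.size(); ++i) {
        (*ref)[indices[i]] = updates[i];  // duplicate index: last writer wins
      }
    }

For instance, with `indices = {0, 0}` and `updates = {1.0f, 2.0f}`, this serial loop leaves `ref[0] == 2.0f`, but a parallel implementation is free to leave either value.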
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 9d069ab4c6..a89a0a4af7 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -3,9 +3,9 @@
load("//google/protobuf:protobuf.bzl", "cc_proto_library")
load("//google/protobuf:protobuf.bzl", "py_proto_library")
-# configure may change the following lines.
-CUDA_VERSION = '7.0'
-CUDNN_VERSION = '6.5'
+# configure may change the following lines to '.X.Y' or similar
+CUDA_VERSION = ''
+CUDNN_VERSION = ''
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index c6dccc06ff..4b8088fde8 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -75,7 +75,7 @@ filegroup(
cc_library(
name = "cuda",
data = [
- "//third_party/gpus/cuda:lib64/libcudart.so." + tf_get_cuda_version(),
+ "//third_party/gpus/cuda:lib64/libcudart.so" + tf_get_cuda_version(),
],
linkopts = [
"-Wl,-rpath,third_party/gpus/cuda/lib64",
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index d8ba37babc..18395f3292 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <chrono>
#include <condition_variable>
#include <mutex>
+#include "tensorflow/core/platform/default/thread_annotations.h"
namespace tensorflow {
@@ -29,16 +30,24 @@ enum LinkerInitialized { LINKER_INITIALIZED };
// A class that wraps around the std::mutex implementation, only adding an
// additional LinkerInitialized constructor interface.
-class mutex : public std::mutex {
+class LOCKABLE mutex : public std::mutex {
public:
mutex() {}
// The default implementation of std::mutex is safe to use after the linker
// initializations
explicit mutex(LinkerInitialized x) {}
+
+ void lock() ACQUIRE() { std::mutex::lock(); }
+ void unlock() RELEASE() { std::mutex::unlock(); }
+};
+
+class SCOPED_LOCKABLE mutex_lock : public std::unique_lock<std::mutex> {
+ public:
+ mutex_lock(class mutex& m) ACQUIRE(m) : std::unique_lock<std::mutex>(m) {}
+ ~mutex_lock() RELEASE() {}
};
using std::condition_variable;
-typedef std::unique_lock<std::mutex> mutex_lock;
inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
condition_variable* cv, int64 ms) {
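
A usage sketch (hypothetical `Counter` class, and assuming the pre-existing `GUARDED_BY` macro from thread_annotations.h): with `mutex` now marked LOCKABLE and `mutex_lock` SCOPED_LOCKABLE, clang's -Wthread-safety analysis can verify call sites such as:

    // Hypothetical example, not part of this change.
    class Counter {
     public:
      void Increment() {
        mutex_lock l(mu_);  // ACQUIRE(mu_) at construction, RELEASE() at scope exit
        ++value_;           // OK: the analysis knows mu_ is held here
      }

     private:
      mutex mu_;
      int value_ GUARDED_BY(mu_);  // unguarded access now triggers a warning
    };

Note that replacing the old `typedef std::unique_lock<std::mutex> mutex_lock;` with a SCOPED_LOCKABLE class is what lets the analysis model acquisition and release over the lock object's lifetime.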
diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h
index d8a9253926..46143b2ea3 100644
--- a/tensorflow/core/platform/default/thread_annotations.h
+++ b/tensorflow/core/platform/default/thread_annotations.h
@@ -73,6 +73,15 @@ limitations under the License.
#define ACQUIRED_BEFORE(...) \
THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__))
+#define ACQUIRE(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(acquire_capability(__VA_ARGS__))
+
+#define ACQUIRE_SHARED(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__))
+
+#define RELEASE(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(release_capability(__VA_ARGS__))
+
// Document a function that expects a mutex to be held prior to entry.
// The mutex is expected to be held both on entry to and exit from the
// function.
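
A short sketch of the new macros on explicit lock/unlock helpers (hypothetical `Buffer` class; the annotations tell the analysis that the function itself acquires or releases the capability):

    // Hypothetical example, not part of this change.
    class Buffer {
     public:
      void Lock() ACQUIRE(mu_) { mu_.lock(); }      // caller holds mu_ afterwards
      void Unlock() RELEASE(mu_) { mu_.unlock(); }  // caller gives up mu_

     private:
      mutex mu_;
    };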
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4c0a9b8d70..9d46087814 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,7 +19,7 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 0
-#define TF_MINOR_VERSION 6
+#define TF_MINOR_VERSION 7
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 4e1dfb8c4f..c98207b6f0 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -20,6 +20,8 @@ limitations under the License.
#include <algorithm>
+#include "tensorflow/core/platform/types.h"
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
@@ -69,6 +71,58 @@ __device__ __host__ inline T ldg(const T* address) {
#endif
}
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementations for all reasonable types.
+#define CUDA_ATOMIC_WRAPPER(op, T) \
+ __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
+
+#define USE_CUDA_ATOMIC(op, T) \
+ CUDA_ATOMIC_WRAPPER(op, T) { \
+ return atomic##op(address, val); \
+ }
+
+// For atomicAdd.
+USE_CUDA_ATOMIC(Add, int32);
+USE_CUDA_ATOMIC(Add, uint32);
+USE_CUDA_ATOMIC(Add, uint64);
+USE_CUDA_ATOMIC(Add, float);
+
+// Custom implementation of atomicAdd for double.
+// This implementation is copied from the CUDA manual.
+CUDA_ATOMIC_WRAPPER(Add, double) {
+ uint64* address_as_ull = (uint64*)address;
+ uint64 old = *address_as_ull, assumed;
+
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val + __longlong_as_double(assumed)));
+
+ // Note: uses integer comparison to avoid hang in case of NaN
+ } while (assumed != old);
+
+ return __longlong_as_double(old);
+}
+
+// For atomicSub.
+
+// Custom implementation of atomicSub: negate the value and add.
+#define WRAPPED_ATOMIC_SUB(T) \
+ CUDA_ATOMIC_WRAPPER(Sub, T) { \
+ return CudaAtomicAdd(address, -val); \
+ }
+
+WRAPPED_ATOMIC_SUB(uint64);
+WRAPPED_ATOMIC_SUB(int32);
+WRAPPED_ATOMIC_SUB(uint32);
+WRAPPED_ATOMIC_SUB(float);
+WRAPPED_ATOMIC_SUB(double);
+
+#undef WRAPPED_ATOMIC_SUB
+
+#undef USE_CUDA_ATOMIC
+#undef CUDA_ATOMIC_WRAPPER
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
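
As a usage sketch (hypothetical kernel; assumes only the `CUDA_1D_KERNEL_LOOP` macro and the `CudaAtomicAdd` wrappers defined above), concurrent accumulation works for any type with a `CudaAtomicAdd` overload:

    // Hypothetical histogram kernel: many threads may hit the same bin,
    // so the accumulation must be atomic.
    __global__ void HistogramKernel(const int* bins, float* counts, int n) {
      CUDA_1D_KERNEL_LOOP(i, n) {
        CudaAtomicAdd(counts + bins[i], 1.0f);
      }
    }

Swapping `float` for `double` here exercises the atomicCAS-based CudaAtomicAdd above, since pre-Pascal hardware has no native double-precision atomicAdd.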
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
index cfdbb07cd5..6b2a526834 100644
--- a/tensorflow/core/util/events_writer.cc
+++ b/tensorflow/core/util/events_writer.cc
@@ -115,7 +115,7 @@ bool EventsWriter::Flush() {
// recordio_writer_->Sync() can return true even if the underlying
// file has been deleted. EventWriter.FileDeletionBeforeWriting
// demonstrates this and will fail if the FileHasDisappeared()
- // conditon is removed.
+ // condition is removed.
// Also, we deliberately attempt to Sync() before checking for a
// disappearing file, in case for some file system File::Exists() is
// false after File::Open() but before File::Sync().