author    | Vijay Vasudevan <vrv@google.com> | 2016-02-17 11:42:30 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-17 12:56:41 -0800
commit    | fe056f0b5e52db86766761f5e6446a89c1aa3938 (patch)
tree      | 68bce0e257d181a3fa37f83c97fdff0fdad877fc /tensorflow/core
parent    | 19d632338f983e02dd0268b931e9cced03b74805 (diff)
Merge changes from github.
Change: 114882676
Diffstat (limited to 'tensorflow/core')
53 files changed, 458 insertions, 142 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 6506b06096..57eb5999b2 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -535,7 +535,7 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
    pending_feeds.erase(id);
  }

-  // Initialize the stack with the fecth nodes.
+  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 3b61dc8697..53645a2061 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -304,7 +304,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
  ExpandInlineFunctions(lib_, g);
  EXPECT_EQ(e2, DebugString(g));

-  // Get rid of redunant Identity nodes.
+  // Get rid of redundant Identity nodes.
  RemoveIdentityNodes(g);
  const char* e3 = R"P(
(n2:float) -> (n42:float) {
@@ -683,7 +683,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
-      // A bunch of extra arithmatic that y doesn't depend on
+      // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
@@ -722,7 +722,7 @@ TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
      // Nodes
      {// variable
       {{"v"}, "Variable", {}, {{"dtype", T}, {"shape", TensorShape({})}}},
-      // read the variable. Shouln't be removed.
+      // read the variable. Shouldn't be removed.
       {{"v_read"}, "Identity", {"v"}, {{"T", T}}},
       // returns v + v
       {{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});
@@ -761,7 +761,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
-      // A bunch of extra arithmatic that y doesn't depend on
+      // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Identity", {"a"}, {{"T", T}}},
       {{"x2"}, "Identity", {"x1"}, {{"T", T}}},
       {{"x3"}, "Identity", {"x2"}, {{"T", T}}},
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index b8e4a1f783..f904d5be8d 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -105,7 +105,7 @@ class Allocator {
  // Returns true if this allocator tracks the sizes of allocations.
  // RequestedSize and AllocatedSize must be overridden if
-  // TracksAlloctionSizes is overridden to return true.
+  // TracksAllocationSizes is overridden to return true.
  virtual bool TracksAllocationSizes() { return false; }

  // Returns true if this allocator requires tensors with 0 elements
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 503e578fac..8231b91a3f 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -160,10 +160,10 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
-// InstatiateFunction calls "get_function" to find signatures of other
+// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.

-// Placeholders in "fdef" is substitued based on "attr_values" here.
+// Placeholders in "fdef" is substituted based on "attr_values" here.
typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
    InstantiateAttrValueSlice;
@@ -329,7 +329,7 @@ class FunctionLibraryRuntime {
// std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a
-// definition of a brain function which computate the gradient for the
+// definition of a brain function which compute the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index a8c2d58b92..dc98044804 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -473,7 +473,7 @@ class OpKernelContext {
      // will never use eigen_gpu_device. It seems better to have
      // ensure_eigen_gpu_device fall through and regenerate the
      // nullptr every time an OpKernelContext is instantiated, than
-      // to do an unneccessary allocation of a dummy eigen GPU
+      // to do an unnecessary allocation of a dummy eigen GPU
      // device for CPU device Ops.
      eigen_gpu_device = device->MakeGpuDevice();
    }
@@ -1037,7 +1037,7 @@ typedef ::tensorflow::KernelDefBuilder Name;
  static ::tensorflow::kernel_factory::OpKernelRegistrar \
      registrar__body__##ctr##__object( \
          ::tensorflow::register_kernel::kernel_builder.Build(), \
-          +[](::tensorflow::OpKernelConstruction* context) \
+          [](::tensorflow::OpKernelConstruction* context) \
              -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })

void* GlobalKernelRegistry();
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index f86fc6a141..037e6d4684 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -95,7 +95,7 @@ class CostModel {
  // Check that an estimate is available for every OP node in graph.
  void CheckInitialized(const Graph& graph) const;

-  // Helper routines to encapsulate static estimatation heuristics
+  // Helper routines to encapsulate static estimation heuristics

  // Compute an estimate of the time to copy "b" bytes over the network,
  // given a fixed cost of "network_latency_millis" milliseconds and
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index b4b91e115e..4ad2a306b2 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -15,7 +15,7 @@ limitations under the License.

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
-// compuations. The basic model is a DAG (directed acyclic graph) with
+// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges represent dependencies, indicating the target may only be
//   executed once the source has completed; and
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index 2a212bbc49..ec28343668 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -37,7 +37,7 @@ namespace tensorflow {
//     node_builder.Input(input);
//     return opts.FinalizeBuilder(&node_builder);
//   }
-// }  // namspace ops
+// }  // namespace ops
//
// // Or, alternatively:
// namespace ops {
@@ -45,7 +45,7 @@ namespace tensorflow {
//     static const string kOpName = "Identity";
//     return UnaryOp(kOpName, input, opts);
//   }
-// }  // namspace ops
+// }  // namespace ops
//
// You call it like:
//   GraphDefBuilder b;
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
index 4ae0133977..5c69af0144 100644
--- a/tensorflow/core/graph/graph_partition.h
+++ b/tensorflow/core/graph/graph_partition.h
@@ -40,7 +40,7 @@ struct PartitionOptions {
  // A function that returns the incarnation of a device given the
  // device's fullname. If not found, GetIncarnationFunc should return
-  // kIlledgalIncarnation.
+  // kIllegalIncarnation.
  static const uint64 kIllegalIncarnation = 0;
  typedef std::function<uint64(const string&)> GetIncarnationFunc;
  GetIncarnationFunc get_incarnation = nullptr;
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
index a275a32c3b..128ab2b3b9 100644
--- a/tensorflow/core/kernels/avgpooling_op.h
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -41,7 +41,7 @@ struct SpatialAvgPooling {

typedef Eigen::GpuDevice GPUDevice;

-// Lauch a custom GPU kernels from Yanqing for the avgpooling backward operation
+// Launch a custom GPU kernels from Yanqing for the avgpooling backward operation
// that works NHWC data formats.
// Arguments:
//     top_diff: backprop to the output of the pooling layer
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index c4f8298e27..5f5d7cf8db 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -85,6 +85,7 @@ void HostConstantOp::Compute(OpKernelContext* ctx) {
  ctx->set_output(0, tensor_);
}

+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -93,6 +94,7 @@ REGISTER_KERNEL_BUILDER(Name("Const")
                            .HostMemory("output")
                            .TypeConstraint<int32>("dtype"),
                        HostConstantOp);
+#endif

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
@@ -178,10 +180,6 @@ REGISTER_KERNEL(GPU, int16);
REGISTER_KERNEL(GPU, int64);
// Currently we do not support filling strings and complex64 on GPU

-#endif  // GOOGLE_CUDA
-
-#undef REGISTER_KERNEL
-
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -192,6 +190,9 @@ REGISTER_KERNEL_BUILDER(Name("Fill")
                            .HostMemory("value")
                            .HostMemory("output"),
                        FillOp<CPUDevice, int32>);
+#endif
+
+#undef REGISTER_KERNEL

template <typename Device, typename T>
class ZerosLikeOp : public OpKernel {
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
index ca61a94391..1b976c7210 100644
--- a/tensorflow/core/kernels/cwise_op_abs.cc
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -23,7 +23,6 @@ REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU),
#endif
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Abs", functor::abs, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -34,5 +33,6 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        UnaryOp<CPUDevice, functor::abs<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc
index 49d9e0c507..8a9c1979e5 100644
--- a/tensorflow/core/kernels/cwise_op_add.cc
+++ b/tensorflow/core/kernels/cwise_op_add.cc
@@ -20,7 +20,6 @@ REGISTER8(BinaryOp, CPU, "Add", functor::add, float, double, int32, int64,
          int8, int16, complex64, string);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Add")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::add<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 979db4c50c..e97b6b4360 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16,
#if GOOGLE_CUDA
REGISTER5(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16,
          int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Div")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::div<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc
index 28801a49d6..1b744445ff 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to.cc
@@ -21,7 +21,6 @@ REGISTER9(BinaryOp, CPU, "Equal", functor::equal_to, float, double, uint8, int8,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Equal", functor::equal_to, float, double, uint8,
          int8, int16, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Equal")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::equal_to<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index 7f746745bf..f860ea19ac 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Greater", functor::greater, float, double, int32,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Greater", functor::greater, float, double, int64,
          uint8, int8, int16);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Greater")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::greater<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index adf06121a2..465c5a22ae 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float, double,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float, double,
          int64, uint8, int8, int16);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::greater_equal<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 1710304703..e7acfa091f 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "Less", functor::less, float, double, int32, int64,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Less", functor::less, float, double, int64, uint8,
          int8, int16);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("Less")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::less<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 65f79b5799..3175dcae5e 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -21,7 +21,6 @@ REGISTER7(BinaryOp, CPU, "LessEqual", functor::less_equal, float, double, int32,
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "LessEqual", functor::less_equal, float, double,
          int64, uint8, int8, int16);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -33,5 +32,6 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::less_equal<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc
index 732e6b48f2..a2fc480674 100644
--- a/tensorflow/core/kernels/cwise_op_maximum.cc
+++ b/tensorflow/core/kernels/cwise_op_maximum.cc
@@ -20,7 +20,6 @@ REGISTER4(BinaryOp, CPU, "Maximum", functor::maximum, float, double, int32,
          int64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Maximum", functor::maximum, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Maximum")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::maximum<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc
index 017daca4cc..0c6797ec86 100644
--- a/tensorflow/core/kernels/cwise_op_minimum.cc
+++ b/tensorflow/core/kernels/cwise_op_minimum.cc
@@ -20,7 +20,6 @@ REGISTER4(BinaryOp, CPU, "Minimum", functor::minimum, float, double, int32,
          int64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Minimum", functor::minimum, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Minimum")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::minimum<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
index 94f7ddd3b2..5b5ad90207 100644
--- a/tensorflow/core/kernels/cwise_op_sign.cc
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -19,7 +19,6 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sign", functor::sign, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -30,5 +29,6 @@ REGISTER_KERNEL_BUILDER(Name("Sign")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        UnaryOp<CPUDevice, functor::sign<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index 8858db9e5f..b3727ec361 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -20,7 +20,6 @@ REGISTER5(BinaryOp, CPU, "Sub", functor::sub, float, double, int32, int64,
          complex64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, float, double, int64);
-#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,5 +31,6 @@ REGISTER_KERNEL_BUILDER(Name("Sub")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::sub<int32>>);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
index 22a26e8310..38a46583bd 100644
--- a/tensorflow/core/kernels/identity_op.cc
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -47,6 +47,7 @@ REGISTER_GPU_KERNEL(bfloat16);

#undef REGISTER_GPU_KERNEL

+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -56,5 +57,6 @@ REGISTER_KERNEL_BUILDER(Name("Identity")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        IdentityOp);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 286b74ca64..f3ad98ab0a 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -170,7 +170,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
                          PadOp<GPUDevice, T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
-#endif  // GOOGLE_CUDA

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -182,5 +181,6 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32>);
+#endif

}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc
index 58bd103c80..40be4bdb20 100644
--- a/tensorflow/core/kernels/range_sampler.cc
+++ b/tensorflow/core/kernels/range_sampler.cc
@@ -60,7 +60,7 @@ namespace {
// We use batch_size and num_tries, where num_tries is the observed number of
// tries it took to get batch_size unique values.
//
-// Assuming (falsely) that the nubmer of tries to get a batch of batch_size
+// Assuming (falsely) that the number of tries to get a batch of batch_size
// distinct values is _always_ num_tries, the probability that the value
// is in a batch is (1 - (1-p)^num_tries)
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
index b2a1d44da1..2513f50b37 100644
--- a/tensorflow/core/kernels/range_sampler.h
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -65,7 +65,7 @@ class RangeSampler {
  // Expected counts for the elements of the returned "batch" are reported
  // in the aligned array "batch_expected_count".
  //
-  // The user can optionally provide "extras", containg values in the range.
+  // The user can optionally provide "extras", containing values in the range.
  // The expected counts for the extras are reported in the aligned array
  // "extras_expected_count".
  //
diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc
index 2413623684..1ee959c026 100644
--- a/tensorflow/core/kernels/reshape_op.cc
+++ b/tensorflow/core/kernels/reshape_op.cc
@@ -30,6 +30,7 @@ REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"),
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

+#if GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -40,5 +41,6 @@ REGISTER_KERNEL_BUILDER(Name("Reshape")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ReshapeOp);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index afc03efbca..30fd105b5f 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -15,6 +15,8 @@ limitations under the License.

// See docs in ../ops/state_ops.cc.

+#include "tensorflow/core/kernels/scatter_op.h"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -23,33 +25,38 @@ limitations under the License.
namespace tensorflow {

-enum class UpdateOp { ASSIGN, ADD, SUB };
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {

-template <UpdateOp Op>
+template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
-struct Assign<UpdateOp::ASSIGN> {
+struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
};
template <>
-struct Assign<UpdateOp::ADD> {
+struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
};
template <>
-struct Assign<UpdateOp::SUB> {
+struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
};

-template <class T, typename Index, UpdateOp op>
+}  // namespace
+
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
 public:
  //   QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
@@ -108,85 +115,136 @@ class ScatterUpdateOp : public OpKernel {
                    "updates.shape ", Tupdates.shape().DebugString(),
                    ", indices.shape ", Tindices.shape().DebugString(),
                    ", params.shape ", Tparams.shape().DebugString()));
-    const Index N = Tindices.NumElements();

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

+    const Index N = Tindices.NumElements();
    if (N > 0) {
-      const Index first_dim_size = Tparams.dim_size(0);
-      // Validate all the indices are in range
-      auto Tindices_vec = Tindices.flat<Index>();
-      for (Index i = 0; i < N; i++) {
-        const Index index = Tindices_vec(i);
-        OP_REQUIRES(c, index >= 0 && index < first_dim_size,
-                    errors::InvalidArgument(
-                        strings::StrCat("Index ", index, " at offset ", i,
-                                        " in indices is out of range")));
-      }
+      auto Tindices_flat = Tindices.flat<Index>();
      auto Tparams_flat = Tparams.flat_outer_dims<T>();
      auto Tupdates_flat =
          Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
-      for (Index i = 0; i < N; i++) {
-        // Copy last Ndim-1 dimensions of Tupdates[i] to
-        // Tparams[Tindices[i]]
-        Assign<op>::Run(Tparams_flat.template chip<0>(Tindices_vec(i)),
-                        Tupdates_flat.template chip<0>(i));
-      }
+
+      functor::ScatterFunctor<Device, T, Index, op> functor;
+      functor(c, c->template eigen_device<Device>(),
+              Tparams_flat, Tupdates_flat, Tindices_flat);
    }
  }
};

-#define REGISTER_SCATTER_UPDATE(type, index_type) \
-  REGISTER_KERNEL_BUILDER( \
-      Name("ScatterUpdate") \
-          .Device(DEVICE_CPU) \
-          .TypeConstraint<type>("T") \
-          .TypeConstraint<index_type>("Tindices"), \
-      ScatterUpdateOp<type, index_type, UpdateOp::ASSIGN>);
+namespace functor {
+// Implementation of update functor for CPU.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<CPUDevice, T, Index, op> {
+  void operator()(OpKernelContext* c, const CPUDevice& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices) {
+    Index N = indices.size();
+    // Validate all the indices are in range
+    Index first_dim_size = params.dimension(0);
+    for (Index i = 0; i < N; i++) {
+      const Index index = indices(i);
+      OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+                  errors::InvalidArgument(
+                      strings::StrCat("Index ", index, " at offset ", i,
+                                      " in indices is out of range")));
+    }
+    for (Index i = 0; i < N; i++) {
+      // Copy last Ndim-1 dimensions of Tupdates[i] to
+      // Tparams[Tindices[i]]
+      Assign<op>::Run(params.template chip<0>(indices(i)),
+                      updates.template chip<0>(i));
+    }
+  }
+};
+}  // namespace functor

-#define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32)
-#define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64)
+#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
+  REGISTER_KERNEL_BUILDER( \
+      Name(name) \
+          .Device(DEVICE_##dev) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tindices"), \
+      ScatterUpdateOp<dev##Device, type, index_type, op>)

-TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT32);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT64);
+#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
+  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
+  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

-#undef REGISTER_SCATTER_UPDATE_INT64
-#undef REGISTER_SCATTER_UPDATE_INT32
-#undef REGISTER_SCATTER_UPDATE
+#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+  REGISTER_SCATTER_KERNEL( \
+      type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+  REGISTER_SCATTER_KERNEL( \
+      type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);

-#define REGISTER_SCATTER_ADD(type, index_type) \
-  REGISTER_KERNEL_BUILDER(Name("ScatterAdd") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<index_type>("Tindices"), \
-                          ScatterUpdateOp<type, index_type, UpdateOp::ADD>);
+#define REGISTER_SCATTER_UPDATE(type, dev) \
+  REGISTER_SCATTER_KERNEL( \
+      type, dev, "ScatterUpdate", scatter_op::UpdateOp::ASSIGN);

-#define REGISTER_SCATTER_ADD_INT32(type) REGISTER_SCATTER_ADD(type, int32)
-#define REGISTER_SCATTER_ADD_INT64(type) REGISTER_SCATTER_ADD(type, int64)
+// Registers CPU kernels.
+#define REGISTER_SCATTER_ADD_SUB_CPU(type) \
+  REGISTER_SCATTER_ADD_SUB(type, CPU);

-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT32);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT64);
+#define REGISTER_SCATTER_UPDATE_CPU(type) \
+  REGISTER_SCATTER_UPDATE(type, CPU);

-#undef REGISTER_SCATTER_ADD_INT32
-#undef REGISTER_SCATTER_ADD_INT64
-#undef REGISTER_SCATTER_ADD
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
+
+// Registers GPU kernels.
+#if GOOGLE_CUDA
+#define REGISTER_SCATTER_ADD_SUB_GPU(type) \
+  REGISTER_SCATTER_ADD_SUB(type, GPU);

-#define REGISTER_SCATTER_SUB(type, index_type) \
-  REGISTER_KERNEL_BUILDER(Name("ScatterSub") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<index_type>("Tindices"), \
-                          ScatterUpdateOp<type, index_type, UpdateOp::SUB>);
+#define REGISTER_SCATTER_UPDATE_GPU(type) \
+  REGISTER_SCATTER_UPDATE(type, GPU);

-#define REGISTER_SCATTER_SUB_INT32(type) REGISTER_SCATTER_SUB(type, int32)
-#define REGISTER_SCATTER_SUB_INT64(type) REGISTER_SCATTER_SUB(type, int64)
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);

-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT32);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT64);
+#endif  // GOOGLE_CUDA

-#undef REGISTER_SCATTER_SUB_INT64
-#undef REGISTER_SCATTER_SUB_INT32
-#undef REGISTER_SCATTER_SUB
+#undef REGISTER_SCATTER_ADD
+#undef REGISTER_SCATTER_ADD_SUB
+#undef REGISTER_SCATTER_ADD_SUB_CPU
+#undef REGISTER_SCATTER_ADD_SUB_GPU
+#undef REGISTER_SCATTER_UPDATE
+#undef REGISTER_SCATTER_UPDATE_CPU
+#undef REGISTER_SCATTER_UPDATE_GPU
+#undef REGISTER_SCATTER_KERNEL
+#undef REGISTER_SCATTER_KERNEL_INDEX
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+#define DECLARE_GPU_SPECS_OP(T, Index, op) \
+  template <> \
+  void ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+      OpKernelContext* c, const GPUDevice& d, \
+      typename TTypes<T>::Matrix params, \
+      typename TTypes<T>::ConstMatrix updates, \
+      typename TTypes<Index>::ConstFlat indices); \
+  extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+
+#define DECLARE_GPU_SPECS_INDEX(T, Index) \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+
+#define DECLARE_GPU_SPECS(T) \
+  DECLARE_GPU_SPECS_INDEX(T, int32); \
+  DECLARE_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_INDEX
+#undef DECLARE_GPU_SPECS_OP
+
+}  // namespace functor
+#endif  // GOOGLE_CUDA

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.h b/tensorflow/core/kernels/scatter_op.h
new file mode 100644
index 0000000000..b7c7df97a7
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op.h
@@ -0,0 +1,48 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SCATTER_OP_H_
+#define TENSORFLOW_KERNELS_SCATTER_OP_H_
+
+// Functor definitions for Scatter ops, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace scatter_op {
+
+enum class UpdateOp { ASSIGN, ADD, SUB };
+
+}  // namespace scatter_op
+
+namespace functor {
+
+// Functor used by ScatterOp to do the computations.
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor {
+  void operator()(OpKernelContext* c, const Device& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices);
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_SCATTER_OP_H_
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
new file mode 100644
index 0000000000..6ef23419ab
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -0,0 +1,108 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/scatter_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterOpCustomKernel(
+    T* params, const T* updates, const Index* indices,
+    Index first_dim_size, Index updates_size, Index indices_size) {
+  Index update_block = updates_size / indices_size;
+  CUDA_1D_KERNEL_LOOP(i, updates_size) {
+    int indices_i = i / update_block;
+    int updates_i = i;
+    int param_first_index = indices[indices_i];
+    if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+    int params_i = param_first_index * update_block + (i % update_block);
+    switch (op) {
+      case scatter_op::UpdateOp::ASSIGN: {
+        params[params_i] = ldg(updates + updates_i);
+        break;
+      }
+      case scatter_op::UpdateOp::ADD: {
+        CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
+        break;
+      }
+      case scatter_op::UpdateOp::SUB: {
+        CudaAtomicSub(params + params_i, ldg(updates + updates_i));
+        break;
+      }
+    }
+  }
+}
+
+namespace functor {
+// Specialization for a GPU device.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<GPUDevice, T, Index, op> {
+  void operator()(OpKernelContext* c, const GPUDevice& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices) {
+    // TODO: Implement indices range check. The hardest part is with returning
+    // a value after the range check, as we do not want to do device to host
+    // memcpy during a stream.
+    const Index first_dim_size = params.dimension(0);
+    const Index indices_size = indices.size();
+    const Index updates_size = updates.size();
+    CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
+    ScatterOpCustomKernel<T,Index,op>
+        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+            params.data(), updates.data(), indices.data(),
+            first_dim_size, updates_size, indices_size);
+  }
+};
+
+}  // namespace functor
+
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+
+#define DEFINE_GPU_SPECS(T) \
+  DEFINE_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_GPU_SPECS_INDEX(T, int64);
+
+DEFINE_GPU_SPECS(float);
+DEFINE_GPU_SPECS(double);
+// TODO: The following fails to compile.
+// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+
+#undef DEFINE_GPU_SPECS
+#undef DEFINE_GPU_SPECS_INDEX
+#undef DEFINE_GPU_SPECS_OP
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index cb6abbe75f..1d6285882c 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -243,6 +243,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
  testing::StopTiming();
  const int kRows = 10000000 / embedding_size;
  std::vector<float> values;
+  values.reserve(kRows);
  for (int i = 0; i < kRows * embedding_size; i++) {
    values.push_back(i);
  }
@@ -270,6 +271,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
  while (iters-- > 0) {
    Status s = bm.RunOpKernel();
  }
+  testing::StopTiming();
}

static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index c2ba5ac275..b2c8be925a 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -43,6 +43,7 @@ class ShapeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Shape").Device(DEVICE_CPU).HostMemory("output"),
                        ShapeOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("Shape") \
                              .Device(DEVICE_GPU) \
@@ -61,6 +62,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ShapeOp);
+#endif

class ShapeNOp : public OpKernel {
 public:
@@ -82,6 +84,7 @@ class ShapeNOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ShapeN").Device(DEVICE_CPU).HostMemory("output"),
                        ShapeNOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("ShapeN") \
                              .Device(DEVICE_GPU) \
@@ -100,6 +103,7 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ShapeNOp);
+#endif

class RankOp : public OpKernel {
 public:
@@ -118,6 +122,7 @@ class RankOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
                        RankOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("Rank") \
                              .Device(DEVICE_GPU) \
@@ -143,6 +148,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank")
                            .HostMemory("input")
                            .HostMemory("output"),
                        RankOp);
+#endif

class SizeOp : public OpKernel {
 public:
@@ -162,6 +168,7 @@
REGISTER_KERNEL_BUILDER(Name("Size").Device(DEVICE_CPU).HostMemory("output"),
                        SizeOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("Size") \
                              .Device(DEVICE_GPU) \
@@ -180,6 +187,7 @@ REGISTER_KERNEL_BUILDER(Name("Size")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SizeOp);
+#endif

class ExpandDimsOp : public OpKernel {
 public:
@@ -225,6 +233,7 @@ class ExpandDimsOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"),
                        ExpandDimsOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("ExpandDims") \
                              .Device(DEVICE_GPU) \
@@ -241,6 +250,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .HostMemory("dim")
                            .HostMemory("output"),
                        ExpandDimsOp);
+#endif

class SqueezeOp : public OpKernel {
 public:
@@ -313,6 +323,7 @@ class SqueezeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);

+#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("Squeeze").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
@@ -329,5 +340,6 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SqueezeOp);
+#endif

}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 1aa5ae0ab5..69b53500a1 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -67,7 +67,7 @@ static const int N = 128;
// Note that all the data/indices of all the blocks are stored in the same
// vectors respectively. To identify block boundaries, we store the block
// offsets using index3_offset/index_offset. If there are n blocks in the slice,
-// index3_offset and index_offset have n entires. The indices for the ith block
+// index3_offset and index_offset have n entries. The indices for the ith block
// are the values in the following range:
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
// index_offset.
@@ -475,7 +475,7 @@ class SparseMatMulOp : public OpKernel {
    if (!a_is_sparse_ && !b_is_sparse_) {
      // Fallback to Eigen contract.
      // Note that we currently don't optimize the case where only right is
-      // sparse. That can generally be handled by tranposing the order of the
+      // sparse. That can generally be handled by transposing the order of the
      // matmul.
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
@@ -540,7 +540,7 @@ class SparseMatMulOp : public OpKernel {
  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
-  // stored in mat_slices. "slice_block_size" is used to perform futher column
+  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline BlockingCounter* CreateSparseSlices(
      const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
@@ -776,7 +776,7 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
  *NR = right.dimension(1);
  if (*KR * *NR > mem) {
-    // 4096 may be enough to ammortize the cost of writes.
+    // 4096 may be enough to amortize the cost of writes.
    *KR = std::min<int>(*KR, 4096);
  }
  // Use sizes that are multiples of K and 256.
diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc
index 26e495a520..1cfc01b6e2 100644
--- a/tensorflow/core/lib/core/command_line_flags.cc
+++ b/tensorflow/core/lib/core/command_line_flags.cc
@@ -43,7 +43,7 @@ bool StringToValue<string>(const string& content, string* value) {
// Return OK if the argument is used. It store the extracted value into the
// matching flag.
// Return NOT_FOUND if the argument is not recognized.
-// Retrun INVALID_ARGUMENT if the command is recognized, but fails to extract
+// Return INVALID_ARGUMENT if the command is recognized, but fails to extract
// its value.
template <typename T>
Status ParseArgument(const string& argument) {
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index 6b2cba809c..1f35c0cab8 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -374,7 +374,7 @@ class InlinedVector {

  // Moves srcs[0,n-1] contents to dst[0,n-1].
  static void Move(const T* src, size_t n, T* dst) {
-    for (int i = 0; i < n; i++) {
+    for (size_t i = 0; i < n; i++) {
      new (dst + i) T(std::move(*(src + i)));
    }
  }
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
index 04b092f053..4beca30c3f 100644
--- a/tensorflow/core/lib/jpeg/jpeg_handle.h
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/

// This file declares the functions and structures for memory I/O with libjpeg
-// These functions are not meant to be used directly, see jpeg_mem.h isntead.
+// These functions are not meant to be used directly, see jpeg_mem.h instead.

#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc
index 1fc1397499..3daaa1447c 100644
--- a/tensorflow/core/lib/random/distribution_sampler.cc
+++ b/tensorflow/core/lib/random/distribution_sampler.cc
@@ -42,7 +42,7 @@ DistributionSampler::DistributionSampler(
  std::vector<int> low;
  low.reserve(n);

-  // compute propotional weights
+  // compute proportional weights
  for (int i = 0; i < n; i++) {
    double p = (weights[i] * n) / sum;
    pr[i] = p;
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index 9f510342d7..a2ee5c96aa 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -192,7 +192,7 @@ class SingleSampleAdapter {
//   each invocation. It needs to define kResultElementCount for the
//   sample count for each invocation, and ResultType for actual
//   returned sample type.
-// RealType: the data type of the real numberes that will be returned by the
+// RealType: the data type of the real numbers that will be returned by the
//   distribution. This could be either float or double for now.
// This class is meant to be implemented through specialization. The default
// is not defined by design.
@@ -259,7 +259,7 @@ class NormalDistribution<Generator, double> {
//   each invocation. It needs to define kResultElementCount for the
//   sample count for each invocation, and ResultType for actual
//   returned sample type.
-// RealType: the data type of the real numberes that will be returned by the
+// RealType: the data type of the real numbers that will be returned by the
//   distribution. This could be either float or double for now.
// This class is meant to be implemented through specialization. The default
// is not defined by design.
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
index 4d3e4a5cdc..13398b838f 100644
--- a/tensorflow/core/lib/random/random_distributions_test.cc
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -36,7 +36,7 @@ namespace {
static constexpr float kZLimit = 6.0;

// A utility function to fill the given array with samples from the given
-// distribution, using the single adatper of the underlying generator
+// distribution, using the single adapter of the underlying generator
template <class Distribution>
void FillRandomsWithSingles(PhiloxRandom gen,
                            typename Distribution::ResultElementType* p,
@@ -87,7 +87,7 @@ bool CheckSamplesMoments(const std::vector<T>& samples,
      break;
    }
    // moments[i] store the i-th order measured moments.
-    // bypass std::vector::opeartor[] because they are too slow in the debug
+    // bypass std::vector::operator[] because they are too slow in the debug
    // mode, given the large number of samples.
    moments_data[i] += moment;
    ++moments_sample_count_data[i];
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 859a654e36..778545b44b 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -33,7 +33,7 @@ char* FastInt32ToBufferLeft(int32 i, char* buffer) {
  if (i < 0) {
    *buffer++ = '-';
    // We need to do the negation in modular (i.e., "unsigned")
-    // arithmetic; MSVC++ apprently warns for plain "-u", so
+    // arithmetic; MSVC++ apparently warns for plain "-u", so
    // we write the equivalent expression "0 - u" instead.
    u = 0 - u;
  }
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index f1924ccf93..4dd0bcdec4 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -77,7 +77,7 @@ char* FloatToBuffer(float i, char* buffer);
string FpToString(Fprint fp);

// Attempt to parse a fingerprint in the form encoded by FpToString. If
-// successsful, stores the fingerprint in *fp and returns true. Otherwise,
+// successful, stores the fingerprint in *fp and returns true. Otherwise,
// returns false.
bool StringToFp(const string& s, Fprint* fp);
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 02632a6b2b..33b6028153 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -30,7 +30,7 @@ limitations under the License.
// The AlphaNum type was designed to be used as the parameter type for StrCat().
// Any routine accepting either a string or a number may accept it.
// The basic idea is that by accepting a "const AlphaNum &" as an argument
-// to your function, your callers will automagically convert bools, integers,
+// to your function, your callers will automatically convert bools, integers,
// and floating point values to strings for you.
//
// NOTE: Use of AlphaNum outside of the //strings package is unsupported except
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index d9c3d8d2c9..99c90a811b 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -446,22 +446,20 @@ Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
height of the underlying image.

-Example:
-
-```
-# Generate a single distorted bounding box.
-begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
-    tf.shape(image),
-    bounding_boxes=bounding_boxes)
-
-# Draw the bounding box in an image summary.
-image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
-                                              bbox_for_draw)
-tf.image_summary('images_with_box', image_with_box)
-
-# Employ the bounding box to distort the image.
-distorted_image = tf.slice(image, begin, size)
-```
+For example,
+
+    # Generate a single distorted bounding box.
+    begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+        tf.shape(image),
+        bounding_boxes=bounding_boxes)
+
+    # Draw the bounding box in an image summary.
+    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+                                                  bbox_for_draw)
+    tf.image_summary('images_with_box', image_with_box)
+
+    # Employ the bounding box to distort the image.
+    distorted_image = tf.slice(image, begin, size)

Note that if no bounding box information is available, setting
`use_image_if_no_bounding_boxes = true` will assume there is a single implicit
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 644bb546b3..1a1e019230 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -7122,22 +7122,27 @@ op {
  name: "SampleDistortedBoundingBox"
  input_arg {
    name: "image_size"
+    description: "1-D, containing `[height, width, channels]`."
    type_attr: "T"
  }
  input_arg {
    name: "bounding_boxes"
+    description: "3-D with shape `[batch, N, 4]` describing the N bounding boxes\nassociated with the image."
    type: DT_FLOAT
  }
  output_arg {
    name: "begin"
+    description: "1-D, containing `[offset_height, offset_width, 0]`. Provide as input to\n`tf.slice`."
    type_attr: "T"
  }
  output_arg {
    name: "size"
+    description: "1-D, containing `[target_height, target_width, -1]`. Provide as input to\n`tf.slice`."
    type_attr: "T"
  }
  output_arg {
    name: "bboxes"
+    description: "3-D with shape `[1, 1, 4]` containing the distorted bounding box.\nProvide as input to `tf.image.draw_bounding_boxes`."
    type: DT_FLOAT
  }
  attr {
@@ -7159,6 +7164,7 @@ op {
    default_value {
      i: 0
    }
+    description: "If either `seed` or `seed2` are set to non-zero, the random number\ngenerator is seeded by the given `seed`. Otherwise, it is seeded by a random\nseed."
  }
  attr {
    name: "seed2"
@@ -7166,6 +7172,7 @@ op {
    default_value {
      i: 0
    }
+    description: "A second seed to avoid seed collision."
  }
  attr {
    name: "min_object_covered"
@@ -7173,6 +7180,7 @@ op {
    default_value {
      f: 0.1
    }
+    description: "The cropped area of the image must contain at least this\nfraction of any bounding box supplied."
  }
  attr {
    name: "aspect_ratio_range"
@@ -7183,6 +7191,7 @@ op {
        f: 1.33
      }
    }
+    description: "The cropped area of the image must have an aspect ratio =\nwidth / height within this range."
  }
  attr {
    name: "area_range"
@@ -7193,6 +7202,7 @@ op {
        f: 1
      }
    }
+    description: "The cropped area of the image must contain a fraction of the\nsupplied image within in this range."
  }
  attr {
    name: "max_attempts"
@@ -7200,6 +7210,7 @@ op {
    default_value {
      i: 100
    }
+    description: "Number of attempts at generating a cropped region of the image\nof the specified constraints. After `max_attempts` failures, return the entire\nimage."
  }
  attr {
    name: "use_image_if_no_bounding_boxes"
@@ -7207,9 +7218,10 @@ op {
    default_value {
      b: false
    }
+    description: "Controls behavior if no bounding boxes supplied.\nIf true, assume an implicit bounding box covering the whole input. If false,\nraise an error."
  }
  summary: "Generate a single randomly distorted bounding box for an image."
-  description: "Bounding box annotations are often supplied in addition to ground-truth labels\nin image recognition or object localization tasks. A common technique for\ntraining such a system is to randomly distort an image while preserving\nits content, i.e. *data augmentation*. This Op outputs a randomly distorted\nlocalization of an object, i.e. bounding box, given an `image_size`,\n`bounding_boxes` and a series of constraints.\n\nThe output of this Op is a single bounding box that may be used to crop the\noriginal image. The output is returned as 3 tensors: `begin`, `size` and\n`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the\nimage. The latter may be supplied to `tf.image.draw_bounding_box` to visualize\nwhat the bounding box looks like.\n\nBounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The\nbounding box coordinates are floats in `[0.0, 1.0]` relative to the width and\nheight of the underlying image."
+  description: "Bounding box annotations are often supplied in addition to ground-truth labels\nin image recognition or object localization tasks. A common technique for\ntraining such a system is to randomly distort an image while preserving\nits content, i.e. *data augmentation*. This Op outputs a randomly distorted\nlocalization of an object, i.e. bounding box, given an `image_size`,\n`bounding_boxes` and a series of constraints.\n\nThe output of this Op is a single bounding box that may be used to crop the\noriginal image. The output is returned as 3 tensors: `begin`, `size` and\n`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the\nimage. The latter may be supplied to `tf.image.draw_bounding_box` to visualize\nwhat the bounding box looks like.\n\nBounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The\nbounding box coordinates are floats in `[0.0, 1.0]` relative to the width and\nheight of the underlying image.\n\nFor example,\n\n    # Generate a single distorted bounding box.\n    begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(\n        tf.shape(image),\n        bounding_boxes=bounding_boxes)\n\n    # Draw the bounding box in an image summary.\n    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),\n                                                  bbox_for_draw)\n    tf.image_summary(\'images_with_box\', image_with_box)\n\n    # Employ the bounding box to distort the image.\n    distorted_image = tf.slice(image, begin, size)\n\nNote that if no bounding box information is available, setting\n`use_image_if_no_bounding_boxes = true` will assume there is a single implicit\nbounding box covering the whole image. If `use_image_if_no_bounding_boxes` is\nfalse and no bounding boxes are supplied, an error is raised."
  is_stateful: true
}
op {
@@ -7482,7 +7494,7 @@ op {
    description: "If True, the assignment will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
} summary: "Applies sparse updates to a variable reference." - description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] = updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] = updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf `indices` contains duplicate entries, lexicographically later entries\noverride earlier entries.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterUpdate.png\" alt>\n</div>" + description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] = updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] = updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf values in `ref` is to be updated more than once, because there are\nduplicate entires in `indices`, the order at which the updates happen\nfor each value is undefined.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterUpdate.png\" alt>\n</div>" } op { name: "SegmentMax" diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index a5a3d14a07..3bec698acc 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -182,8 +182,9 @@ This operation computes This operation outputs `ref` after the update is done. This makes it easier to chain operations that need to use the reset value. -If `indices` contains duplicate entries, lexicographically later entries -override earlier entries. +If values in `ref` is to be updated more than once, because there are +duplicate entires in `indices`, the order at which the updates happen +for each value is undefined. Requires `updates.shape = indices.shape + ref.shape[1:]`. diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 9d069ab4c6..a89a0a4af7 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -3,9 +3,9 @@ load("//google/protobuf:protobuf.bzl", "cc_proto_library") load("//google/protobuf:protobuf.bzl", "py_proto_library") -# configure may change the following lines. -CUDA_VERSION = '7.0' -CUDNN_VERSION = '6.5' +# configure may change the following lines to '.X.Y' or similar +CUDA_VERSION = '' +CUDNN_VERSION = '' # Appends a suffix to a list of deps. def tf_deps(deps, suffix): diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index c6dccc06ff..4b8088fde8 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -75,7 +75,7 @@ filegroup( cc_library( name = "cuda", data = [ - "//third_party/gpus/cuda:lib64/libcudart.so." 
+        "//third_party/gpus/cuda:lib64/libcudart.so" + tf_get_cuda_version(),
     ],
     linkopts = [
         "-Wl,-rpath,third_party/gpus/cuda/lib64",
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index d8ba37babc..18395f3292 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <chrono>
 #include <condition_variable>
 #include <mutex>
+#include "tensorflow/core/platform/default/thread_annotations.h"
 
 namespace tensorflow {
 
@@ -29,16 +30,24 @@ enum LinkerInitialized { LINKER_INITIALIZED };
 
 // A class that wraps around the std::mutex implementation, only adding an
 // additional LinkerInitialized constructor interface.
-class mutex : public std::mutex {
+class LOCKABLE mutex : public std::mutex {
  public:
   mutex() {}
   // The default implementation of std::mutex is safe to use after the linker
   // initializations
   explicit mutex(LinkerInitialized x) {}
+
+  void lock() ACQUIRE() { std::mutex::lock(); }
+  void unlock() RELEASE() { std::mutex::unlock(); }
+};
+
+class SCOPED_LOCKABLE mutex_lock : public std::unique_lock<std::mutex> {
+ public:
+  mutex_lock(class mutex& m) ACQUIRE(m) : std::unique_lock<std::mutex>(m) {}
+  ~mutex_lock() RELEASE() {}
 };
 
 using std::condition_variable;
-typedef std::unique_lock<std::mutex> mutex_lock;
 
 inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
                                            condition_variable* cv, int64 ms) {
diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h
index d8a9253926..46143b2ea3 100644
--- a/tensorflow/core/platform/default/thread_annotations.h
+++ b/tensorflow/core/platform/default/thread_annotations.h
@@ -73,6 +73,15 @@ limitations under the License.
 #define ACQUIRED_BEFORE(...) \
   THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__))
 
+#define ACQUIRE(...) \
+  THREAD_ANNOTATION_ATTRIBUTE__(acquire_capability(__VA_ARGS__))
+
+#define ACQUIRE_SHARED(...) \
+  THREAD_ANNOTATION_ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__))
+
+#define RELEASE(...) \
+  THREAD_ANNOTATION_ATTRIBUTE__(release_capability(__VA_ARGS__))
+
 // Document a function that expects a mutex to be held prior to entry.
 // The mutex is expected to be held both on entry to and exit from the
 // function.
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4c0a9b8d70..9d46087814 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,7 +19,7 @@ limitations under the License.
 // TensorFlow uses semantic versioning, see http://semver.org/.
 
 #define TF_MAJOR_VERSION 0
-#define TF_MINOR_VERSION 6
+#define TF_MINOR_VERSION 7
 #define TF_PATCH_VERSION 0
 
 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 4e1dfb8c4f..c98207b6f0 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -20,6 +20,8 @@ limitations under the License.
 
 #include <algorithm>
 
+#include "tensorflow/core/platform/types.h"
+
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)
@@ -69,6 +71,58 @@ __device__ __host__ inline T ldg(const T* address) {
 #endif
 }
 
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementations for all reasonable types.
+#define CUDA_ATOMIC_WRAPPER(op, T) \
+  __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
+
+#define USE_CUDA_ATOMIC(op, T) \
+  CUDA_ATOMIC_WRAPPER(op, T) { \
+    return atomic##op(address, val); \
+  }
+
+// For atomicAdd.
+USE_CUDA_ATOMIC(Add, int32);
+USE_CUDA_ATOMIC(Add, uint32);
+USE_CUDA_ATOMIC(Add, uint64);
+USE_CUDA_ATOMIC(Add, float);
+
+// Custom implementation of atomicAdd for double.
+// This implementation is copied from the CUDA manual.
+CUDA_ATOMIC_WRAPPER(Add, double) {
+  uint64* address_as_ull = (uint64*)address;
+  uint64 old = *address_as_ull, assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(val + __longlong_as_double(assumed)));
+
+    // Note: uses integer comparison to avoid hang in case of NaN
+  } while (assumed != old);
+
+  return __longlong_as_double(old);
+}
+
+// For atomicSub.
+
+// Custom implementation of atomicSub that simply negates the value.
+#define WRAPPED_ATOMIC_SUB(T) \
+  CUDA_ATOMIC_WRAPPER(Sub, T) { \
+    return CudaAtomicAdd(address, -val); \
+  }
+
+WRAPPED_ATOMIC_SUB(uint64);
+WRAPPED_ATOMIC_SUB(int32);
+WRAPPED_ATOMIC_SUB(uint32);
+WRAPPED_ATOMIC_SUB(float);
+WRAPPED_ATOMIC_SUB(double);
+
+#undef WRAPPED_ATOMIC_SUB
+
+#undef USE_CUDA_ATOMIC
+#undef CUDA_ATOMIC_WRAPPER
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
index cfdbb07cd5..6b2a526834 100644
--- a/tensorflow/core/util/events_writer.cc
+++ b/tensorflow/core/util/events_writer.cc
@@ -115,7 +115,7 @@ bool EventsWriter::Flush() {
   // recordio_writer_->Sync() can return true even if the underlying
   // file has been deleted. EventWriter.FileDeletionBeforeWriting
   // demonstrates this and will fail if the FileHasDisappeared()
-  // conditon is removed.
+  // condition is removed.
   // Also, we deliberately attempt to Sync() before checking for a
   // disappearing file, in case for some file system File::Exists() is
   // false after File::Open() but before File::Sync().
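
The ScatterUpdate doc change above makes the duplicate-index behavior explicitly undefined: when `indices` repeats a row, no update order is promised. A minimal C++ sketch of those semantics; the `Scatter` helper and driver below are illustrative only, not TensorFlow code, and this particular loop happens to apply updates in order:

    #include <cstdio>
    #include <vector>

    // One possible realization of the scatter-update semantics: for a
    // duplicated index, whichever update this loop applies last wins.
    // The op itself leaves that order undefined.
    void Scatter(std::vector<float>* ref, const std::vector<int>& indices,
                 const std::vector<float>& updates) {
      for (size_t i = 0; i < indices.size(); ++i) {
        (*ref)[indices[i]] = updates[i];
      }
    }

    int main() {
      std::vector<float> ref = {0, 0, 0};
      // Index 1 is duplicated, so under the new contract ref[1] could
      // legally end up as either 20 or 30.
      Scatter(&ref, {0, 1, 1}, {10, 20, 30});
      for (float v : ref) std::printf("%g ", v);  // this loop order prints "10 30 0"
      std::printf("\n");
      return 0;
    }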
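The LOCKABLE/SCOPED_LOCKABLE changes to mutex.h exist so that Clang's -Wthread-safety analysis can check lock usage statically. A small sketch of the intended usage pattern, assuming GUARDED_BY is available from the same thread_annotations.h header; the Counter class is a hypothetical example, not code from this commit:

    #include "tensorflow/core/platform/default/mutex.h"
    #include "tensorflow/core/platform/default/thread_annotations.h"

    namespace tensorflow {

    class Counter {
     public:
      void Increment() {
        mutex_lock l(mu_);  // ACQUIRE(mu_) in the constructor, RELEASE() in ~mutex_lock
        ++value_;
      }
      int Get() {
        mutex_lock l(mu_);
        return value_;
      }

     private:
      mutex mu_;
      int value_ GUARDED_BY(mu_) = 0;  // analysis flags access without mu_ held
    };

    }  // namespace tensorflow

With the annotated lock()/unlock() and the SCOPED_LOCKABLE mutex_lock, compiling such code with clang -Wthread-safety would warn if `value_` were touched outside a `mutex_lock` scope, which a plain `typedef std::unique_lock<std::mutex> mutex_lock` could not express.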
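The double-precision CudaAtomicAdd above is the classic compare-and-swap retry loop. The same pattern written against std::atomic, as a host-side C++ illustration only (not part of this commit); like the CUDA version, it compares raw 64-bit patterns rather than doubles, so a stored NaN cannot cause an infinite retry:

    #include <atomic>
    #include <cstdint>
    #include <cstring>

    // Returns the value stored at *address before the add, mirroring the
    // atomicCAS-based CudaAtomicAdd(double) above.
    double AtomicAddDouble(std::atomic<std::uint64_t>* address, double val) {
      std::uint64_t old_bits = address->load();
      std::uint64_t new_bits;
      double old_val;
      do {
        std::memcpy(&old_val, &old_bits, sizeof old_val);   // bit-cast to double
        const double new_val = old_val + val;
        std::memcpy(&new_bits, &new_val, sizeof new_bits);  // bit-cast back
        // On failure, compare_exchange_weak reloads old_bits with the current
        // value, so the loop retries with fresh data; the integer comparison
        // avoids the NaN != NaN hang the CUDA comment warns about.
      } while (!address->compare_exchange_weak(old_bits, new_bits));
      return old_val;
    }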