From 2701ab910894da95c25bcf6f2e30f0a6c2c20552 Mon Sep 17 00:00:00 2001 From: Florian Courtial Date: Wed, 30 May 2018 23:36:52 +0200 Subject: Add C++ SegmentSum gradient operation. --- tensorflow/cc/gradients/math_grad.cc | 20 ++++++++++++++++++++ tensorflow/cc/gradients/math_grad_test.cc | 10 ++++++++++ 2 files changed, 30 insertions(+) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a..62404fff09 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -1006,6 +1006,26 @@ Status ProdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Prod", ProdGrad); +Status SegmentSumGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The SegmentSum operation sums segments of the Tensor that have the same + // index in the segment_ids parameter. + // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1] + // will produce [2 + 3 + 4, 5] = [9, 5] + // The gradient that will flow back to the gather operation will look like + // [x1, x2], it will have the same shape as the output of the SegmentSum + // operation. The differentiation step of the SegmentSum operation just + // broadcast the gradient in order to retrieve the z's shape. + // dy/dz = [x1, x1, x1, x2] + grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); + + // stop propagation along segment_ids + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe662..acc100d144 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -41,6 +41,7 @@ using ops::Mul; using ops::Placeholder; using ops::Pow; using ops::Prod; +using ops::SegmentSum; using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; @@ -902,5 +903,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow -- cgit v1.2.3 From bdd8bf316e4ab7d699127d192d30eb614a158462 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 11 Jul 2018 20:24:10 +0800 Subject: Remove all references of windows_msvc config_setting --- tensorflow/BUILD | 9 --------- tensorflow/cc/BUILD | 1 - tensorflow/core/BUILD | 6 ------ tensorflow/core/platform/default/build_config.bzl | 3 --- tensorflow/java/BUILD | 1 - tensorflow/tensorflow.bzl | 9 --------- tensorflow/tools/pip_package/BUILD | 1 - tensorflow/tools/proto_text/BUILD | 1 - third_party/curl.BUILD | 14 -------------- third_party/farmhash.BUILD | 8 -------- third_party/flatbuffers/flatbuffers.BUILD | 3 +-- third_party/gif.BUILD | 9 --------- third_party/jpeg/jpeg.BUILD | 8 -------- third_party/lmdb.BUILD | 6 ------ third_party/nasm.BUILD | 9 --------- third_party/snappy.BUILD | 1 - third_party/sqlite.BUILD | 8 +++----- third_party/swig.BUILD | 6 ------ third_party/zlib.BUILD | 1 - 19 files changed, 4 insertions(+), 100 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 51eea94847..314db599ca 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -115,12 +115,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "no_tensorflow_py_deps", define_values = {"no_tensorflow_py_deps": "true"}, @@ -484,7 +478,6 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_framework_version_script.lds)", @@ -526,7 +519,6 @@ tf_cc_shared_object( "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -551,7 +543,6 @@ tf_cc_shared_object( "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a98f0b00b2..9533e96f13 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -595,7 +595,6 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6fa557d53f..d979b55580 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -861,7 +861,6 @@ tf_cuda_library( "util/work_sharder.h", ] + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "util/memmapped_file_system.h", "util/memmapped_file_system_writer.h", @@ -2025,7 +2024,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-ldl", "-lpthread", @@ -2114,7 +2112,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2139,7 +2136,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2471,7 +2467,6 @@ tf_cuda_library( ], ) + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "util/memmapped_file_system.cc", "util/memmapped_file_system_writer.cc", @@ -2482,7 +2477,6 @@ tf_cuda_library( linkopts = select({ "//tensorflow:freebsd": ["-lm"], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl", "-lm"], }), deps = [ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 28891320c4..fb4ee1c33c 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -467,7 +467,6 @@ def tf_platform_srcs(files): return select({ "//tensorflow:windows" : native.glob(windows_set), - "//tensorflow:windows_msvc" : native.glob(windows_set), "//conditions:default" : native.glob(posix_set), }) @@ -479,7 +478,6 @@ def tf_additional_lib_hdrs(exclude = []): ], exclude = exclude) return select({ "//tensorflow:windows" : windows_hdrs, - "//tensorflow:windows_msvc" : windows_hdrs, "//conditions:default" : native.glob([ "platform/default/*.h", "platform/posix/*.h", @@ -494,7 +492,6 @@ def tf_additional_lib_srcs(exclude = []): ], exclude = exclude) return select({ "//tensorflow:windows" : windows_srcs, - "//tensorflow:windows_msvc" : windows_srcs, "//conditions:default" : native.glob([ "platform/default/*.cc", "platform/posix/*.cc", diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index d1108f251c..b2b7ee3fa5 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -345,7 +345,6 @@ tf_cc_binary( "$(location {})".format(LINKER_EXPORTED_SYMBOLS), ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-s", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e4241667ad..8e3323ded5 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -137,14 +137,12 @@ def if_not_mobile(a): def if_not_windows(a): return select({ clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": a, }) def if_windows(a): return select({ clean_dep("//tensorflow:windows"): a, - clean_dep("//tensorflow:windows_msvc"): a, "//conditions:default": [], }) @@ -226,7 +224,6 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): clean_dep("//tensorflow:android"): android_copts, clean_dep("//tensorflow:darwin"): [], clean_dep("//tensorflow:windows"): get_win_copts(is_external), - clean_dep("//tensorflow:windows_msvc"): get_win_copts(is_external), clean_dep("//tensorflow:ios"): ["-std=c++11"], "//conditions:default": ["-pthread"] })) @@ -286,7 +283,6 @@ def _rpath_linkopts(name): "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),), ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),), ], @@ -656,7 +652,6 @@ def tf_cc_test(name, "-pie", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], clean_dep("//tensorflow:darwin"): [ "-lm", ], @@ -838,7 +833,6 @@ def tf_cc_test_mkl(srcs, "-pie", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-lpthread", "-lm" @@ -1351,7 +1345,6 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): "-lm", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], clean_dep("//tensorflow:darwin"): [], }),) @@ -1461,7 +1454,6 @@ def tf_py_wrap_cc(name, "$(location %s.lds)"%vscriptname, ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,--version-script", "$(location %s.lds)"%vscriptname, @@ -1472,7 +1464,6 @@ def tf_py_wrap_cc(name, "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "%s.lds"%vscriptname, ] diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c9d53f46c3..91d574b295 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -176,7 +176,6 @@ sh_binary( srcs = ["build_pip_package.sh"], data = select({ "//tensorflow:windows": [":simple_console_for_windows"], - "//tensorflow:windows_msvc": [":simple_console_for_windows"], "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", "//tensorflow/contrib/lite/python:interpreter_test_data", diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD index 31e8fb9120..fc2c041b6c 100644 --- a/tensorflow/tools/proto_text/BUILD +++ b/tensorflow/tools/proto_text/BUILD @@ -49,7 +49,6 @@ cc_library( copts = if_ios(["-DGOOGLE_LOGGING"]), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 1638b72161..c93fac6549 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -243,7 +243,6 @@ cc_library( "lib/vtls/darwinssl.c", ], "@org_tensorflow//tensorflow:windows": CURL_WIN_SRCS, - "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_SRCS, "//conditions:default": [ "lib/vtls/openssl.c", ], @@ -260,7 +259,6 @@ cc_library( ], copts = select({ "@org_tensorflow//tensorflow:windows": CURL_WIN_COPTS, - "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_COPTS, "//conditions:default": [ "-Iexternal/curl/lib", "-D_GNU_SOURCE", @@ -280,10 +278,6 @@ cc_library( # See curl.h for discussion of write size and Windows "/DCURL_MAX_WRITE_SIZE=16384", ], - "@org_tensorflow//tensorflow:windows_msvc": [ - # See curl.h for discussion of write size and Windows - "/DCURL_MAX_WRITE_SIZE=16384", - ], "//conditions:default": [ "-DCURL_MAX_WRITE_SIZE=65536", ], @@ -307,12 +301,6 @@ cc_library( "-DEFAULTLIB:crypt32.lib", "-DEFAULTLIB:Normaliz.lib", ], - "@org_tensorflow//tensorflow:windows_msvc": [ - "-DEFAULTLIB:ws2_32.lib", - "-DEFAULTLIB:advapi32.lib", - "-DEFAULTLIB:crypt32.lib", - "-DEFAULTLIB:Normaliz.lib", - ], "//conditions:default": [ "-lrt", ], @@ -323,7 +311,6 @@ cc_library( ] + select({ "@org_tensorflow//tensorflow:ios": [], "@org_tensorflow//tensorflow:windows": [], - "@org_tensorflow//tensorflow:windows_msvc": [], "//conditions:default": [ "@boringssl//:ssl", ], @@ -426,7 +413,6 @@ cc_binary( ], copts = select({ "@org_tensorflow//tensorflow:windows": CURL_BIN_WIN_COPTS, - "@org_tensorflow//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS, "//conditions:default": [ "-Iexternal/curl/lib", "-D_GNU_SOURCE", diff --git a/third_party/farmhash.BUILD b/third_party/farmhash.BUILD index a51e1511c1..4b8464684a 100644 --- a/third_party/farmhash.BUILD +++ b/third_party/farmhash.BUILD @@ -2,13 +2,6 @@ licenses(["notice"]) # MIT exports_files(["COPYING"]) -config_setting( - name = "windows_msvc", - values = { - "cpu": "x64_windows_msvc", - }, -) - config_setting( name = "windows", values = { @@ -23,7 +16,6 @@ cc_library( # Disable __builtin_expect support on Windows copts = select({ ":windows": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"], - ":windows_msvc": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"], "//conditions:default": [], }), includes = ["src/."], diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index 3a19d28667..4a3701e893 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -18,8 +18,7 @@ config_setting( ) FLATBUFFERS_COPTS = select({ - "@bazel_tools//src:windows": [], - "@bazel_tools//src:windows_msvc": [], + ":windows": [], "//conditions:default": ["-Wno-implicit-fallthrough", "-fexceptions"], }) diff --git a/third_party/gif.BUILD b/third_party/gif.BUILD index 78fbd6c0e0..cbe730fe10 100644 --- a/third_party/gif.BUILD +++ b/third_party/gif.BUILD @@ -21,7 +21,6 @@ cc_library( ], hdrs = ["lib/gif_lib.h"], defines = select({ - #"@org_tensorflow//tensorflow:android": [ ":android": [ "S_IREAD=S_IRUSR", "S_IWRITE=S_IWUSR", @@ -33,7 +32,6 @@ cc_library( visibility = ["//visibility:public"], deps = select({ ":windows": [":windows_polyfill"], - ":windows_msvc": [":windows_polyfill"], "//conditions:default": [], }), ) @@ -50,13 +48,6 @@ genrule( cmd = "touch $@", ) -config_setting( - name = "windows_msvc", - values = { - "cpu": "x64_windows_msvc", - }, -) - config_setting( name = "windows", values = { diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 663a218733..99431de62e 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -22,7 +22,6 @@ libjpegturbo_copts = select({ "-w", ], ":windows": WIN_COPTS, - ":windows_msvc": WIN_COPTS, "//conditions:default": [ "-O3", "-w", @@ -423,7 +422,6 @@ genrule( outs = ["jconfig.h"], cmd = select({ ":windows": "cp $(location jconfig_win.h) $@", - ":windows_msvc": "cp $(location jconfig_win.h) $@", ":k8": "cp $(location jconfig_nowin_simd.h) $@", ":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@", ":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@", @@ -441,7 +439,6 @@ genrule( outs = ["jconfigint.h"], cmd = select({ ":windows": "cp $(location jconfigint_win.h) $@", - ":windows_msvc": "cp $(location jconfigint_win.h) $@", "//conditions:default": "cp $(location jconfigint_nowin.h) $@", }), ) @@ -541,11 +538,6 @@ config_setting( values = {"cpu": "x64_windows"}, ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, -) - config_setting( name = "linux_ppc64le", values = {"cpu": "ppc"}, diff --git a/third_party/lmdb.BUILD b/third_party/lmdb.BUILD index 9b3e1d97c8..f36a698ee3 100644 --- a/third_party/lmdb.BUILD +++ b/third_party/lmdb.BUILD @@ -20,7 +20,6 @@ cc_library( ], linkopts = select({ ":windows": ["-DEFAULTLIB:advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl - ":windows_msvc": ["-DEFAULTLIB:advapi32.lib"], "//conditions:default": ["-lpthread"], }), visibility = ["//visibility:public"], @@ -30,8 +29,3 @@ config_setting( name = "windows", values = {"cpu": "x64_windows"}, ) - -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, -) diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD index 89330eac54..2b877883b9 100644 --- a/third_party/nasm.BUILD +++ b/third_party/nasm.BUILD @@ -142,7 +142,6 @@ cc_binary( ], copts = select({ ":windows": [], - ":windows_msvc": [], "//conditions:default": [ "-w", "-std=c99", @@ -150,7 +149,6 @@ cc_binary( }), defines = select({ ":windows": [], - ":windows_msvc": [], "//conditions:default": [ "HAVE_SNPRINTF", "HAVE_SYS_TYPES_H", @@ -159,13 +157,6 @@ cc_binary( visibility = ["@jpeg//:__pkg__"], ) -config_setting( - name = "windows_msvc", - values = { - "cpu": "x64_windows_msvc", - }, -) - config_setting( name = "windows", values = { diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index 58120ccdf2..d93f030769 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -20,7 +20,6 @@ cc_library( hdrs = ["snappy.h"], copts = ["-DHAVE_CONFIG_H"] + select({ "@org_tensorflow//tensorflow:windows": [], - "@org_tensorflow//tensorflow:windows_msvc": [], "//conditions:default": [ "-fno-exceptions", "-Wno-sign-compare", diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD index 2876f305f1..8b876fb56f 100644 --- a/third_party/sqlite.BUILD +++ b/third_party/sqlite.BUILD @@ -4,7 +4,6 @@ licenses(["unencumbered"]) # Public Domain SQLITE_COPTS = [ - "-Os", "-DSQLITE_ENABLE_JSON1", "-DHAVE_DECL_STRERROR_R=1", "-DHAVE_STDINT_H=1", @@ -15,15 +14,14 @@ SQLITE_COPTS = [ "@org_tensorflow//tensorflow:windows": [ "-DSQLITE_MAX_TRIGGER_DEPTH=100", ], - "@org_tensorflow//tensorflow:windows_msvc": [ - "-DSQLITE_MAX_TRIGGER_DEPTH=100", - ], "@org_tensorflow//tensorflow:darwin": [ + "-Os", "-DHAVE_GMTIME_R=1", "-DHAVE_LOCALTIME_R=1", "-DHAVE_USLEEP=1", ], "//conditions:default": [ + "-Os", "-DHAVE_FDATASYNC=1", "-DHAVE_GMTIME_R=1", "-DHAVE_LOCALTIME_R=1", @@ -48,7 +46,7 @@ cc_library( "SQLITE_OMIT_DEPRECATED", ], linkopts = select({ - "@org_tensorflow//tensorflow:windows_msvc": [], + "@org_tensorflow//tensorflow:windows": [], "//conditions:default": [ "-ldl", "-lpthread", diff --git a/third_party/swig.BUILD b/third_party/swig.BUILD index f2f647401b..59a3d9e671 100644 --- a/third_party/swig.BUILD +++ b/third_party/swig.BUILD @@ -71,7 +71,6 @@ cc_binary( ], copts = ["$(STACK_FRAME_UNLIMITED)"] + select({ ":windows": [], - ":windows_msvc": [], "//conditions:default": [ "-Wno-parentheses", "-Wno-unused-variable", @@ -331,11 +330,6 @@ genrule( " $< >$@", ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, -) - config_setting( name = "windows", values = {"cpu": "x64_windows"}, diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD index e8048dd98a..33694eaaae 100644 --- a/third_party/zlib.BUILD +++ b/third_party/zlib.BUILD @@ -34,7 +34,6 @@ cc_library( hdrs = ["zlib.h"], copts = select({ "@org_tensorflow//tensorflow:windows": [], - "@org_tensorflow//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wno-shift-negative-value", "-DZ_HAVE_UNISTD_H", -- cgit v1.2.3 From b462815ccc279c79a5592a0a2d4492718da02c5d Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Wed, 18 Jul 2018 08:31:46 -0700 Subject: Add gradient for tensorflow::ops::Fill See https://github.com/tensorflow/tensorflow/issues/20926 --- tensorflow/cc/gradients/array_grad.cc | 18 ++++++++++++++++++ tensorflow/cc/gradients/array_grad_test.cc | 8 ++++++++ 2 files changed, 26 insertions(+) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddc..e9173227aa 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b648..f41de3dc20 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); -- cgit v1.2.3 From e3ef4b19627853e694a51ea0b1465a060caa8952 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Tue, 31 Jul 2018 16:02:15 -0400 Subject: Add extra log for failing to load variable purpose --- tensorflow/cc/saved_model/loader.cc | 1 + 1 file changed, 1 insertion(+) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 98be66a6ad..ca8cb8a9be 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -169,6 +169,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, const string variables_index_path = io::JoinPath( variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { + LOG(INFO) << "Falied to restore variables from " << variables_index_path; LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " "were restored."; return Status::OK(); -- cgit v1.2.3 From 19f86cbeadb7014b9940be1f6921776ef9d2f986 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Tue, 31 Jul 2018 17:30:06 -0400 Subject: address comments --- tensorflow/cc/saved_model/loader.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index ca8cb8a9be..22ad5e0162 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -169,9 +169,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, const string variables_index_path = io::JoinPath( variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { - LOG(INFO) << "Falied to restore variables from " << variables_index_path; LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. Failed to restore from " << variables_index_path; return Status::OK(); } const string variables_path = -- cgit v1.2.3 From 234c9a3fd3a430f25c604c42b457b72842375f30 Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Tue, 31 Jul 2018 14:50:55 -0700 Subject: Slightly modify error message --- tensorflow/cc/saved_model/loader.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 22ad5e0162..a5eae97f6e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -170,7 +170,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored. Failed to restore from " << variables_index_path; + "were restored. File does not exist: " << variables_index_path; return Status::OK(); } const string variables_path = -- cgit v1.2.3 From 16c2b25e7e23fb1ac373cd2162ce18ca71e9b0a8 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Tue, 7 Aug 2018 16:04:41 -0400 Subject: make the line shorter --- tensorflow/cc/saved_model/loader.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a5eae97f6e..3830416159 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -170,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored. File does not exist: " << variables_index_path; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = -- cgit v1.2.3