-rwxr-xr-x  configure | 6
-rw-r--r--  tensorflow/BUILD | 15
-rw-r--r--  tensorflow/c/BUILD | 8
-rw-r--r--  tensorflow/c/exported_symbols.lds | 1
-rw-r--r--  tensorflow/c/version_script.lds | 9
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc | 30
-rw-r--r--  tensorflow/compiler/tests/BUILD | 14
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/randomized_tests.cc | 179
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py | 266
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py | 5
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 186
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/elu_op.cc | 65
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 190
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc | 25
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 163
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 161
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/liveness_util.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/liveness_util_test.cc | 68
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD | 46
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc | 139
-rw-r--r--  tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc | 204
-rw-r--r--  tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h | 59
-rw-r--r--  tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc | 154
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py | 42
-rw-r--r--  tensorflow/contrib/distributions/__init__.py | 7
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py | 20
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py | 3
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py | 18
-rw-r--r--  tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py | 11
-rw-r--r--  tensorflow/contrib/keras/python/keras/backend.py | 2
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py | 33
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py | 24
-rw-r--r--  tensorflow/contrib/learn/python/learn/datasets/mnist.py | 3
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py | 11
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/model_fn.py | 18
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py | 56
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/run_config.py | 11
-rw-r--r--  tensorflow/contrib/learn/python/learn/experiment.py | 26
-rw-r--r--  tensorflow/contrib/learn/python/learn/experiment_test.py | 23
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/generator_io.py | 1
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py | 2
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/graph_io.py | 5
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py | 14
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/gc_test.py | 4
-rw-r--r--  tensorflow/contrib/linalg/BUILD | 12
-rw-r--r--  tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py | 23
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator_util.py | 47
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 236
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py | 30
-rw-r--r--  tensorflow/contrib/session_bundle/gc_test.py | 4
-rw-r--r--  tensorflow/core/BUILD | 7
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 12
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h | 2
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.h | 6
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc | 10
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.h | 3
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc | 7
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h | 3
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 17
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 4
-rw-r--r--  tensorflow/core/framework/shape_inference.cc | 22
-rw-r--r--  tensorflow/core/framework/shape_inference.h | 14
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc | 35
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc | 171
-rw-r--r--  tensorflow/core/grappler/BUILD | 11
-rw-r--r--  tensorflow/core/grappler/devices.cc | 16
-rw-r--r--  tensorflow/core/grappler/devices.h | 4
-rw-r--r--  tensorflow/core/grappler/grappler_item_builder.cc | 3
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 27
-rw-r--r--  tensorflow/core/grappler/op_types.h | 29
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 35
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.cc | 260
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.h | 63
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel_test.cc | 125
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 12
-rw-r--r--  tensorflow/core/kernels/batchtospace_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/crop_and_resize_op.cc | 49
-rw-r--r--  tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc | 34
-rw-r--r--  tensorflow/core/kernels/crop_and_resize_op_test.cc | 51
-rw-r--r--  tensorflow/core/kernels/gather_functor.cc | 1
-rw-r--r--  tensorflow/core/kernels/gather_functor_gpu.cu.cc | 1
-rw-r--r--  tensorflow/core/kernels/gather_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/gather_op_test.cc | 34
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc | 113
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc | 1
-rw-r--r--  tensorflow/core/kernels/maxpooling_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 9
-rw-r--r--  tensorflow/core/kernels/quantize_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/quantize_op_test.cc | 44
-rw-r--r--  tensorflow/core/kernels/random_op.cc | 30
-rw-r--r--  tensorflow/core/kernels/random_poisson_op.cc | 22
-rw-r--r--  tensorflow/core/kernels/save_restore_v2_ops.cc | 6
-rw-r--r--  tensorflow/core/kernels/spacetobatch_op.cc | 4
-rw-r--r--  tensorflow/core/lib/random/philox_random.h | 14
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 21
-rw-r--r--  tensorflow/core/ops/array_ops_test.cc | 12
-rw-r--r--  tensorflow/core/ops/data_flow_ops.cc | 2
-rw-r--r--  tensorflow/core/ops/random_ops.cc | 21
-rw-r--r--  tensorflow/core/ops/set_ops.cc | 4
-rw-r--r--  tensorflow/core/platform/cpu_info.cc | 2
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 9
-rw-r--r--  tensorflow/core/util/cuda_kernel_helper.h | 22
-rw-r--r--  tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 18
-rw-r--r--  tensorflow/core/util/tensor_bundle/tensor_bundle.h | 4
-rw-r--r--  tensorflow/docs_src/get_started/get_started.md | 4
-rw-r--r--  tensorflow/docs_src/programmers_guide/debugger.md | 2
-rw-r--r--  tensorflow/docs_src/tutorials/recurrent.md | 19
-rw-r--r--  tensorflow/python/client/session.py | 3
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc | 2
-rw-r--r--  tensorflow/python/debug/BUILD | 2
-rw-r--r--  tensorflow/python/debug/cli/analyzer_cli.py | 166
-rw-r--r--  tensorflow/python/debug/cli/analyzer_cli_test.py | 91
-rw-r--r--  tensorflow/python/debug/cli/cli_shared.py | 12
-rw-r--r--  tensorflow/python/debug/cli/curses_ui.py | 84
-rw-r--r--  tensorflow/python/debug/cli/stepper_cli.py | 20
-rw-r--r--  tensorflow/python/debug/lib/source_utils.py | 145
-rw-r--r--  tensorflow/python/debug/lib/source_utils_test.py | 158
-rw-r--r--  tensorflow/python/estimator/BUILD | 16
-rw-r--r--  tensorflow/python/estimator/estimator.py | 20
-rw-r--r--  tensorflow/python/estimator/estimator_test.py | 38
-rw-r--r--  tensorflow/python/estimator/run_config.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/gather_op_test.py | 103
-rw-r--r--  tensorflow/python/kernel_tests/linalg_ops_test.py | 2
-rw-r--r--  tensorflow/python/layers/normalization.py | 273
-rw-r--r--  tensorflow/python/layers/normalization_test.py | 59
-rw-r--r--  tensorflow/python/ops/array_ops.py | 13
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 6
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 16
-rw-r--r--  tensorflow/python/ops/rnn.py | 68
-rw-r--r--  tensorflow/python/ops/session_ops.py | 14
-rw-r--r--  tensorflow/python/ops/state_ops.py | 9
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 2
-rw-r--r--  tensorflow/python/training/adam.py | 33
-rw-r--r--  tensorflow/python/training/adam_test.py | 16
-rw-r--r--  tensorflow/python/training/device_setter.py | 2
-rw-r--r--  tensorflow/python/training/device_setter_test.py | 7
-rw-r--r--  tensorflow/python/training/monitored_session.py | 2
-rw-r--r--  tensorflow/tensorboard/backend/application.py | 48
-rw-r--r--  tensorflow/tensorboard/backend/application_test.py | 65
-rw-r--r--  tensorflow/tensorboard/backend/event_processing/event_accumulator.py | 8
-rw-r--r--  tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py | 10
-rw-r--r--  tensorflow/tensorboard/backend/event_processing/event_multiplexer.py | 15
-rw-r--r--  tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py | 47
-rw-r--r--  tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html | 9
-rw-r--r--  tensorflow/tensorboard/http_api.md | 7
-rw-r--r--  tensorflow/tensorboard/package.json | 2
-rw-r--r--  tensorflow/tensorboard/plugins/base_plugin.py | 12
-rw-r--r--  tensorflow/tensorboard/plugins/debugger/debugger_plugin.py | 15
-rw-r--r--  tensorflow/tensorboard/plugins/debugger/debugger_plugin_test.py | 13
-rw-r--r--  tensorflow/tensorboard/plugins/projector/projector_plugin.py | 107
-rw-r--r--  tensorflow/tensorboard/plugins/projector/projector_plugin_test.py | 332
-rw-r--r--  tensorflow/tensorboard/plugins/text/text_plugin.py | 10
-rw-r--r--  tensorflow/tensorboard/plugins/text/text_plugin_test.py | 14
-rw-r--r--  tensorflow/tensorflow.bzl | 142
-rw-r--r--  tensorflow/tools/dist_test/server/Dockerfile.test | 8
-rw-r--r--  tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb | 2
-rw-r--r--  tensorflow/tools/graph_transforms/quantize_weights.cc | 4
-rw-r--r--  tensorflow/tools/graph_transforms/quantize_weights_test.cc | 120
-rw-r--r--  tensorflow/workspace.bzl | 37
-rw-r--r--  third_party/fft2d/BUILD | 30
-rw-r--r--  third_party/fft2d/LICENSE | 3
-rw-r--r--  third_party/fft2d/fft.h | 36
-rw-r--r--  third_party/fft2d/fft2d.BUILD | 36
176 files changed, 5879 insertions, 879 deletions
diff --git a/configure b/configure
index 6360641be2..48a4594da6 100755
--- a/configure
+++ b/configure
@@ -56,7 +56,7 @@ rm -f .tf_configure.bazelrc
touch .tf_configure.bazelrc
touch .bazelrc
sed_hyphen_i "/tf_configure/d" .bazelrc
-echo "import .tf_configure.bazelrc" >> .bazelrc
+echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc
# Delete any leftover BUILD files from the Makefile build, which would interfere
# with Bazel parsing.
@@ -284,6 +284,7 @@ export TF_NEED_CUDA
write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA"
export TF_NEED_OPENCL
+write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL"
if [ "$TF_NEED_CUDA" == "1" ]; then
while [[ "$TF_CUDA_CLANG" == "" ]]; do
@@ -547,6 +548,7 @@ while true; do
fi
if [ -e "$HOST_CXX_COMPILER" ]; then
export HOST_CXX_COMPILER
+ write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER"
break
fi
echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
@@ -570,6 +572,7 @@ while true; do
fi
if [ -e "$HOST_C_COMPILER" ]; then
export HOST_C_COMPILER
+ write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER"
break
fi
echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
@@ -600,6 +603,7 @@ while true; do
if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
export COMPUTECPP_TOOLKIT_PATH
+ write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH"
break
fi
echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index e437987112..b98be57ec0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -351,9 +351,24 @@ filegroup(
# -------------------------------------------
cc_binary(
name = "libtensorflow.so",
+ linkopts = select({
+ "//tensorflow:darwin": [
+ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file
+ "//tensorflow/c:exported_symbols.lds",
+ ],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "-z defs",
+ "-s",
+ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file
+ "//tensorflow/c:version_script.lds",
+ ],
+ }),
linkshared = 1,
deps = [
"//tensorflow/c:c_api",
+ "//tensorflow/c:exported_symbols.lds",
+ "//tensorflow/c:version_script.lds",
"//tensorflow/core:tensorflow",
],
)
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 0019dfeeb1..6e39deee63 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -45,6 +45,14 @@ tf_cuda_library(
}),
)
+exports_files(
+ [
+ "version_script.lds",
+ "exported_symbols.lds",
+ ],
+ visibility = ["//visibility:public"],
+)
+
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],
diff --git a/tensorflow/c/exported_symbols.lds b/tensorflow/c/exported_symbols.lds
new file mode 100644
index 0000000000..a14bdaa48b
--- /dev/null
+++ b/tensorflow/c/exported_symbols.lds
@@ -0,0 +1 @@
+_TF_*
diff --git a/tensorflow/c/version_script.lds b/tensorflow/c/version_script.lds
new file mode 100644
index 0000000000..455bd7362b
--- /dev/null
+++ b/tensorflow/c/version_script.lds
@@ -0,0 +1,9 @@
+VERS_1.0 {
+ # Export symbols in c_api.h.
+ global:
+ TF_*;
+
+ # Hide everything else.
+ local:
+ *;
+};
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 85ef9560bb..59a45538a7 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -52,7 +52,8 @@ const char kUsageHeader[] =
"header file that gives access to the functionality in the object file.\n"
"A typical invocation looks like this:\n"
"\n"
- " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
+ " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
+ "--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& kind, const string& fname,
@@ -73,6 +74,9 @@ void ParseTensorId(const string& name, TensorId* id) {
Status Main(const MainFlags& flags) {
// Process config.
Config config;
+ if (flags.config.empty()) {
+ return errors::InvalidArgument("Must specify --config");
+ }
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
@@ -85,6 +89,9 @@ Status Main(const MainFlags& flags) {
}
// Read and initialize the graph.
+ if (flags.graph.empty()) {
+ return errors::InvalidArgument("Must specify --graph");
+ }
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
std::unique_ptr<Graph> graph;
@@ -101,6 +108,9 @@ Status Main(const MainFlags& flags) {
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
StringPiece(obj.data(), obj.size())));
HeaderOpts header_opts;
+ if (flags.cpp_class.empty()) {
+ return errors::InvalidArgument("Must specify --cpp_class");
+ }
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
&header_opts.namespaces));
string header;
@@ -131,12 +141,16 @@ int main(int argc, char** argv) {
QCHECK(parsed_flags_ok) << "\n" << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
- QCHECK(argc == 1 && !flags.config.empty() &&
- (flags.dump_fetch_nodes ||
- (!flags.graph.empty() && !flags.entry_point.empty())))
- << "\n"
- << usage;
-
- TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
+ QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
+ "other than flags\n\n"
+ << usage;
+ tensorflow::Status status = tensorflow::tfcompile::Main(flags);
+ if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
+ std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
+ << usage;
+ return 1;
+ } else {
+ TF_QCHECK_OK(status);
+ }
return 0;
}
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 03e255e6b8..0592e3d4b1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -306,6 +306,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "spacetobatch_op_test",
+ size = "medium",
+ srcs = ["spacetobatch_op_test.py"],
+ shard_count = 3,
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9efdaee7ab..7221a0a3c7 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -108,6 +108,12 @@ class BinaryOpsTest(XLATestCase):
expected=np.array([-75, -48, -21, 0], dtype=dtype))
self._testBinary(
+ gen_nn_ops._elu_grad,
+ np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
+ expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
+
+ self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index a0cd905f17..7d91594db0 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -218,12 +218,11 @@ class OpTest : public ::testing::Test {
static constexpr int kDefaultMaxRank = 5;
static constexpr int64 kDefaultMaxDimensionSize = 20LL;
- // Returns a random dimension size.
+ // Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
// Returns a random shape. The tensor has rank in the range [min_rank,
- // max_rank).
- // Each dimension has size [0, kDefaultMaxDimensionSize].
+ // max_rank). Each dimension has size [min_size, max_size).
std::vector<int64> RandomDims(int min_rank = 0,
int max_rank = kDefaultMaxRank,
int64 min_size = 0,
@@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
return;
}
+ for (const Tensor& expected : expected_outputs) {
+ VLOG(1) << "Expected: " << expected.DebugString();
+ }
VLOG(1) << "Running test graph";
TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
@@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) {
});
}
+TEST_F(OpTest, BatchToSpace) {
+ Repeatedly([this]() {
+ const int num_block_dims = 2;
+ std::vector<int64> block_dims =
+ RandomDims(num_block_dims, num_block_dims, 0, 5);
+ int64 block_size = RandomDim(0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + 1);
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[0] *= block_size;
+ input_dims[1 + i] = block_dims[i];
+ }
+ input_dims[1 + num_block_dims] = RandomDim();
+
+ std::vector<int64> crop_vals;
+ std::uniform_int_distribution<int> distribution(0, 4);
+ for (int i = 0; i < num_block_dims; ++i) {
+ // Chooses crop values; does not always choose legal values.
+ crop_vals.push_back(distribution(generator()));
+ crop_vals.push_back(distribution(generator()));
+ }
+ Tensor crops;
+ CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(crops)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
+ });
+}
+
+TEST_F(OpTest, BatchToSpaceND) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+ int num_block_dims = block_dims.size();
+ std::vector<int64> remaining_dims = RandomDims(0, 3);
+ std::vector<int64> block_multipliers =
+ RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[0] *= block_dims[i];
+ }
+ std::copy(block_multipliers.begin(), block_multipliers.end(),
+ input_dims.begin() + 1);
+ std::copy(remaining_dims.begin(), remaining_dims.end(),
+ input_dims.begin() + 1 + num_block_dims);
+
+ std::vector<int64> crop_vals;
+ std::uniform_int_distribution<int> distribution(0, 3);
+ for (int i = 0; i < num_block_dims; ++i) {
+ // Chooses crop values; does not always choose legal values.
+ crop_vals.push_back(distribution(generator()));
+ crop_vals.push_back(distribution(generator()));
+ }
+ Tensor crops;
+ CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BatchToSpaceND")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(test::AsTensor<int32>(
+ std::vector<int32>(block_dims.begin(), block_dims.end())))
+ .Input(crops)
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, BiasAdd) {
Repeatedly([this]() {
auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
@@ -1214,6 +1289,23 @@ TEST_F(OpTest, DynamicStitch) {
});
}
+TEST_F(OpTest, Elu) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, EluGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
@@ -2019,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) {
});
}
+TEST_F(OpTest, SpaceToBatch) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);
+ const int num_block_dims = 2;
+ int64 block_size = RandomDim(0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + 1);
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[1 + i] = block_dims[i] * block_size;
+ }
+ input_dims[1 + num_block_dims] = RandomDim();
+
+ std::vector<int64> padding_vals;
+ std::uniform_int_distribution<int> distribution(0, 7);
+ for (int i = 0; i < num_block_dims; ++i) {
+ int64 pad_before;
+ int64 pad_after;
+ do {
+ pad_before = distribution(generator());
+ pad_after = distribution(generator());
+ } while (pad_before + pad_after > input_dims[1 + i]);
+ input_dims[1 + i] -= pad_before + pad_after;
+ padding_vals.push_back(pad_before);
+ padding_vals.push_back(pad_after);
+ }
+ Tensor paddings;
+ CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(paddings)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
+ });
+}
+
+TEST_F(OpTest, SpaceToBatchND) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+ int num_block_dims = block_dims.size();
+ std::vector<int64> remaining_dims = RandomDims(0, 3);
+ std::vector<int64> block_multipliers =
+ RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[1 + i] = block_dims[i] * block_multipliers[i];
+ }
+ std::copy(remaining_dims.begin(), remaining_dims.end(),
+ input_dims.begin() + 1 + num_block_dims);
+
+ std::vector<int64> padding_vals;
+ std::uniform_int_distribution<int> distribution(0, 7);
+ for (int i = 0; i < num_block_dims; ++i) {
+ int64 pad_before;
+ int64 pad_after;
+ do {
+ pad_before = distribution(generator());
+ pad_after = distribution(generator());
+ } while (pad_before + pad_after > input_dims[1 + i]);
+ input_dims[1 + i] -= pad_before + pad_after;
+ padding_vals.push_back(pad_before);
+ padding_vals.push_back(pad_after);
+ }
+ Tensor paddings;
+ CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("SpaceToBatchND")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(test::AsTensor<int32>(
+ std::vector<int32>(block_dims.begin(), block_dims.end())))
+ .Input(paddings)
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, SparseMatMul) {
Repeatedly([this]() {
int64 x = RandomDim();
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
new file mode 100644
index 0000000000..9c3b86c84b
--- /dev/null
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -0,0 +1,266 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for SpaceToBatch and BatchToSpace ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.platform import test
+
+
+def space_to_batch_direct(input_array, block_shape, paddings):
+ """Direct Python implementation of space-to-batch conversion.
+
+ This is used for tests only.
+
+ Args:
+ input_array: N-D array
+ block_shape: 1-D array of shape [num_block_dims].
+ paddings: 2-D array of shape [num_block_dims, 2].
+
+ Returns:
+ Converted tensor.
+ """
+ input_array = np.array(input_array)
+ block_shape = np.array(block_shape)
+ num_block_dims = len(block_shape)
+ paddings = np.array(paddings).reshape((len(block_shape), 2))
+
+ padded = np.pad(input_array,
+ pad_width=([[0, 0]] + list(paddings) + [[0, 0]] *
+ (input_array.ndim - 1 - num_block_dims)),
+ mode="constant")
+ reshaped_padded_shape = [input_array.shape[0]]
+ output_shape = [input_array.shape[0] * np.prod(block_shape)]
+ for block_dim, block_shape_value in enumerate(block_shape):
+ reduced_size = padded.shape[block_dim + 1] // block_shape_value
+ reshaped_padded_shape.append(reduced_size)
+ output_shape.append(reduced_size)
+ reshaped_padded_shape.append(block_shape_value)
+ reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:])
+ output_shape.extend(input_array.shape[num_block_dims + 1:])
+
+ reshaped_padded = padded.reshape(reshaped_padded_shape)
+ permuted_reshaped_padded = np.transpose(reshaped_padded, (
+ list(np.arange(num_block_dims) * 2 + 2) + [0] +
+ list(np.arange(num_block_dims) * 2 + 1) + list(
+ np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims
+ * 2)))
+ return permuted_reshaped_padded.reshape(output_shape)
+
+
+class SpaceToBatchTest(XLATestCase):
+ """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
+
+ def _testPad(self, inputs, paddings, block_size, outputs):
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self.float_types:
+ # outputs = space_to_batch(inputs)
+ placeholder = array_ops.placeholder(dtype)
+ x_tf = gen_array_ops._space_to_batch(
+ placeholder, paddings, block_size=block_size)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+ # inputs = batch_to_space(outputs)
+ x_tf = gen_array_ops._batch_to_space(
+ placeholder, paddings, block_size=block_size)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+ def _testOne(self, inputs, block_size, outputs):
+ paddings = np.zeros((2, 2), dtype=np.int32)
+ self._testPad(inputs, paddings, block_size, outputs)
+
+ # [1, 2, 2, 1] <-> [4, 1, 1, 1]
+ def testSmallInput2x2(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ block_size = 2
+ x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1]
+ def testSmallInput2x2Pad1x0(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ paddings = np.array([[1, 0], [1, 0]], dtype=np.int32)
+ block_size = 3
+ x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]],
+ [[[3]]], [[[4]]]]
+ self._testPad(x_np, paddings, block_size, x_out)
+
+ # Test with depth larger than 1.
+ # [1, 2, 2, 3] <-> [4, 1, 1, 3]
+ def testDepthInput2x2(self):
+ x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
+ block_size = 2
+ x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Test for larger input dimensions.
+ # [1, 4, 4, 1] <-> [4, 2, 2, 1]
+ def testLargerInput2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+ block_size = 2
+ x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]],
+ [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Test with batch larger than 1.
+ # [2, 2, 4, 1] <-> [8, 1, 2, 1]
+ def testBatchInput2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
+ [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+ block_size = 2
+ x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]],
+ [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Tests for larger input spatial dimensions AND batch larger than 1, to ensure
+ # that elements are correctly laid out spatially and properly interleaved
+ # along the batch dimension.
+ # [2, 4, 4, 1] <-> [8, 2, 2, 1]
+ def testLargerInputBatch2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]], [[13], [14], [15], [16]]],
+ [[[17], [18], [19], [20]], [[21], [22], [23], [24]],
+ [[25], [26], [27], [28]], [[29], [30], [31], [32]]]]
+ x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]],
+ [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]],
+ [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]],
+ [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]]
+ block_size = 2
+ self._testOne(x_np, block_size, x_out)
+
+
+class SpaceToBatchNDTest(XLATestCase):
+ """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
+
+ def _testPad(self, inputs, block_shape, paddings, outputs):
+ block_shape = np.array(block_shape)
+ paddings = np.array(paddings).reshape((len(block_shape), 2))
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self.float_types:
+ placeholder = array_ops.placeholder(dtype)
+ # outputs = space_to_batch(inputs)
+ x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+ # inputs = batch_to_space(outputs)
+ placeholder = array_ops.placeholder(dtype)
+ x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+ def _testDirect(self, input_shape, block_shape, paddings):
+ inputs = np.arange(np.prod(input_shape), dtype=np.float32)
+ inputs = inputs.reshape(input_shape)
+ self._testPad(inputs, block_shape, paddings,
+ space_to_batch_direct(inputs, block_shape, paddings))
+
+ def testZeroBlockDimsZeroRemainingDims(self):
+ self._testPad(
+ inputs=[1, 2],
+ block_shape=[],
+ paddings=[],
+ outputs=[1, 2],)
+
+ def testZeroBlockDimsOneRemainingDim(self):
+ self._testPad(
+ inputs=[[1, 2], [3, 4]],
+ block_shape=[],
+ paddings=[],
+ outputs=[[1, 2], [3, 4]])
+
+ # Same thing, but with a no-op block dim.
+ self._testPad(
+ inputs=[[1, 2], [3, 4]],
+ block_shape=[1],
+ paddings=[[0, 0]],
+ outputs=[[1, 2], [3, 4]])
+
+ def testZeroBlockDimsTwoRemainingDims(self):
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[],
+ paddings=[],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ # Same thing, but with a no-op block dim.
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[1],
+ paddings=[[0, 0]],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ # Same thing, but with two no-op block dims.
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[1, 1],
+ paddings=[[0, 0], [0, 0]],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ def testOneBlockDimZeroRemainingDims(self):
+ self._testPad(
+ inputs=[[1, 2, 3], [4, 5, 6]],
+ block_shape=[2],
+ paddings=[1, 0],
+ outputs=[[0, 2], [0, 5], [1, 3], [4, 6]])
+
+ def testOneBlockDimOneRemainingDim(self):
+ self._testPad(
+ inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]],
+ block_shape=[2],
+ paddings=[1, 0],
+ outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
+ [[4, 41], [6, 61]]])
+
+ def testDirect(self):
+ # Test with zero-size remaining dimension.
+ self._testDirect(
+ input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
+
+ # Test with zero-size blocked dimension.
+ self._testDirect(
+ input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
+
+ # Test with padding up from zero size.
+ self._testDirect(
+ input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])
+
+ self._testDirect(
+ input_shape=[3, 3, 4, 5, 2],
+ block_shape=[3, 4, 2],
+ paddings=[[1, 2], [0, 0], [3, 0]])
+
+ self._testDirect(
+ input_shape=[3, 3, 4, 5, 2],
+ block_shape=[3, 4, 2, 2],
+ paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])
+
+ self._testDirect(
+ input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+ block_shape=[1, 1, 3, 4, 2, 2],
+ paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])
+
+ self._testDirect(
+ input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+ block_shape=[1, 1, 3, 4, 2, 2, 1],
+ paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]])
+
+
+if __name__ == "__main__":
+ test.main()
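As a quick sanity check (illustrative only, not part of the commit), the space_to_batch_direct helper defined above can be run outside the test harness; with a [1, 2, 2, 1] input, block_shape [2, 2], and zero paddings it reproduces the [4, 1, 1, 1] result asserted in testSmallInput2x2:

import numpy as np

# Assumes space_to_batch_direct from the test file above is importable.
x = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
out = space_to_batch_direct(x, block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
assert out.shape == (4, 1, 1, 1)
assert np.array_equal(out.ravel(), [1, 2, 3, 4])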
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 1e85d3a2c8..3f324d1071 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -210,6 +210,11 @@ class UnaryOpsTest(XLATestCase):
dtype=dtype))
self._assertOpOutputMatchesExpected(
+ nn_ops.elu,
+ np.array([[-1, 0, 1]], dtype=dtype),
+ expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0, 1]], dtype=dtype))
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 53aa749a0a..44ff13ca34 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Any", "reduction_indices"},
{"ArgMax", "dimension"},
{"AvgPoolGrad", "orig_input_shape"},
+ {"BatchToSpace", "crops"},
+ {"BatchToSpaceND", "block_shape"},
+ {"BatchToSpaceND", "crops"},
{"BroadcastGradientArgs", "s0"},
{"BroadcastGradientArgs", "s1"},
{"Concat", "concat_dim"},
@@ -69,6 +72,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"ReverseV2", "axis"},
{"Slice", "begin"},
{"Slice", "size"},
+ {"SpaceToBatch", "paddings"},
+ {"SpaceToBatchND", "block_shape"},
+ {"SpaceToBatchND", "paddings"},
{"Split", "split_dim"},
{"SplitV", "split_dim"},
{"SplitV", "size_splits"},
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 2ee80a41e8..14d2a72f7c 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -15,6 +15,7 @@ tf_kernel_library(
srcs = [
"aggregate_ops.cc",
"batch_matmul_op.cc",
+ "batchtospace_op.cc",
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
@@ -26,6 +27,7 @@ tf_kernel_library(
"depthwise_conv_ops.cc",
"diag_op.cc",
"dynamic_stitch_op.cc",
+ "elu_op.cc",
"fill_op.cc",
"function_ops.cc",
"identity_op.cc",
@@ -49,6 +51,7 @@ tf_kernel_library(
"shape_op.cc",
"slice_op.cc",
"softmax_op.cc",
+ "spacetobatch_op.cc",
"split_op.cc",
"strided_slice_op.cc",
"tile_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
new file mode 100644
index 0000000000..eb4bd47ee5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void BatchToSpace(XlaOpKernelContext* ctx,
+ const xla::ComputationDataHandle& input, DataType input_dtype,
+ const TensorShape& input_tensor_shape,
+ gtl::ArraySlice<int64> block_shape,
+ const xla::Literal& crops) {
+ const int input_rank = input_tensor_shape.dims();
+ const gtl::InlinedVector<int64, 4> input_shape =
+ input_tensor_shape.dim_sizes();
+ const int block_rank = block_shape.size();
+
+ OP_REQUIRES(
+ ctx, input_rank >= 1 + block_rank,
+ errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+ " instead of ", input_rank));
+ gtl::ArraySlice<int64> remainder_shape(input_shape);
+ remainder_shape.remove_prefix(1 + block_rank);
+
+ OP_REQUIRES(
+ ctx,
+ xla::ShapeUtil::Rank(crops.shape()) == 2 &&
+ block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
+ 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
+ errors::InvalidArgument("crops should have shape [", block_rank,
+ ", 2] instead of ",
+ xla::ShapeUtil::HumanString(crops.shape())));
+
+ xla::ComputationBuilder* b = ctx->builder();
+ const int64 batch_size = input_shape[0];
+
+ // Compute the product of the block_shape values.
+ int64 block_num_elems = 1;
+ for (int i = 0; i < block_rank; ++i) {
+ block_num_elems *= block_shape[i];
+ }
+ OP_REQUIRES(ctx, block_num_elems > 0,
+ errors::InvalidArgument(
+ "The product of the block dimensions must be positive"));
+
+ // 1. Reshape `input` to `reshaped` of shape:
+ // [block_shape[0], ..., block_shape[M-1],
+ // batch / prod(block_shape),
+ // input_shape[1], ..., input_shape[N-1]]
+
+ OP_REQUIRES(
+ ctx, batch_size % block_num_elems == 0,
+ errors::InvalidArgument("Input batch dimension (", batch_size,
+ ") is not divisible by product of block sizes (",
+ block_num_elems, ")"));
+ std::vector<int64> reshaped_shape(input_rank + block_rank);
+ std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
+ reshaped_shape[block_rank] = batch_size / block_num_elems;
+ std::copy(input_shape.begin() + 1, input_shape.end(),
+ reshaped_shape.begin() + block_rank + 1);
+ xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+
+ // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1], block_shape[0],
+ // ...,
+ // input_shape[M], block_shape[M-1],
+ //
+ // input_shape[M+1], ..., input_shape[N-1]]
+ std::vector<int64> permutation(reshaped_shape.size());
+ permutation[0] = block_rank;
+ for (int i = 0; i < block_rank; ++i) {
+ permutation[1 + 2 * i] = block_rank + 1 + i;
+ permutation[1 + 2 * i + 1] = i;
+ }
+ std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+ 1 + block_rank * 2);
+ xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+
+ // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1] * block_shape[0],
+ // ...,
+ // input_shape[M] * block_shape[M-1],
+ //
+ // input_shape[M+1],
+ // ...,
+ // input_shape[N-1]]
+ std::vector<int64> reshaped_permuted_shape(input_rank);
+ reshaped_permuted_shape[0] = batch_size / block_num_elems;
+ for (int i = 0; i < block_rank; ++i) {
+ reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ reshaped_permuted_shape.begin() + 1 + block_rank);
+
+ xla::ComputationDataHandle reshaped_permuted =
+ b->Reshape(permuted, reshaped_permuted_shape);
+
+ // 4. Crop the start and end of dimensions `[1, ..., M]` of
+ // `reshaped_permuted` according to `crops` to produce the output of shape:
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
+ // ...,
+ // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
+ //
+ // input_shape[M+1], ..., input_shape[N-1]]
+ std::vector<int64> start_indices(input_rank, 0);
+ std::vector<int64> end_indices = reshaped_permuted_shape;
+ for (int i = 0; i < block_rank; ++i) {
+ int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
+ int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
+ OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
+ errors::InvalidArgument("Crops must be non-negative"));
+ start_indices[1 + i] = crop_start;
+ end_indices[1 + i] -= crop_end;
+ OP_REQUIRES(
+ ctx, start_indices[1 + i] <= end_indices[1 + i],
+ errors::InvalidArgument(
+ "Cropped size must be non-negative: start: ", crop_start,
+ " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
+ }
+ xla::ComputationDataHandle output =
+ b->Slice(reshaped_permuted, start_indices, end_indices);
+ ctx->SetOutput(0, output);
+}
+
+class BatchToSpaceNDOp : public XlaOpKernel {
+ public:
+ explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> block_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+ xla::Literal crops;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
+
+ BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ block_shape, crops);
+ }
+};
+REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+
+class BatchToSpaceOp : public XlaOpKernel {
+ public:
+ explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+ OP_REQUIRES(
+ ctx, block_size_ > 1,
+ errors::InvalidArgument("Block size should be > 1: ", block_size_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::Literal crops;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
+
+ BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ {block_size_, block_size_}, crops);
+ }
+
+ private:
+ int block_size_;
+};
+REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+
+} // namespace
+} // namespace tensorflow
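The BatchToSpace lowering above proceeds in four steps: reshape, transpose, reshape, crop. The following NumPy sketch mirrors that same sequence and is useful for reasoning about the shapes involved; the function name and argument layout are assumptions for illustration, not part of the commit.

import numpy as np

def batch_to_space_direct(x, block_shape, crops):
    """Illustrative NumPy mirror of the reshape/transpose/reshape/crop steps."""
    x = np.asarray(x)
    block_shape = list(block_shape)
    m = len(block_shape)
    crops = np.asarray(crops).reshape((m, 2))
    batch = x.shape[0]
    block_elems = int(np.prod(block_shape))
    # 1. Split the batch dimension into the block dimensions.
    reshaped = x.reshape(block_shape + [batch // block_elems] + list(x.shape[1:]))
    # 2. Interleave each spatial dimension with its block dimension.
    perm = [m]
    for i in range(m):
        perm += [m + 1 + i, i]
    perm += list(range(2 * m + 1, reshaped.ndim))
    permuted = np.transpose(reshaped, perm)
    # 3. Merge each (spatial, block) pair back into a single dimension.
    merged_shape = ([batch // block_elems] +
                    [x.shape[1 + i] * block_shape[i] for i in range(m)] +
                    list(x.shape[1 + m:]))
    merged = permuted.reshape(merged_shape)
    # 4. Crop the start and end of each spatial dimension.
    slices = [slice(None)]
    for i in range(m):
        slices.append(slice(crops[i, 0], merged_shape[1 + i] - crops[i, 1]))
    return merged[tuple(slices)]

For example, applying this to the [4, 1, 1, 1] tensor [1, 2, 3, 4] with block_shape=[2, 2] and zero crops recovers the [1, 2, 2, 1] input [[1, 2], [3, 4]] from testSmallInput2x2 in spacetobatch_op_test.py above.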
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
new file mode 100644
index 0000000000..62a5e1bd42
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of XLA Elu Ops
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+namespace {
+
+class EluOp : public XlaOpKernel {
+ public:
+ explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Computes x if x > 0, otherwise exp(x) - 1 (elementwise ELU).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+ const auto zero = XlaHelpers::Zero(b, input_type(0));
+ const auto one = XlaHelpers::One(b, input_type(0));
+ const auto pred = b->Gt(ctx->Input(0), zero);
+ const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+ ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
+ }
+};
+
+class EluGradOp : public XlaOpKernel {
+ public:
+ explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Return the lhs (incoming gradient) if the rhs (input feature) > 0,
+ // otherwise return lhs * (1 + rhs).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+ const auto zero = XlaHelpers::Zero(b, input_type(0));
+ const auto one = XlaHelpers::One(b, input_type(0));
+ const auto grad = ctx->Input(0);
+ const auto activation = ctx->Input(1);
+ const auto exp_grad = b->Mul(grad, b->Add(activation, one));
+ const auto pred = b->Gt(activation, zero);
+ ctx->SetOutput(0, b->Select(pred, grad, exp_grad));
+ }
+};
+
+REGISTER_XLA_OP(Name("Elu"), EluOp);
+REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
+
+} // namespace
+} // namespace tensorflow
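The Elu and EluGrad lowerings above compute x for x > 0 and exp(x) - 1 otherwise, with the gradient passed through where the activation is positive and scaled by (activation + 1) elsewhere. A minimal NumPy sketch of the same math (the function names are assumptions for illustration):

import numpy as np

def elu(x):
    # x where x > 0, else exp(x) - 1, matching the Select/Exp/Sub sequence above.
    x = np.asarray(x, dtype=np.float64)
    return np.where(x > 0, x, np.exp(x) - 1)

def elu_grad(grad, activation):
    # Pass grad through where the forward output is > 0; otherwise scale by
    # (activation + 1), which equals exp(x) for x <= 0.
    grad = np.asarray(grad, dtype=np.float64)
    activation = np.asarray(activation, dtype=np.float64)
    return np.where(activation > 0, grad, grad * (activation + 1))

# These match the test values added in unary_ops_test.py and binary_ops_test.py:
# elu([-1, 0, 1]) ~= [-0.63212056, 0, 1]
# elu_grad([1, 2, 3, 4, 5, 6], [-.6, -.4, -.2, 0, .2, .4]) == [0.4, 1.2, 2.4, 4, 5, 6]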
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
new file mode 100644
index 0000000000..f15b354cb2
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -0,0 +1,190 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void SpaceToBatch(XlaOpKernelContext* ctx,
+ const xla::ComputationDataHandle& input, DataType input_dtype,
+ const TensorShape& input_tensor_shape,
+ gtl::ArraySlice<int64> block_shape,
+ const xla::Literal& paddings) {
+ const int input_rank = input_tensor_shape.dims();
+ const gtl::InlinedVector<int64, 4> input_shape =
+ input_tensor_shape.dim_sizes();
+ const int block_rank = block_shape.size();
+
+ OP_REQUIRES(
+ ctx, input_rank >= 1 + block_rank,
+ errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+ " instead of ", input_rank));
+ gtl::ArraySlice<int64> remainder_shape(input_shape);
+ remainder_shape.remove_prefix(1 + block_rank);
+
+ OP_REQUIRES(
+ ctx,
+ xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
+ block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
+ 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
+ errors::InvalidArgument("paddings should have shape [", block_rank,
+ ", 2] instead of ",
+ xla::ShapeUtil::HumanString(paddings.shape())));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
+ // input according to `paddings` to produce `padded` of shape `padded_shape`.
+ xla::PaddingConfig padding_config;
+ std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
+ int64 block_num_elems = 1LL;
+ padding_config.add_dimensions(); // Don't pad the batch dimension.
+ for (int i = 0; i < block_rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
+ int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
+ OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
+ errors::InvalidArgument("Paddings must be non-negative"));
+ dim->set_edge_padding_low(pad_start);
+ dim->set_edge_padding_high(pad_end);
+ padded_shape[1 + i] += pad_start + pad_end;
+ block_num_elems *= block_shape[i];
+ }
+ // Don't pad the remainder dimensions.
+ for (int i = 0; i < remainder_shape.size(); ++i) {
+ padding_config.add_dimensions();
+ }
+ OP_REQUIRES(ctx, block_num_elems > 0,
+ errors::InvalidArgument(
+ "The product of the block dimensions must be positive"));
+
+ xla::ComputationDataHandle padded =
+ b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
+
+ // 2. Reshape `padded` to `reshaped_padded` of shape:
+ //
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1],
+ // block_shape[M-1]] +
+ // remaining_shape
+ const int64 batch_size = input_shape[0];
+ std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
+ reshaped_padded_shape[0] = batch_size;
+ for (int i = 0; i < block_rank; ++i) {
+ OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
+ errors::InvalidArgument("padded_shape[", 1 + i,
+ "]=", padded_shape[1 + i],
+ " is not divisible by block_shape[", i,
+ "]=", block_shape[i]));
+
+ reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
+ reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ reshaped_padded_shape.begin() + 1 + 2 * block_rank);
+
+ xla::ComputationDataHandle reshaped_padded =
+ b->Reshape(padded, reshaped_padded_shape);
+
+ // 3. Permute dimensions of `reshaped_padded` to produce
+ // `permuted_reshaped_padded` of shape:
+ //
+ // block_shape +
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ std::vector<int64> permutation(reshaped_padded_shape.size());
+ for (int i = 0; i < block_rank; ++i) {
+ permutation[i] = 1 + 2 * i + 1;
+ permutation[block_rank + 1 + i] = 1 + 2 * i;
+ }
+ permutation[block_rank] = 0;
+ std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+ 1 + block_rank * 2);
+ xla::ComputationDataHandle permuted_reshaped_padded =
+ b->Transpose(reshaped_padded, permutation);
+
+ // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
+ // batch dimension, producing an output tensor of shape:
+ //
+ // [batch * prod(block_shape)] +
+ // [padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ // Build the output shape: the block elements are folded into the batch
+ // dimension and the remaining dimensions are copied through unchanged.
+ std::vector<int64> output_shape(input_rank);
+ output_shape[0] = batch_size * block_num_elems;
+ for (int i = 0; i < block_rank; ++i) {
+ output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ output_shape.begin() + 1 + block_rank);
+
+ xla::ComputationDataHandle output =
+ b->Reshape(permuted_reshaped_padded, output_shape);
+ ctx->SetOutput(0, output);
+}
+
+class SpaceToBatchNDOp : public XlaOpKernel {
+ public:
+ explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> block_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+ xla::Literal paddings;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));
+
+ SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ block_shape, paddings);
+ }
+};
+REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
+
+class SpaceToBatchOp : public XlaOpKernel {
+ public:
+ explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+ OP_REQUIRES(
+ ctx, block_size_ > 1,
+ errors::InvalidArgument("Block size should be > 1: ", block_size_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::Literal paddings;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));
+
+ SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ {block_size_, block_size_}, paddings);
+ }
+
+ private:
+ int block_size_;
+};
+REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
+
+} // namespace
+} // namespace tensorflow
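SpaceToBatch is the inverse of the BatchToSpace lowering earlier in this change, and its four steps (pad, reshape, transpose, reshape) can likewise be mirrored in NumPy. A minimal sketch under the same illustrative assumptions as the earlier helpers:

import numpy as np

def space_to_batch_sketch(x, block_shape, paddings):
    """Illustrative NumPy mirror of the pad/reshape/transpose/reshape steps."""
    x = np.asarray(x)
    block_shape = list(block_shape)
    m = len(block_shape)
    paddings = np.asarray(paddings).reshape((m, 2))
    batch = x.shape[0]
    # 1. Zero-pad the block dimensions.
    pad_width = [[0, 0]] + list(paddings) + [[0, 0]] * (x.ndim - 1 - m)
    padded = np.pad(x, pad_width, mode="constant")
    # 2. Split each padded block dimension into (outer, block) pairs.
    reshaped_shape = [batch]
    for i in range(m):
        reshaped_shape += [padded.shape[1 + i] // block_shape[i], block_shape[i]]
    reshaped_shape += list(x.shape[1 + m:])
    reshaped = padded.reshape(reshaped_shape)
    # 3. Move the block components in front of the batch dimension.
    perm = [1 + 2 * i + 1 for i in range(m)] + [0] + [1 + 2 * i for i in range(m)]
    perm += list(range(1 + 2 * m, reshaped.ndim))
    permuted = np.transpose(reshaped, perm)
    # 4. Flatten the block components into the batch dimension.
    out_shape = [batch * int(np.prod(block_shape))]
    out_shape += [padded.shape[1 + i] // block_shape[i] for i in range(m)]
    out_shape += list(x.shape[1 + m:])
    return permuted.reshape(out_shape)

On the testSmallInput2x2 data from spacetobatch_op_test.py, this sketch and the space_to_batch_direct helper defined there produce the same [4, 1, 1, 1] result.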
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 53dcdec7a2..a022de36a2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
+ xla::Literal* out) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ switch (literal.shape().element_type()) {
+ case xla::S32:
+ out->Clear();
+ *out->mutable_shape() = literal.shape();
+ out->mutable_shape()->set_element_type(xla::S64);
+ for (int32 x : literal.s32s()) {
+ out->add_s64s(x);
+ }
+ return Status::OK();
+
+ case xla::S64:
+ out->Swap(&literal);
+ return Status::OK();
+
+ default:
+ return errors::InvalidArgument(
+ "Invalid argument to ConstantInputAsInt64Literal: ",
+ xla::ShapeUtil::HumanString(literal.shape()));
+ }
+}
+
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 60e3b59d32..f97e07bea5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -110,6 +110,9 @@ class XlaOpKernelContext {
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+ // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
+ Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b9118fab25..695e4e7f07 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -594,8 +594,10 @@ cc_test(
deps = [
":buffer_assignment",
":computation_tracker",
+ ":copy_insertion",
":cpu_plugin",
":hlo",
+ ":hlo_ordering",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 27a1c0fec8..0969cff39a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -868,7 +868,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
computation->root_instruction()->dimensions();
EXPECT_EQ(1, broadcast_dims.size());
EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 ||
- broadcast_dims[3] == 3);
+ broadcast_dims[0] == 3);
}
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index e2b550fc02..931f589800 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -41,6 +41,8 @@ limitations under the License.
namespace xla {
+using ::tensorflow::gtl::FlatMap;
+using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::HumanReadableNumBytes;
@@ -394,8 +396,8 @@ Status GatherComputationsByAllocationType(
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- tensorflow::gtl::FlatSet<HloComputation*> thread_local_set;
- tensorflow::gtl::FlatSet<HloComputation*> global_set;
+ FlatSet<HloComputation*> thread_local_set;
+ FlatSet<HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
@@ -554,10 +556,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
Status BufferAssigner::AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const tensorflow::gtl::FlatSet<const HloInstruction*>* hlos_to_allocate,
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
- const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
- colocated_allocations,
+ const FlatSet<const HloInstruction*>* hlos_to_allocate,
+ const FlatSet<const LogicalBuffer*>& colocated_buffers,
+ const FlatSet<BufferAllocation::Index>& colocated_allocations,
BufferAssignment* assignment) {
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
// size.
@@ -578,7 +579,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Generate a post order sort of instructions for sorting of the
// LogicalBuffers.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> post_order_position;
+ FlatMap<const HloInstruction*, int> post_order_position;
int position = 0;
for (auto* instruction : computation->MakeInstructionPostOrder()) {
post_order_position.emplace(instruction, position);
@@ -590,7 +591,7 @@ Status BufferAssigner::AssignBuffersForComputation(
const BufferLiveness& liveness = assignment->liveness();
const std::vector<const HloInstruction*>* sequential_order =
liveness.hlo_ordering().SequentialOrder(*computation);
- tensorflow::gtl::FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
+ FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
// Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
// first for simplicity. This means any previously created BufferAllocation is
@@ -791,7 +792,7 @@ Status BufferAssigner::AssignBuffersForComputation(
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const std::vector<const HloInstruction*>& sequence,
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers_to_assign,
+ const FlatSet<const LogicalBuffer*>& buffers_to_assign,
const HloComputation& computation, BufferAssignment* assignment) {
// Run the sequence of instructions through the heap simulator. The heuristic
// that seems to give the best results is lazy-best-fit, with all runs of
@@ -881,40 +882,137 @@ void BufferAssigner::AddSetToColocatedBufferSets(
}
}
+// Conceptually the same as AddSetToColocatedBufferSets, but specific to the
+// colocated buffers for while instructions. 'colocated_set' contains the
+// buffers for a single while instruction that must be colocated. The idea here
+// is to apply a memory-saving heuristic for separate while instructions whose
+// buffers are disjoint in liveness, by using the colocation mechanism to force
+// buffer sharing. This often reduces memory for multi-layer RNNs.
+//
+// TODO(b/32491382): We should be able to remove this heuristic after we
+// implement module-level liveness analysis, which would let us directly detect
+// buffer sharing opportunities between the while instruction buffer and the
+// buffers from the predicate and body computation, as well as sharing across
+// different while instructions.
+void BufferAssigner::AddWhileSetToColocatedBufferSets(
+ const std::vector<const LogicalBuffer*>& colocated_set,
+ const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
+ const HloComputation& computation, const BufferLiveness& buffer_liveness,
+ std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
+ CHECK(!colocated_set.empty());
+
+ // Parallel while loops cannot safely share colocated buffer sets.
+ if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) {
+ AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+ return;
+ }
+
+ // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets
+ // are added in postorder over computations and instructions.
+ const int64 init_buffer_size = buffer_size_(*while_init_buffer);
+ for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) {
+ const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i];
+
+ // Skip predecessor sets not associated with while loops.
+ if (std::all_of(predecessor_set.begin(), predecessor_set.end(),
+ [](const LogicalBuffer* buffer) {
+ return buffer->instruction()->opcode() !=
+ HloOpcode::kWhile;
+ })) {
+ continue;
+ }
+
+ // Skip predecessor sets already associated with 'while_hlo'.
+ if (std::any_of(predecessor_set.begin(), predecessor_set.end(),
+ [&while_hlo](const LogicalBuffer* buffer) {
+ return buffer->instruction() == while_hlo;
+ })) {
+ continue;
+ }
+
+ // Build vector of predecessor while result buffers.
+ std::vector<const LogicalBuffer*> predecessor_while_buffers;
+ for (const LogicalBuffer* buffer : predecessor_set) {
+ if (buffer->instruction()->opcode() == HloOpcode::kWhile &&
+ buffer_size_(*buffer) == init_buffer_size &&
+ buffer->instruction()->parent() == &computation) {
+ predecessor_while_buffers.push_back(buffer);
+ }
+ }
+ if (predecessor_while_buffers.empty()) {
+ continue;
+ }
+
+ // Skip predecessor set if the live range of any predecessor buffer
+ // overlaps with 'while_init_buffer'. Note that tuple element buffer
+ // forwarding can cause the same buffer to appear on both sides of the
+ // interference comparison below.
+ if (std::any_of(
+ predecessor_while_buffers.begin(), predecessor_while_buffers.end(),
+ [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) {
+ return while_init_buffer->id() != buffer->id() &&
+ buffer_liveness.MayInterfere(*while_init_buffer, *buffer);
+ })) {
+ continue;
+ }
+
+ // All our checks have passed; merge 'predecessor_set' with 'colocated_set',
+ // and add the merged set to 'colocated_buffer_sets'. This forces the
+ // colocation of buffers across different while instructions.
+ FlatSet<const LogicalBuffer*> unique;
+ unique.insert(predecessor_set.begin(), predecessor_set.end());
+ unique.insert(colocated_set.begin(), colocated_set.end());
+ std::vector<const LogicalBuffer*> merged_set(unique.begin(), unique.end());
+ AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets);
+ return;
+ }
+
+ // Failed to merge into predecessor set; add 'colocated_set' as-is.
+ AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+}
+
namespace {
+
// Checks that points-to set of 'instruction' is unambiguous and distinct
// (ensured by CopyInsertion), then adds the buffer from the points-to set at
// 'index' to 'colocated_set'.
-void AddBufferToColocatedSet(const HloInstruction* instruction,
- const ShapeIndex& index,
- const TuplePointsToAnalysis& points_to_analysis,
- std::vector<const LogicalBuffer*>* colocated_set) {
+const LogicalBuffer* AddBufferToColocatedSet(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis,
+ std::vector<const LogicalBuffer*>* colocated_set) {
// CopyInsertion ensures root points-to set is unambiguous and distinct.
const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
CHECK(!points_to.IsAmbiguous());
CHECK(points_to.IsDistinct());
colocated_set->push_back(points_to.element(index)[0]);
+ return colocated_set->back();
}
+
} // namespace
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile and kCall).
void BufferAssigner::BuildColocatedBufferSets(
- const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+ const HloModule* module, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
- for (auto& computation : module->computations()) {
- for (auto& instruction : computation->instructions()) {
+ const TuplePointsToAnalysis& points_to_analysis =
+ buffer_liveness.points_to_analysis();
+ for (const HloComputation* computation : module->MakeComputationPostOrder()) {
+ for (const HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
const HloOpcode opcode = instruction->opcode();
if (opcode == HloOpcode::kWhile) {
- HloInstruction* while_hlo = instruction.get();
+ const HloInstruction* while_hlo = instruction;
TF_CHECK_OK(ShapeUtil::ForEachSubshape(
while_hlo->shape(),
- [this, while_hlo, &points_to_analysis, colocated_buffer_sets](
- const Shape& /*subshape*/, const ShapeIndex& index) {
+ [this, while_hlo, &points_to_analysis, &buffer_liveness,
+ computation, colocated_buffer_sets](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
- AddBufferToColocatedSet(while_hlo->operand(0), index,
- points_to_analysis, &colocated_set);
+ auto* init_buffer =
+ AddBufferToColocatedSet(while_hlo->operand(0), index,
+ points_to_analysis, &colocated_set);
// Add while.result.
AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
&colocated_set);
@@ -930,12 +1028,15 @@ void BufferAssigner::BuildColocatedBufferSets(
AddBufferToColocatedSet(
while_hlo->while_body()->root_instruction(), index,
points_to_analysis, &colocated_set);
- AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+ AddWhileSetToColocatedBufferSets(
+ colocated_set, init_buffer, while_hlo, *computation,
+ buffer_liveness, colocated_buffer_sets);
return Status::OK();
}));
} else if (opcode == HloOpcode::kCall) {
- HloInstruction* call_hlo = instruction.get();
- HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction();
+ const HloInstruction* call_hlo = instruction;
+ const HloInstruction* root_hlo =
+ call_hlo->to_apply()->root_instruction();
TF_CHECK_OK(ShapeUtil::ForEachSubshape(
call_hlo->shape(),
[this, call_hlo, root_hlo, &points_to_analysis,
@@ -961,8 +1062,8 @@ void BufferAssigner::BuildColocatedBufferSets(
void BufferAssigner::AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
- tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations) {
+ FlatSet<const LogicalBuffer*>* colocated_buffers,
+ FlatSet<BufferAllocation::Index>* colocated_allocations) {
for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
BufferAllocation* allocation = nullptr;
for (const LogicalBuffer* buffer : colocated_buffer_set) {
@@ -1008,9 +1109,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to
// AssignBuffersForComputation for fast membership testing.
- std::unique_ptr<tensorflow::gtl::FlatSet<const HloInstruction*>> hlo_set;
+ std::unique_ptr<FlatSet<const HloInstruction*>> hlo_set;
if (hlos_to_allocate != nullptr) {
- hlo_set = MakeUnique<tensorflow::gtl::FlatSet<const HloInstruction*>>(
+ hlo_set = MakeUnique<FlatSet<const HloInstruction*>>(
hlos_to_allocate->begin(), hlos_to_allocate->end());
}
@@ -1022,11 +1123,11 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Once b/32491382 enables module-level liveness analysis, we may be able
// to assign colocated buffers (or at least reuse their allocation for
// buffers outside of the set) in AssignBuffersForComputation.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> colocated_buffers;
- tensorflow::gtl::FlatSet<BufferAllocation::Index> colocated_allocations;
+ FlatSet<const LogicalBuffer*> colocated_buffers;
+ FlatSet<BufferAllocation::Index> colocated_allocations;
if (colocate_related_buffers_) {
std::vector<ColocatedBufferSet> colocated_buffer_sets;
- BuildColocatedBufferSets(module, assignment->points_to_analysis(),
+ BuildColocatedBufferSets(module, assignment->liveness(),
&colocated_buffer_sets);
AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
&colocated_buffers, &colocated_allocations);
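
The merge criterion implemented by AddWhileSetToColocatedBufferSets can be summarized by the following standalone sketch, which uses hypothetical stand-in types rather than LogicalBuffer and BufferLiveness: the most recently added predecessor while set wins, its candidate buffers must be while results of the same size as the new loop's init buffer, and none of them may have a live range that overlaps that init buffer.

#include <cstdint>
#include <functional>
#include <vector>

// Illustration-only stand-ins for LogicalBuffer and the liveness query.
struct Buffer {
  int64_t id;
  int64_t size;
  bool is_while_result;
};
using MayInterfere = std::function<bool(const Buffer&, const Buffer&)>;

// Returns the index of the predecessor set the new while loop should be merged
// into, or -1 if no predecessor qualifies.
int FindMergeablePredecessorSet(
    const std::vector<std::vector<Buffer>>& predecessor_sets,
    const Buffer& while_init, const MayInterfere& may_interfere) {
  // Scan in reverse so the most recently added (most local) set is preferred.
  for (int i = static_cast<int>(predecessor_sets.size()) - 1; i >= 0; --i) {
    std::vector<Buffer> candidates;
    for (const Buffer& b : predecessor_sets[i]) {
      if (b.is_while_result && b.size == while_init.size) {
        candidates.push_back(b);
      }
    }
    if (candidates.empty()) continue;  // Not a while set of matching size.
    bool interferes = false;
    for (const Buffer& b : candidates) {
      // Tuple-element forwarding can put the same buffer on both sides of the
      // comparison, so a buffer never "interferes" with itself.
      if (b.id != while_init.id && may_interfere(b, while_init)) {
        interferes = true;
        break;
      }
    }
    if (!interferes) return i;
  }
  return -1;
}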
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index b82acb19b3..ec1375e24d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -465,7 +465,7 @@ class BufferAssigner {
// ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
// which should be colocated in the same buffer allocation.
void BuildColocatedBufferSets(
- const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+ const HloModule* module, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);
// For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
@@ -482,6 +482,14 @@ class BufferAssigner {
const std::vector<const LogicalBuffer*>& colocated_set,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);
+ // Conceptually the same as AddSetToColocatedBufferSets, but specific to the
+ // colocated buffers for while instructions.
+ void AddWhileSetToColocatedBufferSets(
+ const std::vector<const LogicalBuffer*>& colocated_set,
+ const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
+ const HloComputation& computation, const BufferLiveness& buffer_liveness,
+ std::vector<ColocatedBufferSet>* colocated_buffer_sets);
+
const HloModule* module_;
// Function which returns the buffer size for a given logical buffer (shape).
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index bb7342d508..f6637d6098 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -23,10 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@@ -1245,6 +1247,163 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
}
}
-} // namespace
+class WhileBufferAssignmentTest : public HloTestBase {
+ protected:
+ std::unique_ptr<HloComputation> BuildWhileConditionComputation(
+ const string& name) {
+ auto builder = HloComputation::Builder(name);
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
+ auto ten = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
+ return builder.Build();
+ }
+
+ std::unique_ptr<HloComputation> BuildWhileBodyComputation(
+ const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
+ auto weights = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+ auto output = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kMultiply, input, weights));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({input, weights, output}));
+ return builder.Build();
+ }
+
+ void RunCopyInsertion(HloModule* module) {
+ CopyInsertion copy_insertion;
+ EXPECT_IS_OK(copy_insertion.Run(module).status());
+ }
+
+ std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
+ int64 alignment = 1) {
+ auto sequence =
+ CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
+ return BufferAssigner::Run(
+ module, MakeUnique<SequentialHloOrdering>(module, sequence),
+ ByteSizeOf, alignment)
+ .ConsumeValueOrDie();
+ }
+
+ static int64 ByteSizeOf(const LogicalBuffer& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
+ }
+
+ Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
+ Shape loop_state_shape_ =
+ ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
+};
+
+TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
+ auto module = MakeUnique<HloModule>(TestName());
+ auto builder = HloComputation::Builder("entry");
+
+ auto input0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape_, "input0"));
+ auto weights0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+ auto weights1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, data_shape_, "weights1"));
+
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+ auto output0 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ auto output1 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+ auto cond0 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body0 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+ auto tuple0 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output0}));
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+
+ auto cond1 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body1 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+ auto input1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input1, weights1, output1}));
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
+
+ module->AddEntryComputation(builder.Build());
+ RunCopyInsertion(module.get());
+ auto assignment = RunBufferAssignment(module.get());
+ // While instruction 'while0' has no predecessor while instructions with
+ // which to share allocations.
+
+ // While instruction 'while1' can share allocations with the following
+ // buffers:
+ // *) while0[2], while1[0]
+ // *) while0[1], while1[1]
+ EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
+ assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
+ EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
+ assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
+}
+
+TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
+ auto module = MakeUnique<HloModule>(TestName());
+ auto builder = HloComputation::Builder("entry");
+
+ auto input0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape_, "input0"));
+ auto weights0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+ auto output0 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ auto output1 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+ auto cond0 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body0 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+ auto tuple0 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output0}));
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+
+ auto cond1 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body1 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output1}));
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
+
+ module->AddEntryComputation(builder.Build());
+ RunCopyInsertion(module.get());
+ auto assignment = RunBufferAssignment(module.get());
+
+ EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
+ assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
+ EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
+ assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
+}
+
+} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 616b239a93..ceb0cdaa31 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) {
}
}
+bool HloOpcodeIsVariadic(HloOpcode opcode) {
+ switch (opcode) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kFusion:
+ case HloOpcode::kMap:
+ case HloOpcode::kTuple:
+ return true;
+ default:
+ return false;
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 978ed5e79b..e2cdbfdfa7 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
// Returns true iff the given opcode is a comparison operation.
bool HloOpcodeIsComparison(HloOpcode opcode);
+// Returns true iff the given opcode has variadic operands.
+bool HloOpcodeIsVariadic(HloOpcode opcode);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
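
A minimal sketch of the intended use; it mirrors the SetNodeAttrs change in hlo_tfgraph_builder.cc later in this diff, where 'attrs' is a node's mutable attribute map and 'instruction' is an HloInstruction*:

// Fixed-arity opcodes need no operand-count attribute; only the variadic ones
// (kCall, kConcatenate, kFusion, kMap, kTuple) record how many operands they
// actually received.
if (HloOpcodeIsVariadic(instruction->opcode())) {
  tensorflow::AttrValue attr_value;
  attr_value.set_i(instruction->operands().size());
  attrs["arg_num"] = attr_value;
}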
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e3c983071..eb7fe467b3 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -40,6 +40,8 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module,
} // namespace
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ run_called_ = true;
+
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
std::vector<string> tmp =
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a8c2d51873..682c4b952d 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface {
// Returns a reference to the added pass.
template <typename T, typename... Args>
T& AddPass(Args&&... args) {
+ CHECK(!run_called_) << "AddPass cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
passes_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface {
// (it is required to always return "false" from its Run() method).
template <typename T, typename... Args>
T& AddInvariantChecker(Args&&... args) {
+ CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface {
Compiler::HloDumper dumper_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
+ bool run_called_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
};
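
A minimal usage sketch of the ordering the new run_called_ flag enforces; the pass types and pipeline name are hypothetical, and an HloModule* module plus a Compiler::HloDumper dumper are assumed to be in scope:

HloPassPipeline pipeline("example-pipeline", dumper);
pipeline.AddPass<SomeSimplificationPass>();  // OK: Run() has not been called yet.
TF_CHECK_OK(pipeline.Run(module).status());
// pipeline.AddPass<SomeOtherPass>();        // Would now trip CHECK(!run_called_).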
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
index caaf56a551..1f625ae0e2 100644
--- a/tensorflow/compiler/xla/service/liveness_util.cc
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -101,12 +101,12 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
} // namespace
// User and operand can share buffers iff both instructions emit the same shape
-// and layout, and 'user' meets one of the following two qualifications:
-// *) Is element-wise.
+// and layout, and 'user' meets one of the following qualifications:
+// *) Is element-wise; or
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
-// at operand 0.
-// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
+//    at operand 0; or
+// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
@@ -144,7 +144,8 @@ bool CanShareOperandBufferWithUser(
break;
}
return false;
- } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
+ } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
+ user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
std::vector<int64> operand_indices = user->OperandIndices(operand);
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index 2ff71d6f3c..079b59265b 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -185,5 +185,73 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
*points_to_analysis_));
}
+TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape update_shape = ShapeUtil::MakeShape(F32, {4});
+ Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, update_shape, "update"));
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, starts_shape, "starts"));
+ auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, data, update, starts));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // The DynamicUpdateSlice instruction can share with the data operand, but not
+ // with update or starts.
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+
+ auto make_cond = [this, &data_shape]() {
+ auto builder = HloComputation::Builder(TestName() + ".Cond");
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
+ return builder.Build();
+ };
+
+ auto make_body = [this, &data_shape]() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
+ return builder.Build();
+ };
+
+ module_ = MakeUnique<HloModule>(TestName());
+ HloComputation* cond_computation =
+ module_->AddEmbeddedComputation(make_cond());
+ HloComputation* body_computation =
+ module_->AddEmbeddedComputation(make_body());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
+ data_shape, cond_computation, body_computation, data));
+ computation_ = module_->AddEntryComputation(builder.Build());
+
+ RunAnalysis();
+
+ // The While instruction can share with the data operand.
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 1796a732e5..16d4282466 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -265,6 +265,37 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
*LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
}
+TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
+ auto builder = HloComputation::Builder(TestName());
+ Array3D<float> input_vals(2, 3, 4);
+ input_vals.FillRandom(1.0);
+
+ Array4D<float> expected(2, 3, 4, 5);
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 4; ++k) {
+ for (int m = 0; m < 5; ++m) {
+ expected(i, j, k, m) = input_vals(i, j, k);
+ }
+ }
+ }
+ }
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR3FromArray3D<float>(input_vals)));
+
+ // Broadcast the R3 input into dimensions {0, 1, 2} of the R4 output.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 34fce21758..d00a317534 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -211,9 +211,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
-XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
+XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
@@ -221,6 +221,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
RunR1ToR0Test(16 * 1024 + 1);
}
+XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) { RunR1ToR0Test(64 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) { RunR1ToR0Test(1024 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) { RunR1ToR0Test(4096 * 4096); }
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 46eab7f02b..ab598b8edd 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -176,6 +176,52 @@ cc_binary(
],
)
+cc_library(
+ name = "hlo_tfgraph_builder",
+ srcs = ["hlo_tfgraph_builder.cc"],
+ hdrs = ["hlo_tfgraph_builder.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "hlo_tfgraph_builder_test",
+ srcs = ["hlo_tfgraph_builder_test.cc"],
+ deps = [
+ ":hlo_tfgraph_builder",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_binary(
+ name = "dumped_computation_to_tf_graphdef",
+ srcs = ["dumped_computation_to_tf_graphdef.cc"],
+ deps = [
+ ":hlo_tfgraph_builder",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
new file mode 100644
index 0000000000..1aa769ee5a
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: dumped_computation_to_tf_graphdef \
+// --output_dir=/tmp/graphs/ some_binary_snapshot_proto*
+//
+// Dumps a TensorFlow GraphDef in text format for a snapshot computation. The
+// dumped graph represents an HLO computation, with HLO instructions as nodes,
+// and can be visualized by loading the dumped files into TensorBoard.
+//
+// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// ServiceInterface::SnapshotComputation to disk.
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::Env;
+using tensorflow::io::JoinPath;
+using tensorflow::strings::StrAppend;
+
+namespace xla {
+namespace tools {
+namespace {
+
+// Dumps all computations in the module to the given directory.
+void DumpTfGraph(const HloModule& module, const string& directory_path) {
+ Env* env = Env::Default();
+ TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
+ string fname = module.name();
+ std::replace(fname.begin(), fname.end(), '/', '_');
+ // Since the file name will be used as the top-level scope name, clean it up
+ // to make it a valid scope name.
+ CleanNodeName(&fname);
+ StrAppend(&fname, ".pbtxt");
+ string path = JoinPath(directory_path, fname);
+ HloTfGraphBuilder builder;
+ TF_CHECK_OK(builder.AddComputation(*module.entry_computation()));
+ std::cout << "Dumping " << module.name() << " to " << path << std::endl;
+ TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef()));
+}
+
+} // namespace
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args,
+ const string& output_dir) {
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+ // To avoid adding a new flag, use local service and lower the computations
+ // locally.
+ LocalService* local_service =
+ ClientLibrary::GetXlaService(client->platform());
+ // Build HloModule for each Computation and dump to file.
+ for (char* arg : args) {
+ SessionModule session_module;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
+ &session_module));
+ auto computation_status = client->LoadSnapshot(session_module);
+ if (!computation_status.ok()) {
+ fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
+ computation_status.status().ToString().c_str());
+ continue;
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ StatusOr<UserComputation*> user_computation_status =
+ local_service->computation_tracker().Resolve(computation.handle());
+ if (!user_computation_status.ok()) {
+ fprintf(stderr,
+ "failed to resolve computation to UserComputation %s: %s\n", arg,
+ user_computation_status.status().ToString().c_str());
+ continue;
+ }
+
+ auto* user_computation = user_computation_status.ValueOrDie();
+ StatusOr<std::unique_ptr<HloModule>> module_status =
+ local_service->computation_tracker().BuildHloModule(
+ user_computation->GetVersionedHandle());
+
+ if (!module_status.ok()) {
+ fprintf(stderr, "failed to build HloModule %s: %s\n", arg,
+ module_status.status().ToString().c_str());
+ continue;
+ }
+
+ DumpTfGraph(*module_status.ValueOrDie(), output_dir);
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ string output_dir = "";
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("output_dir", &output_dir,
+ "Directory to write GraphDef data to."),
+ };
+
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_ok || output_dir.empty()) {
+ LOG(QFATAL) << usage;
+ }
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args, output_dir);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc
new file mode 100644
index 0000000000..fe835a20c4
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+using ::tensorflow::GraphDef;
+using ::tensorflow::NodeDef;
+using ::tensorflow::TensorShapeProto;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+using ::tensorflow::str_util::Join;
+
+namespace xla {
+namespace tools {
+namespace {
+
+string GetOpDefName(const HloInstruction* instruction) {
+ string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
+ tensorflow::str_util::TitlecaseString(&name, "-");
+ name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
+
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ string fusion_name = ToString(instruction->fusion_kind());
+ StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+ }
+ return name;
+}
+
+TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
+ TensorShapeProto tensor_shape;
+ const Shape& shape = instruction->shape();
+ for (auto dim : shape.dimensions()) {
+ tensor_shape.add_dim()->set_size(dim);
+ }
+ return tensor_shape;
+}
+
+} // namespace
+
+void CleanNodeName(string* name) {
+ name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
+ const string chars_to_replace = "<>[]";
+ auto pred = [&](char c) {
+ return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
+ chars_to_replace.end();
+ };
+ std::replace_if(name->begin(), name->end(), pred, '_');
+}
+
+Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
+ LOG(INFO) << "Adding computation " << computation.name();
+ for (auto embedded : computation.MakeEmbeddedComputationsList()) {
+ LOG(INFO) << "Adding embedded computation " << embedded->name();
+ for (auto& instruction : embedded->instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+ }
+ }
+ for (auto& instruction : computation.instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+ }
+ return Status::OK();
+}
+
+const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; }
+
+const string& HloTfGraphBuilder::GetNodeNameForInstruction(
+ const HloInstruction* instruction) {
+ if (ContainsKey(instruction_to_node_name_, instruction)) {
+ return instruction_to_node_name_[instruction];
+ }
+ // If an instruction is fused, put it in the subgraph of the fusion;
+ // otherwise, put it in the computation subgraph.
+ string node_name =
+ instruction->IsFused()
+ ? GetNodeNameForInstruction(instruction->fusion_instruction())
+ : instruction->parent()->name();
+ string instruction_name = instruction->name();
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ StrAppend(&instruction_name, ".", instruction->parameter_number());
+ }
+ StrAppend(&node_name, "/", instruction_name);
+ CleanNodeName(&node_name);
+ auto ret =
+ instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
+ CHECK(ret.second);
+ return ret.first->second;
+}
+
+void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
+ NodeDef* node_def) const {
+ auto& attrs = *node_def->mutable_attr();
+
+ // Set the number of arguments for instructions that have variadic operands.
+ if (HloOpcodeIsVariadic(instruction->opcode())) {
+ tensorflow::AttrValue attr_value;
+ attr_value.set_i(instruction->operands().size());
+ attrs["arg_num"] = attr_value;
+ }
+
+ // Set the node type.
+ attrs["type"].set_s(
+ xla::PrimitiveType_Name(instruction->shape().element_type()));
+
+ // Set the shape of the output tensor. "_output_shapes" is a special attribute
+ // name used by TensorBoard for shapes of output tensors.
+ tensorflow::AttrValue shapes;
+ *shapes.mutable_list()->add_shape() = GetTensorShape(instruction);
+ attrs["_output_shapes"] = shapes;
+
+ // Set the layout.
+ if (LayoutUtil::HasLayout(instruction->shape())) {
+ string layout_string;
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // For tuples, emit the full shape because the layout of a tuple is not
+ // represented in a single Layout field.
+ layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
+ } else {
+ layout_string = StrCat(
+ "{", Join(instruction->shape().layout().minor_to_major(), ","), "}");
+ }
+ attrs["layout"].set_s(layout_string);
+ }
+
+ // Set op-specific attributes.
+ switch (instruction->opcode()) {
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReverse:
+ case HloOpcode::kTranspose:
+ for (auto dim : instruction->dimensions()) {
+ attrs["dims"].mutable_list()->add_i(dim);
+ }
+ break;
+ case HloOpcode::kGetTupleElement:
+ attrs["index"].set_i(instruction->tuple_index());
+ break;
+ case HloOpcode::kRng:
+ attrs["dist"].set_s(
+ RandomDistribution_Name(instruction->random_distribution()));
+ break;
+ case HloOpcode::kConstant:
+ if (ShapeUtil::IsScalar(instruction->shape())) {
+ attrs["value"].set_s(
+ LiteralUtil::GetAsString(instruction->literal(), {}));
+ }
+ break;
+ case HloOpcode::kCustomCall:
+ attrs["custom_call_target"].set_s(instruction->custom_call_target());
+ break;
+ default:
+ break;
+ }
+}
+
+Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
+ if (!visited_instructions_.insert(instruction).second) {
+ // Skip instructions that have already been added.
+ return Status::OK();
+ }
+
+ NodeDef* node_def = graph_def_.add_node();
+ node_def->set_name(GetNodeNameForInstruction(instruction));
+ node_def->set_op(GetOpDefName(instruction));
+ SetNodeAttrs(instruction, node_def);
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ for (auto& fused_instruction : instruction->fused_instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get()));
+ }
+ }
+ // Add all edges including control edges.
+ for (unsigned i = 0; i < instruction->operands().size(); ++i) {
+ *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i));
+ }
+ // Called computations are control dependencies.
+ for (const auto* called_computation : instruction->called_computations()) {
+ *node_def->add_input() = StrCat(
+ "^", GetNodeNameForInstruction(called_computation->root_instruction()));
+ }
+ return Status::OK();
+}
+
+} // namespace tools
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h
new file mode 100644
index 0000000000..3052eae113
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace xla {
+namespace tools {
+
+// This constructs a tensorflow graph for HLO computations.
+class HloTfGraphBuilder {
+ public:
+ // Adds a computation to the graph.
+ Status AddComputation(const HloComputation& computation);
+
+ const tensorflow::GraphDef& GetGraphDef() const;
+
+ private:
+ // Gets the node name of an instruction. The node name is hierarchical. For
+ // example, if an instruction is fused, it will be put in a subgraph of the
+ // fusion instruction.
+ const string& GetNodeNameForInstruction(const HloInstruction* instruction);
+
+ void SetNodeAttrs(const HloInstruction* instruction,
+ tensorflow::NodeDef* node_def) const;
+
+ Status AddInstruction(const HloInstruction* instruction);
+
+ tensorflow::GraphDef graph_def_;
+ // This records instructions that have been visited.
+ std::unordered_set<const HloInstruction*> visited_instructions_;
+ // A cache that maps instruction to the node name.
+ std::unordered_map<const HloInstruction*, string> instruction_to_node_name_;
+};
+
+// Cleans the node name to make it a valid name in a tensorflow graph.
+void CleanNodeName(string* name);
+
+} // namespace tools
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
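
A minimal usage sketch, mirroring DumpTfGraph in dumped_computation_to_tf_graphdef.cc earlier in this diff; the output path is illustrative and an HloModule named module is assumed to be in scope:

HloTfGraphBuilder builder;
TF_CHECK_OK(builder.AddComputation(*module.entry_computation()));
// The GraphDef can be written as a text proto and loaded into TensorBoard.
TF_CHECK_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                        "/tmp/graphs/example.pbtxt",
                                        builder.GetGraphDef()));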
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc
new file mode 100644
index 0000000000..626bcc6d85
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc
@@ -0,0 +1,154 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace tools {
+namespace {
+
+using ::tensorflow::GraphDef;
+
+class HloTfGraphBuilderTest : public HloTestBase {
+ protected:
+ HloTfGraphBuilderTest() {}
+ HloTfGraphBuilder generator_;
+
+ // Creates a computation which takes a scalar and returns its negation.
+ std::unique_ptr<HloComputation> CreateNegateComputation() {
+ auto builder = HloComputation::Builder("Negate");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
+ return builder.Build();
+ }
+
+ // Creates a computation which calls map with the given computation.
+ std::unique_ptr<HloComputation> CreateMapComputation(
+ HloComputation* map_computation) {
+ auto builder = HloComputation::Builder("Map");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map_computation));
+ return builder.Build();
+ }
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
+ auto builder = HloComputation::Builder("Concatenate");
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ auto param_1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ auto param_2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param1"));
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(F32, {2, 4}), {param_1, param_2}, 1));
+ TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 3);
+ const auto &node = graph_def.node(2);
+ EXPECT_EQ(node.name(), "Concatenate/concatenate");
+
+ // Check dimensions.
+ auto dims_value = node.attr().find("dims");
+ CHECK(dims_value != node.attr().end());
+ EXPECT_EQ(dims_value->second.list().i_size(), 1);
+ EXPECT_EQ(dims_value->second.list().i(0), 1);
+
+ // Check shapes.
+ auto shape_value = node.attr().find("_output_shapes");
+ CHECK(shape_value != node.attr().end());
+ EXPECT_EQ(shape_value->second.list().shape_size(), 1);
+ EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
+ EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
+ EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
+}
+
+TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
+ auto builder = HloComputation::Builder("Const");
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
+ TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 1);
+ const auto &node = graph_def.node(0);
+ auto value = node.attr().find("value");
+ CHECK(value != node.attr().end());
+ EXPECT_EQ(value->second.s(), "123");
+
+ auto type = node.attr().find("type");
+ CHECK(type != node.attr().end());
+ EXPECT_EQ(type->second.s(), "S32");
+}
+
+TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
+ auto negate_computation = CreateNegateComputation();
+ TF_CHECK_OK(generator_.AddComputation(*negate_computation));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 2);
+ EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0");
+ EXPECT_EQ(graph_def.node(0).op(), "HloParameter");
+ EXPECT_EQ(graph_def.node(1).name(), "Negate/negate");
+ EXPECT_EQ(graph_def.node(1).op(), "HloNegate");
+ EXPECT_EQ(graph_def.node(1).input_size(), 1);
+ EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0");
+}
+
+TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
+ auto builder = HloComputation::Builder("GE");
+ auto param_1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto param_2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32_, "param1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
+ TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 3);
+ EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
+ EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
+ EXPECT_EQ(graph_def.node(2).input_size(), 2);
+ EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to");
+ EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
+}
+
+TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
+ // Create computations with a diamond-shaped callgraph.
+ auto negate_computation = CreateNegateComputation();
+ auto map1_computation = CreateMapComputation(negate_computation.get());
+ auto map2_computation = CreateMapComputation(negate_computation.get());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto map1 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
+ auto map2 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
+ auto computation = builder.Build();
+ TF_CHECK_OK(generator_.AddComputation(*computation));
+ EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
+}
+
+} // namespace
+} // namespace tools
+} // namespace xla
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
index 81e40dbe5e..c7f185aab8 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
@@ -42,12 +42,10 @@ class StochasticTensorTest(test.TestCase):
sigma2 = constant_op.constant([0.1, 0.2, 0.3])
prior_default = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.SampleValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))
@@ -55,8 +53,7 @@ class StochasticTensorTest(test.TestCase):
prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
- distributions.Normal(
- loc=prior, scale=sigma2))
+ distributions.Normal(loc=prior, scale=sigma2))
self.assertTrue(isinstance(likelihood.value_type, st.SampleValue))
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
@@ -102,8 +99,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@@ -113,8 +109,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
prior_single_value = prior_single.value()
@@ -125,8 +120,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@@ -163,8 +157,7 @@ class StochasticTensorTest(test.TestCase):
# With passed-in loss_fn.
dt = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.MeanValue(stop_gradient=True),
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=constant_op.constant(8.0)))
@@ -199,8 +192,7 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = constant_op.constant([1.1, 1.2, 1.3])
obs = array_ops.zeros((2, 3))
z = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma), value=obs)
+ distributions.Normal(loc=mu, scale=sigma), value=obs)
[obs_val, z_val] = sess.run([obs, z.value()])
self.assertAllEqual(obs_val, z_val)
@@ -212,15 +204,13 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = array_ops.placeholder(dtypes.float32)
obs = array_ops.placeholder(dtypes.float32)
z = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma), value=obs)
+ distributions.Normal(loc=mu, scale=sigma), value=obs)
mu2 = array_ops.placeholder(dtypes.float32, shape=[None])
sigma2 = array_ops.placeholder(dtypes.float32, shape=[None])
obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
z2 = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu2, scale=sigma2), value=obs2)
+ distributions.Normal(loc=mu2, scale=sigma2), value=obs2)
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [z, z2])
@@ -231,22 +221,18 @@ class ObservedStochasticTensorTest(test.TestCase):
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3,)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3, 1)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
- value=array_ops.zeros(
- (1, 2), dtype=dtypes.int32))
+ distributions.Normal(loc=mu, scale=sigma),
+ value=array_ops.zeros((1, 2), dtype=dtypes.int32))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 470b9edb79..e17197080a 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -135,8 +135,9 @@ from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['ConditionalDistribution',
- 'ConditionalTransformedDistribution',
- 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED']
+_allowed_symbols = [
+ 'ConditionalDistribution', 'ConditionalTransformedDistribution',
+ 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED'
+]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
index 71460a1769..5d6e4d9197 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
@@ -488,9 +488,7 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_identity_multiplier=2.,
scale_perturb_diag=[2., 1],
- scale_perturb_factor=[[2., 0],
- [0., 0],
- [0, 1]])
+ scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 2, 3])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
@@ -526,9 +524,7 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_diag=[2., 3, 4],
scale_perturb_diag=[2., 1],
- scale_perturb_factor=[[2., 0],
- [0., 0],
- [0, 1]])
+ scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 3, 5])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
@@ -561,17 +557,11 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to scale = [[10, 0, 0], [1, 3, 0], [2, 3, 5]]
bijector = affine_lib.Affine(
shift=mu,
- scale_tril=[[2., 0, 0],
- [1, 3, 0],
- [2, 3, 4]],
+ scale_tril=[[2., 0, 0], [1, 3, 0], [2, 3, 4]],
scale_perturb_diag=[2., 1],
- scale_perturb_factor=[[2., 0],
- [0., 0],
- [0, 1]])
+ scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(
- shift=mu, scale_tril=[[10., 0, 0],
- [1, 3, 0],
- [2, 3, 5]])
+ shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index ecf068bf6b..cb514e625b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -70,7 +70,8 @@ class ChainBijectorTest(test.TestCase):
softmax_centered_lib.SoftmaxCentered(
event_ndims=1, validate_args=True),
softmax_centered_lib.SoftmaxCentered(
- event_ndims=0, validate_args=True)])
+ event_ndims=0, validate_args=True)
+ ])
x = tensor_shape.TensorShape([])
y = tensor_shape.TensorShape([2 + 1])
self.assertAllEqual(y, bijector.forward_event_shape(x))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
index e16f9dff22..40018de63f 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
@@ -36,17 +36,19 @@ class SigmoidBijectorTest(test.TestCase):
y = special.expit(x)
ildj = -np.log(y) - np.log1p(-y)
self.assertAllClose(
- y, sigmoid.Sigmoid().forward(x).eval(),
- atol=0., rtol=1e-2)
+ y, sigmoid.Sigmoid().forward(x).eval(), atol=0., rtol=1e-2)
self.assertAllClose(
- x, sigmoid.Sigmoid().inverse(y).eval(),
- atol=0., rtol=1e-4)
+ x, sigmoid.Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4)
self.assertAllClose(
- ildj, sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(),
- atol=0., rtol=1e-6)
+ ildj,
+ sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(),
+ atol=0.,
+ rtol=1e-6)
self.assertAllClose(
- -ildj, sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(),
- atol=0., rtol=1e-4)
+ -ildj,
+ sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(),
+ atol=0.,
+ rtol=1e-4)
def testScalarCongruency(self):
with self.test_session():
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 7fee2e1f3a..e3f6ddd8c0 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -171,11 +171,12 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits, probs=probs, validate_args=validate_args)
super(RelaxedBernoulli, self).__init__(
- distribution=logistic.Logistic(self._logits / self._temperature,
- 1. / self._temperature,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
- name=name + "/Logistic"),
+ distribution=logistic.Logistic(
+ self._logits / self._temperature,
+ 1. / self._temperature,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=name + "/Logistic"),
bijector=sigmoid_lib.Sigmoid(validate_args=validate_args),
validate_args=validate_args,
name=name)
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index d7c646c19a..d149138796 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -3614,7 +3614,7 @@ _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
_config = json.load(open(_config_path))
- except json.decoder.JSONDecodeError:
+ except ValueError:
_config = {}
_floatx = _config.get('floatx', floatx())
assert _floatx in {'float16', 'float32', 'float64'}
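Why this hunk matters: `json.decoder.JSONDecodeError` only exists on Python 3.5+ and subclasses `ValueError`, so catching `ValueError` handles a malformed `keras.json` on both Python 2 and 3. A minimal sketch of the pattern (the path handling is illustrative):

```python
import json
import os


def _load_json_config(path):
  # ValueError also covers json.JSONDecodeError on Python 3.5+, which does
  # not exist on Python 2, so this works on both interpreters.
  if not os.path.exists(path):
    return {}
  try:
    return json.load(open(path))
  except ValueError:
    return {}
```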
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0140f6d0d3..13cabe6e04 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -379,7 +379,10 @@ def batch_norm(inputs,
fused=False,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
- scope=None):
+ scope=None,
+ renorm=False,
+ renorm_clipping=None,
+ renorm_decay=0.99):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -446,6 +449,19 @@ def batch_norm(inputs,
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
scope: Optional scope for `variable_scope`.
+ renorm: Whether to use Batch Renormalization
+ (https://arxiv.org/abs/1702.03275). This adds extra variables during
+ training. The inference is the same for either value of this parameter.
+ renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
+ scalar `Tensors` used to clip the renorm correction. The correction
+ `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
+ `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
+ dmax are set to inf, 0, inf, respectively.
+ renorm_decay: Momentum used to update the moving means and standard
+ deviations with renorm. Unlike `momentum`, this affects training
+ and should be neither too small (which would add noise) nor too large
+ (which would give stale estimates). Note that `decay` is still applied
+ to get the means and variances for inference.
Returns:
A `Tensor` representing the output of the operation.
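For context, a hedged sketch of how the new `renorm*` arguments might be passed through `tf.contrib.layers.batch_norm`; the clipping values below are illustrative choices, not defaults, and the placeholder exists only to give the input a shape:

```python
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 28, 28, 64])

# Batch Renormalization: the correction (r, d) is clipped to the given
# ranges, and renorm_decay drives the moving statistics used during training.
outputs = tf.contrib.layers.batch_norm(
    inputs,
    is_training=True,
    renorm=True,
    renorm_clipping={'rmax': 3.0, 'rmin': 1.0 / 3.0, 'dmax': 5.0},
    renorm_decay=0.99)
```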
@@ -464,6 +480,8 @@ def batch_norm(inputs,
if param_regularizers is not None:
raise ValueError('Regularizers are not currently '
'supported for fused batch norm.')
+ if renorm:
+ raise ValueError('Renorm is not supported for fused batch norm.')
return _fused_batch_norm(
inputs,
decay=decay,
@@ -524,6 +542,9 @@ def batch_norm(inputs,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
trainable=trainable,
+ renorm=renorm,
+ renorm_clipping=renorm_clipping,
+ renorm_momentum=renorm_decay,
name=sc.name,
_scope=sc,
_reuse=reuse)
@@ -551,6 +572,9 @@ def batch_norm(inputs,
# Custom updates collections are not supported because the update logic
# is different in this case, in particular w.r.t. "forced updates" and
# update op reuse.
+ if renorm:
+ raise ValueError('renorm is not supported with batch_weights, '
+ 'updates_collections or zero_debias_moving_mean')
inputs_shape = inputs.get_shape()
inputs_rank = inputs_shape.ndims
if inputs_rank is None:
@@ -1241,6 +1265,13 @@ def flatten(inputs,
def _sparse_inner_flatten(inputs, new_rank):
"""Helper function for `inner_flatten`."""
+ inputs_rank = inputs.dense_shape.get_shape().as_list()[0]
+ if inputs_rank < new_rank:
+ raise ValueError(
+ 'Inputs has rank less than new_rank. {} must have rank at least'
+ ' {}. Received rank {}, shape {}'.format(inputs, new_rank, inputs_rank,
+ inputs.get_shape()))
+
outer_dimensions = inputs.dense_shape[:new_rank - 1]
inner_dimensions = inputs.dense_shape[new_rank - 1:]
new_shape = array_ops.concat((outer_dimensions,
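A short sketch of the new rank check, mirroring the tests added below: asking `_inner_flatten` for a rank higher than the input already has now raises a `ValueError` instead of producing a malformed shape (this uses the private helper exactly as the tests do):

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers as _layers

# A rank-2 sparse tensor cannot be "flattened" to rank 3.
sp = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                     values=[1.0, 2.0],
                     dense_shape=[2, 3])
try:
  _layers._inner_flatten(sp, new_rank=3)
except ValueError as e:
  print(e)  # Inputs has rank less than new_rank. ...
```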
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 2b170e92ba..ee4ebf2c43 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1465,6 +1465,30 @@ class PartialFlattenTest(test.TestCase):
flattened5 = _layers._inner_flatten(inputs, 5)
self.assertEqual([2, None, 4, None, 30], flattened5.get_shape().as_list())
+ def testDenseFlattenRankAssertion(self):
+ """Test `_inner_flatten` rank assertion for dense tensors."""
+ shape = [2, 3]
+ new_rank = 3
+ inputs = array_ops.placeholder(dtypes.int32)
+ inputs.set_shape(shape)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'inputs has rank less than new_rank'):
+ _layers._inner_flatten(inputs, new_rank)
+
+ def testSparseFlattenRankAssertion(self):
+ """Test `_inner_flatten` rank assertion for sparse tensors."""
+ shape = [2, 3]
+ new_rank = 3
+ np.random.seed(10301)
+ random_ = np.random.rand(*shape)
+ indices, values, _ = _sparsify(random_)
+ inputs = sparse_tensor.SparseTensor(indices, values, shape)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'Inputs has rank less than new_rank'):
+ _layers._inner_flatten(inputs, new_rank)
+
class FCTest(test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
index 01262ff5f8..fd50070dac 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
@@ -27,7 +27,8 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.framework import dtypes
-SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
+# CVDF mirror of http://yann.lecun.com/exdb/mnist/
+SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
def _read32(bytestream):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 107454dca1..29ea692f8f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -362,6 +362,11 @@ class BaseEstimator(
self._config = config
logging.info('Using config: %s', str(vars(self._config)))
+ if self._config.session_config is None:
+ self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ self._session_config = self._config.session_config
+
# Model directory.
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
@@ -829,7 +834,7 @@ class BaseEstimator(
eval_ops=update_op,
final_ops=eval_dict,
hooks=hooks,
- config=config_pb2.ConfigProto(allow_soft_placement=True))
+ config=self._session_config)
current_global_step = eval_results[global_step_key]
_write_dict_to_summary(eval_dir, eval_results, current_global_step)
@@ -864,7 +869,7 @@ class BaseEstimator(
session_creator=monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
scaffold=infer_ops.scaffold,
- config=config_pb2.ConfigProto(allow_soft_placement=True)))
+ config=self._session_config))
if not as_iterable:
with mon_sess:
if not mon_sess.should_stop():
@@ -976,7 +981,7 @@ class BaseEstimator(
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=config_pb2.ConfigProto(allow_soft_placement=True)
+ config=self._session_config
) as mon_sess:
loss = None
while not mon_sess.should_stop():
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index c0a3918549..c56741a4d1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -53,12 +53,17 @@ class ModeKeys(object):
EVAL = 'eval'
INFER = 'infer'
+ @classmethod
+ def validate(cls, key):
+ if key not in (cls.TRAIN, cls.EVAL, cls.INFER):
+ raise ValueError('Invalid mode %s.' % key)
+
class ModelFnOps(
collections.namedtuple('ModelFnOps', [
'predictions', 'loss', 'train_op', 'eval_metric_ops',
'output_alternatives', 'training_chief_hooks', 'training_hooks',
- 'scaffold'
+ 'scaffold', 'mode'
])):
"""Ops returned from a model_fn."""
@@ -119,6 +124,8 @@ class ModelFnOps(
Raises:
ValueError: If validation fails.
"""
+ ModeKeys.validate(mode)
+
# Assert all ops are from the same graph.
get_graph_from_inputs((predictions, loss, train_op))
@@ -183,14 +190,13 @@ class ModelFnOps(
output_alternatives=output_alternatives,
training_chief_hooks=training_chief_hooks,
training_hooks=training_hooks,
- scaffold=scaffold)
+ scaffold=scaffold,
+ mode=mode)
- def estimator_spec(self, mode, default_serving_output_alternative_key=None):
+ def estimator_spec(self, default_serving_output_alternative_key=None):
"""Creates an equivalent `EstimatorSpec`.
Args:
- mode: One of `ModeKeys`. Specifies if this training, evaluation or
- prediction.
default_serving_output_alternative_key: Required for multiple heads. If
you have multiple entries in `output_alternatives` dict (comparable to
multiple heads), `EstimatorSpec` requires a default head that will be
@@ -265,7 +271,7 @@ class ModelFnOps(
return result
return core_model_fn_lib.EstimatorSpec(
- mode=mode,
+ mode=self.mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,
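Since `mode` is now validated and stored on `ModelFnOps` itself, `estimator_spec()` no longer takes it as an argument. A hedged sketch of the new calling convention; the prediction tensor is purely illustrative:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

model_fn_ops = model_fn_lib.ModelFnOps(
    mode=model_fn_lib.ModeKeys.INFER,             # validated and stored here
    predictions={'scores': tf.constant([0.1, 0.9])})

# The mode is read back from the ModelFnOps rather than being passed in.
estimator_spec = model_fn_ops.estimator_spec()
```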
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py
index 51b32359a3..4f76013a2a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py
@@ -80,18 +80,20 @@ class ModelFnopsTest(test.TestCase):
def testEstimatorSpec_except_export(self):
predictions = self.create_predictions()
- model_fn_ops = self.create_model_fn_ops(predictions, None)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, None, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
def testEstimatorSpec_export_regression_with_scores(self):
predictions = self.create_predictions()
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -108,9 +110,10 @@ class ModelFnopsTest(test.TestCase):
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION,
output_alternatives_predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -124,9 +127,10 @@ class ModelFnopsTest(test.TestCase):
predictions = self.create_predictions()
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -145,9 +149,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["scores"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -167,9 +172,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["probabilities"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -187,9 +193,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["classes"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -208,9 +215,10 @@ class ModelFnopsTest(test.TestCase):
[1, 2, 3])
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -226,9 +234,10 @@ class ModelFnopsTest(test.TestCase):
predictions = self.create_predictions()
output_alternatives = {"logistic_head": (
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -245,9 +254,10 @@ class ModelFnopsTest(test.TestCase):
output_alternatives = {"unspecified_head": (
constants.ProblemType.UNSPECIFIED, predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
+ estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@@ -263,10 +273,10 @@ class ModelFnopsTest(test.TestCase):
constants.ProblemType.LINEAR_REGRESSION, predictions),
"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
- model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
+ model_fn_ops = self.create_model_fn_ops(
+ predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
- estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
- "regression_head")
+ estimator_spec = model_fn_ops.estimator_spec("regression_head")
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index bc7465bbc2..37ee814b62 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -214,7 +214,8 @@ class RunConfig(ClusterConfig):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
evaluation_master='',
- model_dir=None):
+ model_dir=None,
+ session_config=None):
"""Constructor.
Note that the superclass `ClusterConfig` may set properties like
@@ -246,6 +247,9 @@ class RunConfig(ClusterConfig):
evaluation_master: the master on which to perform evaluation.
model_dir: directory where model parameters, graph etc are saved. If
`None`, see `Estimator` about where the model will be saved.
+ session_config: a ConfigProto used to set session parameters, or None.
+ Note that with this argument it is easy to provide settings that break
+ otherwise perfectly good models. Use with care.
"""
super(RunConfig, self).__init__(
master=master, evaluation_master=evaluation_master)
@@ -261,6 +265,7 @@ class RunConfig(ClusterConfig):
self._tf_random_seed = tf_random_seed
self._save_summary_steps = save_summary_steps
self._save_checkpoints_secs = save_checkpoints_secs
+ self._session_config = session_config
if save_checkpoints_secs == RunConfig._USE_DEFAULT:
if save_checkpoints_steps is None:
self._save_checkpoints_secs = 600
@@ -346,6 +351,10 @@ class RunConfig(ClusterConfig):
return self._save_checkpoints_steps
@property
+ def session_config(self):
+ return self._session_config
+
+ @property
def keep_checkpoint_max(self):
return self._keep_checkpoint_max
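A hedged sketch of the new `session_config` plumbing: a `ConfigProto` handed to `RunConfig` is now used for the training, evaluation and inference sessions instead of the previously hard-coded `allow_soft_placement` default. The GPU memory fraction below is purely illustrative:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import run_config

session_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))

# Estimators constructed with this config create their sessions using
# `session_config` rather than a default ConfigProto.
config = run_config.RunConfig(session_config=session_config)
```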
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index cecc24c17d..4f7c72c9dd 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -118,7 +118,8 @@ class Experiment(object):
occur if no new snapshot is available, hence, this is the minimum.
delay_workers_by_global_step: if `True` delays training workers
based on global step instead of time.
- export_strategies: A list of `ExportStrategy`s, or a single one, or None.
+ export_strategies: Iterable of `ExportStrategy`s, or a single one, or
+ `None`.
train_steps_per_iteration: (applies only to continuous_train_and_eval).
Perform this many (integer) number of train steps for each
training-evaluation iteration. With a small value, the model will be
@@ -184,16 +185,19 @@ class Experiment(object):
def eval_steps(self):
return self._eval_steps
- def _set_export_strategies(self, value):
- if value is None:
- self._export_strategies = []
- elif isinstance(value, list):
- self._export_strategies = value[:]
- elif isinstance(value, export_strategy.ExportStrategy):
- self._export_strategies = [value]
- else:
- raise ValueError("`export_strategies` must be an ExportStrategy, "
- "a list of ExportStrategies, or None.")
+ def _set_export_strategies(self, values): # pylint: disable=missing-docstring
+ export_strategies = []
+ if values:
+ if isinstance(values, export_strategy.ExportStrategy):
+ export_strategies.append(values)
+ else:
+ for value in values:
+ if not isinstance(value, export_strategy.ExportStrategy):
+ raise ValueError("`export_strategies` must be an ExportStrategy,"
+ " an iterable of ExportStrategy, or `None`,"
+ " found %s." % value)
+ export_strategies.append(value)
+ self._export_strategies = tuple(export_strategies)
def extend_train_hooks(self, additional_hooks):
"""Extends the hooks for training."""
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index 00ed062b0a..f9f95f8e67 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -484,6 +484,25 @@ class ExperimentTest(test.TestCase):
self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors)
self.assertAllEqual([noop_hook], input_hooks)
+ def test_invalid_export_strategies(self):
+ for est in self._estimators_for_tests():
+ with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
+ experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps=100,
+ eval_steps=100,
+ export_strategies='not_an_export_strategy')
+ with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
+ experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps=100,
+ eval_steps=100,
+ export_strategies=['not_an_export_strategy'])
+
def test_export_strategies_reset(self):
for est in self._estimators_for_tests():
eval_metrics = 'eval_metrics' if not isinstance(
@@ -498,7 +517,7 @@ class ExperimentTest(test.TestCase):
eval_metrics=eval_metrics,
train_steps=100,
eval_steps=100,
- export_strategies=[export_strategy_1])
+ export_strategies=(export_strategy_1,))
ex.train_and_evaluate()
self.assertEqual(1, est.export_count)
@@ -728,7 +747,7 @@ class ExperimentTest(test.TestCase):
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
- export_strategies=[exp_strategy])
+ export_strategies=(exp_strategy,))
ex.test()
self.assertEqual(1, est.fit_count)
self.assertEqual(1, est.eval_count)
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
index 4a70f00407..c302c7725a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
@@ -131,4 +131,5 @@ def generator_input_fn(x,
target = features.pop(target_key[0])
return features, target
return features
+
return _generator_input_fn
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index ae68e35c21..bc767ec18b 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import queue_runner_impl
class GeneratorIoTest(test.TestCase):
-
+
def testGeneratorInputFn(self):
def generator():
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
index 0f317b7bb0..9bdd3206b2 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
@@ -359,8 +359,9 @@ def _read_keyed_batch_examples_helper(file_pattern,
# Check input parameters are given and reasonable.
if (not queue_capacity) or (queue_capacity <= 0):
raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
- if (batch_size is None) or ((not isinstance(batch_size, ops.Tensor)) and
- (batch_size <= 0 or batch_size > queue_capacity)):
+ if (batch_size is None) or (
+ (not isinstance(batch_size, ops.Tensor)) and
+ (batch_size <= 0 or batch_size >= queue_capacity)):
raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
(batch_size, queue_capacity))
if (read_batch_size is None) or (
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index 83643689e1..542aaabc95 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -114,6 +114,18 @@ class GraphIOTest(test.TestCase):
name=name)
self.assertRaisesRegexp(
ValueError,
+ "Invalid batch_size",
+ graph_io.read_batch_examples,
+ _VALID_FILE_PATTERN,
+ default_batch_size,
+ io_ops.TFRecordReader,
+ False,
+ num_epochs=None,
+ queue_capacity=default_batch_size,
+ num_threads=num_threads,
+ name=name)
+ self.assertRaisesRegexp(
+ ValueError,
"Invalid queue_capacity",
graph_io.read_batch_examples,
_VALID_FILE_PATTERN,
@@ -356,7 +368,7 @@ class GraphIOTest(test.TestCase):
]
filename = self._create_temp_file("".join(json_lines))
batch_size = 10000
- queue_capacity = 10000
+ queue_capacity = 100000
name = "my_large_batch"
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
index d3270dcc16..9c63096d0e 100644
--- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
@@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-def tearDownModule():
- gfile.DeleteRecursively(test.get_temp_dir())
-
-
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index 9b196e2cf5..34a293f80b 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -30,7 +30,7 @@ cuda_py_tests(
cuda_py_tests(
name = "linear_operator_addition_test",
- size = "medium",
+ size = "small",
srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
additional_deps = [
":linalg_py",
@@ -43,7 +43,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -61,7 +60,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -79,7 +77,6 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -96,7 +93,6 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -112,7 +108,6 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -128,7 +123,6 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
cuda_py_tests(
@@ -144,12 +138,11 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
cuda_py_tests(
name = "linear_operator_util_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/linear_operator_util_test.py"],
additional_deps = [
":linalg_py",
@@ -160,7 +153,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
)
py_library(
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
index a06af336e7..f047f4b978 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
@@ -229,6 +229,29 @@ class MatmulWithBroadcastTest(test.TestCase):
self.assertAllEqual(expected, result)
+class MatrixAdjointTest(test.TestCase):
+
+ def testNonBatchMatrix(self):
+ a = [[1, 2, 3j], [4, 5, -6j]] # Shape (2, 3)
+ expected = [[1, 4], [2, 5], [-3j, 6j]] # Shape (3, 2)
+ with self.test_session():
+ a_adj = linear_operator_util.matrix_adjoint(a)
+ self.assertEqual((3, 2), a_adj.get_shape())
+ self.assertAllClose(expected, a_adj.eval())
+
+ def testBatchMatrix(self):
+ matrix_0 = [[1j, 2, 3], [4, 5, 6]]
+ matrix_0_a = [[-1j, 4], [2, 5], [3, 6]]
+ matrix_1 = [[11, 22, 33], [44, 55, 66j]]
+ matrix_1_a = [[11, 44], [22, 55], [33, -66j]]
+ batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
+ expected_adj = [matrix_0_a, matrix_1_a] # Shape (2, 3, 2)
+ with self.test_session():
+ matrix_adj = linear_operator_util.matrix_adjoint(batch_matrix)
+ self.assertEqual((2, 3, 2), matrix_adj.get_shape())
+ self.assertAllEqual(expected_adj, matrix_adj.eval())
+
+
class DomainDimensionStubOperator(object):
def __init__(self, domain_dimension):
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
index a52a235677..9f8cb23169 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
@@ -289,6 +289,53 @@ def matmul_with_broadcast(a,
b_is_sparse=b_is_sparse)
+def matrix_adjoint(a, name="matrix_adjoint"):
+ """Transposes last two dimensions of tensor `a`, and takes complex conjugate.
+
+ If `a` is real valued, the result is equivalent to `matrix_transpose`.
+
+ For example:
+
+ ```python
+ # Matrix with no batch dimension.
+ # 'x' is [[1 2 3j]
+ # [4 5 -6j]]
+ tf.matrix_adjoint(x) ==> [[1 4]
+ [2 5]
+ [-3j 6j]]
+
+ # Matrix with two batch dimensions.
+ # x.shape is [1, 2, 3, 4]
+ # tf.matrix_adjoint(x) is shape [1, 2, 4, 3]
+ ```
+
+ Note that `tf.matmul` provides kwargs allowing for adjoint of arguments. This
+ is done with minimal cost, and is preferable to using this function. E.g.
+
+ ```
+ # Good! Adjoint is taken at minimal additional cost.
+ tf.matmul(matrix, b, adjoint_b=True)
+
+ # Inefficient!
+ tf.matmul(matrix, tf.matrix_adjoint(b))
+ ```
+
+ Args:
+ a: A `Tensor` with `rank >= 2`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A batch matrix `Tensor` with same `dtype` as `a`.
+
+ Raises:
+ ValueError: If `a` is determined statically to have `rank < 2`.
+ """
+ with ops.name_scope(name, values=[a]):
+ a = ops.convert_to_tensor(a, name="a")
+ a_transpose = array_ops.matrix_transpose(a)
+ return math_ops.conj(a_transpose)
+
+
def shape_tensor(shape, name=None):
"""Convert Tensor using default type, unless empty list or tuple."""
# Works just like random_ops._ShapeTensor.
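A quick, hedged usage sketch of the new `matrix_adjoint` helper, checking it against its definition as the conjugate transpose (values mirror the unit test above):

```python
import tensorflow as tf
from tensorflow.contrib.linalg.python.ops import linear_operator_util

a = tf.constant([[1 + 0j, 2, 3j],
                 [4, 5, -6j]])                  # shape (2, 3), complex dtype
a_adj = linear_operator_util.matrix_adjoint(a)  # shape (3, 2)

with tf.Session() as sess:
  print(sess.run(a_adj))  # ~ [[1, 4], [2, 5], [-3j, 6j]]
```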
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
new file mode 100644
index 0000000000..2ad0fd5310
--- /dev/null
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -0,0 +1,236 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A decoder that performs beam search.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.contrib.rnn import core_rnn_cell
+from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.layers import base as layers_base
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.util import nest
+
+
+__all__ = [
+ "BeamSearchDecoderOutput",
+ "BeamSearchDecoderState",
+ "BeamSearchDecoder",
+]
+
+
+class BeamSearchDecoderOutput(
+ collections.namedtuple("BeamSearchDecoderOutput", ("rnn_output",))):
+ pass
+
+
+class BeamSearchDecoderState(
+ collections.namedtuple("BeamSearchDecoderState",
+ ("cell_state", "log_prob", "beam_ids"))):
+ pass
+
+
+class BeamSearchDecoder(decoder.Decoder):
+ """BeamSearch sampling decoder."""
+
+ def __init__(self, cell, embedding, start_tokens, end_token,
+ initial_state, beam_width, output_layer=None):
+ """Initialize BeamSearchDecoder.
+
+ Args:
+ cell: An `RNNCell` instance.
+ embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+ or the `params` argument for `embedding_lookup`.
+ start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+ end_token: `int32` scalar, the token that marks end of decoding.
+ initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+ beam_width: Python integer, the number of beams
+ output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
+ `tf.layers.Dense`. Optional layer to apply to the RNN output prior
+ to storing the result or sampling.
+
+ Raises:
+ TypeError: if `cell` is not an instance of `RNNCell`,
+ or `output_layer` is not an instance of `tf.layers.Layer`.
+ ValueError: If `start_tokens` is not a vector or
+ `end_token` is not a scalar.
+ """
+ if not isinstance(cell, core_rnn_cell.RNNCell):
+ raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ if (output_layer is not None
+ and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access
+ raise TypeError(
+ "output_layer must be a Layer, received: %s" % type(output_layer))
+ self._cell = cell
+ self._initial_cell_state = initial_state
+ self._output_layer = output_layer
+
+ if callable(embedding):
+ self._embedding_fn = embedding
+ else:
+ self._embedding_fn = (
+ lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+
+ self._start_tokens = ops.convert_to_tensor(
+ start_tokens, dtype=dtypes.int32, name="start_tokens")
+ self._end_token = ops.convert_to_tensor(
+ end_token, dtype=dtypes.int32, name="end_token")
+ if self._start_tokens.get_shape().ndims != 1:
+ raise ValueError("start_tokens must be a vector")
+ self._batch_size = array_ops.size(start_tokens)
+ self._beam_width = beam_width
+ if self._end_token.get_shape().ndims != 0:
+ raise ValueError("end_token must be a scalar")
+ self._start_inputs = self._embedding_fn(self._start_tokens)
+
+ @property
+ def batch_size(self):
+ return self._batch_size
+
+ def _rnn_output_size(self):
+ size = self._cell.output_size
+ if self._output_layer is None:
+ return size
+ else:
+ # To use layer's compute_output_shape, we need to convert the
+ # RNNCell's output_size entries into shapes with an unknown
+ # batch size. We then pass this through the layer's
+ # compute_output_shape and read off all but the first (batch)
+ # dimensions to get the output size of the rnn with the layer
+ # applied to the top.
+ output_shape_with_unknown_batch = nest.map_structure(
+ lambda s: tensor_shape.TensorShape([None]).concatenate(s),
+ size)
+ layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access
+ output_shape_with_unknown_batch)
+ return nest.map_structure(lambda s: s[1:], layer_output_shape)
+
+ @property
+ def output_size(self):
+ # Return the cell output sizes with the beam width prepended.
+ prepend_beam_width = (
+ lambda s: tensor_shape.TensorShape([self._beam_width]).concatenate(s))
+ return BeamSearchDecoderOutput(
+ rnn_output=nest.map_structure(
+ prepend_beam_width, self._rnn_output_size()))
+
+ @property
+ def output_dtype(self):
+ # Assume every entry of the rnn_output structure shares the dtype of the
+ # first component of the initial cell state, and return that structure.
+ dtype = nest.flatten(self._initial_cell_state)[0].dtype
+ return BeamSearchDecoderOutput(
+ rnn_output=nest.map_structure(lambda _: dtype, self._rnn_output_size()))
+
+ def initialize(self, name=None):
+ """Initialize the decoder.
+
+ Args:
+ name: Name scope for any created operations.
+
+ Returns:
+ `(finished, first_inputs, initial_state)`.
+ """
+ finished, first_inputs = self._finished, self._first_inputs
+
+ initial_state = BeamSearchDecoderState(
+ cell_state=self._initial_cell_state,
+ log_probs=array_ops.zeros(
+ [self.batch_size, self._beam_width],
+ dtype=nest.flatten(self._initial_cell_state)[0].dtype),
+ beam_ids=tensor_array_ops.TensorArray(
+ size=0, dynamic_size=True, dtype=dtypes.int32,
+ clear_after_read=False))
+
+ return (finished, first_inputs, initial_state)
+
+ def _merge_batch_beams(self, t):
+ t_static_shape = t.shape
+ t_shape = array_ops.shape(t)
+ static_batch_size = tensor_util.constant_value(self._batch_size)
+ batch_size_beam_width = (
+ None if static_batch_size is None
+ else static_batch_size * self._beam_width)
+ reshaped_t = array_ops.reshape(
+ t, array_ops.concat(
+ ([self._batch_size * self._beam_width], t_shape[2:]), 0))
+ reshaped_t.set_shape(
+ (tensor_shape.TensorShape([batch_size_beam_width])
+ .concatenate(t_static_shape[2:])))
+ return reshaped_t
+
+ def _split_batch_beams(self, t):
+ t_static_shape = t.shape
+ t_shape = array_ops.shape(t)
+ reshaped_t = array_ops.reshape(
+ t, array_ops.concat(
+ ([self._batch_size, self._beam_width], t_shape[1:]), 0))
+ static_batch_size = tensor_util.constant_value(self._batch_size)
+ reshaped_t.set_shape(
+ (tensor_shape.TensorShape([static_batch_size, self._beam_width])
+ .concatenate(t_static_shape[1:])))
+ return reshaped_t
+
+ def step(self, time, inputs, state, name=None):
+ """Perform a decoding step.
+
+ Args:
+ time: scalar `int32` tensor.
+ inputs: A (structure of) input tensors.
+ state: A (structure of) state tensors and TensorArrays.
+ name: Name scope for any created operations.
+
+ Returns:
+ `(outputs, next_state, next_inputs, finished)`.
+ """
+ with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
+ cell_state = state.cell_state
+ inputs = nest.map_structure(self._merge_batch_beams, inputs)
+ cell_state = nest.map_structure(self._merge_batch_beams, cell_state)
+ cell_outputs, next_cell_state = self._cell(inputs, cell_state)
+ cell_outputs = nest.map_structure(self._split_batch_beams, cell_outputs)
+ next_cell_state = nest.map_structure(self._split_batch_beams,
+ next_cell_state)
+
+ if self._output_layer is not None:
+ cell_outputs = self._output_layer(cell_outputs)
+
+ # TODO(cinjon): Calculate next_log_probs, next_beam_ids,
+ # finished, next_inputs, final_cell_state via beam search
+ # via self._embedding
+ # ....
+ next_beam_ids, next_log_probs, final_cell_state, next_inputs, finished = (
+ None, None, None, None, None)
+
+ beam_ids = state.beam_ids.write(time, next_beam_ids)
+
+ outputs = BeamSearchDecoderOutput(cell_outputs)
+ next_state = BeamSearchDecoderState(
+ log_probs=next_log_probs,
+ beam_ids=beam_ids,
+ cell_state=final_cell_state)
+
+ return (outputs, next_state, next_inputs, finished)
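For intuition, here is a self-contained sketch of the reshape round-trip that `_merge_batch_beams` / `_split_batch_beams` perform above: the `[batch, beam, ...]` layout is flattened so the wrapped `RNNCell` sees an ordinary batch, then restored afterwards.

```python
import tensorflow as tf

batch_size, beam_width, depth = 2, 3, 5
t = tf.random_uniform([batch_size, beam_width, depth])

merged = tf.reshape(t, [batch_size * beam_width, depth])     # fed to the cell
split = tf.reshape(merged, [batch_size, beam_width, depth])  # restored after

with tf.Session() as sess:
  t_val, split_val = sess.run([t, split])
  assert (t_val == split_val).all()
```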
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 1d2674af30..6338eb152e 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@@ -38,34 +39,7 @@ from tensorflow.python.util import nest
__all__ = ["Decoder", "dynamic_decode"]
-def _transpose_batch_time(x):
- """Transpose the batch and time dimensions of a Tensor.
-
- Retains as much of the static shape information as possible.
-
- Args:
- x: A tensor of rank 2 or higher.
-
- Returns:
- x transposed along the first two dimensions.
-
- Raises:
- ValueError: if `x` is rank 1 or lower.
- """
- x_static_shape = x.get_shape()
- if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
- raise ValueError(
- "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
- (x, x_static_shape))
- x_rank = array_ops.rank(x)
- x_t = array_ops.transpose(
- x, array_ops.concat(
- ([1, 0], math_ops.range(2, x_rank)), axis=0))
- x_t.set_shape(
- tensor_shape.TensorShape([
- x_static_shape[1].value, x_static_shape[0].value
- ]).concatenate(x_static_shape[2:]))
- return x_t
+_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)
diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py
index 1a8ee93cca..8faf3ef3d4 100644
--- a/tensorflow/contrib/session_bundle/gc_test.py
+++ b/tensorflow/contrib/session_bundle/gc_test.py
@@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-def tearDownModule():
- gfile.DeleteRecursively(test.get_temp_dir())
-
-
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ba761cd7c6..afcc7891b6 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -272,6 +272,7 @@ cc_library(
"lib/monitoring/sampler.h",
"lib/random/distribution_sampler.h",
"lib/random/philox_random.h",
+ "lib/random/random_distributions.h",
"lib/random/simple_philox.h",
"lib/strings/numbers.h",
"lib/strings/str_util.h",
@@ -383,6 +384,7 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
+ "util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@@ -1535,7 +1537,10 @@ cc_library(
tf_cuda_library(
name = "direct_session_internal",
srcs = ["common_runtime/direct_session.cc"],
- hdrs = ["common_runtime/direct_session.h"],
+ hdrs = [
+ "common_runtime/direct_session.h",
+ "util/env_var.h",
+ ],
copts = tf_copts(),
cuda_deps = [
":gpu_tracer",
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index eda2be3e70..768c2f6f75 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
@@ -242,6 +243,13 @@ DirectSession::DirectSession(const SessionOptions& options,
thread_pools_.push_back(GlobalThreadPool(options));
owns_thread_pools_ = false;
}
+ // The default value of sync_on_finish will be flipped soon and this
+ // environment variable will be removed as well.
+ Status status =
+ ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
// NOTE(mrry): We do not need to use a unique string for the session
// handle, because DirectSession owns its devices. This may change
// in future versions.
@@ -448,7 +456,7 @@ Status DirectSession::Run(const RunOptions& run_options,
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
- args.sync_on_finish = true;
+ args.sync_on_finish = sync_on_finish_;
const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
@@ -632,7 +640,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
- args.sync_on_finish = true;
+ args.sync_on_finish = sync_on_finish_;
if (options_.config.graph_options().build_cost_model()) {
run_state->collector.reset(new StepStatsCollector(nullptr));
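The `TF_SYNC_ON_FINISH` variable is read once, when the `DirectSession` (or `GraphMgr`) is constructed, so it must be in the environment before the session is created. A hedged sketch from Python; the accepted values follow the usual true/false parsing of `ReadBoolFromEnvVar`:

```python
import os

# Must be set before the Session is constructed; "false" (or "0") disables
# the end-of-step device synchronization controlled by sync_on_finish_.
os.environ['TF_SYNC_ON_FINISH'] = 'false'

import tensorflow as tf

with tf.Session() as sess:
  print(sess.run(tf.constant(1.0)))
```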
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 1495648631..b9d22ac522 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -247,6 +247,8 @@ class DirectSession : public Session {
std::vector<thread::ThreadPool*> thread_pools_;
bool owns_thread_pools_ = false;
+ // If true, blocks until device has finished all queued operations in a step.
+ bool sync_on_finish_ = true;
// Schedules 'c' for execution on pool.
void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index bbde0924c7..f23f9361eb 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -64,6 +64,10 @@ class ShapeRefiner {
return it->second.get();
}
+ // Getters and setters for graph_def_version_.
+ int32 graph_def_version() { return graph_def_version_; }
+ void set_graph_def_version(int32 version) { graph_def_version_ = version; }
+
private:
// Extracts the subgraph ending at 'node' that is statically
// computable and inserts into 'out_graph'. If statically computable,
@@ -100,7 +104,7 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
- const int graph_def_version_;
+ int32 graph_def_version_;
const OpRegistryInterface* const ops_registry_;
// The lifetime of the tensors are bound to the runner, so it should be the
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 537d489aae..545ae867f6 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
@@ -48,6 +49,13 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env,
RendezvousMgrInterface* rendezvous_mgr)
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
+ // The default value of sync_on_finish will be flipped soon and this
+ // environment variable will be removed as well.
+ Status status =
+ ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
}
GraphMgr::~GraphMgr() {
@@ -486,7 +494,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
args.step_container = step_container;
- args.sync_on_finish = true;
+ args.sync_on_finish = sync_on_finish_;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, handle);
}
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 4477a2764b..5f51d63857 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -137,6 +137,9 @@ class GraphMgr {
mutex mu_;
int64 next_id_ GUARDED_BY(mu_) = 0;
+ // If true, blocks until the device has finished all queued operations in a step.
+ bool sync_on_finish_ = true;
+
// Table mapping graph handles to registered graphs.
//
// TODO(zhifengc): If the client does not call Deregister, we'll
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 4c87a453e2..d5e6e293d6 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -873,6 +873,13 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status RandomShape(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
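RandomShape forwards the shape encoded in the op's first input (a shape vector) to output 0. A hedged sketch of plugging it into an op registration; HypotheticalRandom is illustrative and not an op added by this change:

    #include "tensorflow/core/framework/common_shape_fns.h"
    #include "tensorflow/core/framework/op.h"

    REGISTER_OP("HypotheticalRandom")
        .Input("shape: int32")
        .Output("output: float")
        // The output shape is whatever the "shape" input tensor holds, when known.
        .SetShapeFn(shape_inference::RandomShape);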
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 73509fb7fb..dc99e48adb 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -199,6 +199,9 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Tested by ops/math_ops_test.cc.
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
+// Shape function for random operations.
+Status RandomShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 3626de58d6..3d913cdaf0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -125,6 +125,23 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
}
}
+Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
+ if (!IsLegacyVector(shape.shape())) {
+ return errors::InvalidArgument(
+ "shape must be a vector of {int32,int64}, got shape ",
+ shape.shape().DebugString());
+ }
+ if (shape.dtype() == DataType::DT_INT32) {
+ auto vec = shape.flat<int32>();
+ return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
+ } else if (shape.dtype() == DataType::DT_INT64) {
+ auto vec = shape.flat<int64>();
+ return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
+ } else {
+ return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
+ }
+}
+
void AsyncOpKernel::Compute(OpKernelContext* context) {
Notification n;
ComputeAsync(context, [&n]() { n.Notify(); });
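OpKernel::MakeShape converts a 1-D int32/int64 shape tensor into a TensorShape. A hedged sketch of how a kernel's Compute method might use it; the surrounding op is illustrative:

    // Inside some OpKernel::Compute(OpKernelContext* ctx), where input 0 is a
    // shape vector of dtype int32 or int64.
    TensorShape out_shape;
    OP_REQUIRES_OK(ctx, MakeShape(ctx->input(0), &out_shape));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output));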
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index d874b9087f..91e6a98304 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -151,6 +151,10 @@ class OpKernel {
return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
}
+ // Turns a shape Tensor into a TensorShape.
+ // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
+ Status MakeShape(const Tensor& shape, TensorShape* out) const;
+
private:
const NodeDef def_;
const DataTypeVector input_types_;
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 449d8f55f5..a990dc2f04 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,8 +239,11 @@ string InferenceContext::DebugString() const {
ProtoDebugString(node_def_));
}
-Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
@@ -261,8 +264,11 @@ Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
existing);
}
-Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing >= rank) {
*out = shape;
@@ -276,8 +282,11 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
@@ -470,12 +479,12 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in,
+Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
DimensionHandle new_dim, ShapeHandle* out) {
if (!RankKnown(s)) {
return ReturnUnknownShape(out);
}
- int dim_index = dim_index_in;
+ int64 dim_index = dim_index_in;
if (dim_index < 0) {
dim_index = s->dims_.size() + dim_index;
}
@@ -510,7 +519,8 @@ ShapeHandle InferenceContext::UnknownShape() {
return shape_manager_.UnknownShape();
}
-ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) {
+ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
+ CHECK_LE(rank, kint32max) << "rank must not exceed kint32max";
std::vector<DimensionHandle> dims(rank);
for (int32 i = 0; i < rank; ++i) {
dims[i] = UnknownDim();
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index b7f1725c5f..5e116884c6 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -194,7 +194,7 @@ class InferenceContext {
return s;
}
- ShapeHandle input(int idx) const { return inputs_[idx]; }
+ ShapeHandle input(int64 idx) const { return inputs_[idx]; }
Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
int num_inputs() const { return inputs_.size(); }
@@ -237,7 +237,7 @@ class InferenceContext {
// idx can be negative for an offset from end of dimensions.
// idx must be in the range [-1 * s.rank, s.rank).
- DimensionHandle Dim(ShapeHandle s, int32 idx) {
+ DimensionHandle Dim(ShapeHandle s, int64 idx) {
if (s->rank_ == kUnknownRank) {
return UnknownDim();
}
@@ -277,11 +277,11 @@ class InferenceContext {
// the shape with asserted rank in <*out>. Otherwise return an error.
//
// Note that <*out> may be set to <shape>.
- Status WithRank(ShapeHandle shape, int32 rank,
+ Status WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
- Status WithRankAtLeast(ShapeHandle shape, int32 rank,
+ Status WithRankAtLeast(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
- Status WithRankAtMost(ShapeHandle shape, int32 rank,
+ Status WithRankAtMost(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
// If <dim> has value <value>, or its value is unknown, returns OK and returns
@@ -332,7 +332,7 @@ class InferenceContext {
// Returns in <out> the shape from replacing <s.dim[dim_index]> with
// <new_dim>.
- Status ReplaceDim(ShapeHandle s, int dim_index, DimensionHandle new_dim,
+ Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns a new shape with the given dims. The returned value is owned by
@@ -344,7 +344,7 @@ class InferenceContext {
ShapeHandle UnknownShape();
// Returns a shape with specified rank but unknown dims.
- ShapeHandle UnknownShapeOfRank(int32 rank);
+ ShapeHandle UnknownShapeOfRank(int64 rank);
// Returns a new shape of zero dimensions.
ShapeHandle Scalar();
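With the int64 overloads, shape functions can pass 64-bit rank and dimension indices directly (anything above kint32max is rejected). A hedged sketch of a shape function exercising WithRank and Dim; the names are illustrative:

    Status ExampleMatrixShapeFn(shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input;
      // Require a rank-2 input; the rank argument is now accepted as int64.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
      // Negative indices count from the end, so -1 selects the last dimension.
      shape_inference::DimensionHandle cols = c->Dim(input, -1);
      c->set_output(0, c->Matrix(c->UnknownDim(), cols));
      return Status::OK();
    }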
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 6b3b5d3604..9d4a0a52f7 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -839,11 +839,6 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
- ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
- if (refiner == nullptr) {
- refiner = &default_refiner;
- }
-
if (!opts.return_tensors.empty()) {
if (return_tensors == nullptr) {
return errors::InvalidArgument(
@@ -857,6 +852,36 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return_tensors->size(), ")");
}
}
+
+ ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
+ if (refiner == nullptr) {
+ refiner = &default_refiner;
+ } else {
+ // Log a warning if we are importing a GraphDef at an older
+ // producer version after already having added non-source/sink
+ // nodes to the graph in the past.
+ if (gdef.versions().producer() > 0 &&
+ gdef.versions().producer() < refiner->graph_def_version() &&
+ g->num_nodes() > 2) {
+ LOG(WARNING) << "Importing a graph with a lower producer version "
+ << gdef.versions().producer()
+ << " into an existing graph with producer version "
+ << refiner->graph_def_version() << ". Shape inference will "
+ << "have run different parts of the graph with different "
+ << "producer versions.";
+ }
+ }
+
+ // Set the graph def version of the refiner as the min of the
+ // current value and the version from the graph we are about to
+ // import.
+ //
+ // Note: to match Run() semantics, we should re-run shape inference
+ // on the entire graph if the producer version has changed. For now
+ // we log the warning above.
+ refiner->set_graph_def_version(
+ std::min(refiner->graph_def_version(), gdef.versions().producer()));
+
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors);
}
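A caller that reuses one ShapeRefiner across several ImportGraphDef calls will now see the refiner's graph_def_version settle at the lowest producer version imported so far; the new test below exercises exactly this. A minimal call-pattern sketch, assuming graph is an existing Graph and gdef_v20/gdef_v17 are GraphDefs with producer versions 20 and 17:

    ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph.op_registry());
    TF_RETURN_IF_ERROR(ImportGraphDef(ImportGraphDefOptions(), gdef_v20, &graph,
                                      &refiner, nullptr));
    TF_RETURN_IF_ERROR(ImportGraphDef(ImportGraphDefOptions(), gdef_v17, &graph,
                                      &refiner, nullptr));
    // refiner.graph_def_version() is now 17, the minimum producer version seen.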
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e20dabc891..e3b7f322cb 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -2271,5 +2271,176 @@ TEST_F(GraphConstructorTest, GraphDefVersionMergingDuringImport) {
EXPECT_EQ(3, graph_.versions().bad_consumers(2));
}
+TEST_F(GraphConstructorTest, ImportGraphDefProvidedShapeRefinerVersions) {
+ ImportGraphDefOptions opts;
+ // A valid graph at producer version 20, but one
+ // that would not import if the graph_def_version were 21.
+ string gdef_ascii = strings::StrCat(R"EOF(
+node {
+ name: "Sum/input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ }
+ tensor_content: "\001\000\000\000\002\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "Sum/reduction_indices"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ }
+ tensor_content: "\000\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "Sum"
+ op: "Sum"
+ input: "Sum/input"
+ input: "Sum/reduction_indices"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+}
+versions {
+ producer: 20
+})EOF");
+
+ // Create a shape refiner with the latest TF_GRAPH_DEF_VERSION.
+ // Importing the graphdef with an existing refiner should
+ // make the refiner inherit the graphdef version from the
+ // passed in graphdef since it has a lower producer.
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+ ExpectOK(gdef_ascii, opts, &refiner);
+
+ // Add another node with a higher producer
+ gdef_ascii = strings::StrCat(R"EOF(
+node {
+ name: "RandomConst"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ }
+ tensor_content: "\001\000\000\000\002\000\000\000"
+ }
+ }
+ }
+}
+versions {
+ producer: 21
+})EOF");
+
+ ExpectOK(gdef_ascii, opts, &refiner);
+ // Check that the refiner's graph def version is the lowest of
+ // the graph defs we have seen so far.
+ EXPECT_EQ(20, refiner.graph_def_version());
+
+ // Add another node with a lower producer
+ gdef_ascii = strings::StrCat(R"EOF(
+node {
+ name: "RandomConst2"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ }
+ tensor_content: "\001\000\000\000\002\000\000\000"
+ }
+ }
+ }
+}
+versions {
+ producer: 17
+})EOF");
+ ExpectOK(gdef_ascii, opts, &refiner);
+
+ // Check that the refiner's graph def version is the lowest of
+ // the graph defs we have seen so far.
+ EXPECT_EQ(17, refiner.graph_def_version());
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index c42eebae53..5d74d3d3b1 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -30,6 +30,16 @@ filegroup(
)
cc_library(
+ name = "op_types",
+ srcs = ["op_types.cc"],
+ hdrs = ["op_types.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "utils",
srcs = ["utils.cc"],
hdrs = ["utils.h"],
@@ -88,6 +98,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":grappler_item",
+ ":op_types",
":utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/core/grappler/devices.cc b/tensorflow/core/grappler/devices.cc
index d3fc9044d3..b318ac22d4 100644
--- a/tensorflow/core/grappler/devices.cc
+++ b/tensorflow/core/grappler/devices.cc
@@ -53,6 +53,22 @@ int GetNumAvailableGPUs() {
return num_eligible_gpus;
}
+int64 AvailableGPUMemory(int gpu_id) {
+#if GOOGLE_CUDA
+ // Look up the device, to see its attributes.
+ perftools::gputools::Platform* gpu_platform = GPUMachineManager();
+ CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
+ perftools::gputools::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ int64 total_memory, available_memory;
+ CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
+
+ return available_memory;
+#else
+ return 0;
+#endif
+}
+
int GetNumAvailableLogicalCPUCores() { return port::NumSchedulableCPUs(); }
} // end namespace grappler
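AvailableGPUMemory asks the StreamExecutor how much memory is currently free on one visible GPU and returns 0 in non-CUDA builds. A hedged usage sketch:

    #include "tensorflow/core/grappler/devices.h"

    // Log the free memory (in bytes) reported for every visible GPU.
    for (int id = 0; id < tensorflow::grappler::GetNumAvailableGPUs(); ++id) {
      LOG(INFO) << "GPU " << id << ": "
                << tensorflow::grappler::AvailableGPUMemory(id)
                << " bytes available";
    }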
diff --git a/tensorflow/core/grappler/devices.h b/tensorflow/core/grappler/devices.h
index 329e8e2e65..2d6c41888d 100644
--- a/tensorflow/core/grappler/devices.h
+++ b/tensorflow/core/grappler/devices.h
@@ -29,6 +29,10 @@ namespace grappler {
// than 8.
int GetNumAvailableGPUs();
+// Maximum amount of GPU memory available per GPU. gpu_id must be in the range
+// [0, GetNumAvailableGPUs()).
+int64 AvailableGPUMemory(int gpu_id);
+
// Get the number of logical CPU cores (aka hyperthreads) available.
int GetNumAvailableLogicalCPUCores();
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 7889b0e025..e37b908fc6 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -90,7 +91,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
node.clear_device();
}
- if (node.op() == "Placeholder" || node.op() == "PlaceholderV2") {
+ if (IsPlaceholder(node)) {
if (node.attr().count("dtype") == 0) {
LOG(ERROR) << "Unknown type for placeholder " << node.name()
<< ", skipping this input";
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
new file mode 100644
index 0000000000..33ef498db0
--- /dev/null
+++ b/tensorflow/core/grappler/op_types.cc
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/op_types.h"
+
+namespace tensorflow {
+namespace grappler {
+
+bool IsPlaceholder(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "Placeholder" || op == "PlaceholderV2";
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
new file mode 100644
index 0000000000..30a3c91411
--- /dev/null
+++ b/tensorflow/core/grappler/op_types.h
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OP_TYPES_H_
+#define TENSORFLOW_GRAPPLER_OP_TYPES_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+bool IsPlaceholder(const NodeDef& node);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_OP_TYPES_H_
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index bd96e2b33c..2ea150ce18 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -26,6 +26,40 @@ filegroup(
)
cc_library(
+ name = "auto_parallel",
+ srcs = ["auto_parallel.cc"],
+ hdrs = [
+ "auto_parallel.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ ],
+)
+
+cc_test(
+ name = "auto_parallel_test",
+ srcs = ["auto_parallel_test.cc"],
+ deps = [
+ ":auto_parallel",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
+cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
hdrs = [
@@ -179,6 +213,7 @@ cc_library(
":constant_folding",
":graph_optimizer",
":layout_optimizer",
+ ":memory_optimizer",
":model_pruner",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
new file mode 100644
index 0000000000..77ab178653
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -0,0 +1,260 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace grappler {
+const char kAutoParallelPrefix[] = "AutoParallel";
+
+NodeDef* AutoParallel::AddNodeDivConst() {
+ NodeDef* node = graph_.add_node();
+ node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
+ node->set_op("Const");
+
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_FLOAT);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+
+ AttrValue attr_tensor;
+ auto tensor = attr_tensor.mutable_tensor();
+ tensor->add_float_val(static_cast<float>(num_replicas_));
+ tensor->set_dtype(DT_FLOAT);
+ node->mutable_attr()->insert({"value", attr_tensor});
+ return node;
+}
+
+NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
+ const string& input_b) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
+ node->set_op("RealDiv");
+ node->add_input(input_a);
+ node->add_input(input_b);
+ AttrValue attr_type;
+ attr_type.set_type(DT_FLOAT);
+ node->mutable_attr()->insert({"T", attr_type});
+ return node;
+}
+
+NodeDef* AutoParallel::AddNodeControl(const string& name,
+ const std::set<string>& deps,
+ GraphDef* graph) {
+ NodeDef* node = graph->add_node();
+ node->set_name(name);
+ node->set_op("NoOp");
+ for (const auto& dep : deps) {
+ node->add_input(strings::StrCat("^", dep));
+ }
+ return node;
+}
+
+Status AutoParallel::Initialize(const GrapplerItem& item) {
+ num_gpus_ = GetNumAvailableGPUs();
+ LOG(INFO) << "Number of GPUs: " << num_gpus_;
+ item_ = &item;
+ graph_ = item.graph;
+ LOG(INFO) << "Original graph size: " << graph_.node_size();
+ if (item.fetch.empty()) {
+ return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
+ }
+
+ if (item.MainVariables().empty()) {
+ return Status(error::INVALID_ARGUMENT, "No variables provided.");
+ }
+
+ for (const auto& init : item.init_ops) {
+ VLOG(1) << "Init node: " << init;
+ }
+
+ for (const auto& fetch : item.fetch) {
+ VLOG(1) << "Fetch node: " << fetch;
+ }
+
+ for (const auto& var : item.MainVariables()) {
+ VLOG(2) << "Variable: " << var->name();
+ }
+
+ std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
+ "ApplyProximalGradientDescent",
+ "ApplyAdadelta",
+ "ApplyAdagrad",
+ "ApplyProximalAdagrad",
+ "ApplyAdagradDA",
+ "ApplyFtrl",
+ "ApplyMomentum",
+ "ApplyAdam",
+ "ApplyRMSProp",
+ "ApplyCenteredRMSProp"};
+ const NodeDef* dequeue_node = nullptr;
+ for (int i = 0; i < graph_.node_size(); i++) {
+ all_nodes_.insert(
+ std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
+ if (graph_.node(i).op() == "QueueDequeueManyV2") {
+ dequeue_node = graph_.mutable_node(i);
+ }
+ if (apply_gradients_ops.find(graph_.node(i).op()) !=
+ apply_gradients_ops.end()) {
+ apply_gradients_nodes_.insert(graph_.node(i).name());
+ VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
+ }
+ }
+
+ auto div_const_node = AddNodeDivConst();
+ all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
+ std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
+ {"ApplyProximalGradientDescent", 4},
+ {"ApplyAdadelta", 6},
+ {"ApplyAdagrad", 3},
+ {"ApplyProximalAdagrad", 5},
+ {"ApplyAdagradDA", 3},
+ {"ApplyFtrl", 3},
+ {"ApplyMomentum", 3},
+ {"ApplyAdam", 9},
+ {"ApplyRMSProp", 7},
+ {"ApplyCenteredRMSProp", 8}};
+ for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
+ auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
+ auto apply_gradients_node = all_nodes_[apply_gradient_node_name];
+
+ auto div_node = AddNodeDiv(
+ apply_gradient_node_name,
+ apply_gradients_node->input(gradient_pos[apply_gradients_op]),
+ div_const_node->name());
+ all_nodes_.insert(std::make_pair(div_node->name(), div_node));
+ *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
+ div_node->name();
+ }
+ LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
+
+ auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
+ LOG(INFO) << "Number of training nodes: " << train_nodes.size();
+
+ std::vector<const NodeDef*> input_nodes;
+ if (dequeue_node) {
+ LOG(INFO) << "Dequeue node: " << dequeue_node->name();
+ input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
+ }
+ LOG(INFO) << "Number of input nodes: " << input_nodes.size();
+
+ std::set<string> dont_replicate_nodes;
+ for (const auto& variable : item.MainVariables()) {
+ dont_replicate_nodes.insert(variable->name());
+ }
+ // Don't replicate input nodes, except for the dequeue node.
+ for (const auto& input_node : input_nodes) {
+ if (input_node->name() != dequeue_node->name()) {
+ dont_replicate_nodes.insert(input_node->name());
+ }
+ }
+
+ for (const auto& node : train_nodes) {
+ if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
+ replica_nodes_.insert(node->name());
+ }
+ }
+ LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();
+
+ for (const auto& node : all_nodes_) {
+ if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
+ shared_nodes_.insert(node.first);
+ }
+ }
+ LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
+ return Status::OK();
+}
+
+bool AutoParallel::NotSharedNode(const string& name) {
+ return shared_nodes_.find(name) == shared_nodes_.end();
+}
+
+void AutoParallel::AddSharedNodes(GraphDef* graph) {
+ string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
+ for (const auto& node : shared_nodes_) {
+ auto new_node = graph->add_node();
+ *new_node = *all_nodes_[node];
+ for (int i = 0; i < new_node->input_size(); i++) {
+ if (NotSharedNode(NodeName(new_node->input(i)))) {
+ string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
+ *new_node->mutable_input(i) = new_name;
+ }
+ }
+ }
+}
+
+void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
+ string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
+ for (const auto& node : replica_nodes_) {
+ auto new_node = graph->add_node();
+ *new_node = *all_nodes_[node];
+ if (NotSharedNode(new_node->name())) {
+ new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
+ if (num_gpus_ > 0) {
+ new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
+ }
+ for (int i = 0; i < new_node->input_size(); i++) {
+ if (NotSharedNode(NodeName(new_node->input(i)))) {
+ string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
+ *new_node->mutable_input(i) = new_name;
+ }
+ }
+ }
+ }
+}
+
+void AutoParallel::BuildGraph(GraphDef* graph) {
+ AddSharedNodes(graph);
+ for (int i = 0; i < num_replicas_; i++) {
+ AddOneReplica(graph, i);
+ }
+ std::set<string> fetches;
+ for (int i = 0; i < item_->fetch.size(); i++) {
+ for (int j = 0; j < num_replicas_; j++) {
+ string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
+ string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
+ fetches.insert(fetch);
+ }
+ }
+ string name_control =
+ strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
+ auto control = AddNodeControl(name_control, fetches, graph);
+
+ for (const auto& fetch : item_->fetch) {
+ AddNodeControl(fetch, {control->name()}, graph);
+ }
+ LOG(INFO) << "Parallelized graph size: " << graph->node_size();
+}
+
+Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ TF_RETURN_IF_ERROR(Initialize(item));
+ BuildGraph(output);
+ return Status::OK();
+}
+
+void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // TODO(yaozhang): Add feedback.
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
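In short, the optimizer keeps variables and most input nodes shared, replicates the training subgraph once per replica, and divides the gradient fed into every Apply* node by num_replicas via the inserted RealDiv, so the shared variables receive averaged updates. A hedged sketch of invoking it directly; it mirrors the unit test added below and assumes item is a populated GrapplerItem:

    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/grappler/optimizers/auto_parallel.h"

    // Rewrite 'item.graph' into a 2-replica version.
    tensorflow::grappler::AutoParallel parallel(/*num_replicas=*/2);
    tensorflow::GraphDef parallelized;
    TF_RETURN_IF_ERROR(parallel.Optimize(/*cluster=*/nullptr, item, &parallelized));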
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h
new file mode 100644
index 0000000000..cac0db2c23
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Automatically parallelize a graph by splitting in the batch dimension.
+class AutoParallel : public GraphOptimizer {
+ public:
+ AutoParallel(int num_replicas) : num_replicas_(num_replicas) {}
+ ~AutoParallel() override {}
+
+ string name() const override { return "autoparallel"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+
+ private:
+ GraphDef graph_;
+ std::map<string, NodeDef*> all_nodes_;
+ std::set<string> apply_gradients_nodes_;
+ std::set<string> replica_nodes_;
+ std::set<string> shared_nodes_;
+ const GrapplerItem* item_;
+ int num_replicas_;
+ int num_gpus_;
+ Status Initialize(const GrapplerItem& item);
+ NodeDef* AddNodeDivConst();
+ NodeDef* AddNodeDiv(const string& name, const string& input_a,
+ const string& input_b);
+ NodeDef* AddNodeControl(const string& name, const std::set<string>& deps,
+ GraphDef* graph);
+ bool NotSharedNode(const string& name);
+ void AddSharedNodes(GraphDef* graph);
+ void AddOneReplica(GraphDef* graph, int number);
+ void BuildGraph(GraphDef* graph);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
new file mode 100644
index 0000000000..b7786ccd14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class AutoParallelTest : public ::testing::Test {};
+
+TEST_F(AutoParallelTest, SimpleParallel) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
+ Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
+ Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
+ Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
+ Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
+ auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
+ {constant_b}, {DT_FLOAT});
+ Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]});
+ Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
+ Output apply_gradient = ops::ApplyGradientDescent(
+ s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});
+
+ GrapplerItem item;
+ item.init_ops.push_back("assign");
+ item.fetch.push_back("apply_gradient");
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ AutoParallel parallel(2);
+ GraphDef output;
+ Status status = parallel.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(20, output.node_size());
+
+ const NodeDef& node_assign = output.node(0);
+ EXPECT_EQ("assign", node_assign.name());
+ EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_assign.input(1));
+
+ const NodeDef& node_constant_b = output.node(1);
+ EXPECT_EQ("constant_b", node_constant_b.name());
+
+ const NodeDef& node_fifo_queue = output.node(2);
+ EXPECT_EQ("fifo_queue", node_fifo_queue.name());
+
+ const NodeDef& node_var = output.node(3);
+ EXPECT_EQ("var", node_var.name());
+
+ const NodeDef& node_div_const0 = output.node(4);
+ EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-Const",
+ node_div_const0.name());
+
+ const NodeDef& node_div0 = output.node(5);
+ EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-apply_gradient",
+ node_div0.name());
+ const NodeDef& node_add0 = output.node(6);
+ EXPECT_EQ("AutoParallel-Replica-0-add", node_add0.name());
+
+ const NodeDef& node_gradient0 = output.node(7);
+ EXPECT_EQ("AutoParallel-Replica-0-apply_gradient", node_gradient0.name());
+
+ const NodeDef& node_constant_a0 = output.node(8);
+ EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_constant_a0.name());
+
+ const NodeDef& node_dequeue0 = output.node(9);
+ EXPECT_EQ("AutoParallel-Replica-0-dequeue", node_dequeue0.name());
+
+ const NodeDef& node_learning_rate0 = output.node(10);
+ EXPECT_EQ("AutoParallel-Replica-0-learning_rate", node_learning_rate0.name());
+
+ const NodeDef& node_div_const1 = output.node(11);
+ EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-Const",
+ node_div_const1.name());
+
+ const NodeDef& node_div1 = output.node(12);
+ EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-apply_gradient",
+ node_div1.name());
+
+ const NodeDef& node_add1 = output.node(13);
+ EXPECT_EQ("AutoParallel-Replica-1-add", node_add1.name());
+
+ const NodeDef& node_gradient1 = output.node(14);
+ EXPECT_EQ("AutoParallel-Replica-1-apply_gradient", node_gradient1.name());
+
+ const NodeDef& node_constant_a1 = output.node(15);
+ EXPECT_EQ("AutoParallel-Replica-1-constant_a", node_constant_a1.name());
+
+ const NodeDef& node_dequeue1 = output.node(16);
+ EXPECT_EQ("AutoParallel-Replica-1-dequeue", node_dequeue1.name());
+
+ const NodeDef& node_learning_rate1 = output.node(17);
+ EXPECT_EQ("AutoParallel-Replica-1-learning_rate", node_learning_rate1.name());
+
+ const NodeDef& node_fetch = output.node(18);
+ EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
+ EXPECT_EQ("^AutoParallel-Replica-0-apply_gradient", node_fetch.input(0));
+ EXPECT_EQ("^AutoParallel-Replica-1-apply_gradient", node_fetch.input(1));
+
+ const NodeDef& node_gradient = output.node(19);
+ EXPECT_EQ("apply_gradient", node_gradient.name());
+ EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 67ffa7a4b6..0fe9359b75 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/lib/core/status.h"
@@ -37,6 +38,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
if (optimizer == "layout") {
graph_optimizer.reset(new LayoutOptimizer());
}
+ if (optimizer == "memory") {
+ graph_optimizer.reset(new MemoryOptimizer());
+ }
return graph_optimizer;
}
@@ -55,8 +59,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
+ if (cfg_.memory_optimization() > 0) {
+ optimizers.push_back(
+ std::unique_ptr<GraphOptimizer>(new MemoryOptimizer()));
+ }
} else {
- std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout"};
+ std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout",
+ "memory"};
for (const auto& optimizer : cfg_.optimizers()) {
if (avaliable_optimizers.find(optimizer) != avaliable_optimizers.end()) {
optimizers.push_back(NewOptimizer(optimizer));
@@ -81,7 +90,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizer->Optimize(nullptr, optimized_item, optimized_graph));
}
}
-
// Copy the graph version.
*optimized_graph->mutable_versions() = item.graph.versions();
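The memory optimizer can now be requested either implicitly (memory_optimization() > 0) or by name through the RewriterConfig optimizers list. A hedged sketch of the by-name configuration; the proto header path is an assumption:

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Ask the meta optimizer to run the new MemoryOptimizer pass by name.
    tensorflow::RewriterConfig config;
    config.add_optimizers("memory");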
diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc
index b24a834083..99b5d3daaa 100644
--- a/tensorflow/core/kernels/batchtospace_op.cc
+++ b/tensorflow/core/kernels/batchtospace_op.cc
@@ -97,6 +97,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
+ OP_REQUIRES(
+ context, block_shape_product > 0,
+ errors::InvalidArgument("Product of block sizes must be positive, got ",
+ block_shape_product));
const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
OP_REQUIRES(
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index caf73420ba..746fe63e2a 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -216,12 +216,14 @@ struct CropAndResize<CPUDevice, T> {
const float x_lerp = in_x - left_x_index;
for (int d = 0; d < depth; ++d) {
- const float top_left(image(b_in, top_y_index, left_x_index, d));
- const float top_right(image(b_in, top_y_index, right_x_index, d));
- const float bottom_left(
- image(b_in, bottom_y_index, left_x_index, d));
- const float bottom_right(
- image(b_in, bottom_y_index, right_x_index, d));
+ const float top_left(
+ static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
+ const float top_right(
+ static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
+ const float bottom_left(static_cast<float>(
+ image(b_in, bottom_y_index, left_x_index, d)));
+ const float bottom_right(static_cast<float>(
+ image(b_in, bottom_y_index, right_x_index, d)));
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom =
bottom_left + (bottom_right - bottom_left) * x_lerp;
@@ -545,12 +547,14 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float x_lerp = in_x - left_x_index;
for (int d = 0; d < depth; ++d) {
- const float top_left(image(b_in, top_y_index, left_x_index, d));
- const float top_right(image(b_in, top_y_index, right_x_index, d));
- const float bottom_left(
- image(b_in, bottom_y_index, left_x_index, d));
- const float bottom_right(
- image(b_in, bottom_y_index, right_x_index, d));
+ const float top_left(
+ static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
+ const float top_right(
+ static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
+ const float bottom_left(static_cast<float>(
+ image(b_in, bottom_y_index, left_x_index, d)));
+ const float bottom_right(static_cast<float>(
+ image(b_in, bottom_y_index, right_x_index, d)));
// Compute the image gradient.
float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
x_lerp * (bottom_right - top_right);
@@ -606,18 +610,25 @@ inline void CheckValidBoxInd<CPUDevice>(
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("image_size"), \
- CropAndResizeGradImageOp<CPUDevice, T>); \
- \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
+#undef REGISTER_KERNEL
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("image_size"), \
+ CropAndResizeGradImageOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
+TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@@ -685,7 +696,7 @@ inline void CheckValidBoxInd<GPUDevice>(
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<GPUDevice, T>);
-TF_CALL_float(REGISTER_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index 75146b28e6..254475db46 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -88,26 +88,26 @@ __global__ void CropAndResizeKernel(
const int right_x_index = ceilf(in_x);
const float x_lerp = in_x - left_x_index;
- const float top_left(
+ const float top_left(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
left_x_index) *
depth +
- d]);
- const float top_right(
+ d]));
+ const float top_right(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
right_x_index) *
depth +
- d]);
- const float bottom_left(
+ d]));
+ const float bottom_left(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
left_x_index) *
depth +
- d]);
- const float bottom_right(
+ d]));
+ const float bottom_right(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
right_x_index) *
depth +
- d]);
+ d]));
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
crops_ptr[out_idx] = top + (bottom - top) * y_lerp;
@@ -258,26 +258,26 @@ __global__ void CropAndResizeBackpropBoxesKernel(
const int right_x_index = ceilf(in_x);
const float x_lerp = in_x - left_x_index;
- const float top_left =
+ const float top_left(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
left_x_index) *
depth +
- d];
- const float top_right =
+ d]));
+ const float top_right(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
right_x_index) *
depth +
- d];
- const float bottom_left =
+ d]));
+ const float bottom_left(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
left_x_index) *
depth +
- d];
- const float bottom_right =
+ d]));
+ const float bottom_right(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
right_x_index) *
depth +
- d];
+ d]));
// Compute the image gradient.
float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
@@ -436,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
-TF_CALL_float(DEFINE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index 68e077e44d..3a7f180598 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
@@ -31,9 +32,10 @@ namespace tensorflow {
class CropAndResizeOpTest : public OpsTestBase {
protected:
+ template <typename T>
void MakeOp(float extrapolation_value) {
TF_EXPECT_OK(NodeDefBuilder("crop_and_resize_op", "CropAndResize")
- .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DataTypeToEnum<T>::value))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32))
@@ -43,12 +45,33 @@ class CropAndResizeOpTest : public OpsTestBase {
}
};
-TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) {
- MakeOp(0);
+#define REGISTER_TEST(T) \
+ TEST_F(CropAndResizeOpTest, TestCropAndResize##T) { \
+ MakeOp<T>(0); \
+ AddInputFromArray<T>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); \
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); \
+ AddInputFromArray<int32>(TensorShape({1}), {0}); \
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); \
+ test::FillValues<float>(&expected, {2.5}); \
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0)); \
+ }
+
+REGISTER_TEST(float)
+REGISTER_TEST(double)
+REGISTER_TEST(int8)
+REGISTER_TEST(uint8)
+
+#undef REGISTER_TEST
+
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8) {
+ MakeOp<uint8>(0);
// Input:
// 1, 2
// 3, 4
- AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<uint8>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {0});
AddInputFromArray<int32>(TensorShape({2}), {1, 1});
@@ -60,7 +83,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@@ -76,7 +99,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@@ -97,7 +120,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@@ -118,7 +141,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2, 3
// 4, 5, 6
@@ -143,7 +166,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2, 3
// 4, 5, 6
@@ -169,7 +192,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
const float v = -1;
- MakeOp(v);
+ MakeOp<float>(v);
// Input:
// 1, 2
// 3, 4
@@ -190,7 +213,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
- MakeOp(0);
+ MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@@ -208,7 +231,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
}
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
- MakeOp(0);
+ MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {0});
@@ -220,7 +243,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
- MakeOp(0);
+ MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
@@ -233,7 +256,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
- MakeOp(0);
+ MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {1});
diff --git a/tensorflow/core/kernels/gather_functor.cc b/tensorflow/core/kernels/gather_functor.cc
index be220d5c95..8ef027a1dd 100644
--- a/tensorflow/core/kernels/gather_functor.cc
+++ b/tensorflow/core/kernels/gather_functor.cc
@@ -38,6 +38,7 @@ namespace functor {
DECLARE_GPU_SPECS_INDEX(T, int64)
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_complex64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.cc b/tensorflow/core/kernels/gather_functor_gpu.cu.cc
index f1c1025078..456f4023a7 100644
--- a/tensorflow/core/kernels/gather_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/gather_functor_gpu.cu.cc
@@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
DEFINE_GPU_SPECS_INDEX(T, int64);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+TF_CALL_complex64(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index d8182218af..31af37693c 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -114,6 +114,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
+TF_CALL_complex64(REGISTER_GATHER_GPU);
#undef REGISTER_GATHER_GPU
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
index c340223aa1..23645dafad 100644
--- a/tensorflow/core/kernels/gather_op_test.cc
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -40,9 +40,9 @@ namespace {
class GatherOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType index_type) {
+ void MakeOp(DataType data_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "Gather")
- .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(data_type))
.Input(FakeInput(index_type))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
@@ -50,7 +50,7 @@ class GatherOpTest : public OpsTestBase {
};
TEST_F(GatherOpTest, ScalarIndices) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
@@ -63,8 +63,26 @@ TEST_F(GatherOpTest, ScalarIndices) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(GatherOpTest, ScalarIndices_Complex) {
+ MakeOp(DT_COMPLEX64, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<std::complex<float>>(
+ TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
+ std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14)});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_COMPLEX64, TensorShape({}));
+ test::FillValues<std::complex<float>>(&expected,
+ {std::complex<float>(3, 13)});
+ test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
+}
+
TEST_F(GatherOpTest, Simple_TwoD32) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -79,7 +97,7 @@ TEST_F(GatherOpTest, Simple_TwoD32) {
}
TEST_F(GatherOpTest, ZeroSize_TwoD32) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 0}), {});
@@ -92,7 +110,7 @@ TEST_F(GatherOpTest, ZeroSize_TwoD32) {
}
TEST_F(GatherOpTest, Simple_TwoD64) {
- MakeOp(DT_INT64);
+ MakeOp(DT_FLOAT, DT_INT64);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -107,7 +125,7 @@ TEST_F(GatherOpTest, Simple_TwoD64) {
}
TEST_F(GatherOpTest, HighRank) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3});
@@ -121,7 +139,7 @@ TEST_F(GatherOpTest, HighRank) {
}
TEST_F(GatherOpTest, Error_IndexOutOfRange) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index c5d5657492..a383cc8199 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -295,6 +295,83 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) {
reinterpret_cast<const float*>(output_tensor.flat<float>().data()));
}
+static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
+ const GraphTransferInfo& gfi1) {
+ LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
+ << gfi1.const_node_info_size();
+
+ // 1. check node_info
+ ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
+ for (int i = 0; i < gfi0.node_info_size(); ++i) {
+ const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i);
+ const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i);
+ EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
+ EXPECT_EQ(ni0.ByteSize(), ni1.ByteSize());
+ }
+
+ // 2. check const_node_info
+ ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
+ for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
+ const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i);
+ const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i);
+ ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
+ for (int j = 0; j < cni0.shape_size(); ++j) {
+ EXPECT_EQ(cni0.shape(j), cni1.shape(j));
+ }
+ EXPECT_EQ(cni0.ByteSize(), cni1.ByteSize());
+ EXPECT_EQ(cni0.DebugString(), cni1.DebugString());
+ }
+
+ // 3. check node_input_info
+ ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
+ for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
+ const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i);
+ const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i);
+ EXPECT_EQ(nii0.ByteSize(), nii1.ByteSize());
+ EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
+ }
+
+ // 4. check node_output_info
+ ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
+ for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
+ const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i);
+ const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i);
+ ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
+ for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
+ EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
+ }
+ EXPECT_EQ(noi0.ByteSize(), noi1.ByteSize());
+ EXPECT_EQ(noi0.DebugString(), noi1.DebugString());
+ }
+
+ // 5. check graph_input_node_info
+ ASSERT_EQ(gfi0.graph_input_node_info_size(),
+ gfi1.graph_input_node_info_size());
+ for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
+ const GraphTransferInfo::GraphInputNodeInfo& gini0 =
+ gfi0.graph_input_node_info(i);
+ const GraphTransferInfo::GraphInputNodeInfo& gini1 =
+ gfi1.graph_input_node_info(i);
+ EXPECT_EQ(gini0.ByteSize(), gini1.ByteSize());
+ EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
+ }
+
+ // 6. check graph_output_node_info
+ ASSERT_EQ(gfi0.graph_output_node_info_size(),
+ gfi1.graph_output_node_info_size());
+ for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
+ const GraphTransferInfo::GraphOutputNodeInfo& goni0 =
+ gfi0.graph_output_node_info(i);
+ const GraphTransferInfo::GraphOutputNodeInfo& goni1 =
+ gfi1.graph_output_node_info(i);
+ EXPECT_EQ(goni0.ByteSize(), goni1.ByteSize());
+ EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
+ }
+
+ // 7. check destination
+ EXPECT_EQ(gfi0.destination(), gfi1.destination());
+}
+
// CAVEAT: This test only runs when you specify hexagon library using
// makefile.
// CAVEAT: This test is disabled by default because hexagon can keep only
@@ -450,34 +527,22 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
prof1.Stop();
prof1.DumpStatistics("Estiame shape by shape inference");
- LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
- << gfi1.const_node_info_size();
+ CompareGraphTransferInfo(gfi0, gfi1);
- ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
+ const RemoteFusedGraphExecuteInfo ei0 =
+ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi0);
+ const RemoteFusedGraphExecuteInfo ei1 =
+ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi1);
- ASSERT_EQ(gt0.GetGraphTransferInfo().const_node_info_size(),
- gt1.GetGraphTransferInfo().const_node_info_size());
+ GraphTransferInfo rgfi0;
+ rgfi0.ParseFromString(ei0.serialized_executor_parameters());
+ GraphTransferInfo rgfi1;
+ rgfi1.ParseFromString(ei1.serialized_executor_parameters());
- for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
- const GraphTransferInfo::ConstNodeInfo& ni0 = gfi0.const_node_info(i);
- const GraphTransferInfo::ConstNodeInfo& ni1 = gfi1.const_node_info(i);
- ASSERT_EQ(ni0.shape_size(), ni1.shape_size());
- for (int j = 0; j < ni0.shape_size(); ++j) {
- EXPECT_EQ(ni0.shape(j), ni1.shape(j));
- }
- }
-
- ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
- for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
- const GraphTransferInfo::NodeOutputInfo& no0 = gfi0.node_output_info(i);
- const GraphTransferInfo::NodeOutputInfo& no1 = gfi1.node_output_info(i);
- ASSERT_EQ(no0.max_byte_size_size(), no1.max_byte_size_size());
- for (int j = 0; j < no0.max_byte_size_size(); ++j) {
- EXPECT_EQ(no0.max_byte_size(j), no1.max_byte_size(j));
- }
- }
+ CompareGraphTransferInfo(rgfi0, rgfi1);
+ CompareGraphTransferInfo(gfi0, rgfi0);
+ CompareGraphTransferInfo(gfi1, rgfi1);
}
-
#endif
} // namespace tensorflow
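The new CompareGraphTransferInfo helper above relies on DebugString()/ByteSize() equality rather than field-by-field proto comparison. A minimal sketch of that idiom, with the proto header path assumed from where GraphTransferInfo is defined:

```cpp
// Sketch only: compare two GraphTransferInfo protos the way the test helper
// does. DebugString() equality checks the textual form; ByteSize() equality is
// a cheap extra check on the serialized size.
#include "tensorflow/core/framework/graph_transfer_info.pb.h"  // path assumed

namespace tensorflow {

bool GraphTransferInfosLookEqual(const GraphTransferInfo& a,
                                 const GraphTransferInfo& b) {
  return a.DebugString() == b.DebugString() && a.ByteSize() == b.ByteSize();
}

}  // namespace tensorflow
```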
diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
index 851d87b15b..ad9200e948 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
@@ -174,6 +174,7 @@ const std::unordered_map<string, SupportedOpType> OP_NAME_TO_SOC_OP_TYPE_MAP{
{"Placeholder", SupportedOpType::NOP},
{"RequantizationRange", SupportedOpType::REQUANTIZATION_RANGE_32},
{"Requantize", SupportedOpType::REQUANTIZE_32_TO_8},
+ {"QuantizedReshape", SupportedOpType::QUANTIZED_RESHAPE},
};
/* static */ const IGraphTransferOpsDefinitions&
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index eb590280c9..6cb56797bf 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -587,8 +587,8 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {2}, 0, tensor_out.shape(), &output));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, tensor_out.shape(), &output));
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 32b210ecb7..e3a57d2f28 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -70,7 +70,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * channels * height * width;
for (int h = hstart; h < hend; ++h) {
@@ -312,9 +312,6 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool output.
// This is equal to Hout*Wout*C.
// bottom_diff: the gradient of the gradient w.r.t. output.
-// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
-// the kernel is run, you will need to make sure that bottom_diff is filled with
-// zero first.
template <typename dtype>
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
@@ -357,12 +354,12 @@ bool MaxPoolBackwardNoMask<T>::operator()(
const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
- const int bottom_size = batch * channels * height * width;
- const int top_size = batch * channels * pooled_height * pooled_width;
+ const int bottom_size = batch * channels * height * width;
SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
+ const int top_size = batch * channels * pooled_height * pooled_width;
MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
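The switch from `-FLT_MAX` to `Eigen::NumTraits<dtype>::lowest()` above matters once the pooling kernels are instantiated for types narrower than float (for example Eigen::half), where `-FLT_MAX` is not representable. A host-side sketch of the same seeding pattern, using illustrative names:

```cpp
// Illustrative only: seed a running maximum with the most negative finite
// value of the element type, so the reduction works for float, double, and
// half alike.
#include <Eigen/Core>  // Eigen::NumTraits

template <typename T>
T RunningMax(const T* data, int n) {
  T max_val = Eigen::NumTraits<T>::lowest();  // most negative finite value of T
  for (int i = 0; i < n; ++i) {
    if (data[i] > max_val) max_val = data[i];
  }
  return max_val;
}
```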
diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc
index 7b34c32ceb..f649287fc1 100644
--- a/tensorflow/core/kernels/quantize_op.cc
+++ b/tensorflow/core/kernels/quantize_op.cc
@@ -86,6 +86,7 @@ class QuantizeV2Op : public OpKernel {
fabsf(input_max_range))) /
100.0f;
max_range = std::max(input_max_range, min_range + epsilon);
+ max_range = std::max(0.0f, max_range);
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
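The added clamp ensures the effective quantization range always contains 0.0, which quint8 quantization needs so that a real zero maps to an exact quantized value; the two new tests below exercise both directions. A small sketch of the adjustment, with hedged names (the kernel's min-range handling is not shown in this hunk but is implied by the "MovesMinToIncludeZero" test):

```cpp
// Sketch, not the kernel's API: widen [min_range, max_range] so it contains
// zero. The max clamp is the line added above; the min clamp mirrors the
// behavior the new tests check.
#include <algorithm>

void WidenRangeToIncludeZero(float* min_range, float* max_range) {
  *min_range = std::min(0.0f, *min_range);
  *max_range = std::max(0.0f, *max_range);
}
```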
diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc
index 41996852f1..48bde3b497 100644
--- a/tensorflow/core/kernels/quantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_op_test.cc
@@ -132,6 +132,50 @@ TEST_F(QuantizedOpTest, QuantizeV2EqualRange) {
EXPECT_LT(0.0f, output_max);
}
+TEST_F(QuantizedOpTest, QuantizeV2MovesMinToIncludeZero) {
+ TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DataTypeToEnum<quint8>::v())
+ .Attr("mode", "MIN_FIRST")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({3}), {0.1, 0.2, 0.3});
+ AddInputFromArray<float>(TensorShape({1}), {0.1});
+ AddInputFromArray<float>(TensorShape({1}), {0.3});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_QUINT8, TensorShape({3}));
+ test::FillValues<quint8>(&expected, {85, 170, 255});
+ test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+ const float output_min = GetOutput(1)->flat<float>()(0);
+ const float output_max = GetOutput(2)->flat<float>()(0);
+ EXPECT_NEAR(0.0f, output_min, 1e-5f);
+ EXPECT_NEAR(0.3f, output_max, 1e-5f);
+}
+
+TEST_F(QuantizedOpTest, QuantizeV2MovesMaxToIncludeZero) {
+ TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DataTypeToEnum<quint8>::v())
+ .Attr("mode", "MIN_FIRST")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({3}), {-0.1, -0.2, -0.3});
+ AddInputFromArray<float>(TensorShape({1}), {-0.3});
+ AddInputFromArray<float>(TensorShape({1}), {-0.1});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_QUINT8, TensorShape({3}));
+ test::FillValues<quint8>(&expected, {170, 85, 0});
+ test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+ const float output_min = GetOutput(1)->flat<float>()(0);
+ const float output_max = GetOutput(2)->flat<float>()(0);
+ EXPECT_NEAR(-0.3f, output_min, 1e-5f);
+ EXPECT_NEAR(0.0f, output_max, 1e-5f);
+}
+
TEST_F(QuantizedOpTest, Dequantize) {
TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize")
.Input(FakeInput(DT_QUINT8))
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 3063fedac8..80b1be8d4c 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -178,27 +178,9 @@ namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
int index, Tensor** output) {
- if (!ctx->op_kernel().IsLegacyVector(shape.shape())) {
- return errors::InvalidArgument(
- "shape must be a vector of {int32,int64}, got shape ",
- shape.shape().DebugString());
- }
- if (shape.dtype() == DataType::DT_INT32) {
- auto vec = shape.flat<int32>();
- TensorShape tensor_shape;
- TF_RETURN_IF_ERROR(
- TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape));
- TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output));
- } else if (shape.dtype() == DataType::DT_INT64) {
- auto vec = shape.flat<int64>();
- TensorShape tensor_shape;
- TF_RETURN_IF_ERROR(
- TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape));
- TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output));
- } else {
- return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
- }
- return Status::OK();
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape));
+ return ctx->allocate_output(index, tensor_shape, output);
}
// For now, use the same interface as RandomOp, so we can choose either one
@@ -465,6 +447,12 @@ class RandomGammaOp : public OpKernel {
#define REGISTER(TYPE) \
template struct functor::FillPhiloxRandom< \
CPUDevice, random::UniformDistribution<random::PhiloxRandom, TYPE> >; \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::NormalDistribution<random::PhiloxRandom, TYPE> >; \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >; \
REGISTER_KERNEL_BUILDER( \
Name("RandomUniform") \
.Device(DEVICE_CPU) \
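The refactored AllocateOutputWithShape above folds the int32/int64 special-casing into a single `OpKernel::MakeShape` call. A sketch of the resulting pattern, with include paths abbreviated and `TF_RETURN_IF_ERROR` assumed available through the usual headers:

```cpp
// Sketch: convert a rank-1 int32/int64 shape tensor into a TensorShape and
// allocate an output of that shape. MakeShape performs the dtype and rank
// validation that the deleted code did by hand.
#include "tensorflow/core/framework/op_kernel.h"

tensorflow::Status AllocateFromShapeTensor(tensorflow::OpKernelContext* ctx,
                                           const tensorflow::Tensor& shape_t,
                                           int index,
                                           tensorflow::Tensor** output) {
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape_t, &shape));
  return ctx->allocate_output(index, shape, output);
}
```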
diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc
index 553a4a7f93..66123e47c6 100644
--- a/tensorflow/core/kernels/random_poisson_op.cc
+++ b/tensorflow/core/kernels/random_poisson_op.cc
@@ -291,33 +291,15 @@ class RandomPoissonOp : public OpKernel {
const Tensor& shape_t = ctx->input(0);
const Tensor& rate_t = ctx->input(1);
- OP_REQUIRES(ctx,
- TensorShapeUtils::IsVector(shape_t.shape()) &&
- (shape_t.dtype() == DataType::DT_INT32 ||
- shape_t.dtype() == DataType::DT_INT64),
- errors::InvalidArgument(
- "shape must be a vector of {int32,int64}, got shape: ",
- shape_t.DebugString()));
TensorShape samples_shape;
- if (shape_t.dtype() == DataType::DT_INT32) {
- auto vec = shape_t.flat<int32>();
- OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
- &samples_shape));
- } else if (shape_t.dtype() == DataType::DT_INT64) {
- auto vec = shape_t.flat<int64>();
- OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
- &samples_shape));
- }
+ OP_REQUIRES_OK(ctx, MakeShape(shape_t, &samples_shape));
const int64 num_samples = samples_shape.num_elements();
- OP_REQUIRES(ctx, num_samples > 0,
- errors::InvalidArgument(
- "Input shape should have non-zero element count, got: ",
- num_samples));
samples_shape.AppendShape(rate_t.shape());
// Allocate output samples.
Tensor* samples_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
+ if (num_samples == 0) return;
const auto rate_flat = rate_t.flat<T>().data();
const int64 num_rate = rate_t.NumElements();
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index 2e09956578..c665bc5b03 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -47,8 +47,9 @@ void ValidateInputs(bool is_save_op, OpKernelContext* context,
context, prefix.NumElements() == 1,
errors::InvalidArgument("Input prefix should have a single element, got ",
prefix.NumElements(), " instead."));
- OP_REQUIRES(context, TensorShapeUtils::IsVector(tensor_names.shape()) &&
- TensorShapeUtils::IsVector(shape_and_slices.shape()),
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsVector(tensor_names.shape()) &&
+ TensorShapeUtils::IsVector(shape_and_slices.shape()),
errors::InvalidArgument(
"Input tensor_names and shape_and_slices "
"should be an 1-D tensors, got ",
@@ -105,6 +106,7 @@ class SaveV2 : public OpKernel {
const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
BundleWriter writer(Env::Default(), prefix_string);
+ OP_REQUIRES_OK(context, writer.status());
VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
for (int i = 0; i < num_tensors; ++i) {
diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index 3815716ccd..c513683918 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -100,6 +100,10 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
+ OP_REQUIRES(
+ context, block_shape_product > 0,
+ errors::InvalidArgument("Product of block sizes must be positive, got ",
+ block_shape_product));
const int internal_block_dims =
block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
index 1fec5a3b44..b2adb4462b 100644
--- a/tensorflow/core/lib/random/philox_random.h
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -101,12 +101,15 @@ class Array {
// 2. PhiloxRandom is compilable by gcc and nvcc.
class PhiloxRandom {
public:
- typedef Array<uint32, 4> ResultType;
- typedef uint32 ResultElementType;
+ using ResultType = Array<uint32, 4>;
+ using ResultElementType = uint32;
// The number of elements that will be returned.
static const int kResultElementCount = 4;
// Cost of generation of a single element (in cycles).
static const int kElementCost = 10;
+ // The type for the 64-bit key stored in the form of two 32-bit uint
+ // that are used in the diffusion process.
+ using Key = Array<uint32, 2>;
PHILOX_DEVICE_INLINE
PhiloxRandom() {}
@@ -125,6 +128,9 @@ class PhiloxRandom {
counter_[3] = static_cast<uint32>(seed_hi >> 32);
}
+ PHILOX_DEVICE_INLINE
+ PhiloxRandom(ResultType counter, Key key) : counter_(counter), key_(key) {}
+
// Skip the specified number of samples of 128-bits in the current stream.
PHILOX_DEVICE_INLINE
void Skip(uint64 count) {
@@ -178,10 +184,6 @@ class PhiloxRandom {
}
private:
- // The type for the 64-bit key stored in the form of two 32-bit uint
- // that are used in the diffusion process.
- typedef Array<uint32, 2> Key;
-
// We use the same constants as recommended by the original paper.
static const uint32 kPhiloxW32A = 0x9E3779B9;
static const uint32 kPhiloxW32B = 0xBB67AE85;
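Making `Key` public and adding the `(counter, key)` constructor lets callers resume or fork a Philox stream at an explicit position instead of re-deriving state from seeds. A minimal usage sketch, assuming only the declarations shown above:

```cpp
// Sketch: build a PhiloxRandom positioned at an explicit 128-bit counter with
// an explicit 64-bit key (ResultType is four 32-bit words, Key is two).
#include "tensorflow/core/lib/random/philox_random.h"

tensorflow::random::PhiloxRandom ResumeStream(
    const tensorflow::random::PhiloxRandom::ResultType& counter,
    const tensorflow::random::PhiloxRandom::Key& key) {
  return tensorflow::random::PhiloxRandom(counter, key);
}
```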
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index e81490c498..e2e07a4bf1 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -41,10 +41,10 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
}
template <typename T>
-std::vector<int64> AsInt64(const Tensor* tensor, int num_elements) {
+std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
std::vector<int64> ret(num_elements);
auto data = tensor->vec<T>();
- for (int i = 0; i < num_elements; ++i) {
+ for (int64 i = 0; i < num_elements; ++i) {
ret[i] = data(i);
}
return ret;
@@ -52,11 +52,11 @@ std::vector<int64> AsInt64(const Tensor* tensor, int num_elements) {
template <typename T>
Status PadKnown(InferenceContext* c, ShapeHandle input,
- const Tensor* paddings_t, int32 num_dims) {
+ const Tensor* paddings_t, int64 num_dims) {
// paddings_t is known.
std::vector<DimensionHandle> dims(num_dims);
auto paddings_data = paddings_t->matrix<T>();
- for (int i = 0; i < num_dims; ++i) {
+ for (int64 i = 0; i < num_dims; ++i) {
const T pad0 = paddings_data(i, 0);
const T pad1 = paddings_data(i, 1);
if (pad0 < 0 || pad1 < 0) {
@@ -1244,9 +1244,12 @@ REGISTER_OP("_ParallelConcatStart")
.Attr("dtype: type")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
- ShapeHandle out;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
- c->set_output(0, out);
+ TensorShapeProto shape_proto;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto));
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeProto(shape_proto, &output_shape));
+ c->set_output(0, output_shape);
return Status::OK();
})
.Doc(R"doc(
@@ -2644,10 +2647,10 @@ output: The padded tensor.
namespace {
template <typename T>
Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
- const Tensor* paddings_t, int32 input_rank) {
+ const Tensor* paddings_t, int64 input_rank) {
auto paddings_data = paddings_t->matrix<T>();
std::vector<DimensionHandle> dims(input_rank);
- for (int i = 0; i < input_rank; ++i) {
+ for (int64 i = 0; i < input_rank; ++i) {
const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
if (pad0 < 0 || pad1 < 0) {
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index bc99fb09e5..adb1320fc7 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1626,4 +1626,16 @@ TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
// Note that other cases of concat are covered in the Concat tests.
}
+TEST(StateOpsTest, _ParallelConcatStart_ShapeFn) {
+ ShapeInferenceTestOp op("_ParallelConcatStart");
+ TensorShape shape({1, 2, 3});
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+ TF_ASSERT_OK(NodeDefBuilder("test", "_ParallelConcatStart")
+ .Attr("shape", shape_proto)
+ .Attr("dtype", DT_FLOAT)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "", "[1,2,3]");
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 10b5df91f1..7e7d499f88 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -120,7 +120,7 @@ REGISTER_OP("DynamicStitch")
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
ShapeHandle extra_shape = c->UnknownShape();
- for (int i = 0; i < num_partitions; ++i) {
+ for (int64 i = 0; i < num_partitions; ++i) {
ShapeHandle indices_shape = c->input(i);
ShapeHandle data_shape = c->input(i + num_partitions);
if (!c->RankKnown(indices_shape)) {
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 7b2da9d8e6..392ac32010 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -23,17 +23,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-namespace {
-
-Status RandomShape(InferenceContext* c) {
- ShapeHandle out;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
- c->set_output(0, out);
- return Status::OK();
-}
-
-} // namepsace
-
REGISTER_OP("RandomUniform")
.Input("shape: T")
.SetIsStateful()
@@ -42,7 +31,7 @@ REGISTER_OP("RandomUniform")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a uniform distribution.
@@ -69,7 +58,7 @@ REGISTER_OP("RandomUniformInt")
.Attr("seed2: int = 0")
.Attr("Tout: {int32, int64}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random integers from a uniform distribution.
@@ -100,7 +89,7 @@ REGISTER_OP("RandomStandardNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a normal distribution.
@@ -128,7 +117,7 @@ REGISTER_OP("ParameterizedTruncatedNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a normal distribution. The parameters may each be a
scalar which applies to the entire output, or a vector of length shape[0] which
@@ -158,7 +147,7 @@ REGISTER_OP("TruncatedNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a truncated normal distribution.
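The registrations above now share `shape_inference::RandomShape` instead of the file-local helper that was deleted. Judging from the removed body, the shared function does the same thing; a sketch for orientation, not the library's exact source:

```cpp
// Sketch mirroring the deleted local helper: the op's output shape comes from
// the value of the rank-1 "shape" input tensor when it is statically known.
#include "tensorflow/core/framework/shape_inference.h"

tensorflow::Status RandomShapeLike(
    tensorflow::shape_inference::InferenceContext* c) {
  tensorflow::shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return tensorflow::Status::OK();
}
```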
diff --git a/tensorflow/core/ops/set_ops.cc b/tensorflow/core/ops/set_ops.cc
index fad7007207..85d1335dcf 100644
--- a/tensorflow/core/ops/set_ops.cc
+++ b/tensorflow/core/ops/set_ops.cc
@@ -235,7 +235,7 @@ REGISTER_OP("SparseToSparseSetOperation")
DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
DimensionHandle output_rank_dim;
if (c->ValueKnown(input0_rank_dim)) {
- const int32 input0_rank = c->Value(input0_rank_dim);
+ const int64 input0_rank = c->Value(input0_rank_dim);
if (input0_rank < 2) {
return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
input0_rank, ".");
@@ -244,7 +244,7 @@ REGISTER_OP("SparseToSparseSetOperation")
c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
output_rank_dim = input0_rank_dim;
} else if (c->ValueKnown(input1_rank_dim)) {
- const int32 input1_rank = c->Value(input1_rank_dim);
+ const int64 input1_rank = c->Value(input1_rank_dim);
if (input1_rank < 2) {
return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
input1_rank, ".");
diff --git a/tensorflow/core/platform/cpu_info.cc b/tensorflow/core/platform/cpu_info.cc
index 9edf2de64c..906826e6f8 100644
--- a/tensorflow/core/platform/cpu_info.cc
+++ b/tensorflow/core/platform/cpu_info.cc
@@ -68,7 +68,7 @@ int GetXCR0EAX() {
// Structure for basic CPUID info
class CPUIDInfo {
-public:
+ public:
CPUIDInfo()
: have_adx_(0),
have_aes_(0),
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 6e9eff6225..63821cb55e 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -10,6 +10,15 @@ message RewriterConfig {
bool optimize_tensor_layout = 1;
bool disable_model_pruning = 2;
bool constant_folding = 3;
+
+ enum MemOptType {
+ // Fully disabled
+ NO_MEM_OPT = 0;
+ // Driven by manual annotations
+ MANUAL = 1;
+ }
+ MemOptType memory_optimization = 4;
+
// If non-empty, will use this as an alternative way to specify a list of
// optimizations to turn on and the order of the optimizations.
repeated string optimizers = 100;
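The new `MemOptType` enum and `memory_optimization` field can be set through the generated proto API. A hedged C++ sketch, assuming the generated header follows the usual .pb.h convention for this .proto:

```cpp
// Sketch: request annotation-driven memory optimization via the new field.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // generated header

tensorflow::RewriterConfig ManualMemoryOptimizationConfig() {
  tensorflow::RewriterConfig config;
  // MANUAL: memory optimizations are applied only where manually annotated.
  config.set_memory_optimization(tensorflow::RewriterConfig::MANUAL);
  return config;
}
```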
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 8bb4ca8ff8..8a3f6c587e 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -128,6 +128,28 @@ __device__ __host__ inline T ldg(const T* address) {
#endif
}
+template <>
+__device__ __host__ inline std::complex<float> ldg(
+ const std::complex<float>* address) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
+ float2 mem = __ldg(reinterpret_cast<const float2*>(address));
+ return std::complex<float>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+template <>
+__device__ __host__ inline std::complex<double> ldg(
+ const std::complex<double>* address) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
+ double2 mem = __ldg(reinterpret_cast<const double2*>(address));
+ return std::complex<double>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.
#define CUDA_ATOMIC_WRAPPER(op, T) \
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index b8989b2c3e..80a910e689 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -249,8 +249,10 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix)
random::New64())),
out_(nullptr),
size_(0) {
- status_ =
- env_->CreateDir(io::Dirname(prefix_).ToString()); // Ignores errors.
+ status_ = env_->CreateDir(io::Dirname(prefix_).ToString());
+ if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
+ return;
+ }
const string filename = DataFilename(prefix_, 0, 1);
std::unique_ptr<WritableFile> wrapper;
status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
@@ -264,9 +266,9 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix)
BundleWriter::~BundleWriter() { CHECK(out_ == nullptr); }
Status BundleWriter::Add(StringPiece key, const Tensor& val) {
+ if (!status_.ok()) return status_;
CHECK_NE(key, kHeaderEntryKey);
const string key_string = key.ToString();
- if (!status_.ok()) return status_;
if (entries_.find(key_string) != entries_.end()) {
status_ = errors::InvalidArgument("Adding duplicate key: ", key);
return status_;
@@ -301,14 +303,14 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key,
const TensorShape& full_tensor_shape,
const TensorSlice& slice_spec,
const Tensor& slice_tensor) {
+ if (!status_.ok()) return status_;
+ CHECK_NE(full_tensor_key, kHeaderEntryKey);
+
// If just a singleton full slice, use the regular Add() to be more efficient.
if (IsFullSlice(slice_spec, full_tensor_shape)) {
return Add(full_tensor_key, slice_tensor);
}
- CHECK_NE(full_tensor_key, kHeaderEntryKey);
- if (!status_.ok()) return status_;
-
// Inserts/updates the full tensor's metadata entry.
//
// In the case of a sharded save, MergeBundles() is responsible for merging
@@ -516,7 +518,8 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
// Merges all metadata tables.
// TODO(zhifengc): KeyValue sorter if it becomes too big.
MergeState merge;
- env->CreateDir(io::Dirname(merged_prefix).ToString()).IgnoreError();
+ Status status = env->CreateDir(io::Dirname(merged_prefix).ToString());
+ if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
for (int i = 0; i < prefixes.size(); ++i) {
TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
}
@@ -534,7 +537,6 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
std::unique_ptr<WritableFile> merged_metadata;
TF_RETURN_IF_ERROR(
env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
- Status status;
{
table::TableBuilder builder(table::Options(), merged_metadata.get());
// Header entry.
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index bca3910f59..676bfe4df6 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -100,6 +100,10 @@ extern const int kTensorBundleVersion;
extern const char* const kHeaderEntryKey;
// Builds a string-string table of tensor names to BundleEntryProto (metadata).
+//
+// On construction, attempts to create a directory given by the dirname of
+// "prefix", so "status()" must be checked before calling any other member
+// functions.
+//
// All threads accessing the same BundleWriter must synchronize.
class BundleWriter {
public:
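Given the constructor contract documented above (directory creation can now fail and is surfaced through `status()`), callers are expected to check the writer before adding entries. A usage sketch under that assumption, with illustrative parameter names:

```cpp
// Sketch: write a single tensor with the updated BundleWriter contract.
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

tensorflow::Status WriteOneTensor(tensorflow::Env* env,
                                  const std::string& prefix,
                                  const std::string& key,
                                  const tensorflow::Tensor& value) {
  tensorflow::BundleWriter writer(env, prefix);
  TF_RETURN_IF_ERROR(writer.status());  // catches CreateDir failures early
  TF_RETURN_IF_ERROR(writer.Add(key, value));
  return writer.Finish();
}
```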
diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md
index 6ae61b43a0..6116c7d87f 100644
--- a/tensorflow/docs_src/get_started/get_started.md
+++ b/tensorflow/docs_src/get_started/get_started.md
@@ -323,6 +323,10 @@ When run, it produces
W: [-0.9999969] b: [ 0.99999082] loss: 5.69997e-11
```
+Notice that the loss is a very small number (close to zero). If you run this
+program, your loss will not be exactly the same, because the model is
+initialized with random values.
+
This more complicated program can still be visualized in TensorBoard
![TensorBoard final model visualization](../images/getting_started_final.png)
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md
index 6f442e6e0c..10eebf6f42 100644
--- a/tensorflow/docs_src/programmers_guide/debugger.md
+++ b/tensorflow/docs_src/programmers_guide/debugger.md
@@ -130,6 +130,8 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
| `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively—i.e., the output recipient tree. |
| `lt -n softmax.*` | List all dumped tensors whose names match the regular-expression pattern `softmax.*`. |
| `lt -t MatMul` | List all dumped tensors whose node type is `MatMul`. |
+| `ls` | List all Python source files responsible for constructing the nodes (and tensors) in the current graph. |
+| `ls -n softmax.*` | List Python source files responsible for constructing the nodes whose names match the pattern `softmax.*`. |
| `ps /path/to/source.py` | Print the Python source file source.py, with the lines annotated with the ops created at each of them, respectively. |
| `ps -t /path/to/source.py` | Same as the command above, but perform annotation using dumped Tensors, instead of ops. |
| `ps -b 30 /path/to/source.py` | Annotate source.py beginning at line 30. |
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md
index a1c0532f5a..8cc6cf15ef 100644
--- a/tensorflow/docs_src/tutorials/recurrent.md
+++ b/tensorflow/docs_src/tutorials/recurrent.md
@@ -173,15 +173,22 @@ final_state = state
## Run the Code
-Start by cloning the [TensorFlow models repo](https://github.com/tensorflow/models) from GitHub.
-You'll also need to download the PTB dataset, as discussed at the beginning of
-this tutorial; we'll assume the dataset is located in `/tmp/simple-examples/data`.
+Before running the code, download the PTB dataset, as discussed at the beginning
+of this tutorial. Then, extract the PTB dataset underneath your home directory
+as follows:
-Run the following commands:
+```bsh
+tar xvfz simple-examples.tgz -C $HOME
+```
+_(Note: On Windows, you may need to use
+[other tools](https://wiki.haskell.org/How_to_unpack_a_tar_file_in_Windows).)_
+
+Now, clone the [TensorFlow models repo](https://github.com/tensorflow/models)
+from GitHub. Run the following commands:
-```bash
+```bsh
cd models/tutorials/rnn/ptb
-python ptb_word_lm.py --data_path=/tmp/simple-examples/data/ --model=small
+python ptb_word_lm.py --data_path=$HOME/simple-examples/data/ --model=small
```
There are 3 supported model configurations in the tutorial code: "small",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 6900ac9a4f..5b50df3ed3 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1094,8 +1094,9 @@ class BaseSession(SessionInterface):
if tensors_to_delete:
feeds = {}
fetches = []
- for tensor_handle in tensors_to_delete:
+ for deleter_key, tensor_handle in enumerate(tensors_to_delete):
holder, deleter = session_ops._get_handle_deleter(self.graph,
+ deleter_key,
tensor_handle)
feeds[holder] = tensor_handle
fetches.append(deleter)
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 99c154bd99..930eb5f283 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -375,6 +375,8 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
PyObject* fields = PyList_New(1);
PyList_SetItem(fields, 0, field);
int convert_result = PyArray_DescrConverter(fields, descr);
+ Py_CLEAR(field);
+ Py_CLEAR(fields);
if (convert_result != 1) {
return errors::Internal("Failed to create numpy array description for ",
"TF_RESOURCE-type tensor");
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 56dd7ceba5..0b87660e5d 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -185,6 +185,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base_ui",
+ ":cli_shared",
":command_parser",
":curses_widgets",
":debugger_cli_common",
@@ -375,6 +376,7 @@ py_test(
":debug_utils",
":source_utils",
"//tensorflow/python:client",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index 0c8004e254..95d3f3f249 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -345,6 +345,25 @@ class DebugAnalyzer(object):
help="Print source beginning at line number (1-based.)")
self._arg_parsers["print_source"] = ap
+ # Parser for list_source.
+ ap = argparse.ArgumentParser(
+ description="List source files responsible for constructing nodes and "
+ "tensors present in the run().",
+ usage=argparse.SUPPRESS)
+ ap.add_argument(
+ "-p",
+ "--path_filter",
+ type=str,
+ default="",
+ help="Regular expression filter for file path.")
+ ap.add_argument(
+ "-n",
+ "--node_name_filter",
+ type=str,
+ default="",
+ help="Regular expression filter for node name.")
+ self._arg_parsers["list_source"] = ap
+
# TODO(cais): Implement list_nodes.
def add_tensor_filter(self, filter_name, filter_callable):
@@ -979,6 +998,15 @@ class DebugAnalyzer(object):
return output
+ def _reconstruct_print_source_command(self,
+ parsed,
+ line_begin_decrease=0,
+ max_elements_per_line_increase=0):
+ return "ps %s %s -b %d -m %d" % (
+ parsed.source_file_path, "-t" if parsed.tensors else "",
+ max(parsed.line_begin - line_begin_decrease, 1),
+ parsed.max_elements_per_line + max_elements_per_line_increase)
+
def print_source(self, args, screen_info=None):
"""Print the content of a source file."""
del screen_info # Unused.
@@ -1000,12 +1028,20 @@ class DebugAnalyzer(object):
labeled_source_lines = []
if parsed.line_begin > 1:
- labeled_source_lines.append(
- RL("(... Omitted %d source lines ...)" % (parsed.line_begin - 1),
- "bold"))
+ omitted_info_line = RL(
+ "(... Omitted %d source lines ...) " % (parsed.line_begin - 1),
+ "bold")
+ omitted_info_line += RL(
+ "+5",
+ debugger_cli_common.MenuItem(
+ None,
+ self._reconstruct_print_source_command(
+ parsed, line_begin_decrease=5)))
+ labeled_source_lines.append(omitted_info_line)
for i, line in enumerate(source_lines[parsed.line_begin - 1:]):
- annotated_line = RL("L%d" % (i + parsed.line_begin), "yellow")
+ annotated_line = RL("L%d" % (i + parsed.line_begin),
+ cli_shared.COLOR_YELLOW)
annotated_line += " " * (line_num_width - len(annotated_line))
annotated_line += line
labeled_source_lines.append(annotated_line)
@@ -1014,11 +1050,17 @@ class DebugAnalyzer(object):
sorted_elements = sorted(source_annotation[i + parsed.line_begin])
for k, element in enumerate(sorted_elements):
if k >= parsed.max_elements_per_line:
- labeled_source_lines.append(
- " (... Omitted %d of %d %s ...)" % (
- len(sorted_elements) - parsed.max_elements_per_line,
- len(sorted_elements),
- "tensor(s)" if parsed.tensors else "op(s)"))
+ omitted_info_line = RL(" (... Omitted %d of %d %s ...) " % (
+ len(sorted_elements) - parsed.max_elements_per_line,
+ len(sorted_elements),
+ "tensor(s)" if parsed.tensors else "op(s)"))
+ omitted_info_line += RL(
+ "+5",
+ debugger_cli_common.MenuItem(
+ None,
+ self._reconstruct_print_source_command(
+ parsed, max_elements_per_line_increase=5)))
+ labeled_source_lines.append(omitted_info_line)
break
label = RL(" " * 4)
@@ -1026,7 +1068,7 @@ class DebugAnalyzer(object):
debug_data.get_node_name(element)):
attribute = debugger_cli_common.MenuItem("", "pt %s" % element)
else:
- attribute = "blue"
+ attribute = cli_shared.COLOR_BLUE
label += RL(element, attribute)
labeled_source_lines.append(label)
@@ -1036,6 +1078,105 @@ class DebugAnalyzer(object):
_add_main_menu(output, node_name=None)
return output
+ def _make_source_table(self, source_list, is_tf_py_library):
+ """Make a table summarizing the source files that create nodes and tensors.
+
+ Args:
+ source_list: List of source files and related information as a list of
+ tuples (file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
+ first_line).
+ is_tf_py_library: (`bool`) whether this table is for files that belong
+ to the TensorFlow Python library.
+
+ Returns:
+ The table as a `debugger_cli_common.RichTextLines` object.
+ """
+ path_head = "Source file path"
+ num_nodes_head = "#(nodes)"
+ num_tensors_head = "#(tensors)"
+ num_dumps_head = "#(tensor dumps)"
+
+ if is_tf_py_library:
+ # Use color to mark files that are guessed to belong to TensorFlow Python
+ # library.
+ color = cli_shared.COLOR_GRAY
+ lines = [RL("TensorFlow Python library file(s):", color)]
+ else:
+ color = cli_shared.COLOR_WHITE
+ lines = [RL("File(s) outside TensorFlow Python library:", color)]
+
+ if not source_list:
+ lines.append(RL("[No files.]"))
+ lines.append(RL())
+ return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
+
+ path_column_width = max(
+ max([len(item[0]) for item in source_list]), len(path_head)) + 1
+ num_nodes_column_width = max(
+ max([len(str(item[2])) for item in source_list]),
+ len(num_nodes_head)) + 1
+ num_tensors_column_width = max(
+ max([len(str(item[3])) for item in source_list]),
+ len(num_tensors_head)) + 1
+
+ head = RL(path_head + " " * (path_column_width - len(path_head)), color)
+ head += RL(num_nodes_head + " " * (
+ num_nodes_column_width - len(num_nodes_head)), color)
+ head += RL(num_tensors_head + " " * (
+ num_tensors_column_width - len(num_tensors_head)), color)
+ head += RL(num_dumps_head, color)
+
+ lines.append(head)
+
+ for item in source_list:
+ path_attributes = [debugger_cli_common.MenuItem(
+ None, "ps %s -b %d" % (item[0], item[5])), color]
+
+ line = RL(item[0], path_attributes)
+ line += " " * (path_column_width - len(line))
+ line += RL(
+ str(item[2]) + " " * (num_nodes_column_width - len(str(item[2]))),
+ color)
+ line += RL(
+ str(item[3]) + " " * (num_tensors_column_width - len(str(item[3]))),
+ color)
+ line += RL(str(item[4]), color)
+ lines.append(line)
+ lines.append(RL())
+
+ return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
+
+ def list_source(self, args, screen_info=None):
+ """List Python source files that constructed nodes and tensors."""
+ del screen_info # Unused.
+
+ parsed = self._arg_parsers["list_source"].parse_args(args)
+ source_list = source_utils.list_source_files_against_dump(
+ self._debug_dump,
+ path_regex_whitelist=parsed.path_filter,
+ node_name_regex_whitelist=parsed.node_name_filter)
+
+ top_lines = [
+ RL("List of source files that created nodes in this run", "bold")]
+ if parsed.path_filter:
+ top_lines.append(
+ RL("File path regex filter: \"%s\"" % parsed.path_filter))
+ if parsed.node_name_filter:
+ top_lines.append(
+ RL("Node name regex filter: \"%s\"" % parsed.node_name_filter))
+ top_lines.append(RL())
+ output = debugger_cli_common.rich_text_lines_from_rich_line_list(top_lines)
+ if not source_list:
+ output.append("[No source file information.]")
+ return output
+
+ output.extend(self._make_source_table(
+ [item for item in source_list if not item[1]], False))
+ output.extend(self._make_source_table(
+ [item for item in source_list if item[1]], True))
+ _add_main_menu(output, node_name=None)
+ return output
+
def _list_inputs_or_outputs(self,
recursive,
node_name,
@@ -1395,6 +1536,11 @@ def create_analyzer_ui(debug_dump,
analyzer.print_source,
analyzer.get_help("print_source"),
prefix_aliases=["ps"])
+ cli.register_command_handler(
+ "list_source",
+ analyzer.list_source,
+ analyzer.get_help("list_source"),
+ prefix_aliases=["ls"])
dumped_tensor_names = []
for datum in debug_dump.dumped_tensor_data:
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index bb2d72e2e4..185d395126 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -28,6 +28,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import analyzer_cli
+from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.lib import debug_data
@@ -569,6 +570,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
cls._analyzer.print_source,
cls._analyzer.get_help("print_source"),
prefix_aliases=["ps"])
+ cls._registry.register_command_handler(
+ "list_source",
+ cls._analyzer.list_source,
+ cls._analyzer.get_help("list_source"),
+ prefix_aliases=["ls"])
@classmethod
def tearDownClass(cls):
@@ -906,7 +912,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
["ERROR: There is no node named \"bar\" in the partition graphs"],
out.lines)
# Check color indicating error.
- self.assertEqual({0: [(0, 59, "red")]}, out.font_attr_segs)
+ self.assertEqual({0: [(0, 59, cli_shared.COLOR_RED)]}, out.font_attr_segs)
check_main_menu(self, out, list_tensors_enabled=True)
def testPrintTensor(self):
@@ -1172,7 +1178,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
out.font_attr_segs[index + 1][0][2].content)
# simple_mul_add/u/Assign is not used in this run because the Variable has
# already been initialized.
- self.assertEqual("blue", out.font_attr_segs[index + 2][0][2])
+ self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2])
self.assertEqual("pt simple_mul_add/u/read",
out.font_attr_segs[index + 3][0][2].content)
@@ -1234,6 +1240,12 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
screen_info={"cols": 80})
self.assertIn("Omitted 2 source lines", out.lines[0])
+ self.assertTrue(out.lines[0].endswith("+5"))
+ expand_lines_command = out.font_attr_segs[0][-1][2].content
+ self.assertStartsWith(expand_lines_command,
+ "ps %s " % self._curr_file_path)
+ self.assertIn("-b 1", expand_lines_command)
+
self.assertIsNone(self._findSourceLine(out, 1))
self.assertIsNone(self._findSourceLine(out, 2))
self.assertIsNotNone(self._findSourceLine(out, 3))
@@ -1250,7 +1262,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
out.font_attr_segs[index + 1][0][2].content)
# simple_mul_add/u/Assign is not used in this run because the Variable has
# already been initialized.
- self.assertEqual("blue", out.font_attr_segs[index + 2][0][2])
+ self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2])
self.assertEqual("pt simple_mul_add/u/read",
out.font_attr_segs[index + 3][0][2].content)
@@ -1266,10 +1278,81 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
["L%d u = variables.Variable(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
- " (... Omitted 2 of 3 op(s) ...)"],
+ " (... Omitted 2 of 3 op(s) ...) +5"],
out.lines[index : index + 3])
self.assertEqual("pt simple_mul_add/u",
out.font_attr_segs[index + 1][0][2].content)
+ more_elements_command = out.font_attr_segs[index + 2][-1][2].content
+ self.assertStartsWith(more_elements_command,
+ "ps %s " % self._curr_file_path)
+ self.assertIn(" -m 6", more_elements_command)
+
+ def testListSourceWorks(self):
+ self._debug_dump.set_python_graph(self._sess.graph)
+ out = self._registry.dispatch_command("list_source", [])
+
+ non_tf_lib_files_start = [
+ i for i in xrange(len(out.lines))
+ if out.lines[i].startswith("Source file path")][0] + 1
+ non_tf_lib_files_end = [
+ i for i in xrange(len(out.lines))
+ if out.lines[i].startswith("TensorFlow Python library file(s):")][0] - 1
+ non_tf_lib_files = [
+ line.split(" ")[0] for line
+ in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]]
+ self.assertIn(self._curr_file_path, non_tf_lib_files)
+
+ # Check that the TF library files are marked with special color attribute.
+ for i in xrange(non_tf_lib_files_end + 1, len(out.lines)):
+ if not out.lines[i]:
+ continue
+ for attr_seg in out.font_attr_segs[i]:
+ self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
+ attr_seg[2] == cli_shared.COLOR_GRAY)
+
+ def testListSourceWithNodeNameFilterWithMatchesWorks(self):
+ self._debug_dump.set_python_graph(self._sess.graph)
+ out = self._registry.dispatch_command("list_source", ["-n", ".*/read"])
+
+ self.assertStartsWith(out.lines[1], "Node name regex filter: \".*/read\"")
+
+ non_tf_lib_files_start = [
+ i for i in xrange(len(out.lines))
+ if out.lines[i].startswith("Source file path")][0] + 1
+ non_tf_lib_files_end = [
+ i for i in xrange(len(out.lines))
+ if out.lines[i].startswith("TensorFlow Python library file(s):")][0] - 1
+ non_tf_lib_files = [
+ line.split(" ")[0] for line
+ in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]]
+ self.assertIn(self._curr_file_path, non_tf_lib_files)
+
+ # Check that the TF library files are marked with special color attribute.
+ for i in xrange(non_tf_lib_files_end + 1, len(out.lines)):
+ if not out.lines[i]:
+ continue
+ for attr_seg in out.font_attr_segs[i]:
+ self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
+ attr_seg[2] == cli_shared.COLOR_GRAY)
+
+ def testListSourceWithNodeNameFilterWithNoMatchesWorks(self):
+ self._debug_dump.set_python_graph(self._sess.graph)
+ out = self._registry.dispatch_command("list_source", ["-n", "^$"])
+
+ self.assertEqual([
+ "List of source files that created nodes in this run",
+ "Node name regex filter: \"^$\"", "",
+ "[No source file information.]"], out.lines)
+
+ def testListSourceWithPathAndNodeNameFiltersWorks(self):
+ self._debug_dump.set_python_graph(self._sess.graph)
+ out = self._registry.dispatch_command(
+ "list_source", ["-p", self._curr_file_path, "-n", ".*read"])
+
+ self.assertEqual([
+ "List of source files that created nodes in this run",
+ "File path regex filter: \"%s\"" % self._curr_file_path,
+ "Node name regex filter: \".*read\"", ""], out.lines[:4])
class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py
index b195347950..8ff0916761 100644
--- a/tensorflow/python/debug/cli/cli_shared.py
+++ b/tensorflow/python/debug/cli/cli_shared.py
@@ -32,6 +32,16 @@ RL = debugger_cli_common.RichLine
# when printing the value of the tensor.
DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000
+COLOR_BLACK = "black"
+COLOR_BLUE = "blue"
+COLOR_CYAN = "cyan"
+COLOR_GRAY = "gray"
+COLOR_GREEN = "green"
+COLOR_MAGENTA = "magenta"
+COLOR_RED = "red"
+COLOR_WHITE = "white"
+COLOR_YELLOW = "yellow"
+
def bytes_to_readable_str(num_bytes, include_b=False):
"""Generate a human-readable string representing number of bytes.
@@ -154,7 +164,7 @@ def error(msg):
"""
return debugger_cli_common.rich_text_lines_from_rich_line_list([
- RL("ERROR: " + msg, "red")])
+ RL("ERROR: " + msg, COLOR_RED)])
def _get_fetch_name(fetch):
diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py
index d8d3bce3de..b7549b406b 100644
--- a/tensorflow/python/debug/cli/curses_ui.py
+++ b/tensorflow/python/debug/cli/curses_ui.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import curses
from curses import textpad
+import os
import signal
import sys
import threading
@@ -27,6 +28,7 @@ import threading
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug.cli import base_ui
+from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import curses_widgets
from tensorflow.python.debug.cli import debugger_cli_common
@@ -42,6 +44,9 @@ _SCROLL_HOME = "home"
_SCROLL_END = "end"
_SCROLL_TO_LINE_INDEX = "scroll_to_line_index"
+_COLOR_READY_COLORTERMS = ["gnome-terminal", "xfce4-terminal"]
+_COLOR_ENABLED_TERM = "xterm-256color"
+
def _get_command_from_line_attr_segs(mouse_x, attr_segs):
"""Attempt to extract command from the attribute segments of a line.
@@ -77,7 +82,7 @@ class ScrollBar(object):
event in the screen region it occupies.
"""
- BASE_ATTR = "black_on_white"
+ BASE_ATTR = cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE
def __init__(self,
min_x,
@@ -225,27 +230,36 @@ class CursesUI(base_ui.BaseUI):
}
_FOREGROUND_COLORS = {
- "white": curses.COLOR_WHITE,
- "red": curses.COLOR_RED,
- "green": curses.COLOR_GREEN,
- "yellow": curses.COLOR_YELLOW,
- "blue": curses.COLOR_BLUE,
- "cyan": curses.COLOR_CYAN,
- "magenta": curses.COLOR_MAGENTA,
- "black": curses.COLOR_BLACK,
+ cli_shared.COLOR_WHITE: curses.COLOR_WHITE,
+ cli_shared.COLOR_RED: curses.COLOR_RED,
+ cli_shared.COLOR_GREEN: curses.COLOR_GREEN,
+ cli_shared.COLOR_YELLOW: curses.COLOR_YELLOW,
+ cli_shared.COLOR_BLUE: curses.COLOR_BLUE,
+ cli_shared.COLOR_CYAN: curses.COLOR_CYAN,
+ cli_shared.COLOR_MAGENTA: curses.COLOR_MAGENTA,
+ cli_shared.COLOR_BLACK: curses.COLOR_BLACK,
}
_BACKGROUND_COLORS = {
- "white": curses.COLOR_WHITE,
- "black": curses.COLOR_BLACK,
+ "transparent": -1,
+ cli_shared.COLOR_WHITE: curses.COLOR_WHITE,
+ cli_shared.COLOR_BLACK: curses.COLOR_BLACK,
}
# Font attribute for search and highlighting.
- _SEARCH_HIGHLIGHT_FONT_ATTR = "black_on_white"
- _ARRAY_INDICES_COLOR_PAIR = "black_on_white"
- _ERROR_TOAST_COLOR_PAIR = "red_on_white"
- _INFO_TOAST_COLOR_PAIR = "blue_on_white"
- _STATUS_BAR_COLOR_PAIR = "black_on_white"
- _UI_WAIT_COLOR_PAIR = "magenta_on_white"
+ _SEARCH_HIGHLIGHT_FONT_ATTR = (
+ cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE)
+ _ARRAY_INDICES_COLOR_PAIR = (
+ cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE)
+ _ERROR_TOAST_COLOR_PAIR = (
+ cli_shared.COLOR_RED + "_on_" + cli_shared.COLOR_WHITE)
+ _INFO_TOAST_COLOR_PAIR = (
+ cli_shared.COLOR_BLUE + "_on_" + cli_shared.COLOR_WHITE)
+ _STATUS_BAR_COLOR_PAIR = (
+ cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE)
+ _UI_WAIT_COLOR_PAIR = (
+ cli_shared.COLOR_MAGENTA + "_on_" + cli_shared.COLOR_WHITE)
+ _NAVIGATION_WARNING_COLOR_PAIR = (
+ cli_shared.COLOR_RED + "_on_" + cli_shared.COLOR_WHITE)
_UI_WAIT_MESSAGE = "Processing..."
@@ -370,29 +384,43 @@ class CursesUI(base_ui.BaseUI):
Creates curses stdscr and initialize the color pairs for display.
"""
-
+ # If the terminal supports rich colors, switch to a color-enabled TERM.
+ if os.getenv("COLORTERM") in _COLOR_READY_COLORTERMS:
+ os.environ["TERM"] = _COLOR_ENABLED_TERM
self._stdscr = curses.initscr()
self._command_window = None
+ self._screen_color_init()
- # Prepare color pairs.
+ def _screen_color_init(self):
+ """Initialization of screen colors."""
curses.start_color()
-
+ curses.use_default_colors()
self._color_pairs = {}
color_index = 0
+ # Prepare color pairs.
for fg_color in self._FOREGROUND_COLORS:
for bg_color in self._BACKGROUND_COLORS:
-
color_index += 1
curses.init_pair(color_index, self._FOREGROUND_COLORS[fg_color],
self._BACKGROUND_COLORS[bg_color])
color_name = fg_color
- if bg_color != "black":
+ if bg_color != "transparent":
color_name += "_on_" + bg_color
self._color_pairs[color_name] = curses.color_pair(color_index)
+ # Try getting color(s) available only under 256-color support.
+ try:
+ color_index += 1
+ curses.init_pair(color_index, 245, -1)
+ self._color_pairs[cli_shared.COLOR_GRAY] = curses.color_pair(color_index)
+ except curses.error:
+ # Use fall-back color(s):
+ self._color_pairs[cli_shared.COLOR_GRAY] = (
+ self._color_pairs[cli_shared.COLOR_GREEN])
+
# A_BOLD or A_BLINK is not really a "color". But place it here for
# convenience.
self._color_pairs["bold"] = curses.A_BOLD
@@ -400,7 +428,7 @@ class CursesUI(base_ui.BaseUI):
self._color_pairs["underline"] = curses.A_UNDERLINE
# Default color pair to use when a specified color pair does not exist.
- self._default_color_pair = self._color_pairs["white"]
+ self._default_color_pair = self._color_pairs[cli_shared.COLOR_WHITE]
def _screen_launch(self, enable_mouse_on_start):
"""Launch the curses screen."""
@@ -588,7 +616,7 @@ class CursesUI(base_ui.BaseUI):
scroll_position = item.scroll_position
else:
self._toast("At the LATEST in navigation history!",
- color="red_on_white")
+ color=self._NAVIGATION_WARNING_COLOR_PAIR)
return
else:
if self._nav_history.can_go_back():
@@ -596,7 +624,7 @@ class CursesUI(base_ui.BaseUI):
scroll_position = item.scroll_position
else:
self._toast("At the OLDEST in navigation history!",
- color="red_on_white")
+ color=self._NAVIGATION_WARNING_COLOR_PAIR)
return
self._display_output(item.screen_output)
@@ -959,7 +987,7 @@ class CursesUI(base_ui.BaseUI):
self._curr_wrapped_output.lines.append("Output cut off at %d lines!" %
self.max_output_lines)
self._curr_wrapped_output.font_attr_segs[self.max_output_lines] = [
- (0, len(output.lines[-1]), "magenta")
+ (0, len(output.lines[-1]), cli_shared.COLOR_MAGENTA)
]
self._display_nav_bar()
@@ -1518,7 +1546,9 @@ class CursesUI(base_ui.BaseUI):
pad, _, _ = self._display_lines(
debugger_cli_common.RichTextLines(
- message, font_attr_segs={0: [(0, len(message), color or "white")]}),
+ message,
+ font_attr_segs={
+ 0: [(0, len(message), color or cli_shared.COLOR_WHITE)]}),
0)
right_end = min(len(message), self._max_x - 2)
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index aee0849832..94eb2754da 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -68,19 +68,19 @@ class NodeStepperCLI(object):
_UPDATED_ATTRIBUTE = "bold"
_STATE_COLORS = {
- STATE_CONT: "green",
- STATE_DIRTY_VARIABLE: "magenta",
- STATE_DUMPED_INTERMEDIATE: "blue",
- STATE_OVERRIDDEN: "yellow",
- STATE_IS_PLACEHOLDER: "cyan",
- STATE_UNFEEDABLE: "red",
+ STATE_CONT: cli_shared.COLOR_GREEN,
+ STATE_DIRTY_VARIABLE: cli_shared.COLOR_MAGENTA,
+ STATE_DUMPED_INTERMEDIATE: cli_shared.COLOR_BLUE,
+ STATE_OVERRIDDEN: cli_shared.COLOR_YELLOW,
+ STATE_IS_PLACEHOLDER: cli_shared.COLOR_CYAN,
+ STATE_UNFEEDABLE: cli_shared.COLOR_RED,
}
_FEED_COLORS = {
- stepper.NodeStepper.FEED_TYPE_CLIENT: "white",
- stepper.NodeStepper.FEED_TYPE_HANDLE: "green",
- stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow",
- stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue",
+ stepper.NodeStepper.FEED_TYPE_CLIENT: cli_shared.COLOR_WHITE,
+ stepper.NodeStepper.FEED_TYPE_HANDLE: cli_shared.COLOR_GREEN,
+ stepper.NodeStepper.FEED_TYPE_OVERRIDE: cli_shared.COLOR_YELLOW,
+ stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: cli_shared.COLOR_BLUE,
}
def __init__(self, node_stepper):
diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py
index cc949932cb..b8a5daf860 100644
--- a/tensorflow/python/debug/lib/source_utils.py
+++ b/tensorflow/python/debug/lib/source_utils.py
@@ -18,13 +18,47 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import os
+import re
+
+_TENSORFLOW_BASEDIR = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.dirname(
+ os.path.normpath(os.path.abspath(__file__))))))
def _convert_watch_key_to_tensor_name(watch_key):
return watch_key[:watch_key.rfind(":")]
+def _guess_is_tensorflow_py_library(py_file_path):
+ """Guess whether a Python source file is a part of the tensorflow library.
+
+ Special cases:
+ 1) Returns False for unit-test files in the library (*_test.py),
+ 2) Returns False for files under python/debug/examples.
+
+ Args:
+ py_file_path: full path of the Python source file in question.
+
+ Returns:
+ (`bool`) Whether the file is a part of the tensorflow library.
+
+ Raises:
+ ValueError: if py_file_path does not end with ".py".
+ """
+
+ if not py_file_path.endswith(".py"):
+ raise ValueError(
+ "Input file path (%s) is not a Python source file." % py_file_path)
+ py_file_path = os.path.normpath(os.path.abspath(py_file_path))
+
+ return (py_file_path.startswith(_TENSORFLOW_BASEDIR) and
+ not py_file_path.endswith("_test.py") and
+ not os.path.dirname(py_file_path).endswith(
+ os.path.normpath("python/debug/examples")))
+
+
def annotate_source(dump,
source_file_path,
do_dumped_tensors=False,
@@ -61,21 +95,16 @@ def annotate_source(dump,
raise ValueError("Cannot perform source annotation due to a lack of set "
"Python graph in the dump object")
- source_file_path = os.path.normpath(source_file_path)
+ source_file_path = os.path.normpath(os.path.abspath(source_file_path))
line_to_op_names = {}
for op in py_graph.get_operations():
- try:
- traceback = dump.node_traceback(op.name)
- except KeyError:
- pass
-
- for file_path, line_number, _, _ in reversed(traceback):
+ for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)):
if (min_line is not None and line_number < min_line or
max_line is not None and line_number >= max_line):
continue
- if os.path.normpath(file_path) != source_file_path:
+ if os.path.normpath(os.path.abspath(file_path)) != source_file_path:
continue
if do_dumped_tensors:
@@ -95,3 +124,103 @@ def annotate_source(dump,
break
return line_to_op_names
+
+
+def list_source_files_against_dump(dump,
+ path_regex_whitelist=None,
+ node_name_regex_whitelist=None):
+ """Generate a list of source files with information regarding ops and tensors.
+
+ Args:
+ dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
+ has been loaded.
+ path_regex_whitelist: A regular-expression filter for source file path.
+ node_name_regex_whitelist: A regular-expression filter for node names.
+
+ Returns:
+ A list of tuples regarding the Python source files involved in constructing
+ the ops and tensors contained in `dump`. Each tuple is:
+ (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
+ first_line)
+
+ is_tf_library: (`bool`) A guess of whether the file belongs to the
+ TensorFlow Python library.
+ num_nodes: How many nodes were created by lines of this source file.
+ These include nodes with dumps and those without.
+ num_tensors: How many Tensors were created by lines of this source file.
+ These include Tensors with dumps and those without.
+ num_dumps: How many debug Tensor dumps were from nodes (and Tensors)
+ that were created by this source file.
+ first_line: The first line number (1-based) that created any nodes or
+ Tensors in this source file.
+
+ The list is sorted by ascending order of source_file_path.
+
+ Raises:
+ ValueError: If the dump object does not have a Python graph set.
+ """
+
+ py_graph = dump.python_graph
+ if not py_graph:
+ raise ValueError("Cannot generate source list due to a lack of set "
+ "Python graph in the dump object")
+
+ path_to_node_names = collections.defaultdict(set)
+ path_to_tensor_names = collections.defaultdict(set)
+ path_to_first_line = {}
+ tensor_name_to_num_dumps = {}
+
+ path_regex = (re.compile(path_regex_whitelist)
+ if path_regex_whitelist else None)
+ node_name_regex = (re.compile(node_name_regex_whitelist)
+ if node_name_regex_whitelist else None)
+
+ to_skip_file_paths = set()
+ for op in py_graph.get_operations():
+ if node_name_regex and not node_name_regex.match(op.name):
+ continue
+
+ for file_path, line_number, _, _ in dump.node_traceback(op.name):
+ file_path = os.path.normpath(os.path.abspath(file_path))
+ if (file_path in to_skip_file_paths or
+ path_regex and not path_regex.match(file_path) or
+ not os.path.isfile(file_path)):
+ to_skip_file_paths.add(file_path)
+ continue
+
+ path_to_node_names[file_path].add(op.name)
+ if file_path in path_to_first_line:
+ if path_to_first_line[file_path] > line_number:
+ path_to_first_line[file_path] = line_number
+ else:
+ path_to_first_line[file_path] = line_number
+
+ for output_tensor in op.outputs:
+ tensor_name = output_tensor.name
+ path_to_tensor_names[file_path].add(tensor_name)
+
+ watch_keys = dump.debug_watch_keys(op.name)
+ for watch_key in watch_keys:
+ node_name, output_slot, debug_op = watch_key.split(":")
+ tensor_name = "%s:%s" % (node_name, output_slot)
+ if tensor_name not in tensor_name_to_num_dumps:
+ tensor_name_to_num_dumps[tensor_name] = len(
+ dump.get_tensors(node_name, int(output_slot), debug_op))
+
+ path_to_num_dumps = {}
+ for path in path_to_tensor_names:
+ path_to_num_dumps[path] = sum(
+ tensor_name_to_num_dumps.get(tensor_name, 0)
+ for tensor_name in path_to_tensor_names[path])
+
+ output = []
+ for file_path in path_to_node_names:
+ output.append((
+ file_path,
+ _guess_is_tensorflow_py_library(file_path),
+ len(path_to_node_names.get(file_path, {})),
+ len(path_to_tensor_names.get(file_path, {})),
+ path_to_num_dumps.get(file_path, 0),
+ path_to_first_line[file_path]))
+
+ return sorted(output, key=lambda x: x[0])
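
For context, a short usage sketch of the new list_source_files_against_dump helper added above (assuming a DebugDumpDir named dump whose Python graph has already been attached via set_python_graph); it simply unpacks the 6-tuples described in the docstring:

    from tensorflow.python.debug.lib import source_utils

    source_list = source_utils.list_source_files_against_dump(
        dump, node_name_regex_whitelist=r"while/.*")
    for (path, is_tf_lib, num_nodes, num_tensors, num_dumps,
         first_line) in source_list:
      if is_tf_lib:
        continue  # Skip files guessed to be part of the TensorFlow library.
      print("%s (first op at line %d): %d nodes, %d tensors, %d dumps" %
            (path, first_line, num_nodes, num_tensors, num_dumps))
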
diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py
index 5d28bff207..138c75de31 100644
--- a/tensorflow/python/debug/lib/source_utils_test.py
+++ b/tensorflow/python/debug/lib/source_utils_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -42,6 +43,37 @@ def line_number_above():
return inspect.stack()[1][2] - 1
+class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.curr_file_path = os.path.normpath(os.path.abspath(__file__))
+
+ def tearDown(self):
+ ops.reset_default_graph()
+
+ def testGuessedBaseDirIsProbablyCorrect(self):
+ self.assertEqual(
+ "tensorflow", os.path.basename(source_utils._TENSORFLOW_BASEDIR))
+
+ def testUnitTestFileReturnsFalse(self):
+ self.assertFalse(source_utils._guess_is_tensorflow_py_library(
+ self.curr_file_path))
+
+ def _disabledtestSourceUtilModuleReturnsTrue(self):
+ self.assertTrue(source_utils._guess_is_tensorflow_py_library(
+ source_utils.__file__))
+
+ def testFileInPythonKernelsPathReturnsTrue(self):
+ x = constant_op.constant(42.0, name="x")
+ self.assertTrue(source_utils._guess_is_tensorflow_py_library(
+ x.op.traceback[-1][0]))
+
+ def testNonPythonFileRaisesException(self):
+ with self.assertRaisesRegexp(ValueError, r"is not a Python source file"):
+ source_utils._guess_is_tensorflow_py_library(
+ os.path.join(os.path.dirname(self.curr_file_path), "foo.cc"))
+
+
class SourceHelperTest(test_util.TensorFlowTestCase):
def createAndRunGraphHelper(self):
@@ -199,5 +231,131 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
os.remove(unrelated_source_path)
+class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
+
+ def createAndRunGraphWithWhileLoop(self):
+ """Create and run a TensorFlow Graph with a while loop to generate dumps."""
+
+ self.dump_root = self.get_temp_dir()
+ self.curr_file_path = os.path.abspath(
+ inspect.getfile(inspect.currentframe()))
+
+ # Run a simple TF graph to generate some debug dumps that can be used in
+ # source annotation.
+ with session.Session() as sess:
+ loop_body = lambda i: math_ops.add(i, 2)
+ self.traceback_first_line = line_number_above()
+
+ loop_cond = lambda i: math_ops.less(i, 16)
+
+ i = constant_op.constant(10, name="i")
+ loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(
+ run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(loop, options=run_options, run_metadata=run_metadata)
+
+ self.dump = debug_data.DebugDumpDir(
+ self.dump_root, partition_graphs=run_metadata.partition_graphs)
+ self.dump.set_python_graph(sess.graph)
+
+ def setUp(self):
+ self.createAndRunGraphWithWhileLoop()
+
+ def tearDown(self):
+ if os.path.isdir(self.dump_root):
+ shutil.rmtree(self.dump_root)
+ ops.reset_default_graph()
+
+ def testGenerateSourceList(self):
+ source_list = source_utils.list_source_files_against_dump(self.dump)
+
+ # Assert that the file paths are sorted and unique.
+ file_paths = [item[0] for item in source_list]
+ self.assertEqual(sorted(file_paths), file_paths)
+ self.assertEqual(len(set(file_paths)), len(file_paths))
+
+ # Assert that each item of source_list has length 6.
+ for item in source_list:
+ self.assertTrue(isinstance(item, tuple))
+ self.assertEqual(6, len(item))
+
+ # The while loop body should have executed 3 times. The following table
+ # lists the tensors and how many times each of them is dumped.
+ # Tensor name # of times dumped:
+ # i:0 1
+ # while/Enter:0 1
+ # while/Merge:0 4
+ # while/Merge:1 4
+ # while/Less/y:0 4
+ # while/Less:0 4
+ # while/LoopCond:0 4
+ # while/Switch:0 1
+ # while/Swtich:1 3
+    #   (The preceding line refers to while/Switch:1; "Swtich" in older
+    #   comments of this table was a typo.)
+ # while/Identity:0 3
+ # while/Add/y:0 3
+ # while/Add:0 3
+ # while/NextIteration:0 3
+ # while/Exit:0 1
+ # ----------------------------
+ # (Total) 39
+ #
+ # The total number of nodes is 12.
+ # The total number of tensors is 14 (2 of the nodes have 2 outputs:
+ # while/Merge, while/Switch).
+
+ _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
+ source_list[file_paths.index(self.curr_file_path)])
+ self.assertFalse(is_tf_py_library)
+ self.assertEqual(12, num_nodes)
+ self.assertEqual(14, num_tensors)
+ self.assertEqual(39, num_dumps)
+ self.assertEqual(self.traceback_first_line, first_line)
+
+ def testGenerateSourceListWithNodeNameFilter(self):
+ source_list = source_utils.list_source_files_against_dump(
+ self.dump, node_name_regex_whitelist=r"while/Add.*")
+
+ # Assert that the file paths are sorted.
+ file_paths = [item[0] for item in source_list]
+ self.assertEqual(sorted(file_paths), file_paths)
+ self.assertEqual(len(set(file_paths)), len(file_paths))
+
+    # Assert that each item of source_list has length 6.
+ for item in source_list:
+ self.assertTrue(isinstance(item, tuple))
+ self.assertEqual(6, len(item))
+
+ # Due to the node-name filtering the result should only contain 2 nodes
+ # and 2 tensors. The total number of dumped tensors should be 6:
+ # while/Add/y:0 3
+ # while/Add:0 3
+ _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
+ source_list[file_paths.index(self.curr_file_path)])
+ self.assertFalse(is_tf_py_library)
+ self.assertEqual(2, num_nodes)
+ self.assertEqual(2, num_tensors)
+ self.assertEqual(6, num_dumps)
+
+ def testGenerateSourceListWithPathRegexFilter(self):
+ curr_file_basename = os.path.basename(self.curr_file_path)
+ source_list = source_utils.list_source_files_against_dump(
+ self.dump,
+ path_regex_whitelist=(
+ ".*" + curr_file_basename.replace(".", "\\.") + "$"))
+
+ self.assertEqual(1, len(source_list))
+ (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
+ first_line) = source_list[0]
+ self.assertEqual(self.curr_file_path, file_path)
+ self.assertFalse(is_tf_py_library)
+ self.assertEqual(12, num_nodes)
+ self.assertEqual(14, num_tensors)
+ self.assertEqual(39, num_dumps)
+ self.assertEqual(self.traceback_first_line, first_line)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 616b7ae49b..f1471a515f 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -108,18 +108,19 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":checkpoint_utils",
":export",
":model_fn",
":run_config",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client",
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:metrics",
"//tensorflow/python:platform",
+ "//tensorflow/python:random_seed",
"//tensorflow/python:summary",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python/saved_model:builder",
"//tensorflow/python/saved_model:tag_constants",
],
@@ -131,20 +132,31 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":estimator",
+ ":export",
":model_fn",
":numpy_io",
":run_config",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
"//tensorflow/python/ops/losses",
"//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 36918af552..80c5bbf684 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -141,6 +141,11 @@ class Estimator(object):
logging.info('Using config: %s', str(vars(self._config)))
+ if self._config.session_config is None:
+ self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ self._session_config = self._config.session_config
+
self._device_fn = _get_replica_device_setter(self._config)
if model_fn is None:
@@ -317,7 +322,7 @@ class Estimator(object):
session_creator=training.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
scaffold=estimator_spec.scaffold,
- config=config_pb2.ConfigProto(allow_soft_placement=True)),
+ config=self._session_config),
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
preds_evaluated = mon_sess.run(predictions)
@@ -552,7 +557,8 @@ class Estimator(object):
training.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
- defer_build=True))
+ defer_build=True,
+ save_relative_paths=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or
@@ -579,7 +585,7 @@ class Estimator(object):
chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
+ config=self._session_config) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
@@ -634,7 +640,7 @@ class Estimator(object):
eval_ops=update_op,
final_ops=eval_dict,
hooks=hooks,
- config=config_pb2.ConfigProto(allow_soft_placement=True))
+ config=self._session_config)
_write_dict_to_summary(
output_dir=eval_dir,
@@ -643,12 +649,6 @@ class Estimator(object):
return eval_results
- def _verify_default_metric_key(self, metric_key, eval_dict):
- if metric_key in six.iterkeys(eval_dict):
- raise ValueError(
- 'Metric with name `%s` is not allowed, because Estimator '
- 'already defines a default metric with the same name.' % metric_key)
-
def _check_hooks_type(hooks):
"""Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index a1659156a6..84813073d3 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -23,6 +23,8 @@ import tempfile
import numpy as np
+from google.protobuf import text_format
+
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -34,6 +36,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import layers
+from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -48,6 +51,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import session_run_hook
@@ -236,6 +240,40 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(
5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
+ def test_checkpoint_contains_relative_paths(self):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(
+ model_dir=tmpdir,
+ model_fn=model_fn_global_step_incrementer)
+ est.train(dummy_input_fn, steps=5)
+
+ checkpoint_file_content = file_io.read_file_to_string(
+ os.path.join(tmpdir, 'checkpoint'))
+ ckpt = checkpoint_state_pb2.CheckpointState()
+ text_format.Merge(checkpoint_file_content, ckpt)
+ self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
+ self.assertAllEqual(
+ ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+
+ def test_train_save_copy_reload(self):
+ tmpdir = tempfile.mkdtemp()
+ model_dir1 = os.path.join(tmpdir, 'model_dir1')
+ est1 = estimator.Estimator(
+ model_dir=model_dir1,
+ model_fn=model_fn_global_step_incrementer)
+ est1.train(dummy_input_fn, steps=5)
+
+ model_dir2 = os.path.join(tmpdir, 'model_dir2')
+ os.renames(model_dir1, model_dir2)
+ est2 = estimator.Estimator(
+ model_dir=model_dir2,
+ model_fn=model_fn_global_step_incrementer)
+ self.assertEqual(
+ 5, estimator._load_global_step_from_checkpoint_dir(est2.model_dir))
+ est2.train(dummy_input_fn, steps=5)
+ self.assertEqual(
+ 10, estimator._load_global_step_from_checkpoint_dir(est2.model_dir))
+
def test_steps0_raises_error(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c6e6c60991..79b55c6853 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -73,6 +73,10 @@ class RunConfig(object):
return 600
@property
+ def session_config(self):
+ return None
+
+ @property
def save_checkpoints_steps(self):
return None
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index dac8d58b35..1f161e59cd 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -31,61 +31,80 @@ from tensorflow.python.platform import test
class GatherTest(test.TestCase):
use_gpu = False
+ def _buildParams(self, data, dtype):
+ data = data.astype(dtype.as_numpy_dtype)
+ # For complex types, add an index-dependent imaginary component so we can
+ # tell we got the right value.
+ if dtype.is_complex:
+ return data + 10j * data
+ return data
+
def testScalar1D(self):
with self.test_session(use_gpu=self.use_gpu):
- params = constant_op.constant([0, 1, 2, 3, 7, 5])
- indices = constant_op.constant(4)
- gather_t = array_ops.gather(params, indices)
- gather_val = gather_t.eval()
- self.assertAllEqual(7, gather_val)
- self.assertEqual([], gather_t.get_shape())
+ data = np.array([0, 1, 2, 3, 7, 5])
+ for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128):
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices = constant_op.constant(4)
+ gather_t = array_ops.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(params_np[4], gather_val)
+ self.assertEqual([], gather_t.get_shape())
def testScalar2D(self):
with self.test_session(use_gpu=self.use_gpu):
- params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
- [9, 10, 11], [12, 13, 14]])
- indices = constant_op.constant(2)
- gather_t = array_ops.gather(params, indices)
- gather_val = gather_t.eval()
- self.assertAllEqual([6, 7, 8], gather_val)
- self.assertEqual([3], gather_t.get_shape())
+ data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128):
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices = constant_op.constant(2)
+ gather_t = array_ops.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(params_np[2], gather_val)
+ self.assertEqual([3], gather_t.get_shape())
def testSimpleTwoD32(self):
with self.test_session(use_gpu=self.use_gpu):
- params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
- [9, 10, 11], [12, 13, 14]])
- indices = constant_op.constant([0, 4, 0, 2])
- gather_t = array_ops.gather(params, indices)
- gather_val = gather_t.eval()
- self.assertAllEqual([[0, 1, 2], [12, 13, 14], [0, 1, 2], [6, 7, 8]],
- gather_val)
- self.assertEqual([4, 3], gather_t.get_shape())
+ data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128):
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices = constant_op.constant([0, 4, 0, 2])
+ gather_t = array_ops.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(params_np[[0, 4, 0, 2]], gather_val)
+ self.assertEqual([4, 3], gather_t.get_shape())
def testHigherRank(self):
np.random.seed(1)
# We check that scalar and empty shapes work as well
for shape in (7, 0), (4, 3, 2):
for indices_shape in (), (0,), (3, 0), (3, 5):
- params = np.random.randn(*shape)
- indices = np.random.randint(shape[0], size=indices_shape)
- with self.test_session(use_gpu=self.use_gpu):
- tf_params = constant_op.constant(params)
- tf_indices = constant_op.constant(indices)
- gather = array_ops.gather(tf_params, tf_indices)
- self.assertAllEqual(params[indices], gather.eval())
- self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
- # Test gradients
- gather_grad = np.random.randn(*gather.get_shape().as_list())
- params_grad, indices_grad = gradients_impl.gradients(
- gather, [tf_params, tf_indices], gather_grad)
- self.assertEqual(indices_grad, None)
- self.assertEqual(type(params_grad), ops.IndexedSlices)
- params_grad = ops.convert_to_tensor(params_grad)
- correct_params_grad = np.zeros(shape)
- for i, g in zip(indices.flat,
- gather_grad.reshape((indices.size,) + shape[1:])):
- correct_params_grad[i] += g
- self.assertAllClose(correct_params_grad, params_grad.eval())
+ for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128):
+ params = self._buildParams(np.random.randn(*shape), dtype)
+ indices = np.random.randint(shape[0], size=indices_shape)
+ with self.test_session(use_gpu=self.use_gpu):
+ tf_params = constant_op.constant(params)
+ tf_indices = constant_op.constant(indices)
+ gather = array_ops.gather(tf_params, tf_indices)
+ self.assertAllEqual(params[indices], gather.eval())
+ self.assertEqual(indices.shape + params.shape[1:],
+ gather.get_shape())
+ # Test gradients
+ gather_grad = np.random.randn(*gather.get_shape().as_list()).astype(
+ dtype.as_numpy_dtype)
+ params_grad, indices_grad = gradients_impl.gradients(
+ gather, [tf_params, tf_indices], gather_grad)
+ self.assertEqual(indices_grad, None)
+ self.assertEqual(type(params_grad), ops.IndexedSlices)
+ params_grad = ops.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
+ for i, g in zip(indices.flat,
+ gather_grad.reshape((indices.size,) + shape[1:])):
+ correct_params_grad[i] += g
+ self.assertAllClose(correct_params_grad, params_grad.eval())
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
@@ -103,7 +122,7 @@ class GatherTest(test.TestCase):
def testEmptySlices(self):
with self.test_session(use_gpu=self.use_gpu):
- for dtype in np.float32, np.float64:
+ for dtype in np.float32, np.float64, np.complex64, np.complex128:
for itype in np.int32, np.int64:
params = np.zeros((7, 0), dtype=dtype)
indices = np.array([3, 4], dtype=itype)
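
A quick numpy illustration (not part of the test) of the _buildParams trick above: adding 10j * data gives every element an imaginary part proportional to its real value, so a gather that picks up the wrong element is immediately visible in both components:

    import numpy as np

    data = np.array([0, 1, 2, 3, 7, 5]).astype(np.complex64)
    params = data + 10j * data
    print(params[4])   # (7+70j): real and imaginary parts agree on the value 7
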
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index ff299e6511..153d4ab662 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.python.ops.special_math_ops."""
+"""Tests for tensorflow.python.ops.linalg_ops."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 8659382834..c998f57da7 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -69,6 +69,19 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: A string, the name of the layer.
+ renorm: Whether to use Batch Renormalization
+ (https://arxiv.org/abs/1702.03275). This adds extra variables during
+ training. The inference is the same for either value of this parameter.
+ renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
+ scalar `Tensors` used to clip the renorm correction. The correction
+ `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
+ `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
+ dmax are set to inf, 0, inf, respectively.
+ renorm_momentum: Momentum used to update the moving means and standard
+ deviations with renorm. Unlike `momentum`, this affects training
+ and should be neither too small (which would add noise) nor too large
+ (which would give stale estimates). Note that `momentum` is still applied
+ to get the means and variances for inference.
"""
def __init__(self,
@@ -85,6 +98,9 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
gamma_regularizer=None,
trainable=True,
name=None,
+ renorm=False,
+ renorm_clipping=None,
+ renorm_momentum=0.99,
**kwargs):
super(BatchNormalization, self).__init__(
name=name, trainable=trainable, **kwargs)
@@ -99,6 +115,15 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
self.moving_variance_initializer = moving_variance_initializer
self.beta_regularizer = beta_regularizer
self.gamma_regularizer = gamma_regularizer
+ self.renorm = renorm
+ if renorm:
+ renorm_clipping = renorm_clipping or {}
+ keys = ['rmax', 'rmin', 'dmax']
+ if set(renorm_clipping) - set(keys):
+ raise ValueError('renorm_clipping %s contains keys not in %s' %
+ (renorm_clipping, keys))
+ self.renorm_clipping = renorm_clipping
+ self.renorm_momentum = renorm_momentum
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
@@ -148,9 +173,90 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
shape=(param_dim,),
initializer=self.moving_variance_initializer,
trainable=False)
+ if self.renorm:
+ # Create variables to maintain the moving mean and standard deviation.
+ # These are used in training and thus are different from the moving
+ # averages above. The renorm variables are colocated with moving_mean
+ # and moving_variance.
+ # NOTE: below, the outer `with device` block causes the current device
+ # stack to be cleared. The nested ones use a `lambda` to set the desired
+ # device and ignore any devices that may be set by the custom getter.
+ def _renorm_variable(name, shape):
+ var = vs.get_variable(name,
+ shape=shape,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False)
+ return var
+ with ops.device(None):
+ with ops.device(lambda _: self.moving_mean.device):
+ self.renorm_mean = _renorm_variable('renorm_mean', (param_dim,))
+ self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
+ # We initialize renorm_stddev to 0, and maintain the (0-initialized)
+ # renorm_stddev_weight. This allows us to (1) mix the average
+ # stddev with the minibatch stddev early in training, and (2) compute
+ # the unbiased average stddev by dividing renorm_stddev by the weight.
+ with ops.device(lambda _: self.moving_variance.device):
+ self.renorm_stddev = _renorm_variable('renorm_stddev', (param_dim,))
+ self.renorm_stddev_weight = _renorm_variable(
+ 'renorm_stddev_weight', ())
finally:
vs.get_variable_scope().set_partitioner(partitioner)
+ def _renorm_correction_and_moments(self, mean, variance, training):
+ """Returns the correction and update values for renorm."""
+ stddev = math_ops.sqrt(variance + self.epsilon)
+ # Compute the average mean and standard deviation, as if they were
+ # initialized with this batch's moments.
+ mixed_renorm_mean = (self.renorm_mean +
+ (1. - self.renorm_mean_weight) * mean)
+ mixed_renorm_stddev = (self.renorm_stddev +
+ (1. - self.renorm_stddev_weight) * stddev)
+ # Compute the corrections for batch renorm.
+ r = stddev / mixed_renorm_stddev
+ d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
+ # Ensure the corrections use pre-update moving averages.
+ with ops.control_dependencies([r, d]):
+ mean = array_ops.identity(mean)
+ stddev = array_ops.identity(stddev)
+ rmin, rmax, dmax = [self.renorm_clipping.get(key)
+ for key in ['rmin', 'rmax', 'dmax']]
+ if rmin is not None:
+ r = math_ops.maximum(r, rmin)
+ if rmax is not None:
+ r = math_ops.minimum(r, rmax)
+ if dmax is not None:
+ d = math_ops.maximum(d, -dmax)
+ d = math_ops.minimum(d, dmax)
+ # When not training, use r=1, d=0, and decay=1 meaning no updates.
+ r = _smart_select(training, lambda: r, lambda: array_ops.ones_like(r))
+ d = _smart_select(training, lambda: d, lambda: array_ops.zeros_like(d))
+ decay = _smart_select(training, lambda: self.renorm_momentum, lambda: 1.)
+ def _update_renorm_variable(var, weight, value):
+ """Updates a moving average and weight, returns the unbiased value."""
+ # Update the variables without zero debiasing. The debiasing will be
+ # accomplished by dividing the exponential moving average by the weight.
+ # For example, after a single update, the moving average would be
+      # (1-decay) * value, and the weight will be 1-decay, with their ratio
+ # giving value.
+ new_var = moving_averages.assign_moving_average(
+ var, value, decay, zero_debias=False)
+ new_weight = moving_averages.assign_moving_average(
+ weight, 1., decay, zero_debias=False)
+ return new_var / new_weight
+
+ with ops.colocate_with(self.moving_mean):
+ new_mean = _update_renorm_variable(self.renorm_mean,
+ self.renorm_mean_weight,
+ mean)
+ with ops.colocate_with(self.moving_variance):
+ new_stddev = _update_renorm_variable(self.renorm_stddev,
+ self.renorm_stddev_weight,
+ stddev)
+ # Make sqrt(moving_variance + epsilon) = new_stddev.
+ new_variance = math_ops.square(new_stddev) - self.epsilon
+
+ return (r, d, new_mean, new_variance)
+
def call(self, inputs, training=False):
# First, compute the axes along which to reduce the mean / variance,
# as well as the broadcast shape to be used for all parameters.
@@ -164,82 +270,66 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
# Determines whether broadcasting is needed.
needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
+ scale, offset = self.gamma, self.beta
+
# Determine a boolean value for `training`: could be True, False, or None.
training_value = utils.constant_value(training)
-
- if needs_broadcasting:
- # In this case we must explictly broadcast all parameters.
- if self.center:
- broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
- else:
- broadcast_beta = None
- if self.scale:
- broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
- else:
- broadcast_gamma = None
-
if training_value is not False:
- if needs_broadcasting:
- broadcast_mean, broadcast_variance = nn.moments(
- inputs, reduction_axes, keep_dims=True)
- mean = array_ops.reshape(broadcast_mean, [-1])
- variance = array_ops.reshape(broadcast_variance, [-1])
+ # Some of the computations here are not necessary when training==False
+ # but not a constant. However, this makes the code simpler.
+ mean, variance = nn.moments(inputs, reduction_axes)
+ if self.renorm:
+ r, d, new_mean, new_variance = self._renorm_correction_and_moments(
+ mean, variance, training)
+ # When training, the normalized values (say, x) will be transformed as
+ # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
+ # = x * (r * gamma) + (d * gamma + beta) with renorm.
+ scale = array_ops.stop_gradient(r, name='renorm_r')
+ offset = array_ops.stop_gradient(d, name='renorm_d')
+ if self.gamma is not None:
+ scale *= self.gamma
+ offset *= self.gamma
+ if self.beta is not None:
+ offset += self.beta
else:
- mean, variance = nn.moments(inputs, reduction_axes)
+ new_mean, new_variance = mean, variance
+
+ # Update moving averages when training, and prevent updates otherwise.
+ decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
+ mean_update = moving_averages.assign_moving_average(
+ self.moving_mean, new_mean, decay, zero_debias=False)
+ variance_update = moving_averages.assign_moving_average(
+ self.moving_variance, new_variance, decay, zero_debias=False)
- # Prepare updates if necessary.
if not self.updates:
- mean_update = moving_averages.assign_moving_average(
- self.moving_mean, mean, self.momentum, zero_debias=False)
- variance_update = moving_averages.assign_moving_average(
- self.moving_variance, variance, self.momentum, zero_debias=False)
# In the future this should be refactored into a self.add_update
# methods in order to allow for instance-based BN layer sharing
# across unrelated input streams (e.g. like in Keras).
self.updates.append(mean_update)
self.updates.append(variance_update)
- # Normalize batch. We do this inside separate functions for training
- # and inference so as to avoid evaluating both branches.
- def normalize_in_test():
- if needs_broadcasting:
- broadcast_moving_mean = array_ops.reshape(self.moving_mean,
- broadcast_shape)
- broadcast_moving_variance = array_ops.reshape(self.moving_variance,
- broadcast_shape)
- return nn.batch_normalization(inputs,
- broadcast_moving_mean,
- broadcast_moving_variance,
- broadcast_beta,
- broadcast_gamma,
- self.epsilon)
- else:
- return nn.batch_normalization(inputs,
- self.moving_mean,
- self.moving_variance,
- self.beta if self.center else None,
- self.gamma if self.scale else None,
- self.epsilon)
-
- def normalize_in_training():
- if needs_broadcasting:
- return nn.batch_normalization(inputs,
- broadcast_mean,
- broadcast_variance,
- broadcast_beta,
- broadcast_gamma,
- self.epsilon)
- else:
- return nn.batch_normalization(inputs,
- mean,
- variance,
- self.beta if self.center else None,
- self.gamma if self.scale else None,
- self.epsilon)
+ mean = _smart_select(training,
+ lambda: mean,
+ lambda: self.moving_mean)
+ variance = _smart_select(training,
+ lambda: variance,
+ lambda: self.moving_variance)
+
+ else:
+ mean, variance = self.moving_mean, self.moving_variance
- return utils.smart_cond(training,
- normalize_in_training,
- normalize_in_test)
+ def _broadcast(v):
+ if needs_broadcasting and v is not None:
+        # In this case we must explicitly broadcast all parameters.
+ return array_ops.reshape(v, broadcast_shape)
+ return v
+
+ return nn.batch_normalization(inputs,
+ _broadcast(mean),
+ _broadcast(variance),
+ _broadcast(offset),
+ _broadcast(scale),
+ self.epsilon)
def batch_normalization(inputs,
@@ -257,7 +347,10 @@ def batch_normalization(inputs,
training=False,
trainable=True,
name=None,
- reuse=None):
+ reuse=None,
+ renorm=False,
+ renorm_clipping=None,
+ renorm_momentum=0.99):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
@@ -294,6 +387,19 @@ def batch_normalization(inputs,
name: String, the name of the layer.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
+ renorm: Whether to use Batch Renormalization
+ (https://arxiv.org/abs/1702.03275). This adds extra variables during
+ training. The inference is the same for either value of this parameter.
+ renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
+ scalar `Tensors` used to clip the renorm correction. The correction
+ `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
+ `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
+ dmax are set to inf, 0, inf, respectively.
+ renorm_momentum: Momentum used to update the moving means and standard
+ deviations with renorm. Unlike `momentum`, this affects training
+ and should be neither too small (which would add noise) nor too large
+ (which would give stale estimates). Note that `momentum` is still applied
+ to get the means and variances for inference.
Returns:
Output tensor.
@@ -311,6 +417,9 @@ def batch_normalization(inputs,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
trainable=trainable,
+ renorm=renorm,
+ renorm_clipping=renorm_clipping,
+ renorm_momentum=renorm_momentum,
name=name,
_reuse=reuse,
_scope=name)
@@ -321,3 +430,39 @@ def batch_normalization(inputs,
BatchNorm = BatchNormalization
batch_norm = batch_normalization
+
+
+# Helper function
+
+
+def _smart_select(pred, fn_then, fn_else):
+ """Selects fn_then() or fn_else() based on the value of pred.
+
+ The purpose of this function is the same as `utils.smart_cond`. However, at
+ the moment there is a bug (b/36297356) that seems to kick in only when
+ `smart_cond` delegates to `tf.cond`, which sometimes results in the training
+ hanging when using parameter servers. This function will output the result
+ of `fn_then` or `fn_else` if `pred` is known at graph construction time.
+ Otherwise, it will use `tf.where` which will result in some redundant work
+ (both branches will be computed but only one selected). However, the tensors
+ involved will usually be small (means and variances in batchnorm), so the
+ cost will be small and will not be incurred at all if `pred` is a constant.
+
+ Args:
+ pred: A boolean scalar `Tensor`.
+ fn_then: A callable to use when pred==True.
+ fn_else: A callable to use when pred==False.
+
+ Returns:
+ A `Tensor` whose value is fn_then() or fn_else() based on the value of pred.
+ """
+ pred_value = utils.constant_value(pred)
+ if pred_value:
+ return fn_then()
+ elif pred_value is False:
+ return fn_else()
+ t_then = array_ops.expand_dims(fn_then(), 0)
+ t_else = array_ops.expand_dims(fn_else(), 0)
+ pred = array_ops.reshape(pred, [1])
+ result = array_ops.where(pred, t_then, t_else)
+ return array_ops.squeeze(result, [0])
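
To make the renorm math above concrete, here is a small numpy sketch of the correction described in _renorm_correction_and_moments, i.e. corrected_value = normalized_value * r + d with r clipped to [rmin, rmax] and d to [-dmax, dmax]. It is hedged: it applies the clipping against the running statistics directly, whereas the layer mixes in the minibatch moments via the renorm weights early in training.

    import numpy as np

    def renorm_correction(batch_mean, batch_var, running_mean, running_stddev,
                          epsilon=1e-3, rmin=0.9, rmax=1.1, dmax=0.1):
      stddev = np.sqrt(batch_var + epsilon)
      r = np.clip(stddev / running_stddev, rmin, rmax)
      d = np.clip((batch_mean - running_mean) / running_stddev, -dmax, dmax)
      return r, d

    x = np.random.randn(64)                        # one activation column
    r, d = renorm_correction(x.mean(), x.var(),
                             running_mean=0.1, running_stddev=1.2)
    x_hat = (x - x.mean()) / np.sqrt(x.var() + 1e-3)
    corrected = x_hat * r + d                      # then scaled by gamma, beta
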
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 91b7cb6f48..0f82f73ea4 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
@@ -513,6 +514,64 @@ class BNTest(test.TestCase):
_ = bn.apply(inputs, training=training)
self.assertEqual(len(bn.losses), 1)
+ def testRenorm(self):
+ shape = (4, 3)
+ xt = array_ops.placeholder(dtypes.float32, shape)
+ momentum = 0.99
+ renorm_momentum = 0.8
+ rmax = 1.1
+ rmin = 0.9
+ dmax = 0.1
+ gamma = 2.
+ beta = 3.
+ epsilon = 0.001
+ bn = normalization_layers.BatchNormalization(
+ axis=1,
+ gamma_initializer=init_ops.constant_initializer(gamma),
+ beta_initializer=init_ops.constant_initializer(beta),
+ epsilon=epsilon,
+ momentum=momentum,
+ renorm=True,
+ renorm_clipping={'rmax': rmax, 'rmin': rmin, 'dmax': dmax},
+ renorm_momentum=renorm_momentum)
+ training = array_ops.placeholder(dtypes.bool)
+ yt = bn.apply(xt, training=training)
+
+ moving_mean = 0.
+ moving_variance = 1.
+ renorm_mean = renorm_stddev = 0.
+ renorm_weight = 0.
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ for _ in range(5):
+ x = np.random.random(shape)
+
+ mean = x.mean(0)
+ stddev = np.sqrt(x.var(0) + epsilon)
+ adj_mean = renorm_mean + (1. - renorm_weight) * mean
+ adj_stddev = renorm_stddev + (1. - renorm_weight) * stddev
+ r = (stddev / adj_stddev).clip(rmin, rmax)
+ d = ((mean - adj_mean) / adj_stddev).clip(-dmax, dmax)
+ y_train = ((x - mean) / stddev * r + d) * gamma + beta
+ renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
+ renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
+ renorm_weight += (1. - renorm_weight) * (1. - renorm_momentum)
+ moving_mean += (renorm_mean / renorm_weight -
+ moving_mean) * (1. - momentum)
+ moving_variance += ((renorm_stddev / renorm_weight) ** 2 - epsilon -
+ moving_variance) * (1. - momentum)
+
+ y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
+ gamma) + beta
+
+ yt_val_train, _, _ = sess.run([yt] + bn.updates,
+ feed_dict={xt: x, training: True})
+ yt_val_test, _, _ = sess.run([yt] + bn.updates,
+ feed_dict={xt: x, training: False})
+
+ self.assertAllClose(y_train, yt_val_train, atol=1e-5)
+ self.assertAllClose(y_test, yt_val_test, atol=1e-5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 60057b9ab1..45efc51d5c 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -196,7 +196,7 @@ def broadcast_dynamic_shape(shape_x, shape_y):
Args:
shape_x: A rank 1 integer `Tensor`, representing the shape of x.
- shape_y: A rank 1 integer `Tensor`, representing the shape of x.
+ shape_y: A rank 1 integer `Tensor`, representing the shape of y.
Returns:
A rank 1 integer `Tensor` representing the broadcasted shape.
"""
@@ -1292,6 +1292,17 @@ def matrix_transpose(a, name="matrix_transpose"):
# tf.matrix_transpose(x) is shape [1, 2, 4, 3]
```
+ Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
+ This is done with minimal cost, and is preferable to using this function. E.g.
+
+ ```
+ # Good! Transpose is taken at minimal additional cost.
+ tf.matmul(matrix, b, transpose_b=True)
+
+ # Inefficient!
+ tf.matmul(matrix, tf.matrix_transpose(b))
+ ```
+
Args:
a: A `Tensor` with `rank >= 2`.
name: A name for the operation (optional).
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 99d29a3719..66ccedf546 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -375,7 +375,9 @@ def with_space_to_batch(
input_shape_list = input.get_shape().as_list()
input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
if input_spatial_shape is None or None in input_spatial_shape:
- input_spatial_shape = array_ops.gather(array_ops.shape(input), spatial_dims)
+ input_shape_tensor = array_ops.shape(input)
+ input_spatial_shape = array_ops.stack(
+ [input_shape_tensor[i] for i in spatial_dims])
paddings, crops = array_ops.required_space_to_batch_paddings(
input_shape=input_spatial_shape,
@@ -2021,7 +2023,7 @@ def top_k(input, k=1, sorted=True, name=None):
def conv1d(value, filters, stride, padding,
use_cudnn_on_gpu=None, data_format=None,
name=None):
- """Computes a 1-D convolution given 3-D input and filter tensors.
+ r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape
[batch, in_width, in_channels]
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 86e0cae27a..77f0468c01 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -197,8 +197,10 @@ class ResourceVariable(object):
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value, name=n)
with ops.name_scope("Read"), ops.colocate_with(self._handle):
- value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ # Manually assign reads to the handle's device to avoid log messages.
+ with ops.device(self._handle.device):
+ value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
self._graph_element = value
if caching_device is not None:
# Variables may be created in a tf.device() or ops.colocate_with()
@@ -276,8 +278,9 @@ class ResourceVariable(object):
"""A cached operation which reads the value of this variable."""
if self._cached_value is not None:
return self._cached_value
- return gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ with ops.device(self._handle.device):
+ return gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
def _as_graph_element(self):
"""Conversion function for Graph.as_graph_element()."""
@@ -318,8 +321,9 @@ class ResourceVariable(object):
the read operation.
"""
with ops.name_scope("Read"):
- value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ with ops.device(self._handle.device):
+ value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
# Return an identity so it can get placed on whatever device the context
# specifies instead of the device where the variable is.
return array_ops.identity(value)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 162b13ec21..1051478a7f 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -37,6 +37,36 @@ _state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
# pylint: enable=protected-access
+def _transpose_batch_time(x):
+ """Transpose the batch and time dimensions of a Tensor.
+
+ Retains as much of the static shape information as possible.
+
+ Args:
+ x: A tensor of rank 2 or higher.
+
+ Returns:
+ x transposed along the first two dimensions.
+
+ Raises:
+ ValueError: if `x` is rank 1 or lower.
+ """
+ x_static_shape = x.get_shape()
+ if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
+ raise ValueError(
+ "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
+ (x, x_static_shape))
+ x_rank = array_ops.rank(x)
+ x_t = array_ops.transpose(
+ x, array_ops.concat(
+ ([1, 0], math_ops.range(2, x_rank)), axis=0))
+ x_t.set_shape(
+ tensor_shape.TensorShape([
+ x_static_shape[1].value, x_static_shape[0].value
+ ]).concatenate(x_static_shape[2:]))
+ return x_t
+
+
def _infer_state_dtype(explicit_dtype, state):
"""Infer the dtype of an RNN state.
@@ -492,8 +522,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
if not time_major:
# (B,T,D) => (T,B,D)
- flat_input = tuple(array_ops.transpose(input_, [1, 0, 2])
- for input_ in flat_input)
+ flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
+ flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
parallel_iterations = parallel_iterations or 32
if sequence_length is not None:
@@ -556,11 +586,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# to shape [batch, time, depth]
if not time_major:
# (T,B,D) => (B,T,D)
- flat_output = nest.flatten(outputs)
- flat_output = [array_ops.transpose(output, [1, 0, 2])
- for output in flat_output]
- outputs = nest.pack_sequence_as(
- structure=outputs, flat_sequence=flat_output)
+ outputs = nest.map_structure(_transpose_batch_time, outputs)
return (outputs, final_state)
@@ -1003,34 +1029,20 @@ def raw_rnn(cell, loop_fn,
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
- current_flat = nest.flatten(current)
- candidate_flat = nest.flatten(candidate)
- # pylint: disable=g-long-lambda,cell-var-from-loop
- result_flat = [
- _on_device(
- lambda: array_ops.where(
- elements_finished, current_i, candidate_i),
- device=candidate_i.op.device)
- for (current_i, candidate_i) in zip(current_flat, candidate_flat)]
- # pylint: enable=g-long-lambda,cell-var-from-loop
- return nest.pack_sequence_as(
- structure=current, flat_sequence=result_flat)
+ def copy_fn(cur_i, cand_i):
+ return _on_device(
+ lambda: array_ops.where(elements_finished, cur_i, cand_i),
+ device=cand_i.op.device)
+ return nest.map_structure(copy_fn, current, candidate)
emit_output = _copy_some_through(zero_emit, emit_output)
next_state = _copy_some_through(state, next_state)
- emit_output_flat = nest.flatten(emit_output)
- emit_ta_flat = nest.flatten(emit_ta)
+ emit_ta = nest.map_structure(
+ lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
elements_finished = math_ops.logical_or(elements_finished, next_finished)
- emit_ta_flat = [
- ta.write(time, emit)
- for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]
-
- emit_ta = nest.pack_sequence_as(
- structure=emit_structure, flat_sequence=emit_ta_flat)
-
return (next_time, elements_finished, next_input,
emit_ta, next_state, loop_state)
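
A small sketch of the shape effect of the new _transpose_batch_time helper (assuming the graph-mode tf.placeholder API of this TF version): it swaps only the first two dimensions while retaining static shape information, which is what dynamic_rnn now relies on for its (B,T,D) <=> (T,B,D) conversions:

    import tensorflow as tf
    from tensorflow.python.ops import rnn

    x = tf.placeholder(tf.float32, [32, None, 128])  # (batch, time, depth)
    # pylint: disable=protected-access
    x_t = rnn._transpose_batch_time(x)
    print(x_t.get_shape())                           # (?, 32, 128), i.e. (time, batch, depth)
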
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index 0a06982ad7..3d038cfd8a 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -116,7 +116,7 @@ class TensorHandle(object):
raise TypeError("Persistent tensor %s may have already been deleted."
% self.handle)
self._auto_gc_enabled = False
- holder, deleter = _get_handle_deleter(self._session.graph, self._handle)
+ holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
self._session.run(deleter, feed_dict={holder: self.handle})
def get_raw_handle(self):
@@ -142,11 +142,6 @@ class TensorHandle(object):
return handle_parts[0] + ";" + handle_parts[-1]
@staticmethod
- def _get_deleter_key(handle):
- """The graph key for deleter."""
- return str(handle).split(";")[-1]
-
- @staticmethod
def _get_mover_key(feeder, handle):
"""The graph key for mover."""
return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
@@ -302,10 +297,9 @@ def _get_handle_mover(graph, feeder, handle):
return result
-def _get_handle_deleter(graph, handle):
+def _get_handle_deleter(graph, deleter_key, handle):
"""Return a deletion subgraph for this handle."""
- graph_key = TensorHandle._get_deleter_key(handle)
- result = graph._handle_deleters.get(graph_key)
+ result = graph._handle_deleters.get(deleter_key)
if result is None:
# Create deleter if we haven't done it.
handle_device = TensorHandle._get_device_name(handle)
@@ -313,5 +307,5 @@ def _get_handle_deleter(graph, handle):
holder = array_ops.placeholder(dtypes.string)
deleter = gen_data_flow_ops._delete_session_tensor(holder)
result = (holder, deleter)
- graph._handle_deleters[graph_key] = result
+ graph._handle_deleters[deleter_key] = result
return result
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 2a64cb7b70..f46f56cbb7 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -243,8 +243,9 @@ def assign_add(ref, value, use_locking=None, name=None):
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
"""Update 'ref' by assigning 'value' to it.
- This operation outputs "ref" after the assignment is done.
- This makes it easier to chain operations that need to use the reset value.
+ This operation outputs a Tensor that holds the new value of 'ref' after
+ the value has been assigned. This makes it easier to chain operations
+ that need to use the reset value.
Args:
ref: A mutable `Tensor`.
@@ -261,8 +262,8 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None):
name: A name for the operation (optional).
Returns:
- Same as "ref". Returned as a convenience for operations that want
- to use the new value after the variable has been reset.
+ A `Tensor` that will hold the new value of 'ref' after
+ the assignment has completed.
"""
if ref.dtype._is_ref_dtype:
return gen_state_ops.assign(
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 19c5d3c3ea..b3745fa4e6 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -974,6 +974,8 @@ class VariableScope(object):
partitioner = self._partitioner
if dtype is None:
dtype = self._dtype
+ if use_resource is None:
+ use_resource = self._use_resource
if self._custom_getter is not None:
raise ValueError(
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 111461f784..4b0ef50df5 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
@@ -154,7 +155,7 @@ class AdamOptimizer(optimizer.Optimizer):
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad, use_locking=self._use_locking)
- def _apply_sparse(self, grad, var):
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
@@ -164,23 +165,39 @@ class AdamOptimizer(optimizer.Optimizer):
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
- m_scaled_g_values = grad.values * (1 - beta1_t)
+ m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
- m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
- use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
- v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
- use_locking=self._use_locking)
+ v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+ v_t = state_ops.assign(v, v * beta2_t)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_sqrt + epsilon_t),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking))
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(
+ x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._apply_sparse_shared(
+ grad, var, indices, self._resource_scatter_add)
+
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
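
For reference, a numpy sketch (not the optimizer's code; the bias-corrected learning rate computed from beta1_power and beta2_power is folded into lr here for brevity) of the sparse update pattern that _apply_sparse_shared expresses: decay m and v densely, scatter-add the gradient terms only at the active indices, then apply the dense variable update:

    import numpy as np

    def sparse_adam_step(var, m, v, grad_values, indices,
                         lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
      m *= beta1
      np.add.at(m, indices, (1 - beta1) * grad_values)       # scatter_add on m
      v *= beta2
      np.add.at(v, indices, (1 - beta2) * grad_values ** 2)  # scatter_add on v
      var -= lr * m / (np.sqrt(v) + epsilon)
      return var, m, v
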
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 00ff5d9b9d..62b171e234 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -52,7 +52,7 @@ def adam_update_numpy(param,
class AdamOptimizerTest(test.TestCase):
- def testSparse(self):
+ def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
@@ -62,8 +62,12 @@ class AdamOptimizerTest(test.TestCase):
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -95,6 +99,12 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
def testSparseDevicePlacement(self):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index 7f403f4927..85ee10379a 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
if ps_ops is None:
# TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
# placed in the parameter server.
- ps_ops = ["Variable", "VariableV2"]
+ ps_ops = ["Variable", "VariableV2", "VarHandleOp"]
if not merge_devices:
logging.warning(
diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py
index e05f0f6a1c..bc29e0d21c 100644
--- a/tensorflow/python/training/device_setter_test.py
+++ b/tensorflow/python/training/device_setter_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
@@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase):
self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
self.assertDeviceEqual("/job:worker/cpu:0", a.device)
+ def testResource(self):
+ with ops.device(
+ device_setter.replica_device_setter(cluster=self._cluster_spec)):
+ v = resource_variable_ops.ResourceVariable([1, 2])
+ self.assertDeviceEqual("/job:ps/task:0", v.device)
+
def testPS2TasksWithClusterSpecClass(self):
with ops.device(
device_setter.replica_device_setter(cluster=self._cluster_spec)):
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index cf8692eda1..6d6128d207 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -252,7 +252,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_summaries_secs=None,
config=None,
stop_grace_period_secs=120,
- log_step_count_steps=10000):
+ log_step_count_steps=100):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py
index 005d183039..974762822f 100644
--- a/tensorflow/tensorboard/backend/application.py
+++ b/tensorflow/tensorboard/backend/application.py
@@ -61,6 +61,7 @@ DATA_PREFIX = '/data'
LOGDIR_ROUTE = '/logdir'
RUNS_ROUTE = '/runs'
PLUGIN_PREFIX = '/plugin'
+PLUGINS_LISTING_ROUTE = '/plugins_listing'
SCALARS_ROUTE = '/' + event_accumulator.SCALARS
IMAGES_ROUTE = '/' + event_accumulator.IMAGES
AUDIO_ROUTE = '/' + event_accumulator.AUDIO
@@ -152,30 +153,34 @@ class TensorBoardWSGIApp(object):
reload_multiplexer(self._multiplexer, path_to_run)
self.data_applications = {
- DATA_PREFIX + LOGDIR_ROUTE:
- self._serve_logdir,
- DATA_PREFIX + SCALARS_ROUTE:
- self._serve_scalars,
+ '/app.js':
+ self._serve_js,
+ DATA_PREFIX + AUDIO_ROUTE:
+ self._serve_audio,
+ DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE:
+ self._serve_compressed_histograms,
DATA_PREFIX + GRAPH_ROUTE:
self._serve_graph,
- DATA_PREFIX + RUN_METADATA_ROUTE:
- self._serve_run_metadata,
DATA_PREFIX + HISTOGRAMS_ROUTE:
self._serve_histograms,
- DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE:
- self._serve_compressed_histograms,
DATA_PREFIX + IMAGES_ROUTE:
self._serve_images,
- DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE:
- self._serve_image,
- DATA_PREFIX + AUDIO_ROUTE:
- self._serve_audio,
DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE:
self._serve_individual_audio,
+ DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE:
+ self._serve_image,
+ DATA_PREFIX + LOGDIR_ROUTE:
+ self._serve_logdir,
+ # TODO(chizeng): Delete this RPC once we have skylark rules that obviate
+ # the need for the frontend to determine which plugins are active.
+ DATA_PREFIX + PLUGINS_LISTING_ROUTE:
+ self._serve_plugins_listing,
+ DATA_PREFIX + RUN_METADATA_ROUTE:
+ self._serve_run_metadata,
DATA_PREFIX + RUNS_ROUTE:
self._serve_runs,
- '/app.js':
- self._serve_js
+ DATA_PREFIX + SCALARS_ROUTE:
+ self._serve_scalars,
}
# Serve the routes from the registered plugins using their name as the route
@@ -489,6 +494,21 @@ class TensorBoardWSGIApp(object):
return query_string
@wrappers.Request.application
+ def _serve_plugins_listing(self, request):
+ """Serves an object mapping plugin name to whether it is enabled.
+
+ Args:
+ request: The werkzeug.Request object.
+
+ Returns:
+ A werkzeug.Response object.
+ """
+ return http_util.Respond(
+ request,
+ {plugin.plugin_name: plugin.is_active() for plugin in self._plugins},
+ 'application/json')
+
+ @wrappers.Request.application
def _serve_runs(self, request):
"""WSGI app serving a JSON object about runs and tags.
diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py
index 454ba63e75..002709cd5b 100644
--- a/tensorflow/tensorboard/backend/application_test.py
+++ b/tensorflow/tensorboard/backend/application_test.py
@@ -51,6 +51,40 @@ from tensorflow.tensorboard.backend.event_processing import event_multiplexer
from tensorflow.tensorboard.plugins import base_plugin
+class FakePlugin(base_plugin.TBPlugin):
+ """A plugin with no functionality."""
+
+ def __init__(self, plugin_name, is_active_value):
+ """Constructs a fake plugin.
+
+ Args:
+ plugin_name: The name of this plugin.
+ is_active_value: Whether the plugin is active.
+ """
+ self.plugin_name = plugin_name
+ self._is_active_value = is_active_value
+
+ def get_plugin_apps(self, multiplexer, logdir):
+ """Returns a mapping from routes to handlers offered by this plugin.
+
+ Args:
+ multiplexer: The event multiplexer.
+ logdir: The path to the directory containing logs.
+
+ Returns:
+ An empty dict. This plugin offers no routes.
+ """
+ return {}
+
+ def is_active(self):
+ """Returns whether this plugin is active.
+
+ Returns:
+ A boolean. Whether this plugin is active.
+ """
+ return self._is_active_value
+
+
class TensorboardServerTest(test.TestCase):
_only_use_meta_graph = False # Server data contains only a GraphDef
@@ -62,7 +96,10 @@ class TensorboardServerTest(test.TestCase):
multiplexer = event_multiplexer.EventMultiplexer(
size_guidance=application.DEFAULT_SIZE_GUIDANCE,
purge_orphaned_data=True)
- plugins = []
+ plugins = [
+ FakePlugin(plugin_name='foo', is_active_value=True),
+ FakePlugin(plugin_name='bar', is_active_value=False)
+ ]
app = application.TensorBoardWSGIApp(
self.temp_dir, plugins, multiplexer, reload_interval=0)
try:
@@ -124,6 +161,12 @@ class TensorboardServerTest(test.TestCase):
parsed_object = self._getJson('/data/logdir')
self.assertEqual(parsed_object, {'logdir': self.temp_dir})
+ def testPluginsListing(self):
+ """Test the format of the data/plugins_listing endpoint."""
+ parsed_object = self._getJson('/data/plugins_listing')
+ # Plugin foo is active. Plugin bar is not.
+ self.assertEqual(parsed_object, {'foo': True, 'bar': False})
+
def testRuns(self):
"""Test the format of the /data/runs endpoint."""
run_json = self._getJson('/data/runs')
@@ -484,29 +527,21 @@ class TensorboardSimpleServerConstructionTest(test.TestCase):
class TensorBoardApplcationConstructionTest(test.TestCase):
def testExceptions(self):
-
- class UnnamedPlugin(base_plugin.TBPlugin):
-
- def get_plugin_apps(self):
- pass
-
- class MockPlugin(UnnamedPlugin):
- plugin_name = 'mock'
-
- class OtherMockPlugin(UnnamedPlugin):
- plugin_name = 'mock'
-
logdir = '/fake/foo'
multiplexer = event_multiplexer.EventMultiplexer()
# Fails if there is an unnamed plugin
with self.assertRaises(ValueError):
- plugins = [UnnamedPlugin()]
+ # This plugin lacks a name.
+ plugins = [FakePlugin(plugin_name=None, is_active_value=True)]
application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0)
# Fails if there are two plugins with same name
with self.assertRaises(ValueError):
- plugins = [MockPlugin(), OtherMockPlugin()]
+ plugins = [
+ FakePlugin(plugin_name='foo', is_active_value=True),
+ FakePlugin(plugin_name='foo', is_active_value=True),
+ ]
application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0)
diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py
index beba28da06..d5a91bbb6a 100644
--- a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py
+++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py
@@ -438,6 +438,14 @@ class EventAccumulator(object):
"""
return self._health_pills.Items(node_name)
+ def GetOpsWithHealthPills(self):
+ """Determines which ops have at least 1 health pill event.
+
+ Returns:
+ A list of names of ops with at least 1 health pill event.
+ """
+ return self._health_pills.Keys()
+
def Graph(self):
"""Return the graph definition, if there is one.
diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py
index 38a8cd915f..3734e470b6 100644
--- a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py
+++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py
@@ -297,8 +297,6 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
acc = ea.EventAccumulator(gen)
gen.AddHealthPill(13371337, 41, 'Add', 0, range(1, 13))
gen.AddHealthPill(13381338, 42, 'Add', 1, range(42, 54))
-
- acc = ea.EventAccumulator(gen)
acc.Reload()
# Retrieve the health pills for each node name.
@@ -321,6 +319,14 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
value=range(42, 54)),
gotten_events[1])
+ def testGetOpsWithHealthPills(self):
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddHealthPill(13371337, 41, 'Add', 0, range(1, 13))
+ gen.AddHealthPill(13381338, 42, 'MatMul', 1, range(42, 54))
+ acc.Reload()
+ self.assertItemsEqual(['Add', 'MatMul'], acc.GetOpsWithHealthPills())
+
def testHistograms(self):
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py
index bbf958820a..08e6dbb57d 100644
--- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py
+++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py
@@ -287,6 +287,21 @@ class EventMultiplexer(object):
accumulator = self._GetAccumulator(run)
return accumulator.HealthPills(node_name)
+ def GetOpsWithHealthPills(self, run):
+ """Determines which ops have at least 1 health pill event for a given run.
+
+ Args:
+ run: The name of the run.
+
+ Raises:
+ KeyError: If the run is not found.
+
+ Returns:
+ The list of names of ops with health pill events.
+ """
+ return self._GetAccumulator(run).GetOpsWithHealthPills()
+
def Graph(self, run):
"""Retrieve the graph associated with the provided run.
diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py
index ed5cac4014..ded1856d7e 100644
--- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py
+++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
import os.path
import shutil
@@ -45,10 +46,16 @@ def _CreateCleanDirectory(path):
class _FakeAccumulator(object):
- def __init__(self, path):
+ def __init__(self, path, health_pill_mapping=None):
+ """Constructs a fake accumulator with some fake events.
+
+ Args:
+ path: The path for the run that this accumulator is for.
+ health_pill_mapping: An optional mapping from Op to health pill strings.
+ """
self._path = path
self.reload_called = False
- self._node_names_to_health_pills = {'Add': ['hp1', 'hp2']}
+ self._node_names_to_health_pills = health_pill_mapping or {}
def Tags(self):
return {event_accumulator.IMAGES: ['im1', 'im2'],
@@ -74,6 +81,9 @@ class _FakeAccumulator(object):
health_pills = self._node_names_to_health_pills[node_name]
return [self._path + '/' + health_pill for health_pill in health_pills]
+ def GetOpsWithHealthPills(self):
+ return self._node_names_to_health_pills.keys()
+
def Histograms(self, tag_name):
return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS)
@@ -93,14 +103,13 @@ class _FakeAccumulator(object):
self.reload_called = True
-# pylint: disable=unused-argument
-def _GetFakeAccumulator(
- path,
- size_guidance=None,
- compression_bps=None,
- purge_orphaned_data=None):
- return _FakeAccumulator(path)
-# pylint: enable=unused-argument
+def _GetFakeAccumulator(path,
+ size_guidance=None,
+ compression_bps=None,
+ purge_orphaned_data=None,
+ health_pill_mapping=None):
+ del size_guidance, compression_bps, purge_orphaned_data # Unused.
+ return _FakeAccumulator(path, health_pill_mapping=health_pill_mapping)
class EventMultiplexerTest(test_util.TensorFlowTestCase):
@@ -141,9 +150,27 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
self.assertEqual(run1_expected, run1_actual)
def testHealthPills(self):
+ self.stubs.Set(event_accumulator, 'EventAccumulator',
+ functools.partial(
+ _GetFakeAccumulator,
+ health_pill_mapping={'Add': ['hp1', 'hp2']}))
x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
self.assertEqual(['path1/hp1', 'path1/hp2'], x.HealthPills('run1', 'Add'))
+ def testGetOpsWithHealthPillsWhenHealthPillsAreNotAvailable(self):
+ # The event accumulator lacks health pills for the run.
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual([], x.GetOpsWithHealthPills('run1'))
+
+ def testGetOpsWithHealthPillsWhenHealthPillsAreAvailable(self):
+ # The event accumulator has health pills for the run.
+ self.stubs.Set(event_accumulator, 'EventAccumulator',
+ functools.partial(
+ _GetFakeAccumulator,
+ health_pill_mapping={'Add': ['hp1', 'hp2']}))
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual(['Add'], x.GetOpsWithHealthPills('run1'))
+
def testExceptions(self):
x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
with self.assertRaises(KeyError):
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html
index dbc1dc5c5f..c90efac1d6 100644
--- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html
+++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html
@@ -34,10 +34,9 @@ Display a warning when there is no data found.
and pass the graph either via the constructor, or by calling its
<code>add_graph()</code> method.
You may want to check out the
- <a href="https://www.tensorflow.org/versions/master/how_tos/graph_viz/index.html">
+ <a href="https://www.tensorflow.org/get_started/graph_viz">
graph visualizer tutorial
- </a>
- .
+ </a>.
</p>
</template>
<template is="dom-if" if="[[_isProjector(dataType)]]">
@@ -53,7 +52,7 @@ Display a warning when there is no data found.
<li>
You are not saving any checkpoint. To save your model,
create a
- <a href="https://www.tensorflow.org/versions/master/api_docs/python/state_ops.html#Saver">
+ <a href="https://www.tensorflow.org/api_docs/python/tf/train/Saver">
<code>tf.train.Saver</code>
</a>
and save your model periodically
@@ -86,7 +85,7 @@ Display a warning when there is no data found.
README
</a>
and perhaps the
- <a href="https://www.tensorflow.org/versions/master/how_tos/summaries_and_tensorboard/index.html">
+ <a href="https://www.tensorflow.org/get_started/summaries_and_tensorboard">
TensorBoard tutorial
</a>.
</p>
diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md
index 16c2f95ae1..00aeb6353e 100644
--- a/tensorflow/tensorboard/http_api.md
+++ b/tensorflow/tensorboard/http_api.md
@@ -36,6 +36,13 @@ Returns a JSON object with a key "logdir" that maps to the `logdir` argument
The `logdir` argument is the path of the directory that contains events files.
+## `data/plugins_listing`
+
+Returns a dict mapping each plugin name to a boolean indicating whether that
+plugin is active. A plugin might be inactive if, for instance, it lacks relevant
+data. Every plugin has a key in the dict. This route lets the frontend avoid
+issuing requests to an inactive plugin, since the routes of an inactive plugin
+do not work.
+
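As a hedged illustration only (not part of the documented API), a client could consume this route as sketched below; the host, port, and plugin names are assumptions for the example.

import json
from six.moves.urllib.request import urlopen

# Assumes a TensorBoard instance is serving locally on the default port.
raw = urlopen('http://localhost:6006/data/plugins_listing').read()
listing = json.loads(raw.decode('utf-8'))
# e.g. {'projector': True, 'text': False}; query only the active plugins' routes.
active_plugins = [name for name, is_active in listing.items() if is_active]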
## `data/runs`
Returns a dictionary mapping from `run name` (quoted string) to dictionaries
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index ca6a9e89ce..5dcf2f21e9 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -30,7 +30,7 @@
"merge2": "~0.3.6",
"minimist": "~1.2.0",
"tsify": "^0.14.8",
- "typescript": "2.1.5",
+ "typescript": "2.2.2",
"typings": "1.4.0",
"vinyl-source-stream": "^1.1.0",
"vulcanize": "^1.14.0",
diff --git a/tensorflow/tensorboard/plugins/base_plugin.py b/tensorflow/tensorboard/plugins/base_plugin.py
index 8b1560cf8a..259046dfb4 100644
--- a/tensorflow/tensorboard/plugins/base_plugin.py
+++ b/tensorflow/tensorboard/plugins/base_plugin.py
@@ -51,3 +51,15 @@ class TBPlugin(object):
A dict mapping route paths to WSGI applications.
"""
raise NotImplementedError()
+
+ @abstractmethod
+ def is_active(self):
+ """Determines whether this plugin is active.
+
+ A plugin may not be active for instance if it lacks relevant data. If a
+ plugin is inactive, the frontend may avoid issuing requests to its routes.
+
+ Returns:
+ A boolean value. Whether this plugin is active.
+ """
+ raise NotImplementedError()
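To make the new contract concrete, here is a minimal hypothetical plugin honoring `is_active`; the class name and the activity condition are illustrative only, and, as with the debugger plugin below, `is_active` is assumed to be called only after `get_plugin_apps`.

from tensorflow.tensorboard.plugins import base_plugin


class ExamplePlugin(base_plugin.TBPlugin):
  """Hypothetical plugin that is active only once some run has been loaded."""

  plugin_name = 'example'

  def get_plugin_apps(self, multiplexer, logdir):
    # Keep a handle on the multiplexer so is_active can inspect it later.
    self._multiplexer = multiplexer
    return {}  # A real plugin would map route paths to WSGI apps here.

  def is_active(self):
    # Report activity only when at least one run exists, so the frontend can
    # skip issuing requests to this plugin's routes otherwise.
    return bool(self._multiplexer.Runs())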
diff --git a/tensorflow/tensorboard/plugins/debugger/debugger_plugin.py b/tensorflow/tensorboard/plugins/debugger/debugger_plugin.py
index cfa8f68187..5d34bb91db 100644
--- a/tensorflow/tensorboard/plugins/debugger/debugger_plugin.py
+++ b/tensorflow/tensorboard/plugins/debugger/debugger_plugin.py
@@ -82,6 +82,21 @@ class DebuggerPlugin(base_plugin.TBPlugin):
_HEALTH_PILLS_ROUTE: self._serve_health_pills_handler,
}
+ def is_active(self):
+ """Determines whether this plugin is active.
+
+ This plugin is active if health pill information is present for any run.
+ This method must be called only after get_plugin_apps has been called.
+
+ Returns:
+ A boolean. Whether this plugin is active.
+ """
+ for run_name in self._event_multiplexer.Runs():
+ if self._event_multiplexer.GetOpsWithHealthPills(run_name):
+ return True
+
+ return False
+
@wrappers.Request.application
def _serve_health_pills_handler(self, request):
"""A (wrapped) werkzeug handler for serving health pills.
diff --git a/tensorflow/tensorboard/plugins/debugger/debugger_plugin_test.py b/tensorflow/tensorboard/plugins/debugger/debugger_plugin_test.py
index 2c9135fd27..f1cc2e06da 100644
--- a/tensorflow/tensorboard/plugins/debugger/debugger_plugin_test.py
+++ b/tensorflow/tensorboard/plugins/debugger/debugger_plugin_test.py
@@ -146,6 +146,19 @@ class DebuggerPluginTest(test.TestCase):
self.assertIn('/health_pills', apps)
self.assertIsInstance(apps['/health_pills'], collections.Callable)
+ def testHealthPillsPluginIsActive(self):
+ self.plugin.get_plugin_apps(self.multiplexer, self.log_dir)
+
+ # The multiplexer has sampled health pills.
+ self.assertTrue(self.plugin.is_active())
+
+ def testHealthPillsPluginIsInactive(self):
+ self.plugin.get_plugin_apps(
+ event_multiplexer.EventMultiplexer({}), self.log_dir)
+
+ # The multiplexer lacks sampled health pills.
+ self.assertFalse(self.plugin.is_active())
+
def testRequestHealthPillsForRunFoo(self):
"""Tests that the plugin produces health pills for a specified run."""
response = self.server.post(
diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin.py b/tensorflow/tensorboard/plugins/projector/projector_plugin.py
index 32ebb78e42..001c6e1e35 100644
--- a/tensorflow/tensorboard/plugins/projector/projector_plugin.py
+++ b/tensorflow/tensorboard/plugins/projector/projector_plugin.py
@@ -45,6 +45,8 @@ from tensorflow.tensorboard.plugins.projector import projector_config_pb2
_PLUGIN_PREFIX_ROUTE = 'projector'
PROJECTOR_FILENAME = 'projector_config.pbtxt'
+_PLUGIN_NAME = 'org_tensorflow_tensorboard_projector'
+_PLUGINS_DIR = 'plugins'
# HTTP routes.
CONFIG_ROUTE = '/info'
@@ -112,7 +114,7 @@ class EmbeddingMetadata(object):
class ProjectorPluginAsset(plugin_asset.PluginAsset):
"""Provides a registry for assets needed by the Projector plugin."""
- plugin_name = 'org_tensorflow_tensorboard_projector'
+ plugin_name = _PLUGIN_NAME
def __init__(self):
self._config = projector_config_pb2.ProjectorConfig()
@@ -259,12 +261,20 @@ def _read_tensor_file(fpath):
return np.array(tensor, dtype='float32')
+def _assets_dir_to_logdir(assets_dir):
+ sub_path = os.path.sep + _PLUGINS_DIR + os.path.sep
+ if sub_path in assets_dir:
+ two_parents_up = os.pardir + os.path.sep + os.pardir
+ return os.path.abspath(os.path.join(assets_dir, two_parents_up))
+ return assets_dir
+
+
def _latest_checkpoints_changed(configs, run_path_pairs):
"""Returns true if the latest checkpoint has changed in any of the runs."""
- for run_name, logdir in run_path_pairs:
+ for run_name, assets_dir in run_path_pairs:
if run_name not in configs:
config = projector_config_pb2.ProjectorConfig()
- config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
+ config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME)
if file_io.file_exists(config_fpath):
file_content = file_io.read_file_to_string(config_fpath)
text_format.Merge(file_content, config)
@@ -272,6 +282,7 @@ def _latest_checkpoints_changed(configs, run_path_pairs):
config = configs[run_name]
# See if you can find a checkpoint file in the logdir.
+ logdir = _assets_dir_to_logdir(assets_dir)
ckpt_path = _find_latest_checkpoint(logdir)
if not ckpt_path:
continue
@@ -302,6 +313,12 @@ def _parse_positive_int_param(request, param_name):
return -1
+def _rel_to_abs_asset_path(fpath, config_fpath):
+ if not os.path.isabs(fpath):
+ return os.path.join(os.path.dirname(config_fpath), fpath)
+ return fpath
+
+
class ProjectorPlugin(TBPlugin):
"""Embedding projector."""
@@ -314,8 +331,10 @@ class ProjectorPlugin(TBPlugin):
self.logdir = None
self._configs = None
self.old_num_run_paths = None
+ self.multiplexer = None
def get_plugin_apps(self, multiplexer, logdir):
+ self.multiplexer = multiplexer
self.run_paths = multiplexer.RunPaths()
self.logdir = logdir
self._handlers = {
@@ -328,10 +347,21 @@ class ProjectorPlugin(TBPlugin):
}
return self._handlers
+ def is_active(self):
+ """Determines whether this plugin is active.
+
+ This plugin is only active if any run has an embedding.
+
+ Returns:
+ A boolean. Whether this plugin is active.
+ """
+ return bool(self.configs)
+
@property
def configs(self):
"""Returns a map of run paths to `ProjectorConfig` protos."""
run_path_pairs = list(self.run_paths.items())
+ self._append_plugin_asset_directories(run_path_pairs)
# If there are no summary event files, the projector should still work,
# treating the `logdir` as the model checkpoint directory.
if not run_path_pairs:
@@ -359,7 +389,9 @@ class ProjectorPlugin(TBPlugin):
embedding.tensor_name = embedding.tensor_name[:-2]
# Find the size of embeddings associated with a tensors file.
if embedding.tensor_path and not embedding.tensor_shape:
- tensor = _read_tensor_file(embedding.tensor_path)
+ fpath = _rel_to_abs_asset_path(embedding.tensor_path,
+ self.config_fpaths[run])
+ tensor = _read_tensor_file(fpath)
embedding.tensor_shape.extend([len(tensor), len(tensor[0])])
reader = self._get_reader_for_run(run)
@@ -397,13 +429,12 @@ class ProjectorPlugin(TBPlugin):
"""Reads and returns the projector config files in every run directory."""
configs = {}
config_fpaths = {}
- for run_name, logdir in run_path_pairs:
+ for run_name, assets_dir in run_path_pairs:
config = projector_config_pb2.ProjectorConfig()
- config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
+ config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME)
if file_io.file_exists(config_fpath):
file_content = file_io.read_file_to_string(config_fpath)
text_format.Merge(file_content, config)
-
has_tensor_files = False
for embedding in config.embeddings:
if embedding.tensor_path:
@@ -412,6 +443,7 @@ class ProjectorPlugin(TBPlugin):
if not config.model_checkpoint_path:
# See if you can find a checkpoint file in the logdir.
+ logdir = _assets_dir_to_logdir(assets_dir)
ckpt_path = _find_latest_checkpoint(logdir)
if not ckpt_path and not has_tensor_files:
continue
@@ -421,7 +453,7 @@ class ProjectorPlugin(TBPlugin):
# Sanity check for the checkpoint file.
if (config.model_checkpoint_path and
not checkpoint_exists(config.model_checkpoint_path)):
- logging.warning('Checkpoint file %s not found',
+ logging.warning('Checkpoint file "%s" not found',
config.model_checkpoint_path)
continue
configs[run_name] = config
@@ -438,7 +470,7 @@ class ProjectorPlugin(TBPlugin):
try:
reader = NewCheckpointReader(config.model_checkpoint_path)
except Exception: # pylint: disable=broad-except
- logging.warning('Failed reading %s', config.model_checkpoint_path)
+ logging.warning('Failed reading "%s"', config.model_checkpoint_path)
self.readers[run] = reader
return reader
@@ -469,6 +501,12 @@ class ProjectorPlugin(TBPlugin):
return info
return None
+ def _append_plugin_asset_directories(self, run_path_pairs):
+ for run in self.multiplexer.PluginAssets(_PLUGIN_NAME):
+ assets_dir = os.path.join(self.run_paths[run], _PLUGINS_DIR, _PLUGIN_NAME)
+ assets_path_pair = (run, os.path.abspath(assets_dir))
+ run_path_pairs.append(assets_path_pair)
+
@wrappers.Request.application
def _serve_runs(self, request):
"""Returns a list of runs that have embeddings."""
@@ -481,7 +519,7 @@ class ProjectorPlugin(TBPlugin):
return Respond(request, 'query parameter "run" is required', 'text/plain',
400)
if run not in self.configs:
- return Respond(request, 'Unknown run: %s' % run, 'text/plain', 400)
+ return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400)
config = self.configs[run]
return Respond(request,
@@ -505,17 +543,19 @@ class ProjectorPlugin(TBPlugin):
'text/plain', 400)
if run not in self.configs:
- return Respond(request, 'Unknown run: %s' % run, 'text/plain', 400)
+ return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400)
config = self.configs[run]
fpath = self._get_metadata_file_for_tensor(name, config)
if not fpath:
return Respond(
request,
- 'No metadata file found for tensor %s in the config file %s' %
+ 'No metadata file found for tensor "%s" in the config file "%s"' %
(name, self.config_fpaths[run]), 'text/plain', 400)
+ fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run])
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
- return Respond(request, '%s is not a file' % fpath, 'text/plain', 400)
+ return Respond(request, '"%s" not found, or is not a file' % fpath,
+ 'text/plain', 400)
num_header_rows = 0
with file_io.FileIO(fpath, 'r') as f:
@@ -548,26 +588,24 @@ class ProjectorPlugin(TBPlugin):
'text/plain', 400)
if run not in self.configs:
- return Respond(request, 'Unknown run: %s' % run, 'text/plain', 400)
+ return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400)
- reader = self._get_reader_for_run(run)
config = self.configs[run]
- if reader is None:
- # See if there is a tensor file in the config.
- embedding = self._get_embedding(name, config)
- if not embedding or not embedding.tensor_path:
+ # See if there is a tensor file in the config.
+ embedding = self._get_embedding(name, config)
+ if embedding and embedding.tensor_path:
+ fpath = _rel_to_abs_asset_path(embedding.tensor_path,
+ self.config_fpaths[run])
+ if not file_io.file_exists(fpath):
return Respond(request,
- 'Tensor %s has no tensor_path in the config' % name,
+ 'Tensor file "%s" does not exist' % fpath,
'text/plain', 400)
- if not file_io.file_exists(embedding.tensor_path):
- return Respond(request,
- 'Tensor file %s does not exist' % embedding.tensor_path,
- 'text/plain', 400)
- tensor = _read_tensor_file(embedding.tensor_path)
+ tensor = _read_tensor_file(fpath)
else:
- if not reader.has_tensor(name):
- return Respond(request, 'Tensor %s not found in checkpoint dir %s' %
+ reader = self._get_reader_for_run(run)
+ if not reader or not reader.has_tensor(name):
+ return Respond(request, 'Tensor "%s" not found in checkpoint dir "%s"' %
(name, config.model_checkpoint_path), 'text/plain', 400)
try:
tensor = reader.get_tensor(name)
@@ -595,17 +633,19 @@ class ProjectorPlugin(TBPlugin):
'text/plain', 400)
if run not in self.configs:
- return Respond(request, 'Unknown run: %s' % run, 'text/plain', 400)
+ return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400)
config = self.configs[run]
fpath = self._get_bookmarks_file_for_tensor(name, config)
if not fpath:
return Respond(
request,
- 'No bookmarks file found for tensor %s in the config file %s' %
+ 'No bookmarks file found for tensor "%s" in the config file "%s"' %
(name, self.config_fpaths[run]), 'text/plain', 400)
+ fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run])
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
- return Respond(request, '%s is not a file' % fpath, 'text/plain', 400)
+ return Respond(request, '"%s" not found, or is not a file' % fpath,
+ 'text/plain', 400)
bookmarks_json = None
with file_io.FileIO(fpath, 'rb') as f:
@@ -625,7 +665,7 @@ class ProjectorPlugin(TBPlugin):
'text/plain', 400)
if run not in self.configs:
- return Respond(request, 'Unknown run: %s' % run, 'text/plain', 400)
+ return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400)
config = self.configs[run]
embedding_info = self._get_embedding(name, config)
@@ -633,12 +673,13 @@ class ProjectorPlugin(TBPlugin):
if not embedding_info or not embedding_info.sprite.image_path:
return Respond(
request,
- 'No sprite image file found for tensor %s in the config file %s' %
+ 'No sprite image file found for tensor "%s" in the config file "%s"' %
(name, self.config_fpaths[run]), 'text/plain', 400)
fpath = os.path.expanduser(embedding_info.sprite.image_path)
+ fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run])
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
- return Respond(request, '%s does not exist or is directory' % fpath,
+ return Respond(request, '"%s" does not exist or is directory' % fpath,
'text/plain', 400)
f = file_io.FileIO(fpath, 'rb')
encoded_image_string = f.read()
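For orientation only (the paths below are hypothetical, not from the change), the two helpers introduced above behave roughly as follows: `_assets_dir_to_logdir` climbs from a plugin assets directory back to the run directory, and `_rel_to_abs_asset_path` resolves asset paths relative to the projector config file.

import os

# Hypothetical layout for illustration.
assets_dir = '/tmp/logs/run1/plugins/org_tensorflow_tensorboard_projector'
config_fpath = '/tmp/logs/run1/projector_config.pbtxt'

# _assets_dir_to_logdir: when the path contains a "/plugins/" component,
# go two directories up to recover the run's logdir.
logdir = os.path.abspath(os.path.join(assets_dir, os.pardir, os.pardir))
assert logdir == '/tmp/logs/run1'

# _rel_to_abs_asset_path: a relative metadata/tensor/bookmarks path is taken
# relative to the directory holding projector_config.pbtxt.
metadata_path = os.path.join(os.path.dirname(config_fpath), 'metadata.tsv')
assert metadata_path == '/tmp/logs/run1/metadata.tsv'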
diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py
index 5679eff4a3..9e2e7159d8 100644
--- a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py
+++ b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py
@@ -55,7 +55,7 @@ class ProjectorAppTest(test.TestCase):
self._GenerateProjectorTestData()
self._SetupWSGIApp()
run_json = self._GetJson('/data/plugin/projector/runs')
- self.assertEqual(run_json, ['.'])
+ self.assertTrue(run_json)
def testRunsWithNoCheckpoint(self):
self._SetupWSGIApp()
@@ -73,6 +73,19 @@ class ProjectorAppTest(test.TestCase):
run_json = self._GetJson('/data/plugin/projector/runs')
self.assertEqual(run_json, [])
+ def testRunsWithInvalidModelCheckpointPathInConfig(self):
+ config_path = os.path.join(self.log_dir, 'projector_config.pbtxt')
+ config = projector_config_pb2.ProjectorConfig()
+ config.model_checkpoint_path = 'does_not_exist'
+ embedding = config.embeddings.add()
+ embedding.tensor_name = 'var1'
+ with gfile.GFile(config_path, 'w') as f:
+ f.write(text_format.MessageToString(config))
+ self._SetupWSGIApp()
+
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertEqual(run_json, [])
+
def testInfoWithValidCheckpoint(self):
self._GenerateProjectorTestData()
self._SetupWSGIApp()
@@ -80,7 +93,8 @@ class ProjectorAppTest(test.TestCase):
info_json = self._GetJson('/data/plugin/projector/info?run=.')
self.assertItemsEqual(info_json['embeddings'], [{
'tensorShape': [1, 2],
- 'tensorName': 'var1'
+ 'tensorName': 'var1',
+ 'bookmarksPath': 'bookmarks.json'
}, {
'tensorShape': [10, 10],
'tensorName': 'var2'
@@ -95,17 +109,286 @@ class ProjectorAppTest(test.TestCase):
url = '/data/plugin/projector/tensor?run=.&name=var1'
tensor_bytes = self._Get(url).data
- tensor = np.reshape(np.fromstring(tensor_bytes, dtype='float32'), [1, 2])
- expected_tensor = np.array([[6, 6]], dtype='float32')
+ expected_tensor = np.array([[6, 6]], dtype=np.float32)
+ self._AssertTensorResponse(tensor_bytes, expected_tensor)
+
+ def testBookmarksRequestMissingRunAndName(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks'
+ self.assertEqual(self._Get(url).status_code, 400)
+
+ def testBookmarksRequestMissingName(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks?run=.'
+ self.assertEqual(self._Get(url).status_code, 400)
+
+ def testBookmarksRequestMissingRun(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks?name=var1'
+ self.assertEqual(self._Get(url).status_code, 400)
+
+ def testBookmarksUnknownRun(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks?run=unknown&name=var1'
+ self.assertEqual(self._Get(url).status_code, 400)
+
+ def testBookmarksUnknownName(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks?run=.&name=unknown'
+ self.assertEqual(self._Get(url).status_code, 400)
+
+ def testBookmarks(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ url = '/data/plugin/projector/bookmarks?run=.&name=var1'
+ bookmark = self._GetJson(url)
+ self.assertEqual(bookmark, {'a': 'b'})
+
+ def testEndpointsNoAssets(self):
+ g = ops.Graph()
+ with g.as_default():
+ plugin_asset.get_plugin_asset(projector_plugin.ProjectorPluginAsset)
+
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+
+ self._SetupWSGIApp()
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertEqual(run_json, [])
+
+ def testEndpointsMetadataForVariableAssets(self):
+ self._GenerateProjectorTestData()
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ metadata = projector_plugin.EmbeddingMetadata(3)
+ metadata.add_column('labels', ['a', 'b', 'c'])
+ manager.add_metadata_for_embedding_variable('test', metadata)
+
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+
+ self._SetupWSGIApp()
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertTrue(run_json)
+
+ run = run_json[0]
+ metadata_query = '/data/plugin/projector/metadata?run=%s&name=test' % run
+ metadata_tsv = self._Get(metadata_query).data
+ self.assertEqual(metadata_tsv, b'a\nb\nc\n')
+
+ unk_tensor_query = '/data/plugin/projector/tensor?run=%s&name=test' % run
+ response = self._Get(unk_tensor_query)
+ self.assertEqual(response.status_code, 400)
+
+ expected_tensor = np.array([[6, 6]], dtype=np.float32)
+ tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run
+ tensor_bytes = self._Get(tensor_query).data
+ self._AssertTensorResponse(tensor_bytes, expected_tensor)
+
+ def testEndpointsMetadataForVariableAssetsButNoCheckpoint(self):
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ metadata = projector_plugin.EmbeddingMetadata(3)
+ metadata.add_column('labels', ['a', 'b', 'c'])
+ manager.add_metadata_for_embedding_variable('test', metadata)
+
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+
+ self._SetupWSGIApp()
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertEqual(run_json, [])
+
+ def testEndpointsTensorAndMetadataAssets(self):
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ metadata = projector_plugin.EmbeddingMetadata(3)
+ metadata.add_column('labels', ['a', 'b', 'c'])
+ manager.add_metadata_for_embedding_variable('test', metadata)
+ expected_tensor = np.array([[1, 2], [3, 4], [5, 6]])
+ image1 = np.array([[[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]]])
+ image2 = np.array([[[10, 20, 30], [40, 50, 60]],
+ [[70, 80, 90], [100, 110, 120]]])
+ manager.add_embedding('emb', expected_tensor, metadata, [image1, image2],
+ [2, 2])
+
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+
+ self._SetupWSGIApp()
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertTrue(run_json)
+
+ run = run_json[0]
+ metadata_query = '/data/plugin/projector/metadata?run=%s&name=emb' % run
+ metadata_tsv = self._Get(metadata_query).data
+ self.assertEqual(metadata_tsv, b'a\nb\nc\n')
+
+ unk_metadata_query = '/data/plugin/projector/metadata?run=%s&name=q' % run
+ response = self._Get(unk_metadata_query)
+ self.assertEqual(response.status_code, 400)
+
+ tensor_query = '/data/plugin/projector/tensor?run=%s&name=emb' % run
+ tensor_bytes = self._Get(tensor_query).data
+ self._AssertTensorResponse(tensor_bytes, expected_tensor)
+
+ unk_tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run
+ response = self._Get(unk_tensor_query)
+ self.assertEqual(response.status_code, 400)
+
+ image_query = '/data/plugin/projector/sprite_image?run=%s&name=emb' % run
+ image_bytes = self._Get(image_query).data
+ with ops.Graph().as_default():
+ s = session.Session()
+ image_array = image_ops.decode_png(image_bytes).eval(session=s).tolist()
+ expected_sprite_image = [
+ [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]],
+ [[7, 8, 9], [10, 11, 12], [70, 80, 90], [100, 110, 120]],
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
+ ]
+ self.assertEqual(image_array, expected_sprite_image)
+
+ def testSpriteImageRequestMissingRunAndName(self):
+ self._SetupWSGIApp()
+ q = '/data/plugin/projector/sprite_image'
+ response = self._Get(q)
+ self.assertEqual(response.status_code, 400)
+
+ def testSpriteImageRequestMissingName(self):
+ self._SetupWSGIApp()
+ q = '/data/plugin/projector/sprite_image?run=.'
+ response = self._Get(q)
+ self.assertEqual(response.status_code, 400)
+
+ def testSpriteImageRequestMissingRun(self):
+ self._SetupWSGIApp()
+ q = '/data/plugin/projector/sprite_image?name=emb'
+ response = self._Get(q)
+ self.assertEqual(response.status_code, 400)
+
+ def testSpriteImageUnknownRun(self):
+ self._GenerateProjectorTestData()
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+ image1 = np.array([[[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]]])
+ image2 = np.array([[[10, 20, 30], [40, 50, 60]],
+ [[70, 80, 90], [100, 110, 120]]])
+ manager.add_metadata_for_embedding_variable('var1',
+ thumbnails=[image1, image2],
+ thumbnail_dim=[2, 2])
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+ self._SetupWSGIApp()
+
+ q = '/data/plugin/projector/sprite_image?run=unknown&name=var1'
+ response = self._Get(q)
+ self.assertEqual(response.status_code, 400)
+
+ def testSpriteImageUnknownName(self):
+ self._GenerateProjectorTestData()
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+ image1 = np.array([[[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]]])
+ image2 = np.array([[[10, 20, 30], [40, 50, 60]],
+ [[70, 80, 90], [100, 110, 120]]])
+ manager.add_metadata_for_embedding_variable('var1',
+ thumbnails=[image1, image2],
+ thumbnail_dim=[2, 2])
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+ self._SetupWSGIApp()
+ q = '/data/plugin/projector/sprite_image?run=.&name=unknown'
+ response = self._Get(q)
+ self.assertEqual(response.status_code, 400)
+
+ def testEndpointsComboTensorAssetsAndCheckpoint(self):
+ self._GenerateProjectorTestData()
+ g = ops.Graph()
+ with g.as_default():
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ metadata = projector_plugin.EmbeddingMetadata(3)
+ metadata.add_column('labels', ['a', 'b', 'c'])
+ manager.add_metadata_for_embedding_variable('var1', metadata)
+
+ new_tensor_values = np.array([[1, 2], [3, 4], [5, 6]])
+ manager.add_embedding('new_tensor', new_tensor_values)
+
+ fw = writer.FileWriter(self.log_dir, graph=g)
+ fw.close()
+
+ self._SetupWSGIApp()
+ run_json = self._GetJson('/data/plugin/projector/runs')
+ self.assertTrue(run_json)
+
+ run = run_json[0]
+ var1_values = np.array([[6, 6]], dtype=np.float32)
+ var1_tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run
+ tensor_bytes = self._Get(var1_tensor_query).data
+ self._AssertTensorResponse(tensor_bytes, var1_values)
+
+ metadata_query = '/data/plugin/projector/metadata?run=%s&name=var1' % run
+ metadata_tsv = self._Get(metadata_query).data
+ self.assertEqual(metadata_tsv, b'a\nb\nc\n')
+
+ tensor_query = '/data/plugin/projector/tensor?run=%s&name=new_tensor' % run
+ tensor_bytes = self._Get(tensor_query).data
+ self._AssertTensorResponse(tensor_bytes, new_tensor_values)
+
+ def _AssertTensorResponse(self, tensor_bytes, expected_tensor):
+ tensor = np.reshape(np.fromstring(tensor_bytes, dtype=np.float32),
+ expected_tensor.shape)
self.assertTrue(np.array_equal(tensor, expected_tensor))
+ def testPluginIsActive(self):
+ self._GenerateProjectorTestData()
+ self._SetupWSGIApp()
+
+ # Embedding data is available.
+ self.assertTrue(self.plugin.is_active())
+
+ def testPluginIsNotActive(self):
+ self._SetupWSGIApp()
+
+ # Embedding data is not available.
+ self.assertFalse(self.plugin.is_active())
+
def _SetupWSGIApp(self):
multiplexer = event_multiplexer.EventMultiplexer(
size_guidance=application.DEFAULT_SIZE_GUIDANCE,
purge_orphaned_data=True)
- plugin = projector_plugin.ProjectorPlugin()
+ self.plugin = projector_plugin.ProjectorPlugin()
wsgi_app = application.TensorBoardWSGIApp(
- self.log_dir, [plugin], multiplexer, reload_interval=0)
+ self.log_dir, [self.plugin], multiplexer, reload_interval=0)
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
def _Get(self, path):
@@ -124,6 +407,11 @@ class ProjectorAppTest(test.TestCase):
embedding = config.embeddings.add()
# Add an embedding by its canonical tensor name.
embedding.tensor_name = 'var1:0'
+
+ with gfile.GFile(os.path.join(self.log_dir, 'bookmarks.json'), 'w') as f:
+ f.write('{"a": "b"}')
+ embedding.bookmarks_path = 'bookmarks.json'
+
config_pbtxt = text_format.MessageToString(config)
with gfile.GFile(config_path, 'w') as f:
f.write(config_pbtxt)
@@ -342,6 +630,30 @@ class ProjectorPluginAssetTest(test.TestCase):
'test', np.array([[1], [2], [3]]), thumbnails=thumbnails,
thumbnail_dim=[4])
+ def testAddEmbeddingThumbnailListHasNoEntries(self):
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ with self.assertRaises(ValueError):
+ manager.add_embedding('test', np.array([[1]]), thumbnails=[],
+ thumbnail_dim=[1, 1])
+
+ def testAddEmbeddingThumbnailListNotOfRank4(self):
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ with self.assertRaises(ValueError):
+ manager.add_embedding('test2', np.array([[1]]),
+ thumbnails=np.array([[1]]), thumbnail_dim=[1, 1])
+
+ def testAddEmbeddingThumbnailListEntriesNot3DTensors(self):
+ manager = plugin_asset.get_plugin_asset(
+ projector_plugin.ProjectorPluginAsset)
+
+ with self.assertRaises(ValueError):
+ manager.add_embedding('test3', np.array([[1]]), thumbnails=[[1, 2, 3]],
+ thumbnail_dim=[1, 1])
+
def testAddEmbeddingWithMetadataOfIncorrectLength(self):
manager = plugin_asset.get_plugin_asset(
projector_plugin.ProjectorPluginAsset)
@@ -392,8 +704,8 @@ class ProjectorPluginAssetTest(test.TestCase):
with ops.Graph().as_default() as g:
plugin_asset.get_plugin_asset(projector_plugin.ProjectorPluginAsset)
- fw = writer.FileWriter(logdir)
- fw.add_graph(g)
+ fw = writer.FileWriter(logdir, graph=g)
+ fw.close()
with gfile.Open(os.path.join(plugin_dir, 'projector_config.pbtxt')) as f:
content = f.read()
@@ -405,8 +717,8 @@ class ProjectorPluginAssetTest(test.TestCase):
projector_plugin.ProjectorPluginAsset.plugin_name)
with ops.Graph().as_default() as g:
- fw = writer.FileWriter(logdir)
- fw.add_graph(g)
+ fw = writer.FileWriter(logdir, graph=g)
+ fw.close()
self.assertFalse(
gfile.Exists(plugin_dir),
diff --git a/tensorflow/tensorboard/plugins/text/text_plugin.py b/tensorflow/tensorboard/plugins/text/text_plugin.py
index b337ce2ad0..427a761d1e 100644
--- a/tensorflow/tensorboard/plugins/text/text_plugin.py
+++ b/tensorflow/tensorboard/plugins/text/text_plugin.py
@@ -292,3 +292,13 @@ class TextPlugin(base_plugin.TBPlugin):
RUNS_ROUTE: self.runs_route,
TEXT_ROUTE: self.text_route,
}
+
+ def is_active(self):
+ """Determines whether this plugin is active.
+
+ This plugin is only active if TensorBoard has sampled any text summaries.
+
+ Returns:
+ Whether this plugin is active.
+ """
+ return bool(self.index_impl())
diff --git a/tensorflow/tensorboard/plugins/text/text_plugin_test.py b/tensorflow/tensorboard/plugins/text/text_plugin_test.py
index 846995b9a9..91dca289ce 100644
--- a/tensorflow/tensorboard/plugins/text/text_plugin_test.py
+++ b/tensorflow/tensorboard/plugins/text/text_plugin_test.py
@@ -390,6 +390,20 @@ class TextPluginTest(test.TestCase):
</table>""")
self.assertEqual(convert(d3), d3_expected)
+ def testPluginIsActive(self):
+ plugin = text_plugin.TextPlugin()
+ multiplexer = event_multiplexer.EventMultiplexer()
+ plugin.get_plugin_apps(multiplexer, None)
+
+ # The plugin is inactive because text summaries are not available.
+ self.assertFalse(plugin.is_active())
+
+ multiplexer.AddRunsFromDirectory(self.logdir)
+ multiplexer.Reload()
+
+ # The plugin is active because text summaries are available.
+ self.assertTrue(plugin.is_active())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 156f7b13bd..ddffabd8cb 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -41,76 +41,82 @@ def tf_android_core_proto_headers(core_proto_sources_relative):
])
+# Sanitize a dependency so that it works correctly from code that includes
+# TensorFlow as a submodule.
+def clean_dep(dep):
+ return str(Label(dep))
+
+
def if_android_x86(a):
return select({
- str(Label("//tensorflow:android_x86")): a,
- str(Label("//tensorflow:android_x86_64")): a,
+ clean_dep("//tensorflow:android_x86"): a,
+ clean_dep("//tensorflow:android_x86_64"): a,
"//conditions:default": [],
})
def if_android_arm(a):
return select({
- str(Label("//tensorflow:android_arm")): a,
+ clean_dep("//tensorflow:android_arm"): a,
"//conditions:default": [],
})
def if_android_arm64(a):
return select({
- str(Label("//tensorflow:android_arm64")): a,
+ clean_dep("//tensorflow:android_arm64"): a,
"//conditions:default": [],
})
def if_not_android(a):
return select({
- str(Label("//tensorflow:android")): [],
+ clean_dep("//tensorflow:android"): [],
"//conditions:default": a,
})
def if_android(a):
return select({
- str(Label("//tensorflow:android")): a,
+ clean_dep("//tensorflow:android"): a,
"//conditions:default": [],
})
def if_ios(a):
return select({
- str(Label("//tensorflow:ios")): a,
+ clean_dep("//tensorflow:ios"): a,
"//conditions:default": [],
})
def if_mobile(a):
return select({
- str(Label("//tensorflow:android")): a,
- str(Label("//tensorflow:ios")): a,
+ clean_dep("//tensorflow:android"): a,
+ clean_dep("//tensorflow:ios"): a,
"//conditions:default": [],
})
def if_not_mobile(a):
return select({
- str(Label("//tensorflow:android")): [],
- str(Label("//tensorflow:ios")): [],
+ clean_dep("//tensorflow:android"): [],
+ clean_dep("//tensorflow:ios"): [],
"//conditions:default": a,
})
def if_not_windows(a):
return select({
- str(Label("//tensorflow:windows")): [],
+ clean_dep("//tensorflow:windows"): [],
"//conditions:default": a,
})
def if_x86(a):
return select({
- str(Label("//tensorflow:linux_x86_64")): a,
- str(Label("//tensorflow:windows")): a,
+ clean_dep("//tensorflow:linux_x86_64"): a,
+ clean_dep("//tensorflow:windows"): a,
"//conditions:default": [],
})
@@ -124,13 +130,13 @@ def tf_copts():
"-fno-exceptions",
] + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_android_arm(
["-mfpu=neon"]) + if_x86(["-msse3"]) + select({
- "//tensorflow:android": [
+ clean_dep("//tensorflow:android"): [
"-std=c++11",
"-DTF_LEAN_BINARY",
"-O2",
],
- "//tensorflow:darwin": [],
- "//tensorflow:windows": [
+ clean_dep("//tensorflow:darwin"): [],
+ clean_dep("//tensorflow:windows"): [
"/DLANG_CXX11",
"/D__VERSION__=\\\"MSVC\\\"",
"/DPLATFORM_WINDOWS",
@@ -138,7 +144,7 @@ def tf_copts():
"/DEIGEN_HAS_C99_MATH",
"/DTENSORFLOW_USE_EIGEN_THREADPOOL",
],
- "//tensorflow:ios": ["-std=c++11"],
+ clean_dep("//tensorflow:ios"): ["-std=c++11"],
"//conditions:default": ["-pthread"]
}))
@@ -166,7 +172,7 @@ def tf_gen_op_libs(op_lib_names, deps=None):
name=n + "_op_lib",
copts=tf_copts(),
srcs=["ops/" + n + ".cc"],
- deps=deps + ["//tensorflow/core:framework"],
+ deps=deps + [clean_dep("//tensorflow/core:framework")],
visibility=["//visibility:public"],
alwayslink=1,
linkstatic=1,)
@@ -175,7 +181,7 @@ def tf_gen_op_libs(op_lib_names, deps=None):
def tf_gen_op_wrapper_cc(name,
out_ops_file,
pkg="",
- op_gen="//tensorflow/cc:cc_op_gen_main",
+ op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
deps=None,
override_file=None,
include_internal_ops=0):
@@ -242,11 +248,11 @@ def tf_gen_op_wrappers_cc(name,
other_hdrs=[],
pkg="",
deps=[
- str(Label("//tensorflow/cc:ops")),
- str(Label("//tensorflow/cc:scope")),
- str(Label("//tensorflow/cc:const_op")),
+ clean_dep("//tensorflow/cc:ops"),
+ clean_dep("//tensorflow/cc:scope"),
+ clean_dep("//tensorflow/cc:const_op"),
],
- op_gen=str(Label("//tensorflow/cc:cc_op_gen_main")),
+ op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
override_file=None,
include_internal_ops=0,
visibility=None):
@@ -272,12 +278,12 @@ def tf_gen_op_wrappers_cc(name,
srcs=subsrcs,
hdrs=subhdrs,
deps=deps + if_not_android([
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
+ clean_dep("//tensorflow/core:core_cpu"),
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib"),
+ clean_dep("//tensorflow/core:protos_all_cc"),
]) + if_android([
- "//tensorflow/core:android_tensorflow_lib",
+ clean_dep("//tensorflow/core:android_tensorflow_lib"),
]),
copts=tf_copts(),
alwayslink=1,
@@ -287,16 +293,16 @@ def tf_gen_op_wrappers_cc(name,
srcs=internalsrcs,
hdrs=internalhdrs,
deps=deps + if_not_android([
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
+ clean_dep("//tensorflow/core:core_cpu"),
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib"),
+ clean_dep("//tensorflow/core:protos_all_cc"),
]) + if_android([
- "//tensorflow/core:android_tensorflow_lib",
+ clean_dep("//tensorflow/core:android_tensorflow_lib"),
]),
copts=tf_copts(),
alwayslink=1,
- visibility=["//tensorflow:internal"])
+ visibility=[clean_dep("//tensorflow:internal")])
# Invoke this rule in .../tensorflow/python to build the wrapper library.
@@ -318,10 +324,10 @@ def tf_gen_op_wrapper_py(name,
copts=tf_copts(),
linkstatic=1, # Faster to link this one-time-use binary dynamically
deps=([
- "//tensorflow/core:framework",
- "//tensorflow/python:python_op_gen_main"
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/python:python_op_gen_main")
] + deps),
- visibility=["//tensorflow:internal"],)
+ visibility=[clean_dep("//tensorflow:internal")],)
# Invoke the previous cc_binary to generate a python file.
if not out:
@@ -363,7 +369,7 @@ def tf_gen_op_wrapper_py(name,
srcs_version="PY2AND3",
visibility=visibility,
deps=[
- "//tensorflow/python:framework_for_generated_wrappers_v2",
+ clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
],)
@@ -439,7 +445,7 @@ def tf_cuda_cc_test(name,
name=name,
srcs=srcs,
suffix="_gpu",
- deps=deps + if_cuda(["//tensorflow/core:gpu_runtime"]),
+ deps=deps + if_cuda([clean_dep("//tensorflow/core:gpu_runtime")]),
linkstatic=if_cuda(1, 0),
tags=tags + tf_cuda_tests_tags(),
data=data,
@@ -547,8 +553,8 @@ def tf_gpu_kernel_library(srcs,
hdrs=hdrs,
copts=copts,
deps=deps + if_cuda([
- "//tensorflow/core:cuda",
- "//tensorflow/core:gpu_lib",
+ clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/core:gpu_lib"),
]),
alwayslink=1,
**kwargs)
@@ -579,7 +585,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=None, **kwargs):
native.cc_library(
deps=deps + if_cuda(cuda_deps + [
- "//tensorflow/core:cuda",
+ clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers"
]),
copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]),
@@ -634,7 +640,7 @@ def tf_kernel_library(name,
hdrs = hdrs + native.glob(
[prefix + "*.h"], exclude=[prefix + "*test*", prefix + "*.cu.h"])
- cuda_deps = [str(Label("//tensorflow/core:gpu_lib"))]
+ cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")]
if gpu_srcs:
for gpu_src in gpu_srcs:
if gpu_src.endswith(".cc") and not gpu_src.endswith(".cu.cc"):
@@ -810,8 +816,8 @@ def cc_header_only_library(name, deps=[], **kwargs):
def tf_custom_op_library_additional_deps():
return [
"@protobuf//:protobuf_headers",
- str(Label("//third_party/eigen3")),
- str(Label("//tensorflow/core:framework_headers_lib")),
+ clean_dep("//third_party/eigen3"),
+ clean_dep("//tensorflow/core:framework_headers_lib"),
]
@@ -871,7 +877,7 @@ check_deps = rule(
# implementations of custom ops and kernels.
def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]):
cuda_deps = [
- str(Label("//tensorflow/core:stream_executor_headers_lib")),
+ clean_dep("//tensorflow/core:stream_executor_headers_lib"),
"@local_config_cuda//cuda:cudart_static",
]
deps = deps + tf_custom_op_library_additional_deps()
@@ -888,8 +894,8 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]):
name=name + "_check_deps",
deps=deps + if_cuda(cuda_deps),
disallowed_deps=[
- "//tensorflow/core:framework",
- "//tensorflow/core:lib"
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib")
])
native.cc_binary(
@@ -903,7 +909,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]):
"//conditions:default": [
"-lm",
],
- "//tensorflow:darwin": [],
+ clean_dep("//tensorflow:darwin"): [],
}),)
@@ -956,21 +962,21 @@ def tf_py_wrap_cc(name,
extra_linkopts = select({
"@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
- str(Label("//tensorflow:tf_exported_symbols.lds"))
+ clean_dep("//tensorflow:tf_exported_symbols.lds")
],
- str(Label("//tensorflow:windows")): [],
+ clean_dep("//tensorflow:windows"): [],
"//conditions:default": [
"-Wl,--version-script",
- "//tensorflow:tf_version_script.lds"
+ clean_dep("//tensorflow:tf_version_script.lds")
]
})
extra_deps += select({
"@local_config_cuda//cuda:darwin": [
- "//tensorflow:tf_exported_symbols.lds"
+ clean_dep("//tensorflow:tf_exported_symbols.lds")
],
- "//tensorflow:windows": [],
+ clean_dep("//tensorflow:windows"): [],
"//conditions:default": [
- "//tensorflow:tf_version_script.lds"
+ clean_dep("//tensorflow:tf_version_script.lds")
]
})
@@ -994,7 +1000,7 @@ def tf_py_wrap_cc(name,
srcs=[":" + name + ".py"],
srcs_version="PY2AND3",
data=select({
- "//tensorflow:windows": [":" + cc_library_pyd_name],
+ clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
"//conditions:default": [":" + cc_library_name],
}))
@@ -1003,7 +1009,7 @@ def py_test(deps=[], **kwargs):
native.py_test(
deps=select({
"//conditions:default": deps,
- "//tensorflow:no_tensorflow_py_deps": []
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): []
}),
**kwargs)
@@ -1028,15 +1034,15 @@ def tf_py_test(name,
main=main,
args=args,
tags=tags,
- visibility=[str(Label("//tensorflow:internal"))],
+ visibility=[clean_dep("//tensorflow:internal")],
shard_count=shard_count,
data=data,
deps=select({
"//conditions:default": [
- "//tensorflow/python:extra_py_tests_deps",
- "//tensorflow/python:gradient_checker",
+ clean_dep("//tensorflow/python:extra_py_tests_deps"),
+ clean_dep("//tensorflow/python:gradient_checker"),
] + additional_deps,
- "//tensorflow:no_tensorflow_py_deps": []
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): []
}),
flaky=flaky,
srcs_version="PY2AND3")
@@ -1153,13 +1159,13 @@ def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs):
out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs]
native.genrule(
name=name,
- srcs=srcs + ["//tensorflow/tools/proto_text:placeholder.txt"],
+ srcs=srcs + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")],
outs=out_hdrs + out_srcs,
cmd=
"$(location //tensorflow/tools/proto_text:gen_proto_text_functions) "
+ "$(@D) " + srcs_relative_dir + " $(SRCS)",
tools=[
- "//tensorflow/tools/proto_text:gen_proto_text_functions"
+ clean_dep("//tensorflow/tools/proto_text:gen_proto_text_functions")
],)
return struct(hdrs=out_hdrs, srcs=out_srcs)
@@ -1173,15 +1179,15 @@ def tf_version_info_genrule():
native.genrule(
name="version_info_gen",
srcs=[
- "//tensorflow/tools/git:gen/spec.json",
- "//tensorflow/tools/git:gen/head",
- "//tensorflow/tools/git:gen/branch_ref",
+ clean_dep("//tensorflow/tools/git:gen/spec.json"),
+ clean_dep("//tensorflow/tools/git:gen/head"),
+ clean_dep("//tensorflow/tools/git:gen/branch_ref"),
],
outs=["util/version_info.cc"],
cmd=
"$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\"",
local=1,
- tools=["//tensorflow/tools/git:gen_git_source.py"],)
+ tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],)
def cc_library_with_android_deps(deps,
diff --git a/tensorflow/tools/dist_test/server/Dockerfile.test b/tensorflow/tools/dist_test/server/Dockerfile.test
index 3cd3d5206d..908af8af9b 100644
--- a/tensorflow/tools/dist_test/server/Dockerfile.test
+++ b/tensorflow/tools/dist_test/server/Dockerfile.test
@@ -52,13 +52,13 @@ ADD . /var/tf-k8s
# Download MNIST data for tests
RUN mkdir -p /tmp/mnist-data
RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \
- http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
+ https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \
- http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+ https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \
- http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
+ https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
- http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
+ https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz
# Download Census data for Wide & Deep test
RUN mkdir -p /tmp/census-data
diff --git a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
index b35b14df1f..c9f2b1ab9e 100644
--- a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
+++ b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
@@ -134,7 +134,7 @@
"import os\n",
"from six.moves.urllib.request import urlretrieve\n",
"\n",
- "SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'\n",
+ "SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'\n",
"WORK_DIRECTORY = \"/tmp/mnist-data\"\n",
"\n",
"def maybe_download(filename):\n",
diff --git a/tensorflow/tools/graph_transforms/quantize_weights.cc b/tensorflow/tools/graph_transforms/quantize_weights.cc
index e6f1498224..66d800f0da 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights.cc
@@ -70,6 +70,10 @@ Status QuantizeWeights(const GraphDef& input_graph_def,
min = std::min(min, value);
max = std::max(max, value);
}
+ // Make sure the quantization range includes 0.0f. Not all quantized
+ // Ops behave properly if 0.0f is not in the range.
+ min = std::min(min, 0.0f);
+ max = std::max(0.0f, max);
// min_value == max_value is a tricky case. It can occur for general
// tensors, and of course for scalars. The quantized ops cannot deal
// with this case, so we set max_value to something else.
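As a hedged illustration of why the widened range matters (a standalone sketch, not TensorFlow's actual quantization kernel): with an affine mapping of [min, max] onto quint8, a range such as [0.1, 4.0] cannot represent 0.0 exactly, so zero-valued inputs or padding pick up a constant bias after dequantization. Clamping the range to include 0.0f, as the change above does, keeps zero exactly representable:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Sketch of an affine float -> quint8 mapping over [min, max] (an assumed
// model, not the TensorFlow implementation). If min > 0.0f, the value 0.0f
// clamps to min and dequantizes with a bias; widening the range avoids that.
uint8_t QuantizeSketch(float value, float min, float max) {
  const float scale = (max - min) / 255.0f;
  const float clamped = std::min(max, std::max(min, value));
  return static_cast<uint8_t>((clamped - min) / scale + 0.5f);
}

int main() {
  float min = 0.1f, max = 4.0f;  // all-positive weights, as in the test below
  min = std::min(min, 0.0f);     // same widening as the change above
  max = std::max(0.0f, max);
  std::printf("0.0f -> %u\n", QuantizeSketch(0.0f, min, max));  // prints 0
  return 0;
}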
diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
index cd5feed358..e1a105bdd3 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
@@ -35,51 +35,46 @@ Status QuantizeWeights(const GraphDef& input_graph_def,
class QuantizeWeightsTest : public ::testing::Test {
protected:
- void TestQuantizeWeights() {
+ void BuildGraphDef(const TensorShape& input_shape,
+ std::initializer_list<float> input_values,
+ const TensorShape& weight_shape,
+ std::initializer_list<float> weight_values,
+ GraphDef* original_graph_def) {
auto root = tensorflow::Scope::NewRootScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
- test::FillValues<float>(
- &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
- -5.0f, -3.0f, -6.0f});
+ Tensor input_data(DT_FLOAT, input_shape);
+ test::FillValues<float>(&input_data, input_values);
Output input_op =
- Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+ ops::Const(root.WithOpName("input_op"), Input::Initializer(input_data));
- Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10}));
- test::FillValues<float>(
- &weights_data,
- {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
- 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
- 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
- 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
- Output weights_op =
- Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+ Tensor weights_data(DT_FLOAT, weight_shape);
+ test::FillValues<float>(&weights_data, weight_values);
+ Output weights_op = ops::Const(root.WithOpName("weights_op"),
+ Input::Initializer(weights_data));
- Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op,
- {1, 1, 1, 1}, "VALID");
+ Output conv_op = ops::Conv2D(root.WithOpName("output"), input_op,
+ weights_op, {1, 1, 1, 1}, "VALID");
- GraphDef original_graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+ TF_ASSERT_OK(root.ToGraphDef(original_graph_def));
+ }
- std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
- TF_ASSERT_OK(original_session->Create(original_graph_def));
- std::vector<Tensor> original_outputs;
- TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+ void TestQuantizeWeights() {
+ GraphDef original_graph_def;
+ BuildGraphDef({1, 1, 6, 2},
+ {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f},
+ {1, 2, 2, 10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
+ 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
+ &original_graph_def);
GraphDef quantized_graph_def;
TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
&quantized_graph_def));
- std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
- TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
- std::vector<Tensor> quantized_outputs;
- TF_ASSERT_OK(
- quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
-
- test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
- 0.5);
-
+ // Verify the structure of the quantized graph.
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(quantized_graph_def, &node_lookup);
EXPECT_EQ(1, node_lookup.count("input_op"));
@@ -94,10 +89,69 @@ class QuantizeWeightsTest : public ::testing::Test {
const NodeDef* q_weights_const = node_lookup.at(weights_const_name);
EXPECT_EQ("Const", q_weights_const->op());
EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type());
+
+ // Run the original graph.
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ // Run the quantized graph.
+ std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
+ std::vector<Tensor> quantized_outputs;
+ TF_ASSERT_OK(
+ quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
+
+ // Compare the results.
+ test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
+ 0.5);
}
};
TEST_F(QuantizeWeightsTest, TestQuantizeWeights) { TestQuantizeWeights(); }
+TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) {
+ GraphDef original_graph_def;
+ BuildGraphDef({1, 1, 4, 4},
+ {-1.0f, -4.0f, -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f, -1.0f,
+ -4.0f, -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f},
+ {1, 2, 2, 10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
+ 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
+ &original_graph_def);
+ GraphDef quantized_graph_def;
+ TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
+ &quantized_graph_def));
+
+ std::map<string, const NodeDef*> node_lookup;
+ MapNamesToNodes(quantized_graph_def, &node_lookup);
+
+ auto expected_tensor = [](float value) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&tensor, {value});
+ return tensor;
+ };
+ auto existing_tensor = [&node_lookup](string op) {
+ const NodeDef* node_def = node_lookup.at(op);
+ CHECK(node_def);
+ return GetNodeTensorAttr(*node_def, "value");
+ };
+
+ // The max of input_op is moved from -1.0 to 0.0.
+ test::ExpectTensorNear<float>(
+ expected_tensor(-5.0), existing_tensor("input_op_quantized_min"), 1e-5);
+ test::ExpectTensorNear<float>(
+ expected_tensor(0.0), existing_tensor("input_op_quantized_max"), 1e-5);
+
+ // The min of weights_op is moved from 0.1 to 0.0.
+ test::ExpectTensorNear<float>(
+ expected_tensor(0.0), existing_tensor("weights_op_quantized_min"), 1e-5);
+ test::ExpectTensorNear<float>(
+ expected_tensor(4.0), existing_tensor("weights_op_quantized_max"), 1e-5);
+}
+
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d0ef1d32fc..e5e005fa93 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -320,24 +320,22 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# and com_google_protobuf_cc to enable proto_library support in bazel.
# Unfortunately there is no way to alias http_archives at the moment.
native.http_archive(
- name = "com_google_protobuf",
- urls = [
+ name="com_google_protobuf",
+ urls=[
"http://bazel-mirror.storage.googleapis.com/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz",
"https://github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz",
],
- sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0",
- strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a",
- )
+ sha256="e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0",
+ strip_prefix="protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a",)
native.http_archive(
- name = "com_google_protobuf_cc",
- urls = [
+ name="com_google_protobuf_cc",
+ urls=[
"http://bazel-mirror.storage.googleapis.com/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz",
"https://github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz",
],
- sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0",
- strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a",
- )
+ sha256="e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0",
+ strip_prefix="protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a",)
native.new_http_archive(
name="gmock_archive",
@@ -358,10 +356,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
actual="@gmock_archive//:gtest_main",)
native.git_repository(
- name = "com_github_gflags_gflags",
- commit = "f8a0efe03aa69b3336d8e228b37d4ccb17324b88",
- remote = "https://github.com/gflags/gflags.git",
- )
+ name="com_github_gflags_gflags",
+ commit="f8a0efe03aa69b3336d8e228b37d4ccb17324b88",
+ remote="https://github.com/gflags/gflags.git",)
native.bind(
name="python_headers",
@@ -613,13 +610,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
name="com_microsoft_typescript",
licenses=["notice"], # Apache 2.0
sha256_urls={
- "e3d9e320a2cae99be4aaa37953961a48323cdf16ba9aa2557a44d69571cd9b8d": [
- "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.1.6/lib/tsc.js",
- "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.1.6/lib/tsc.js",
+ "43a7c763fe024d5add8d5365e5a7981f4a359ba5bf86481f545a0db8f60d48cc": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js",
+ "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js",
],
- "f189cebe96eb76b238c6e364e72d4b0324e699f83eeae5deac23506cb3764fc6": [
- "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.1.6/lib/lib.es6.d.ts",
- "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.1.6/lib/lib.es6.d.ts",
+ "aecec1e47a3b3d872e214cb9adb82b30d6bd0471ea0aad7311ad81428566627c": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts",
+ "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts",
],
},
extra_build_file_content="\n".join([
diff --git a/third_party/fft2d/BUILD b/third_party/fft2d/BUILD
new file mode 100644
index 0000000000..93ea06e81b
--- /dev/null
+++ b/third_party/fft2d/BUILD
@@ -0,0 +1,30 @@
+# Headers for 2D Fast Fourier Transform package
+# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
+# This is a separate package because the original downloaded archive doesn't
+# contain any header files.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+# Unrestricted use; can only distribute original package.
+# See fft/readme.txt
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "fft2d_headers",
+ srcs = ["fft.h"],
+)
+
+objc_library(
+ name = "fft2d_headersd_ios",
+ srcs = ["fft.h"],
+)
+
+# Export the source code so that it can be compiled for Android native apps.
+filegroup(
+ name = "fft2d_headers_srcs",
+ srcs = ["fft.h"],
+)
diff --git a/third_party/fft2d/LICENSE b/third_party/fft2d/LICENSE
new file mode 100644
index 0000000000..2bd85506a8
--- /dev/null
+++ b/third_party/fft2d/LICENSE
@@ -0,0 +1,3 @@
+Copyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp).
+You may use, copy, modify this code for any purpose and
+without fee. You may distribute this ORIGINAL package.
diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h
new file mode 100644
index 0000000000..252cc01fec
--- /dev/null
+++ b/third_party/fft2d/fft.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declarations for 1D FFT routines in third_party/fft2d/fft.
+
+#ifndef THIRD_PARTY_FFT2D_FFT_H__
+#define THIRD_PARTY_FFT2D_FFT_H__
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern void cdft(int, int, double *, int *, double *);
+extern void rdft(int, int, double *, int *, double *);
+extern void ddct(int, int, double *, int *, double *);
+extern void ddst(int, int, double *, int *, double *);
+extern void dfct(int, double *, double *, int *, double *);
+extern void dfst(int, double *, double *, int *, double *);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // THIRD_PARTY_FFT2D_FFT_H__
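A minimal usage sketch for these declarations, assuming the calling convention documented in Ooura's package (readme.txt): n is a power of two, a[] is transformed in place, ip[] and w[] are caller-provided work areas, and ip[0] must be 0 before the first call so the trig tables get initialized. The work-area sizes below follow that readme and should be double-checked against it:

#include <cmath>
#include <vector>
#include "third_party/fft2d/fft.h"

int main() {
  const int n = 8;                    // power of two
  std::vector<double> a(n, 1.0);      // real input, transformed in place
  // Work areas per Ooura's readme: ip needs >= 2 + sqrt(n/2) ints, w needs n/2 doubles.
  std::vector<int> ip(2 + static_cast<int>(std::sqrt(n / 2.0)) + 1, 0);
  std::vector<double> w(n / 2);
  rdft(n, 1, a.data(), ip.data(), w.data());   // forward real DFT
  rdft(n, -1, a.data(), ip.data(), w.data());  // inverse
  for (double& v : a) v *= 2.0 / n;            // inverse scaling, per readme
  return 0;
}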
diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD
new file mode 100644
index 0000000000..3dbd36aec0
--- /dev/null
+++ b/third_party/fft2d/fft2d.BUILD
@@ -0,0 +1,36 @@
+# 2D Fast Fourier Transform package
+# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+# Unrestricted use; can only distribute original package.
+licenses(["notice"])
+
+exports_files(["fft/readme.txt"])
+
+FFT2D_SRCS = [
+ "fft/fftsg.c",
+]
+
+# This is the main 2D FFT library. The 2D FFTs in this library call
+# 1D FFTs. In addition, fast DCTs are provided for the special case
+# of 8x8 and 16x16. The code in this library is referred to as
+# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html.
+cc_library(
+ name = "fft2d",
+ srcs = FFT2D_SRCS,
+ linkopts = ["-lm"],
+)
+
+objc_library(
+ name = "fft2d_ios",
+ srcs = FFT2D_SRCS,
+)
+
+# Export the source code so that it can be compiled for Android native apps.
+filegroup(
+ name = "fft2d_srcs",
+ srcs = FFT2D_SRCS,
+)